├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── Setup.hs ├── bft0.sh ├── bft1.sh ├── bft2.sh ├── bft3.sh ├── bftclient.sh ├── bftservers.sh ├── bin ├── Byzantine │ ├── Client.hs │ ├── Command.hs │ └── Server.hs ├── GenerateKeys.hs ├── Simple │ ├── Client.hs │ ├── Command.hs │ └── Server.hs ├── udprecv.hs └── udpsend.hs ├── client.sh ├── servers.sh ├── src └── Network │ └── Tangaroa │ ├── Byzantine │ ├── Client.hs │ ├── Handler.hs │ ├── Role.hs │ ├── Sender.hs │ ├── Server.hs │ ├── Spec │ │ └── Simple.hs │ ├── Timer.hs │ ├── Types.hs │ └── Util.hs │ ├── Client.hs │ ├── Combinator.hs │ ├── Handler.hs │ ├── Role.hs │ ├── Sender.hs │ ├── Server.hs │ ├── Spec │ └── Simple.hs │ ├── Timer.hs │ ├── Types.hs │ └── Util.hs └── tangaroa.cabal /.gitignore: -------------------------------------------------------------------------------- 1 | dist 2 | .cabal-sandbox 3 | cabal.sandbox.config 4 | public_keys* 5 | private_keys* 6 | client_public_keys* 7 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrisnc/tangaroa/5dc68d9de0bdc23506e58f86357576f2c4a0c178/.gitmodules -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2014, Chris Copeland 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | * Neither the name of the author nor the names of its contributors may be used 15 | to endorse or promote products derived from this software without specific 16 | prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. 22 | IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY 23 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 24 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 25 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 26 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 28 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Tangaroa 2 | ======== 3 | Raft with Byzantine Fault Tolerance in Haskell 4 | ---------------------------------------------- 5 | 6 | To build: 7 | - Install GHC and Cabal: http://www.haskell.org/ghc/ https://www.haskell.org/cabal/ (or use a package manager) 8 | - `cabal update` 9 | - `cabal install --only-dependencies` 10 | - `cabal configure` 11 | - `cabal build` 12 | 13 | See the `bin` directory for example server and client implementations. 14 | The `bft*.sh` or `bftservers.sh` scripts will launch BFT Raft nodes, and `bftclient.sh` 15 | will launch a client to connect to them. 16 | 17 | For standard Raft, use `client.sh` and `servers.sh`. 18 | -------------------------------------------------------------------------------- /Setup.hs: -------------------------------------------------------------------------------- 1 | import Distribution.Simple 2 | main = defaultMain 3 | -------------------------------------------------------------------------------- /bft0.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dist/build/bftserver/bftserver -d -p public_keys.txt -c client_public_keys.txt -k private_keys/10000.txt -s 10000 10001 10002 10003 4 | -------------------------------------------------------------------------------- /bft1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dist/build/bftserver/bftserver -d -p public_keys.txt -c client_public_keys.txt -k private_keys/10001.txt -s 10001 10000 10002 10003 4 | -------------------------------------------------------------------------------- /bft2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dist/build/bftserver/bftserver -d -p public_keys.txt -c client_public_keys.txt -k private_keys/10002.txt -s 10002 10001 10000 10003 4 | -------------------------------------------------------------------------------- /bft3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dist/build/bftserver/bftserver -d -p public_keys.txt -c client_public_keys.txt -k private_keys/10003.txt -s 10003 10001 10002 10000 4 | -------------------------------------------------------------------------------- /bftclient.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dist/build/bftclient/bftclient -d -p public_keys.txt -k private_keys/10008.txt -s 10008 10000 10001 10002 10003 4 | -------------------------------------------------------------------------------- /bftservers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dist/build/bftserver/bftserver -p public_keys.txt -c client_public_keys.txt -k private_keys/10000.txt -s 10000 10001 10002 10003 & 4 | dist/build/bftserver/bftserver -p public_keys.txt -c client_public_keys.txt -k private_keys/10001.txt -s 10001 10000 10002 10003 & 5 | dist/build/bftserver/bftserver -p public_keys.txt -c client_public_keys.txt -k private_keys/10002.txt -s 10002 10001 10000 10003 & 6 | dist/build/bftserver/bftserver -p public_keys.txt -c client_public_keys.txt -k private_keys/10003.txt -s 10003 10001 10002 10000 7 | -------------------------------------------------------------------------------- /bin/Byzantine/Client.hs: -------------------------------------------------------------------------------- 1 | module Main 2 | ( main 3 | ) where 4 | 5 | import Network.Tangaroa.Byzantine.Spec.Simple 6 | 7 | import Command 8 | 9 | getCommand :: IO CommandType 10 | getCommand = do 11 | cmd <- getLine 12 | case words cmd of 13 | ["insert", k, v] -> return (Insert k v) 14 | ["delete", k] -> return (Delete k) 15 | ["set", k, v] -> return (Set k v) 16 | ["get", k] -> return (Get k) 17 | _ -> do 18 | putStrLn "Not a recognized command." 19 | getCommand 20 | 21 | showResult :: ResultType -> IO () 22 | showResult r = 23 | case r of 24 | Value v -> putStrLn v 25 | _ -> print r 26 | 27 | main :: IO () 28 | main = do 29 | runClient (\_ -> return Failure) getCommand showResult 30 | -------------------------------------------------------------------------------- /bin/Byzantine/Command.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DeriveGeneric #-} 2 | 3 | module Command 4 | ( CommandType(..) 5 | , ResultType(..) 6 | ) where 7 | 8 | import Data.Binary 9 | import GHC.Generics 10 | 11 | data CommandType = Insert String String 12 | | Delete String 13 | | Set String String 14 | | Get String 15 | deriving (Show, Read, Generic) 16 | 17 | instance Binary CommandType 18 | 19 | data ResultType = Value String -- for successful Get 20 | | Success -- for successful Insert, Delete, Set 21 | | Failure 22 | deriving (Show, Read, Generic) 23 | 24 | instance Binary ResultType 25 | -------------------------------------------------------------------------------- /bin/Byzantine/Server.hs: -------------------------------------------------------------------------------- 1 | module Main 2 | ( main 3 | ) where 4 | 5 | import Network.Tangaroa.Byzantine.Spec.Simple 6 | 7 | import Data.IORef 8 | import Data.Map (Map) 9 | import qualified Data.Map as Map 10 | 11 | import Command 12 | 13 | applyCommand :: IORef (Map String String) -> CommandType -> IO ResultType 14 | applyCommand ref cmd = do 15 | case cmd of 16 | Insert k v -> runInsert ref k v 17 | Delete k -> runDelete ref k 18 | Set k v -> runSet ref k v 19 | Get k -> runGet ref k 20 | 21 | member :: IORef (Map String String) 22 | -> String 23 | -> (IORef (Map String String) -> IO ResultType) 24 | -> (IORef (Map String String) -> IO ResultType) 25 | -> IO ResultType 26 | member ref k memberFn notMemberFn = do 27 | isMember <- Map.member k <$> readIORef ref 28 | if isMember 29 | then memberFn ref 30 | else notMemberFn ref 31 | 32 | -- adds a new mapping, and fails if a mapping already exists 33 | runInsert :: IORef (Map String String) -> String -> String -> IO ResultType 34 | runInsert ref k v = member ref k doFail (doInsert k v) 35 | 36 | -- like insert, but instead fails if a mapping doesn't exist 37 | runSet :: IORef (Map String String) -> String -> String -> IO ResultType 38 | runSet ref k v = member ref k (doInsert k v) doFail 39 | 40 | -- gets the value for a key, fails if it doesn't exist 41 | runGet :: IORef (Map String String) -> String -> IO ResultType 42 | runGet ref k = do 43 | mv <- Map.lookup k <$> readIORef ref 44 | case mv of 45 | Just v -> return (Value v) 46 | Nothing -> return Failure 47 | 48 | -- removes the mapping for a key, fails if it doesn't exist 49 | runDelete :: IORef (Map String String) -> String -> IO ResultType 50 | runDelete ref k = member ref k (doDelete k) doFail 51 | 52 | doFail :: IORef (Map String String) -> IO ResultType 53 | doFail = return . const Failure 54 | 55 | doInsert :: String -> String -> IORef (Map String String) -> IO ResultType 56 | doInsert k v ref = modifyIORef ref (Map.insert k v) >> return Success 57 | 58 | doDelete :: String -> IORef (Map String String) -> IO ResultType 59 | doDelete k ref = modifyIORef ref (Map.delete k) >> return Success 60 | 61 | main :: IO () 62 | main = do 63 | stateVariable <- newIORef Map.empty 64 | runServer (applyCommand stateVariable) 65 | -------------------------------------------------------------------------------- /bin/GenerateKeys.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import Codec.Crypto.RSA 4 | import Crypto.Random 5 | import Text.Read 6 | import Network.Socket 7 | import Data.Word 8 | import System.Directory 9 | import System.IO 10 | import System.FilePath 11 | 12 | import qualified Data.Map as Map 13 | 14 | portnums :: [Word16] 15 | portnums = iterate (+ 1) 10000 16 | 17 | main :: IO () 18 | main = do 19 | g <- newGenIO :: IO SystemRandom 20 | putStrLn "Number of keys to generate? " 21 | hFlush stdout 22 | mn <- fmap readMaybe getLine 23 | case mn of 24 | Just n -> do 25 | let keys = generateKeys g 1024 n 26 | writePublicKeys $ map fst keys 27 | writePrivateKeys $ zip portnums $ map snd keys 28 | Nothing -> putStrLn "Please specify a number of keys to generate." 29 | 30 | generateKeys :: CryptoRandomGen g => g -> Int -> Int -> [(PublicKey, PrivateKey)] 31 | generateKeys g nbits nkeys = case nkeys of 32 | 0 -> [] 33 | n -> (pubkey, privkey) : generateKeys ng nbits (n - 1) where 34 | (pubkey, privkey, ng) = generateKeyPair g nbits 35 | 36 | localhost :: HostAddress 37 | localhost = 0x0100007f 38 | 39 | writePublicKeys :: Show a => [a] -> IO () 40 | writePublicKeys xs = do 41 | putStrLn "Filename for public keys? " 42 | hFlush stdout 43 | filename <- getLine 44 | writeFile filename $ 45 | show $ Map.fromList $ 46 | zip 47 | (zip 48 | (repeat localhost) 49 | portnums) 50 | xs 51 | 52 | writePrivateKeys :: (Show name, Show a) => [(name,a)] -> IO () 53 | writePrivateKeys xs = do 54 | putStrLn "Folder for private keys? " 55 | hFlush stdout 56 | dirname <- getLine 57 | createDirectory dirname 58 | mapM_ (\(fn, x) -> writeFile (dirname show fn ++ ".txt") (show x)) xs 59 | -------------------------------------------------------------------------------- /bin/Simple/Client.hs: -------------------------------------------------------------------------------- 1 | module Main 2 | ( main 3 | ) where 4 | 5 | import Network.Tangaroa.Spec.Simple 6 | 7 | import Command 8 | 9 | getCommand :: IO CommandType 10 | getCommand = do 11 | cmd <- getLine 12 | case words cmd of 13 | ["insert", k, v] -> return (Insert k v) 14 | ["delete", k] -> return (Delete k) 15 | ["set", k, v] -> return (Set k v) 16 | ["get", k] -> return (Get k) 17 | _ -> do 18 | putStrLn "Not a recognized command." 19 | getCommand 20 | 21 | showResult :: ResultType -> IO () 22 | showResult r = 23 | case r of 24 | Value v -> putStrLn v 25 | _ -> print r 26 | 27 | main :: IO () 28 | main = do 29 | runClient (\_ -> return Failure) getCommand showResult 30 | -------------------------------------------------------------------------------- /bin/Simple/Command.hs: -------------------------------------------------------------------------------- 1 | module Command 2 | ( CommandType(..) 3 | , ResultType(..) 4 | ) where 5 | 6 | data CommandType = Insert String String 7 | | Delete String 8 | | Set String String 9 | | Get String 10 | deriving (Show, Read) 11 | 12 | data ResultType = Value String -- for successful Get 13 | | Success -- for successful Insert, Delete, Set 14 | | Failure 15 | deriving (Show, Read) 16 | -------------------------------------------------------------------------------- /bin/Simple/Server.hs: -------------------------------------------------------------------------------- 1 | module Main 2 | ( main 3 | ) where 4 | 5 | import Network.Tangaroa.Spec.Simple 6 | 7 | import Data.IORef 8 | import Data.Map (Map) 9 | import qualified Data.Map as Map 10 | 11 | import Command 12 | 13 | applyCommand :: IORef (Map String String) -> CommandType -> IO ResultType 14 | applyCommand ref cmd = do 15 | case cmd of 16 | Insert k v -> runInsert ref k v 17 | Delete k -> runDelete ref k 18 | Set k v -> runSet ref k v 19 | Get k -> runGet ref k 20 | 21 | member :: IORef (Map String String) 22 | -> String 23 | -> (IORef (Map String String) -> IO ResultType) 24 | -> (IORef (Map String String) -> IO ResultType) 25 | -> IO ResultType 26 | member ref k memberFn notMemberFn = do 27 | isMember <- Map.member k <$> readIORef ref 28 | if isMember 29 | then memberFn ref 30 | else notMemberFn ref 31 | 32 | -- adds a new mapping, and fails if a mapping already exists 33 | runInsert :: IORef (Map String String) -> String -> String -> IO ResultType 34 | runInsert ref k v = member ref k doFail (doInsert k v) 35 | 36 | -- like insert, but instead fails if a mapping doesn't exist 37 | runSet :: IORef (Map String String) -> String -> String -> IO ResultType 38 | runSet ref k v = member ref k (doInsert k v) doFail 39 | 40 | -- gets the value for a key, fails if it doesn't exist 41 | runGet :: IORef (Map String String) -> String -> IO ResultType 42 | runGet ref k = do 43 | mv <- Map.lookup k <$> readIORef ref 44 | case mv of 45 | Just v -> return (Value v) 46 | Nothing -> return Failure 47 | 48 | -- removes the mapping for a key, fails if it doesn't exist 49 | runDelete :: IORef (Map String String) -> String -> IO ResultType 50 | runDelete ref k = member ref k (doDelete k) doFail 51 | 52 | doFail :: IORef (Map String String) -> IO ResultType 53 | doFail = return . const Failure 54 | 55 | doInsert :: String -> String -> IORef (Map String String) -> IO ResultType 56 | doInsert k v ref = modifyIORef ref (Map.insert k v) >> return Success 57 | 58 | doDelete :: String -> IORef (Map String String) -> IO ResultType 59 | doDelete k ref = modifyIORef ref (Map.delete k) >> return Success 60 | 61 | main :: IO () 62 | main = do 63 | stateVariable <- newIORef Map.empty 64 | runServer (applyCommand stateVariable) 65 | -------------------------------------------------------------------------------- /bin/udprecv.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import Control.Monad (forever) 4 | import Network.Socket hiding (recv, recvFrom, send, sendTo) 5 | import Network.Socket.ByteString 6 | import Data.ByteString.Char8 as B 7 | 8 | localhost :: HostAddress 9 | localhost = 0x0100007f 10 | 11 | port :: PortNumber 12 | port = 10000 13 | 14 | maxlen :: Int 15 | maxlen = 8192 16 | 17 | main :: IO () 18 | main = do 19 | sock <- socket AF_INET Datagram defaultProtocol 20 | bind sock (SockAddrInet port localhost) 21 | forever $ do 22 | msg <- recv sock maxlen 23 | B.putStrLn msg 24 | -------------------------------------------------------------------------------- /bin/udpsend.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import Control.Monad (forever) 4 | import Network.Socket hiding (recv, recvFrom, send, sendTo) 5 | import Network.Socket.ByteString 6 | import Data.ByteString.Char8 as B 7 | 8 | localhost :: HostAddress 9 | localhost = 0x0100007f 10 | 11 | port :: PortNumber 12 | port = 10000 13 | 14 | main :: IO () 15 | main = do 16 | sock <- socket AF_INET Datagram defaultProtocol 17 | let target = SockAddrInet port localhost 18 | forever $ do 19 | l <- B.getLine 20 | sendAllTo sock l target 21 | -------------------------------------------------------------------------------- /client.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dist/build/simpleclient/simpleclient -s 10010 10000 10001 10002 10003 10004 4 | -------------------------------------------------------------------------------- /servers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dist/build/simpleserver/simpleserver -s 10000 10001 10002 10003 10004 & 4 | dist/build/simpleserver/simpleserver -s 10001 10000 10002 10003 10004 & 5 | dist/build/simpleserver/simpleserver -s 10002 10001 10000 10003 10004 & 6 | dist/build/simpleserver/simpleserver -s 10003 10001 10002 10000 10004 & 7 | dist/build/simpleserver/simpleserver -s 10004 10001 10002 10003 10000 8 | -------------------------------------------------------------------------------- /src/Network/Tangaroa/Byzantine/Client.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE RecordWildCards #-} 2 | {-# LANGUAGE TemplateHaskell #-} 3 | 4 | module Network.Tangaroa.Byzantine.Client 5 | ( runRaftClient 6 | ) where 7 | 8 | import Control.Concurrent.Chan.Unagi 9 | import Control.Lens hiding (Index) 10 | import Control.Monad.RWS 11 | import Data.Binary 12 | import Data.Foldable (traverse_) 13 | import qualified Data.ByteString.Lazy as B 14 | import qualified Data.Map as Map 15 | import qualified Data.Set as Set 16 | 17 | import Network.Tangaroa.Byzantine.Timer 18 | import Network.Tangaroa.Byzantine.Types 19 | import Network.Tangaroa.Byzantine.Util 20 | import Network.Tangaroa.Byzantine.Sender (sendSignedRPC) 21 | 22 | runRaftClient :: (Binary nt, Binary et, Binary rt, Ord nt) => IO et -> (rt -> IO ()) -> Config nt -> RaftSpec nt et rt mt -> IO () 23 | runRaftClient getEntry useResult rconf spec@RaftSpec{..} = do 24 | let qsize = getQuorumSize $ Set.size $ rconf ^. otherNodes 25 | (ein, eout) <- newChan 26 | runRWS_ 27 | (raftClient (lift getEntry) (lift . useResult)) 28 | (RaftEnv rconf qsize ein eout (liftRaftSpec spec)) 29 | initialRaftState -- only use currentLeader and logEntries 30 | 31 | raftClient :: (Binary nt, Binary et, Binary rt, Ord nt) => Raft nt et rt mt et -> (rt -> Raft nt et rt mt ()) -> Raft nt et rt mt () 32 | raftClient getEntry useResult = do 33 | nodes <- view (cfg.otherNodes) 34 | when (Set.null nodes) $ error "The client has no nodes to send requests to." 35 | currentLeader .= (Just $ Set.findMin nodes) 36 | fork_ messageReceiver 37 | fork_ $ commandGetter getEntry 38 | pendingRequests .= Map.empty 39 | clientHandleEvents useResult 40 | 41 | -- get commands with getEntry and put them on the event queue to be sent 42 | commandGetter :: Raft nt et rt mt et -> Raft nt et rt mt () 43 | commandGetter getEntry = do 44 | nid <- view (cfg.nodeId) 45 | forever $ do 46 | entry <- getEntry 47 | rid <- nextRequestId 48 | enqueueEvent $ ERPC $ CMD $ Command entry nid rid B.empty 49 | 50 | nextRequestId :: Raft nt et rt mt RequestId 51 | nextRequestId = do 52 | currentRequestId += 1 53 | use currentRequestId 54 | 55 | clientHandleEvents :: (Binary nt, Binary et, Binary rt, Ord nt) => (rt -> Raft nt et rt mt ()) -> Raft nt et rt mt () 56 | clientHandleEvents useResult = forever $ do 57 | e <- dequeueEvent 58 | case e of 59 | ERPC (CMD cmd) -> clientSendCommand cmd -- these are commands coming from the commandGetter thread 60 | ERPC (CMDR cmdr) -> clientHandleCommandResponse useResult cmdr 61 | HeartbeatTimeout _ -> do 62 | timeouts <- use numTimeouts 63 | limit <- view (cfg.clientTimeoutLimit) 64 | if timeouts < limit 65 | then do 66 | debug "choosing a new leader and resending commands" 67 | setLeaderToNext 68 | reqs <- use pendingRequests 69 | pendingRequests .= Map.empty -- this will reset the timer on resend 70 | traverse_ clientSendCommand reqs 71 | numTimeouts += 1 72 | else do 73 | debug "starting a revolution" 74 | nid <- view (cfg.nodeId) 75 | mlid <- use currentLeader 76 | case mlid of 77 | Just lid -> do 78 | rid <- nextRequestId 79 | view (cfg.otherNodes) >>= 80 | traverse_ (\n -> sendSignedRPC n (REVOLUTION (Revolution nid lid rid B.empty))) 81 | numTimeouts .= 0 82 | resetHeartbeatTimer 83 | _ -> do 84 | setLeaderToFirst 85 | resetHeartbeatTimer 86 | _ -> return () 87 | 88 | setLeaderToFirst :: Raft nt et rt mt () 89 | setLeaderToFirst = do 90 | nodes <- view (cfg.otherNodes) 91 | when (Set.null nodes) $ error "the client has no nodes to send requests to" 92 | currentLeader .= (Just $ Set.findMin nodes) 93 | 94 | setLeaderToNext :: Ord nt => Raft nt et rt mt () 95 | setLeaderToNext = do 96 | mlid <- use currentLeader 97 | nodes <- view (cfg.otherNodes) 98 | case mlid of 99 | Just lid -> case Set.lookupGT lid nodes of 100 | Just nlid -> currentLeader .= Just nlid 101 | Nothing -> setLeaderToFirst 102 | Nothing -> setLeaderToFirst 103 | 104 | clientSendCommand :: (Binary nt, Binary et, Binary rt) => Command nt et -> Raft nt et rt mt () 105 | clientSendCommand cmd@Command{..} = do 106 | mlid <- use currentLeader 107 | case mlid of 108 | Just lid -> do 109 | sendSignedRPC lid $ CMD cmd 110 | prcount <- fmap Map.size (use pendingRequests) 111 | -- if this will be our only pending request, start the timer 112 | -- otherwise, it should already be running 113 | when (prcount == 0) resetHeartbeatTimer 114 | pendingRequests %= Map.insert _cmdRequestId cmd 115 | Nothing -> do 116 | setLeaderToFirst 117 | clientSendCommand cmd 118 | 119 | clientHandleCommandResponse :: (Binary nt, Binary et, Binary rt, Ord nt) 120 | => (rt -> Raft nt et rt mt ()) 121 | -> CommandResponse nt rt 122 | -> Raft nt et rt mt () 123 | clientHandleCommandResponse useResult cmdr@CommandResponse{..} = do 124 | prs <- use pendingRequests 125 | valid <- verifyRPCWithKey (CMDR cmdr) 126 | when (valid && Map.member _cmdrRequestId prs) $ do 127 | useResult _cmdrResult 128 | currentLeader .= Just _cmdrLeaderId 129 | pendingRequests %= Map.delete _cmdrRequestId 130 | numTimeouts .= 0 131 | prcount <- fmap Map.size (use pendingRequests) 132 | -- if we still have pending requests, reset the timer 133 | -- otherwise cancel it 134 | if (prcount > 0) 135 | then resetHeartbeatTimer 136 | else cancelTimer 137 | -------------------------------------------------------------------------------- /src/Network/Tangaroa/Byzantine/Handler.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE RecordWildCards #-} 2 | 3 | module Network.Tangaroa.Byzantine.Handler 4 | ( handleEvents 5 | ) where 6 | 7 | import Control.Lens 8 | import Control.Monad hiding (mapM) 9 | import Control.Monad.Loops 10 | import Crypto.Hash.SHA256 11 | import Data.Binary 12 | import Data.Sequence (Seq) 13 | import Data.Set (Set) 14 | import Data.Foldable (all, traverse_) 15 | import Data.Traversable (mapM) 16 | import Prelude hiding (mapM, all) 17 | import qualified Data.ByteString.Lazy as LB 18 | import qualified Data.ByteString as B 19 | import qualified Data.Map as Map 20 | import qualified Data.Sequence as Seq 21 | import qualified Data.Set as Set 22 | 23 | import Network.Tangaroa.Byzantine.Types 24 | import Network.Tangaroa.Byzantine.Sender 25 | import Network.Tangaroa.Byzantine.Util 26 | import Network.Tangaroa.Byzantine.Role 27 | import Network.Tangaroa.Byzantine.Timer 28 | 29 | handleEvents :: (Binary nt, Binary et, Binary rt, Ord nt) => Raft nt et rt mt () 30 | handleEvents = forever $ do 31 | e <- dequeueEvent 32 | case e of 33 | ERPC rpc -> handleRPC rpc 34 | ElectionTimeout s -> handleElectionTimeout s 35 | HeartbeatTimeout s -> handleHeartbeatTimeout s 36 | 37 | whenM :: Monad m => m Bool -> m () -> m () 38 | whenM mb ma = do 39 | b <- mb 40 | when b ma 41 | 42 | handleRPC :: (Binary nt, Binary et, Binary rt, Ord nt) => RPC nt et rt -> Raft nt et rt mt () 43 | handleRPC rpc = case rpc of 44 | AE ae -> whenM (verifyRPCWithKey rpc) $ handleAppendEntries ae 45 | AER aer -> whenM (verifyRPCWithKey rpc) $ handleAppendEntriesResponse aer 46 | RV rv -> whenM (verifyRPCWithKey rpc) $ handleRequestVote rv 47 | RVR rvr -> whenM (verifyRPCWithKey rpc) $ handleRequestVoteResponse rvr 48 | CMD cmd -> whenM (verifyRPCWithClientKey rpc) $ handleCommand cmd 49 | CMDR _ -> whenM (verifyRPCWithKey rpc) $ debug "got a command response RPC" 50 | DBG s -> debug $ "got a debug RPC: " ++ s 51 | REVOLUTION rev -> whenM (verifyRPCWithClientKey rpc) $ handleRevolution rev 52 | 53 | handleElectionTimeout :: (Binary nt, Binary et, Binary rt, Ord nt) => String -> Raft nt et rt mt () 54 | handleElectionTimeout s = do 55 | debug $ "election timeout: " ++ s 56 | r <- use role 57 | when (r /= Leader) $ do 58 | lv <- use lazyVote 59 | case lv of 60 | Just (t, c) -> do 61 | updateTerm t 62 | setVotedFor (Just c) 63 | lazyVote .= Nothing 64 | ignoreLeader .= False 65 | currentLeader .= Nothing 66 | fork_ $ sendRequestVoteResponse c True 67 | resetElectionTimer 68 | Nothing -> becomeCandidate 69 | 70 | handleHeartbeatTimeout :: (Binary nt, Binary et, Binary rt, Ord nt) => String -> Raft nt et rt mt () 71 | handleHeartbeatTimeout s = do 72 | debug $ "heartbeat timeout: " ++ s 73 | r <- use role 74 | when (r == Leader) $ do 75 | fork_ sendAllAppendEntries 76 | resetHeartbeatTimer 77 | 78 | checkForNewLeader :: (Binary nt, Binary et, Binary rt, Ord nt) => AppendEntries nt et -> Raft nt et rt mt () 79 | checkForNewLeader AppendEntries{..} = do 80 | ct <- use term 81 | cl <- use currentLeader 82 | if (_aeTerm == ct && cl == Just _leaderId) || 83 | _aeTerm < ct || 84 | Set.size _aeQuorumVotes == 0 85 | then return () 86 | else do 87 | votesValid <- confirmElection _leaderId _aeTerm _aeQuorumVotes 88 | when votesValid $ do 89 | updateTerm _aeTerm 90 | ignoreLeader .= False 91 | currentLeader .= Just _leaderId 92 | 93 | confirmElection :: (Binary nt, Binary et, Binary rt, Ord nt) => nt -> Term -> Set (RequestVoteResponse nt) -> Raft nt et rt mt Bool 94 | confirmElection l t votes = do 95 | debug "confirming election of a new leader" 96 | qsize <- view quorumSize 97 | if Set.size votes >= qsize 98 | then allM (validateVote l t) (Set.toList votes) 99 | else return False 100 | 101 | validateVote :: (Binary nt, Binary et, Binary rt, Ord nt) => nt -> Term -> RequestVoteResponse nt -> Raft nt et rt mt Bool 102 | validateVote l t vote@RequestVoteResponse{..} = do 103 | sigOkay <- verifyRPCWithKey (RVR vote) 104 | return (sigOkay && _rvrCandidateId == l && _rvrTerm == t) 105 | 106 | handleAppendEntries :: (Binary nt, Binary et, Binary rt, Ord nt) => AppendEntries nt et -> Raft nt et rt mt () 107 | handleAppendEntries ae@AppendEntries{..} = do 108 | debug $ "got an appendEntries RPC: prev log entry: Index " ++ show _prevLogIndex ++ " " ++ show _prevLogTerm 109 | checkForNewLeader ae 110 | cl <- use currentLeader 111 | ig <- use ignoreLeader 112 | ct <- use term 113 | case cl of 114 | Just l | not ig && l == _leaderId && _aeTerm == ct -> do 115 | resetElectionTimer 116 | lazyVote .= Nothing 117 | plmatch <- prevLogEntryMatches _prevLogIndex _prevLogTerm 118 | if not plmatch 119 | then fork_ $ sendAppendEntriesResponse _leaderId False True 120 | else do 121 | appendLogEntries _prevLogIndex _aeEntries 122 | doCommit 123 | {-| 124 | if (not (Seq.null _aeEntries)) 125 | -- only broadcast when there are new entries 126 | -- this has the downside that recovering nodes won't update 127 | -- their commit index until new entries come along 128 | -- not sure if this is okay or not 129 | -- committed entries by definition have already been externalized 130 | -- so if a particular node missed it, there were already 2f+1 nodes 131 | -- that didn't 132 | then fork_ sendAllAppendEntriesResponse 133 | else fork_ $ sendAppendEntriesResponse _leaderId True True 134 | --} 135 | fork_ sendAllAppendEntriesResponse 136 | _ | not ig && _aeTerm >= ct -> do 137 | debug "sending unconvinced response" 138 | fork_ $ sendAppendEntriesResponse _leaderId False False 139 | _ -> return () 140 | 141 | mergeCommitProof :: Ord nt => AppendEntriesResponse nt -> Raft nt et rt mt () 142 | mergeCommitProof aer@AppendEntriesResponse{..} = do 143 | ci <- use commitIndex 144 | debug $ "merging commit proof for index: " ++ show _aerIndex 145 | when (_aerIndex > ci) $ 146 | commitProof.at _aerIndex %= maybe (Just (Set.singleton aer)) (Just . Set.insert aer) 147 | 148 | prevLogEntryMatches :: LogIndex -> Term -> Raft nt et rt mt Bool 149 | prevLogEntryMatches pli plt = do 150 | es <- use logEntries 151 | case seqIndex es pli of 152 | -- if we don't have the entry, only return true if pli is startIndex 153 | Nothing -> return (pli == startIndex) 154 | -- if we do have the entry, return true if the terms match 155 | Just LogEntry{..} -> return (_leTerm == plt) 156 | 157 | appendLogEntries :: (Binary nt, Binary et, Ord nt) => LogIndex -> Seq (LogEntry nt et) -> Raft nt et rt mt () 158 | appendLogEntries pli es = do 159 | logEntries %= (Seq.>< es) . Seq.take (pli + 1) 160 | traverse_ (\LogEntry{_leCommand = Command{..}} -> replayMap %= Map.insert (_cmdClientId, _cmdSig) Nothing) es 161 | updateLogHashesFromIndex (pli + 1) 162 | 163 | hashLogEntry :: (Binary nt, Binary et) => Maybe (LogEntry nt et) -> LogEntry nt et -> LogEntry nt et 164 | hashLogEntry (Just LogEntry{ _leHash = prevHash}) le = 165 | le { _leHash = hashlazy (encode (le { _leHash = prevHash }))} 166 | hashLogEntry Nothing le = 167 | le { _leHash = hashlazy (encode (le { _leHash = B.empty }))} 168 | 169 | updateLogHashesFromIndex :: (Binary nt, Binary et) => LogIndex -> Raft nt et rt mt () 170 | updateLogHashesFromIndex i = do 171 | es <- use logEntries 172 | case seqIndex es i of 173 | Just _ -> do 174 | logEntries %= Seq.adjust (hashLogEntry (seqIndex es (i - 1))) i 175 | updateLogHashesFromIndex (i + 1) 176 | Nothing -> return () 177 | 178 | addLogEntryAndHash :: (Binary nt, Binary et) => LogEntry nt et -> Seq (LogEntry nt et) -> Seq (LogEntry nt et) 179 | addLogEntryAndHash le es = 180 | case Seq.viewr es of 181 | _ Seq.:> ple -> es Seq.|> hashLogEntry (Just ple) le 182 | Seq.EmptyR -> Seq.singleton (hashLogEntry Nothing le) 183 | 184 | handleAppendEntriesResponse :: (Binary nt, Binary et, Binary rt, Ord nt) => AppendEntriesResponse nt -> Raft nt et rt mt () 185 | handleAppendEntriesResponse aer@AppendEntriesResponse{..} = do 186 | debug "got an appendEntriesResponse RPC" 187 | mergeCommitProof aer 188 | doCommit 189 | r <- use role 190 | ct <- use term 191 | when (r == Leader) $ do 192 | when (not _aerConvinced && _aerTerm <= ct) $ -- implies not _aerSuccess 193 | lConvinced %= Set.delete _aerNodeId 194 | when (_aerTerm == ct) $ do 195 | when (_aerConvinced && not _aerSuccess) $ 196 | lNextIndex %= Map.adjust (subtract 1) _aerNodeId 197 | when (_aerConvinced && _aerSuccess) $ do 198 | lNextIndex .at _aerNodeId .= Just (_aerIndex + 1) 199 | lConvinced %= Set.insert _aerNodeId 200 | when (not _aerConvinced || not _aerSuccess) $ 201 | fork_ $ sendAppendEntries _aerNodeId 202 | 203 | applyCommand :: Ord nt => Command nt et -> Raft nt et rt mt (nt, CommandResponse nt rt) 204 | applyCommand cmd@Command{..} = do 205 | apply <- view (rs.applyLogEntry) 206 | result <- apply _cmdEntry 207 | replayMap %= Map.insert (_cmdClientId, _cmdSig) (Just result) 208 | ((,) _cmdClientId) <$> makeCommandResponse cmd result 209 | 210 | makeCommandResponse :: Command nt et -> rt -> Raft nt et rt mt (CommandResponse nt rt) 211 | makeCommandResponse Command{..} result = do 212 | nid <- view (cfg.nodeId) 213 | mlid <- use currentLeader 214 | return $ CommandResponse 215 | result 216 | (maybe nid id mlid) 217 | nid 218 | _cmdRequestId 219 | LB.empty 220 | 221 | doCommit :: (Binary nt, Binary et, Binary rt, Ord nt) => Raft nt et rt mt () 222 | doCommit = do 223 | commitUpdate <- updateCommitIndex 224 | when commitUpdate applyLogEntries 225 | 226 | -- apply the un-applied log entries up through commitIndex 227 | -- and send results to the client if you are the leader 228 | -- TODO: have this done on a separate thread via event passing 229 | applyLogEntries :: (Binary nt, Binary et, Binary rt, Ord nt) => Raft nt et rt mt () 230 | applyLogEntries = do 231 | la <- use lastApplied 232 | ci <- use commitIndex 233 | le <- use logEntries 234 | let leToApply = Seq.drop (la + 1) . Seq.take (ci + 1) $ le 235 | results <- mapM (applyCommand . _leCommand) leToApply 236 | r <- use role 237 | when (r == Leader) $ fork_ $ sendResults results 238 | lastApplied .= ci 239 | 240 | 241 | -- checks to see what the largest N where a quorum of nodes 242 | -- has sent us proof of a commit up to that index 243 | updateCommitIndex :: (Binary nt, Binary et, Binary rt, Ord nt) => Raft nt et rt mt Bool 244 | updateCommitIndex = do 245 | ci <- use commitIndex 246 | proof <- use commitProof 247 | qsize <- view quorumSize 248 | es <- use logEntries 249 | 250 | -- get all indices in the log past commitIndex 251 | let inds = [(ci + 1)..(Seq.length es - 1)] 252 | 253 | -- get the prefix of these indices where a quorum of nodes have 254 | -- provided proof of having replicated that entry 255 | let qcinds = takeWhile (\i -> (not . Map.null) (Map.filterWithKey (\k s -> k >= i && Set.size s + 1 >= qsize) proof)) inds 256 | 257 | case qcinds of 258 | [] -> return False 259 | _ -> do 260 | let qci = last qcinds 261 | case Map.lookup qci proof of 262 | Just s -> do 263 | let lhash = _leHash (Seq.index es qci) 264 | valid <- checkCommitProof qci lhash s 265 | if valid 266 | then do 267 | commitIndex .= qci 268 | commitProof %= Map.filterWithKey (\k _ -> k >= qci) 269 | debug $ "commit index is now: " ++ show qci 270 | return True 271 | else 272 | return False 273 | Nothing -> return False 274 | 275 | checkCommitProof :: (Binary nt, Binary et, Binary rt, Ord nt) 276 | => LogIndex -> B.ByteString -> Set (AppendEntriesResponse nt) -> Raft nt et rt mt Bool 277 | checkCommitProof ci lhash aers = do 278 | sigsOkay <- allM (verifyRPCWithKey . AER) (Set.toList aers) 279 | return $ sigsOkay && all (\AppendEntriesResponse{..} -> _aerHash == lhash && _aerIndex == ci) aers 280 | 281 | handleRequestVote :: (Binary nt, Binary et, Binary rt, Eq nt) => RequestVote nt -> Raft nt et rt mt () 282 | handleRequestVote RequestVote{..} = do 283 | debug $ "got a requestVote RPC for " ++ show _rvTerm 284 | mvote <- use votedFor 285 | es <- use logEntries 286 | ct <- use term 287 | cl <- use currentLeader 288 | ig <- use ignoreLeader 289 | case mvote of 290 | _ | ig && cl == Just _rvCandidateId -> return () 291 | -- don't respond to a candidate if they were leader and a client 292 | -- asked us to ignore them 293 | 294 | _ | _rvTerm < ct -> do 295 | -- this is an old candidate 296 | debug "this is for an old term" 297 | fork_ $ sendRequestVoteResponse _rvCandidateId False 298 | 299 | Just c | c == _rvCandidateId && _rvTerm == ct -> do 300 | -- already voted for this candidate in this term 301 | debug "already voted for this candidate" 302 | fork_ $ sendRequestVoteResponse _rvCandidateId True 303 | 304 | Just _ | _rvTerm == ct -> do 305 | -- already voted for a different candidate in this term 306 | debug "already voted for a different candidate" 307 | fork_ $ sendRequestVoteResponse _rvCandidateId False 308 | 309 | _ -> if (_lastLogTerm, _lastLogIndex) >= let (llt, lli, _) = lastLogInfo es in (llt, lli) 310 | -- we have no recorded vote, or this request is for a higher term 311 | -- (we don't externalize votes without updating our own term, so we 312 | -- haven't voted in the higher term before) 313 | -- lazily vote for the candidate if its log is at least as 314 | -- up to date as ours, use the Ord instance of (Term, Index) to prefer 315 | -- higher terms, and then higher last indices for equal terms 316 | then do 317 | lv <- use lazyVote 318 | case lv of 319 | Just (t, _) | t >= _rvTerm -> 320 | debug "would vote lazily, but already voted lazily for candidate in same or higher term" 321 | Just _ -> do 322 | debug "replacing lazy vote" 323 | lazyVote .= Just (_rvTerm, _rvCandidateId) 324 | Nothing -> do 325 | debug "haven't voted, (lazily) voting for this candidate" 326 | lazyVote .= Just (_rvTerm, _rvCandidateId) 327 | else do 328 | debug "haven't voted, but my log is better than this candidate's" 329 | fork_ $ sendRequestVoteResponse _rvCandidateId False 330 | 331 | handleRequestVoteResponse :: (Binary nt, Binary et, Binary rt, Ord nt) => RequestVoteResponse nt -> Raft nt et rt mt () 332 | handleRequestVoteResponse rvr@RequestVoteResponse{..} = do 333 | debug $ "got a requestVoteResponse RPC for " ++ show _rvrTerm ++ ": " ++ show _voteGranted 334 | r <- use role 335 | ct <- use term 336 | when (r == Candidate && ct == _rvrTerm) $ 337 | if _voteGranted 338 | then do 339 | cYesVotes %= Set.insert rvr 340 | checkElection 341 | else 342 | cPotentialVotes %= Set.delete _rvrNodeId 343 | 344 | handleCommand :: (Binary nt, Binary et, Binary rt, Ord nt) => Command nt et -> Raft nt et rt mt () 345 | handleCommand cmd@Command{..} = do 346 | debug "got a command RPC" 347 | r <- use role 348 | ct <- use term 349 | mlid <- use currentLeader 350 | replays <- use replayMap 351 | case (Map.lookup (_cmdClientId, _cmdSig) replays, r, mlid) of 352 | (Just (Just result), _, _) -> do 353 | cmdr <- makeCommandResponse cmd result 354 | sendSignedRPC _cmdClientId $ CMDR cmdr 355 | -- we have already committed this request, so send the result to the client 356 | (Just Nothing, _, _) -> 357 | -- we have already seen this request, but have not yet committed it 358 | -- nothing to do 359 | return () 360 | (_, Leader, _) -> do 361 | -- we're the leader, so append this to our log with the current term 362 | -- and propagate it to replicas 363 | logEntries %= addLogEntryAndHash (LogEntry ct cmd B.empty) 364 | replayMap %= Map.insert (_cmdClientId, _cmdSig) Nothing 365 | fork_ sendAllAppendEntries 366 | fork_ sendAllAppendEntriesResponse 367 | doCommit 368 | (_, _, Just lid) -> 369 | -- we're not the leader, but we know who the leader is, so forward this 370 | -- command (don't sign it ourselves, as it comes from the client) 371 | fork_ $ sendRPC lid $ CMD cmd 372 | (_, _, Nothing) -> 373 | -- we're not the leader, and we don't know who the leader is, so can't do 374 | -- anything 375 | return () 376 | 377 | handleRevolution :: Ord nt => Revolution nt -> Raft nt et rt mt () 378 | handleRevolution Revolution{..} = do 379 | cl <- use currentLeader 380 | whenM (Map.notMember (_revClientId, _revSig) <$> use replayMap) $ 381 | case cl of 382 | Just l | l == _revLeaderId -> do 383 | replayMap %= Map.insert (_revClientId, _revSig) Nothing 384 | -- clear our lazy vote if it was for this leader 385 | lv <- use lazyVote 386 | case lv of 387 | Just (_, lvid) | lvid == _revLeaderId -> lazyVote .= Nothing 388 | _ -> return () 389 | ignoreLeader .= True 390 | _ -> return () 391 | -------------------------------------------------------------------------------- /src/Network/Tangaroa/Byzantine/Role.hs: -------------------------------------------------------------------------------- 1 | module Network.Tangaroa.Byzantine.Role 2 | ( becomeFollower 3 | , becomeLeader 4 | , becomeCandidate 5 | , checkElection 6 | , setVotedFor 7 | ) where 8 | 9 | import Control.Lens hiding (Index) 10 | import Control.Monad 11 | import Data.Binary 12 | import qualified Data.ByteString.Lazy as B 13 | import qualified Data.Map as Map 14 | import qualified Data.Sequence as Seq 15 | import qualified Data.Set as Set 16 | 17 | import Network.Tangaroa.Byzantine.Timer 18 | import Network.Tangaroa.Byzantine.Types 19 | import Network.Tangaroa.Byzantine.Util 20 | import Network.Tangaroa.Byzantine.Sender 21 | import Network.Tangaroa.Combinator 22 | 23 | -- count the yes votes and become leader if you have reached a quorum 24 | checkElection :: (Binary nt, Binary et, Binary rt, Ord nt) => Raft nt et rt mt () 25 | checkElection = do 26 | nyes <- Set.size <$> use cYesVotes 27 | qsize <- view quorumSize 28 | debug $ "yes votes: " ++ show nyes ++ " quorum size: " ++ show qsize 29 | when (nyes >= qsize) $ becomeLeader 30 | 31 | setVotedFor :: Maybe nt -> Raft nt et rt mt () 32 | setVotedFor mvote = do 33 | _ <- rs.writeVotedFor ^$ mvote 34 | votedFor .= mvote 35 | 36 | becomeFollower :: Raft nt et rt mt () 37 | becomeFollower = do 38 | debug "becoming follower" 39 | role .= Follower 40 | resetElectionTimer 41 | 42 | becomeCandidate :: (Binary nt, Binary et, Binary rt, Ord nt) => Raft nt et rt mt () 43 | becomeCandidate = do 44 | debug "becoming candidate" 45 | role .= Candidate 46 | use term >>= updateTerm . (+ 1) 47 | nid <- view (cfg.nodeId) 48 | setVotedFor (Just nid) 49 | ct <- use term 50 | selfVote <- signRPCWithKey $ RequestVoteResponse ct nid True nid B.empty 51 | cYesVotes .= Set.singleton selfVote 52 | (cPotentialVotes .=) =<< view (cfg.otherNodes) 53 | resetElectionTimer 54 | -- this is necessary for a single-node cluster, as we have already won the 55 | -- election in that case. otherwise we will wait for more votes to check again 56 | checkElection -- can possibly transition to leader 57 | r <- use role 58 | when (r == Candidate) $ fork_ sendAllRequestVotes 59 | 60 | becomeLeader :: (Binary nt, Binary et, Binary rt, Ord nt) => Raft nt et rt mt () 61 | becomeLeader = do 62 | debug "becoming leader" 63 | role .= Leader 64 | (currentLeader .=) . Just =<< view (cfg.nodeId) 65 | ni <- Seq.length <$> use logEntries 66 | (lNextIndex .=) =<< Map.fromSet (const ni) <$> view (cfg.otherNodes) 67 | (lMatchIndex .=) =<< Map.fromSet (const startIndex) <$> view (cfg.otherNodes) 68 | lConvinced .= Set.empty 69 | fork_ sendAllAppendEntries 70 | resetHeartbeatTimer 71 | -------------------------------------------------------------------------------- /src/Network/Tangaroa/Byzantine/Sender.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE RecordWildCards #-} 2 | 3 | module Network.Tangaroa.Byzantine.Sender 4 | ( sendAppendEntries 5 | , sendAppendEntriesResponse 6 | , sendRequestVote 7 | , sendRequestVoteResponse 8 | , sendAllAppendEntries 9 | , sendAllRequestVotes 10 | , sendAllAppendEntriesResponse 11 | , sendResults 12 | , sendRPC 13 | , sendSignedRPC 14 | ) where 15 | 16 | import Control.Lens 17 | import Data.Binary 18 | import Data.Foldable (traverse_) 19 | import Data.Sequence (Seq) 20 | import Data.Set (Set) 21 | import qualified Data.ByteString.Lazy as B 22 | import qualified Data.Sequence as Seq 23 | import qualified Data.Set as Set 24 | 25 | import Network.Tangaroa.Byzantine.Util 26 | import Network.Tangaroa.Byzantine.Types 27 | 28 | sendAppendEntries :: (Binary nt, Binary et, Binary rt, Ord nt) => nt -> Raft nt et rt mt () 29 | sendAppendEntries target = do 30 | mni <- use $ lNextIndex.at target 31 | es <- use logEntries 32 | let (pli,plt) = logInfoForNextIndex mni es 33 | ct <- use term 34 | nid <- view (cfg.nodeId) 35 | debug $ "sendAppendEntries: " ++ show ct 36 | qVoteList <- getVotesForNode target 37 | sendSignedRPC target $ AE $ 38 | AppendEntries ct nid pli plt (Seq.drop (pli + 1) es) qVoteList B.empty 39 | 40 | getVotesForNode :: Ord nt => nt -> Raft nt et rt mt (Set (RequestVoteResponse nt)) 41 | getVotesForNode target = do 42 | convinced <- Set.member target <$> use lConvinced 43 | if convinced 44 | then return Set.empty 45 | else use cYesVotes 46 | 47 | sendAppendEntriesResponse :: (Binary nt, Binary et, Binary rt) => nt -> Bool -> Bool -> Raft nt et rt mt () 48 | sendAppendEntriesResponse target success convinced = do 49 | ct <- use term 50 | nid <- view (cfg.nodeId) 51 | debug $ "sendAppendEntriesResponse: " ++ show ct 52 | (_, lindex, lhash) <- lastLogInfo <$> use logEntries 53 | sendSignedRPC target $ AER $ AppendEntriesResponse ct nid success convinced lindex lhash B.empty 54 | 55 | sendAllAppendEntriesResponse :: (Binary nt, Binary et, Binary rt) => Raft nt et rt mt () 56 | sendAllAppendEntriesResponse = 57 | traverse_ (\n -> sendAppendEntriesResponse n True True) =<< view (cfg.otherNodes) 58 | 59 | sendRequestVote :: (Binary nt, Binary et, Binary rt) => nt -> Raft nt et rt mt () 60 | sendRequestVote target = do 61 | ct <- use term 62 | nid <- view (cfg.nodeId) 63 | (llt, lli, _) <- lastLogInfo <$> use logEntries 64 | debug $ "sendRequestVote: " ++ show ct 65 | sendSignedRPC target $ RV $ RequestVote ct nid lli llt B.empty 66 | 67 | sendRequestVoteResponse :: (Binary nt, Binary et, Binary rt) => nt -> Bool -> Raft nt et rt mt () 68 | sendRequestVoteResponse target vote = do 69 | ct <- use term 70 | nid <- view (cfg.nodeId) 71 | debug $ "sendRequestVoteResponse: " ++ show ct 72 | sendSignedRPC target $ RVR $ RequestVoteResponse ct nid vote target B.empty 73 | 74 | sendAllAppendEntries :: (Binary nt, Binary et, Binary rt, Ord nt) => Raft nt et rt mt () 75 | sendAllAppendEntries = traverse_ sendAppendEntries =<< view (cfg.otherNodes) 76 | 77 | sendAllRequestVotes :: (Binary nt, Binary et, Binary rt) => Raft nt et rt mt () 78 | sendAllRequestVotes = traverse_ sendRequestVote =<< use cPotentialVotes 79 | 80 | sendResults :: (Binary nt, Binary et, Binary rt) => Seq (nt, CommandResponse nt rt) -> Raft nt et rt mt () 81 | sendResults results = do 82 | traverse_ (\(target,cmdr) -> sendSignedRPC target $ CMDR cmdr) results 83 | 84 | -- called by leaders sending appendEntries. 85 | -- given a replica's nextIndex, get the index and term to send as 86 | -- prevLog(Index/Term) 87 | logInfoForNextIndex :: Maybe LogIndex -> Seq (LogEntry nt et) -> (LogIndex,Term) 88 | logInfoForNextIndex mni es = 89 | case mni of 90 | Just ni -> let pli = ni - 1 in 91 | case seqIndex es pli of 92 | Just LogEntry{..} -> (pli, _leTerm) 93 | -- this shouldn't happen, because nextIndex - 1 should always be at 94 | -- most our last entry 95 | Nothing -> (startIndex, startTerm) 96 | Nothing -> (startIndex, startTerm) 97 | 98 | sendRPC :: nt -> RPC nt et rt -> Raft nt et rt mt () 99 | sendRPC target rpc = do 100 | send <- view (rs.sendMessage) 101 | ser <- view (rs.serializeRPC) 102 | send target $ ser rpc 103 | 104 | sendSignedRPC :: (Binary nt, Binary et, Binary rt) => nt -> RPC nt et rt -> Raft nt et rt mt () 105 | sendSignedRPC target rpc = do 106 | pk <- view (cfg.privateKey) 107 | sendRPC target $ case rpc of 108 | AE ae -> AE $ signRPC pk ae 109 | AER aer -> AER $ signRPC pk aer 110 | RV rv -> RV $ signRPC pk rv 111 | RVR rvr -> RVR $ signRPC pk rvr 112 | CMD cmd -> CMD $ signRPC pk cmd 113 | CMDR cmdr -> CMDR $ signRPC pk cmdr 114 | REVOLUTION rev -> REVOLUTION $ signRPC pk rev 115 | _ -> rpc 116 | -------------------------------------------------------------------------------- /src/Network/Tangaroa/Byzantine/Server.hs: -------------------------------------------------------------------------------- 1 | module Network.Tangaroa.Byzantine.Server 2 | ( runRaftServer 3 | ) where 4 | 5 | import Data.Binary 6 | import Control.Concurrent.Chan.Unagi 7 | import Control.Lens 8 | import qualified Data.Set as Set 9 | 10 | import Network.Tangaroa.Byzantine.Handler 11 | import Network.Tangaroa.Byzantine.Types 12 | import Network.Tangaroa.Byzantine.Util 13 | import Network.Tangaroa.Byzantine.Timer 14 | 15 | runRaftServer :: (Binary nt, Binary et, Binary rt, Ord nt) => Config nt -> RaftSpec nt et rt mt -> IO () 16 | runRaftServer rconf spec = do 17 | let qsize = getQuorumSize $ 1 + (Set.size $ rconf ^. otherNodes) 18 | (ein, eout) <- newChan 19 | runRWS_ 20 | raft 21 | (RaftEnv rconf qsize ein eout (liftRaftSpec spec)) 22 | initialRaftState 23 | 24 | raft :: (Binary nt, Binary et, Binary rt, Ord nt) => Raft nt et rt mt () 25 | raft = do 26 | fork_ messageReceiver 27 | resetElectionTimer 28 | handleEvents 29 | -------------------------------------------------------------------------------- /src/Network/Tangaroa/Byzantine/Spec/Simple.hs: -------------------------------------------------------------------------------- 1 | module Network.Tangaroa.Byzantine.Spec.Simple 2 | ( runServer 3 | , runClient 4 | ) where 5 | 6 | import Network.Tangaroa.Byzantine.Server 7 | import Network.Tangaroa.Byzantine.Client 8 | import Network.Tangaroa.Byzantine.Types 9 | 10 | import Control.Lens 11 | import Data.Word 12 | import Data.Binary 13 | import Network.Socket 14 | import System.Console.GetOpt 15 | import System.Environment 16 | import System.Exit 17 | import Text.Read 18 | import qualified Data.Set as Set 19 | import qualified Data.Map as Map 20 | import Codec.Crypto.RSA 21 | 22 | options :: [OptDescr (Config NodeType -> IO (Config NodeType))] 23 | options = 24 | [ Option ['s'] ["self"] 25 | (ReqArg setThisNode "SELF_PORT_NUMBER") 26 | "The port number of this node." 27 | , Option ['d'] ["debug"] 28 | (NoArg (return . (enableDebug .~ True))) 29 | "Enable debugging info (show RPCs and timeouts)." 30 | , Option ['p'] ["public-keys"] 31 | (ReqArg getPublicKeys "NODE_PUBLIC_KEY_FILE") 32 | "A file containing a map of nodes to their public key." 33 | , Option ['c'] ["client-keys"] 34 | (ReqArg getClientPublicKeys "CLIENT_PUBLIC_KEY_FILE") 35 | "A file containing a map of clients to their public key." 36 | , Option ['k'] ["private-key"] 37 | (ReqArg getPrivateKey "PRIVATE_KEY_FILE") 38 | "A file containing the node's private key." 39 | ] 40 | 41 | cfgFold :: [a -> IO a] -> a -> IO a 42 | cfgFold [] x = return x 43 | cfgFold (f:fs) x = do 44 | fx <- f x 45 | cfgFold fs fx 46 | 47 | getConfig :: IO (Config NodeType) 48 | getConfig = do 49 | argv <- getArgs 50 | case getOpt Permute options argv of 51 | (opts,args,[]) -> cfgFold (opts ++ map addOtherNode args) defaultConfig 52 | (_,_,errs) -> mapM_ putStrLn errs >> exitFailure 53 | 54 | type NodeType = (HostAddress, Word16) 55 | localhost :: HostAddress 56 | localhost = 0x0100007f 57 | 58 | defaultPortNum :: PortNumber 59 | defaultPortNum = 10000 60 | 61 | defaultConfig :: Config NodeType 62 | defaultConfig = 63 | Config 64 | Set.empty -- other nodes 65 | (localhost,fromIntegral defaultPortNum) -- self address 66 | Map.empty -- publicKeys 67 | Map.empty -- clientPublicKeys 68 | (PrivateKey (PublicKey 0 0 0) 0 0 0 0 0 0) -- empty public key 69 | (3000000,6000000) -- election timeout range 70 | 1500000 -- heartbeat timeout 71 | False -- no debug 72 | 5 -- client timeouts before revolution 73 | 74 | nodeSockAddr :: NodeType -> SockAddr 75 | nodeSockAddr (host,port) = SockAddrInet (fromIntegral port) host 76 | 77 | setThisNode :: String -> Config NodeType -> IO (Config NodeType) 78 | setThisNode s = 79 | return . maybe id (\p -> nodeId .~ (localhost, p)) (readMaybe s) 80 | 81 | addOtherNode :: String -> Config NodeType -> IO (Config NodeType) 82 | addOtherNode s = 83 | return . maybe id (\p -> otherNodes %~ Set.insert (localhost, p)) (readMaybe s) 84 | 85 | getPublicKeys :: FilePath -> Config NodeType -> IO (Config NodeType) 86 | getPublicKeys filename conf = do 87 | contents <- readFile filename 88 | return $ case readMaybe contents of 89 | Just pkm -> conf & publicKeys .~ pkm 90 | Nothing -> conf 91 | 92 | getClientPublicKeys :: FilePath -> Config NodeType -> IO (Config NodeType) 93 | getClientPublicKeys filename conf = do 94 | contents <- readFile filename 95 | return $ case readMaybe contents of 96 | Just pkm -> conf & clientPublicKeys .~ pkm 97 | Nothing -> conf 98 | 99 | getPrivateKey :: FilePath -> Config NodeType -> IO (Config NodeType) 100 | getPrivateKey filename conf = do 101 | contents <- readFile filename 102 | return $ case readMaybe contents of 103 | Just pk -> conf & privateKey .~ pk 104 | Nothing -> conf 105 | 106 | getMsg :: Socket -> IO String 107 | getMsg sock = recv sock 8192 108 | 109 | msgSend :: Socket -> NodeType -> String -> IO () 110 | msgSend sock node s = 111 | sendTo sock s (nodeSockAddr node) >> return () 112 | 113 | showDebug :: NodeType -> String -> IO () 114 | showDebug node msg = putStrLn $ show (snd node) ++ " " ++ msg 115 | 116 | noDebug :: NodeType -> String -> IO () 117 | noDebug _ _ = return () 118 | 119 | simpleRaftSpec :: (Show et, Read et, Show rt, Read rt) 120 | => Socket 121 | -> (et -> IO rt) 122 | -> (NodeType -> String -> IO ()) 123 | -> RaftSpec NodeType et rt String 124 | simpleRaftSpec sock applyFn debugFn = RaftSpec 125 | { 126 | -- TODO don't read log entries 127 | __readLogEntry = return . const Nothing 128 | -- TODO don't write log entries 129 | , __writeLogEntry = \_ _ -> return () 130 | -- TODO always read startTerm 131 | , __readTermNumber = return startTerm 132 | -- TODO don't write term numbers 133 | , __writeTermNumber = return . const () 134 | -- TODO never voted for anyone 135 | , __readVotedFor = return Nothing 136 | -- TODO don't record votes 137 | , __writeVotedFor = return . const () 138 | -- apply log entries to the state machine, given by caller 139 | , __applyLogEntry = applyFn 140 | -- serialize with show 141 | , __serializeRPC = show 142 | -- deserialize with readMaybe 143 | , __deserializeRPC = readMaybe 144 | -- send messages using msgSend 145 | , __sendMessage = msgSend sock 146 | -- get messages using getMsg 147 | , __getMessage = getMsg sock 148 | -- use the debug function given by the caller 149 | , __debugPrint = debugFn 150 | } 151 | 152 | runServer :: (Binary et, Binary rt, Show et, Read et, Show rt, Read rt) => 153 | (et -> IO rt) -> IO () 154 | runServer applyFn = do 155 | rconf <- getConfig 156 | sock <- socket AF_INET Datagram defaultProtocol 157 | bind sock $ nodeSockAddr $ rconf ^. nodeId 158 | let debugFn = if (rconf ^. enableDebug) then showDebug else noDebug 159 | runRaftServer rconf $ simpleRaftSpec sock applyFn debugFn 160 | 161 | runClient :: (Binary et, Binary rt, Show et, Read et, Show rt, Read rt) => 162 | (et -> IO rt) -> IO et -> (rt -> IO ()) -> IO () 163 | runClient applyFn getEntry useResult = do 164 | rconf <- getConfig 165 | sock <- socket AF_INET Datagram defaultProtocol 166 | bind sock $ nodeSockAddr $ rconf ^. nodeId 167 | let debugFn = if (rconf ^. enableDebug) then showDebug else noDebug 168 | runRaftClient getEntry useResult rconf (simpleRaftSpec sock applyFn debugFn) 169 | -------------------------------------------------------------------------------- /src/Network/Tangaroa/Byzantine/Timer.hs: -------------------------------------------------------------------------------- 1 | module Network.Tangaroa.Byzantine.Timer 2 | ( resetElectionTimer 3 | , resetHeartbeatTimer 4 | , cancelTimer 5 | ) where 6 | 7 | import Control.Lens hiding (Index) 8 | import Control.Monad.Trans (lift) 9 | import System.Random 10 | import Control.Concurrent.Lifted 11 | 12 | import Network.Tangaroa.Byzantine.Types 13 | import Network.Tangaroa.Byzantine.Util 14 | 15 | getNewElectionTimeout :: Raft nt et rt mt Int 16 | getNewElectionTimeout = view (cfg.electionTimeoutRange) >>= lift . randomRIO 17 | 18 | resetElectionTimer :: Raft nt et rt mt () 19 | resetElectionTimer = do 20 | timeout <- getNewElectionTimeout 21 | setTimedEvent (ElectionTimeout $ show (timeout `div` 1000) ++ "ms") timeout 22 | 23 | resetHeartbeatTimer :: Raft nt et rt mt () 24 | resetHeartbeatTimer = do 25 | timeout <- view (cfg.heartbeatTimeout) 26 | setTimedEvent (HeartbeatTimeout $ show (timeout `div` 1000) ++ "ms") timeout 27 | 28 | -- | Cancel any existing timer. 29 | cancelTimer :: Raft nt et rt mt () 30 | cancelTimer = do 31 | use timerThread >>= maybe (return ()) killThread 32 | timerThread .= Nothing 33 | 34 | -- | Cancels any pending timer and sets a new timer to trigger an event after t 35 | -- microseconds. 36 | setTimedEvent :: Event nt et rt -> Int -> Raft nt et rt mt () 37 | setTimedEvent e t = do 38 | cancelTimer 39 | tmr <- fork $ wait t >> enqueueEvent e 40 | timerThread .= Just tmr 41 | -------------------------------------------------------------------------------- /src/Network/Tangaroa/Byzantine/Types.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DeriveGeneric #-} 2 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} 3 | {-# LANGUAGE Rank2Types #-} 4 | {-# LANGUAGE RecordWildCards #-} 5 | {-# LANGUAGE TemplateHaskell #-} 6 | 7 | module Network.Tangaroa.Byzantine.Types 8 | ( Raft 9 | , RaftSpec(..) 10 | , LiftedRaftSpec(..) 11 | , readLogEntry, writeLogEntry, readTermNumber, writeTermNumber 12 | , readVotedFor, writeVotedFor, applyLogEntry, serializeRPC 13 | , deserializeRPC, sendMessage, getMessage, debugPrint 14 | , liftRaftSpec 15 | , Term, startTerm 16 | , LogIndex, startIndex 17 | , RequestId, startRequestId 18 | , Config(..), otherNodes, nodeId, electionTimeoutRange, heartbeatTimeout 19 | , enableDebug, publicKeys, clientPublicKeys, privateKey, clientTimeoutLimit 20 | , Role(..) 21 | , RaftEnv(..), cfg, quorumSize, eventIn, eventOut, rs 22 | , LogEntry(..) 23 | , RaftState(..), role, term, votedFor, lazyVote, currentLeader, ignoreLeader 24 | , logEntries, commitIndex, commitProof, lastApplied, timerThread, replayMap 25 | , cYesVotes, cPotentialVotes, lNextIndex, lMatchIndex, lConvinced 26 | , numTimeouts, pendingRequests, currentRequestId 27 | , initialRaftState 28 | , AppendEntries(..) 29 | , AppendEntriesResponse(..) 30 | , RequestVote(..) 31 | , RequestVoteResponse(..) 32 | , Command(..) 33 | , CommandResponse(..) 34 | , Revolution(..) 35 | , RPC(..) 36 | , Event(..) 37 | , SigRPC 38 | , signRPC 39 | , verifyRPC 40 | ) where 41 | 42 | import Control.Concurrent (ThreadId) 43 | import Control.Concurrent.Chan.Unagi 44 | import Control.Lens hiding (Index) 45 | import Control.Monad.RWS 46 | import Codec.Crypto.RSA 47 | import Data.Binary 48 | import Data.Sequence (Seq) 49 | import qualified Data.Sequence as Seq 50 | import Data.Map (Map) 51 | import qualified Data.Map as Map 52 | import Data.Set (Set) 53 | import qualified Data.Set as Set 54 | import qualified Data.ByteString.Lazy as LB 55 | import qualified Data.ByteString as B 56 | 57 | import GHC.Generics 58 | 59 | newtype Term = Term Int 60 | deriving (Show, Read, Eq, Ord, Generic, Num) 61 | 62 | startTerm :: Term 63 | startTerm = Term (-1) 64 | 65 | type LogIndex = Int 66 | 67 | startIndex :: LogIndex 68 | startIndex = (-1) 69 | 70 | newtype RequestId = RequestId Int 71 | deriving (Show, Read, Eq, Ord, Generic, Num) 72 | 73 | startRequestId :: RequestId 74 | startRequestId = RequestId 0 75 | 76 | data Config nt = Config 77 | { _otherNodes :: Set nt 78 | , _nodeId :: nt 79 | , _publicKeys :: Map nt PublicKey 80 | , _clientPublicKeys :: Map nt PublicKey 81 | , _privateKey :: PrivateKey 82 | , _electionTimeoutRange :: (Int,Int) -- in microseconds 83 | , _heartbeatTimeout :: Int -- in microseconds 84 | , _enableDebug :: Bool 85 | , _clientTimeoutLimit :: Int 86 | } 87 | deriving (Show, Generic) 88 | makeLenses ''Config 89 | 90 | data Command nt et = Command 91 | { _cmdEntry :: et 92 | , _cmdClientId :: nt 93 | , _cmdRequestId :: RequestId 94 | , _cmdSig :: LB.ByteString 95 | } 96 | deriving (Show, Read, Generic) 97 | 98 | data CommandResponse nt rt = CommandResponse 99 | { _cmdrResult :: rt 100 | , _cmdrLeaderId :: nt 101 | , _cmdrNodeId :: nt 102 | , _cmdrRequestId :: RequestId 103 | , _cmdrSig :: LB.ByteString 104 | } 105 | deriving (Show, Read, Generic) 106 | 107 | data LogEntry nt et = LogEntry 108 | { _leTerm :: Term 109 | , _leCommand :: Command nt et 110 | , _leHash :: B.ByteString 111 | } 112 | deriving (Show, Read, Generic) 113 | 114 | data AppendEntries nt et = AppendEntries 115 | { _aeTerm :: Term 116 | , _leaderId :: nt 117 | , _prevLogIndex :: LogIndex 118 | , _prevLogTerm :: Term 119 | , _aeEntries :: Seq (LogEntry nt et) 120 | , _aeQuorumVotes :: Set (RequestVoteResponse nt) 121 | , _aeSig :: LB.ByteString 122 | } 123 | deriving (Show, Read, Generic) 124 | 125 | data AppendEntriesResponse nt = AppendEntriesResponse 126 | { _aerTerm :: Term 127 | , _aerNodeId :: nt 128 | , _aerSuccess :: Bool 129 | , _aerConvinced :: Bool 130 | , _aerIndex :: LogIndex 131 | , _aerHash :: B.ByteString 132 | , _aerSig :: LB.ByteString 133 | } 134 | deriving (Show, Read, Generic, Eq, Ord) 135 | 136 | data RequestVote nt = RequestVote 137 | { _rvTerm :: Term 138 | , _rvCandidateId :: nt 139 | , _lastLogIndex :: LogIndex 140 | , _lastLogTerm :: Term 141 | , _rvSig :: LB.ByteString 142 | } 143 | deriving (Show, Read, Generic) 144 | 145 | data RequestVoteResponse nt = RequestVoteResponse 146 | { _rvrTerm :: Term 147 | , _rvrNodeId :: nt 148 | , _voteGranted :: Bool 149 | , _rvrCandidateId :: nt 150 | , _rvrSig :: LB.ByteString 151 | } 152 | deriving (Show, Read, Generic, Eq, Ord) 153 | 154 | data Revolution nt = Revolution 155 | { _revClientId :: nt 156 | , _revLeaderId :: nt 157 | , _revRequestId :: RequestId 158 | , _revSig :: LB.ByteString 159 | } 160 | deriving (Show, Read, Generic) 161 | 162 | data RPC nt et rt = AE (AppendEntries nt et) 163 | | AER (AppendEntriesResponse nt) 164 | | RV (RequestVote nt) 165 | | RVR (RequestVoteResponse nt) 166 | | CMD (Command nt et) 167 | | CMDR (CommandResponse nt rt) 168 | | REVOLUTION (Revolution nt) 169 | | DBG String 170 | deriving (Show, Read, Generic) 171 | 172 | class SigRPC rpc where 173 | signRPC :: PrivateKey -> rpc -> rpc 174 | verifyRPC :: PublicKey -> rpc -> Bool 175 | 176 | instance (Binary nt, Binary et) => SigRPC (Command nt et) where 177 | signRPC k rpc = rpc { _cmdSig = sign k (encode (rpc { _cmdSig = LB.empty })) } 178 | verifyRPC k rpc = verify k (encode (rpc { _cmdSig = LB.empty })) (_cmdSig rpc) 179 | 180 | instance (Binary nt, Binary rt) => SigRPC (CommandResponse nt rt) where 181 | signRPC k rpc = rpc { _cmdrSig = sign k (encode (rpc { _cmdrSig = LB.empty })) } 182 | verifyRPC k rpc = verify k (encode (rpc { _cmdrSig = LB.empty })) (_cmdrSig rpc) 183 | 184 | instance (Binary nt, Binary et) => SigRPC (AppendEntries nt et) where 185 | signRPC k rpc = rpc { _aeSig = sign k (encode (rpc { _aeSig = LB.empty })) } 186 | verifyRPC k rpc = verify k (encode (rpc { _aeSig = LB.empty })) (_aeSig rpc) 187 | 188 | instance Binary nt => SigRPC (AppendEntriesResponse nt) where 189 | signRPC k rpc = rpc { _aerSig = sign k (encode (rpc { _aerSig = LB.empty })) } 190 | verifyRPC k rpc = verify k (encode (rpc { _aerSig = LB.empty })) (_aerSig rpc) 191 | 192 | instance Binary nt => SigRPC (RequestVote nt) where 193 | signRPC k rpc = rpc { _rvSig = sign k (encode (rpc { _rvSig = LB.empty })) } 194 | verifyRPC k rpc = verify k (encode (rpc { _rvSig = LB.empty })) (_rvSig rpc) 195 | 196 | instance Binary nt => SigRPC (RequestVoteResponse nt) where 197 | signRPC k rpc = rpc { _rvrSig = sign k (encode (rpc { _rvrSig = LB.empty })) } 198 | verifyRPC k rpc = verify k (encode (rpc { _rvrSig = LB.empty })) (_rvrSig rpc) 199 | 200 | instance Binary nt => SigRPC (Revolution nt) where 201 | signRPC k rpc = rpc { _revSig = sign k (encode (rpc { _revSig = LB.empty })) } 202 | verifyRPC k rpc = verify k (encode (rpc { _revSig = LB.empty })) (_revSig rpc) 203 | 204 | -- | A structure containing all the implementation details for running 205 | -- the raft protocol. 206 | data RaftSpec nt et rt mt = RaftSpec 207 | { 208 | -- ^ Function to get a log entry from persistent storage. 209 | __readLogEntry :: LogIndex -> IO (Maybe et) 210 | 211 | -- ^ Function to write a log entry to persistent storage. 212 | , __writeLogEntry :: LogIndex -> (Term,et) -> IO () 213 | 214 | -- ^ Function to get the term number from persistent storage. 215 | , __readTermNumber :: IO Term 216 | 217 | -- ^ Function to write the term number to persistent storage. 218 | , __writeTermNumber :: Term -> IO () 219 | 220 | -- ^ Function to read the node voted for from persistent storage. 221 | , __readVotedFor :: IO (Maybe nt) 222 | 223 | -- ^ Function to write the node voted for to persistent storage. 224 | , __writeVotedFor :: Maybe nt -> IO () 225 | 226 | -- ^ Function to apply a log entry to the state machine. 227 | , __applyLogEntry :: et -> IO rt 228 | 229 | -- ^ Function to serialize an RPC. 230 | , __serializeRPC :: RPC nt et rt -> mt 231 | 232 | -- ^ Function to deserialize an RPC. 233 | , __deserializeRPC :: mt -> Maybe (RPC nt et rt) 234 | 235 | -- ^ Function to send a message to a node. 236 | , __sendMessage :: nt -> mt -> IO () 237 | 238 | -- ^ Function to get the next message. 239 | , __getMessage :: IO mt 240 | 241 | -- ^ Function to log a debug message (no newline). 242 | , __debugPrint :: nt -> String -> IO () 243 | } 244 | 245 | data Role = Follower 246 | | Candidate 247 | | Leader 248 | deriving (Show, Generic, Eq) 249 | 250 | data Event nt et rt = ERPC (RPC nt et rt) 251 | | ElectionTimeout String 252 | | HeartbeatTimeout String 253 | deriving (Show) 254 | 255 | -- | A version of RaftSpec where all IO functions are lifted 256 | -- into the Raft monad. 257 | data LiftedRaftSpec nt et rt mt t = LiftedRaftSpec 258 | { 259 | -- ^ Function to get a log entry from persistent storage. 260 | _readLogEntry :: MonadTrans t => LogIndex -> t IO (Maybe et) 261 | 262 | -- ^ Function to write a log entry to persistent storage. 263 | , _writeLogEntry :: MonadTrans t => LogIndex -> (Term,et) -> t IO () 264 | 265 | -- ^ Function to get the term number from persistent storage. 266 | , _readTermNumber :: MonadTrans t => t IO Term 267 | 268 | -- ^ Function to write the term number to persistent storage. 269 | , _writeTermNumber :: MonadTrans t => Term -> t IO () 270 | 271 | -- ^ Function to read the node voted for from persistent storage. 272 | , _readVotedFor :: MonadTrans t => t IO (Maybe nt) 273 | 274 | -- ^ Function to write the node voted for to persistent storage. 275 | , _writeVotedFor :: MonadTrans t => Maybe nt -> t IO () 276 | 277 | -- ^ Function to apply a log entry to the state machine. 278 | , _applyLogEntry :: MonadTrans t => et -> t IO rt 279 | 280 | -- ^ Function to serialize an RPC. 281 | , _serializeRPC :: RPC nt et rt -> mt 282 | 283 | -- ^ Function to deserialize an RPC. 284 | , _deserializeRPC :: mt -> Maybe (RPC nt et rt) 285 | 286 | -- ^ Function to send a message to a node. 287 | , _sendMessage :: MonadTrans t => nt -> mt -> t IO () 288 | 289 | -- ^ Function to get the next message. 290 | , _getMessage :: MonadTrans t => t IO mt 291 | 292 | -- ^ Function to log a debug message (no newline). 293 | , _debugPrint :: nt -> String -> t IO () 294 | } 295 | makeLenses ''LiftedRaftSpec 296 | 297 | liftRaftSpec :: MonadTrans t => RaftSpec nt et rt mt -> LiftedRaftSpec nt et rt mt t 298 | liftRaftSpec RaftSpec{..} = 299 | LiftedRaftSpec 300 | { _readLogEntry = lift . __readLogEntry 301 | , _writeLogEntry = \i et -> lift (__writeLogEntry i et) 302 | , _readTermNumber = lift __readTermNumber 303 | , _writeTermNumber = lift . __writeTermNumber 304 | , _readVotedFor = lift __readVotedFor 305 | , _writeVotedFor = lift . __writeVotedFor 306 | , _applyLogEntry = lift . __applyLogEntry 307 | , _serializeRPC = __serializeRPC 308 | , _deserializeRPC = __deserializeRPC 309 | , _sendMessage = \n m -> lift (__sendMessage n m) 310 | , _getMessage = lift __getMessage 311 | , _debugPrint = \n s -> lift (__debugPrint n s) 312 | } 313 | 314 | data RaftState nt et rt = RaftState 315 | { _role :: Role 316 | , _term :: Term 317 | , _votedFor :: Maybe nt 318 | , _lazyVote :: Maybe (Term, nt) 319 | , _currentLeader :: Maybe nt 320 | , _ignoreLeader :: Bool 321 | , _logEntries :: Seq (LogEntry nt et) 322 | , _commitIndex :: LogIndex 323 | , _lastApplied :: LogIndex 324 | , _commitProof :: Map LogIndex (Set (AppendEntriesResponse nt)) 325 | , _timerThread :: Maybe ThreadId 326 | , _replayMap :: Map (nt, LB.ByteString) (Maybe rt) 327 | , _cYesVotes :: Set (RequestVoteResponse nt) 328 | , _cPotentialVotes :: Set nt 329 | , _lNextIndex :: Map nt LogIndex 330 | , _lMatchIndex :: Map nt LogIndex 331 | , _lConvinced :: Set nt 332 | 333 | -- used by clients 334 | , _pendingRequests :: Map RequestId (Command nt et) 335 | , _currentRequestId :: RequestId 336 | , _numTimeouts :: Int 337 | } 338 | makeLenses ''RaftState 339 | 340 | initialRaftState :: RaftState nt et rt 341 | initialRaftState = RaftState 342 | Follower -- role 343 | startTerm -- term 344 | Nothing -- votedFor 345 | Nothing -- lazyVote 346 | Nothing -- currentLeader 347 | False -- ignoreLeader 348 | Seq.empty -- log 349 | startIndex -- commitIndex 350 | startIndex -- lastApplied 351 | Map.empty -- commitProof 352 | Nothing -- timerThread 353 | Map.empty -- replayMap 354 | Set.empty -- cYesVotes 355 | Set.empty -- cPotentialVotes 356 | Map.empty -- lNextIndex 357 | Map.empty -- lMatchIndex 358 | Set.empty -- lConvinced 359 | Map.empty -- pendingRequests 360 | 0 -- nextRequestId 361 | 0 -- numTimeouts 362 | 363 | data RaftEnv nt et rt mt = RaftEnv 364 | { _cfg :: Config nt 365 | , _quorumSize :: Int 366 | , _eventIn :: InChan (Event nt et rt) 367 | , _eventOut :: OutChan (Event nt et rt) 368 | , _rs :: LiftedRaftSpec nt et rt mt (RWST (RaftEnv nt et rt mt) () (RaftState nt et rt)) 369 | } 370 | makeLenses ''RaftEnv 371 | 372 | type Raft nt et rt mt a = RWST (RaftEnv nt et rt mt) () (RaftState nt et rt) IO a 373 | 374 | instance Binary Term 375 | instance Binary RequestId 376 | 377 | instance (Binary nt, Binary et) => Binary (LogEntry nt et) 378 | instance (Binary nt, Binary et) => Binary (AppendEntries nt et) 379 | instance Binary nt => Binary (AppendEntriesResponse nt) 380 | instance Binary nt => Binary (RequestVote nt) 381 | instance Binary nt => Binary (RequestVoteResponse nt) 382 | instance (Binary nt, Binary et) => Binary (Command nt et) 383 | instance (Binary nt, Binary rt) => Binary (CommandResponse nt rt) 384 | instance Binary nt => Binary (Revolution nt) 385 | 386 | instance (Binary nt, Binary et, Binary rt) => Binary (RPC nt et rt) 387 | -------------------------------------------------------------------------------- /src/Network/Tangaroa/Byzantine/Util.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE FlexibleContexts #-} 2 | {-# LANGUAGE RecordWildCards #-} 3 | 4 | module Network.Tangaroa.Byzantine.Util 5 | ( seqIndex 6 | , lastLogInfo 7 | , getQuorumSize 8 | , debug 9 | , fork_ 10 | , wait 11 | , runRWS_ 12 | , enqueueEvent 13 | , dequeueEvent 14 | , messageReceiver 15 | , verifyRPCWithKey 16 | , verifyRPCWithClientKey 17 | , signRPCWithKey 18 | , updateTerm 19 | ) where 20 | 21 | import Network.Tangaroa.Byzantine.Types 22 | import Network.Tangaroa.Combinator 23 | 24 | import Control.Lens 25 | import Data.Binary 26 | import Codec.Crypto.RSA 27 | import Data.Sequence (Seq) 28 | import Control.Monad.RWS 29 | import Control.Concurrent.Lifted (fork, threadDelay) 30 | import Control.Monad.Trans.Control (MonadBaseControl) 31 | import Control.Concurrent.Chan.Unagi (readChan, writeChan) 32 | import qualified Data.ByteString as B 33 | import qualified Data.Sequence as Seq 34 | import qualified Data.Map as Map 35 | 36 | seqIndex :: Seq a -> Int -> Maybe a 37 | seqIndex s i = 38 | if i >= 0 && i < Seq.length s 39 | then Just (Seq.index s i) 40 | else Nothing 41 | 42 | getQuorumSize :: Int -> Int 43 | getQuorumSize n = minimum [n - f | f <- [0..n], n >= 3*f + 1] 44 | 45 | -- get the last term and index of a log 46 | lastLogInfo :: Seq (LogEntry nt et) -> (Term, LogIndex, B.ByteString) 47 | lastLogInfo es = 48 | case Seq.viewr es of 49 | _ Seq.:> LogEntry{..} -> (_leTerm, Seq.length es - 1, _leHash) 50 | Seq.EmptyR -> (startTerm, startIndex, B.empty) 51 | 52 | debug :: String -> Raft nt et rt mt () 53 | debug s = do 54 | dbg <- view (rs.debugPrint) 55 | nid <- view (cfg.nodeId) 56 | dbg nid s 57 | 58 | fork_ :: MonadBaseControl IO m => m () -> m () 59 | fork_ a = fork a >> return () 60 | 61 | wait :: Int -> Raft nt et rt mt () 62 | wait t = threadDelay t 63 | 64 | runRWS_ :: Monad m => RWST r w s m a -> r -> s -> m () 65 | runRWS_ ma r s = runRWST ma r s >> return () 66 | 67 | enqueueEvent :: Event nt et rt -> Raft nt et rt mt () 68 | enqueueEvent event = do 69 | ein <- view eventIn 70 | lift $ writeChan ein event 71 | 72 | dequeueEvent :: Raft nt et rt mt (Event nt et rt) 73 | dequeueEvent = lift . readChan =<< view eventOut 74 | 75 | -- | Thread to take incoming messages and write them to the event queue. 76 | messageReceiver :: Raft nt et rt mt () 77 | messageReceiver = do 78 | gm <- view (rs.getMessage) 79 | deser <- view (rs.deserializeRPC) 80 | forever $ 81 | gm >>= maybe 82 | (debug "failed to deserialize RPC") 83 | (enqueueEvent . ERPC) 84 | . deser 85 | 86 | verifyWrappedRPC :: (Binary nt, Binary et, Binary rt) => PublicKey -> RPC nt et rt -> Bool 87 | verifyWrappedRPC k rpc = case rpc of 88 | AE ae -> verifyRPC k ae 89 | AER aer -> verifyRPC k aer 90 | RV rv -> verifyRPC k rv 91 | RVR rvr -> verifyRPC k rvr 92 | CMD cmd -> verifyRPC k cmd 93 | CMDR cmdr -> verifyRPC k cmdr 94 | REVOLUTION rev -> verifyRPC k rev 95 | DBG _ -> True 96 | 97 | senderId :: RPC nt et rt -> Maybe nt 98 | senderId rpc = case rpc of 99 | AE ae -> Just (_leaderId ae) 100 | AER aer -> Just (_aerNodeId aer) 101 | RV rv -> Just (_rvCandidateId rv) 102 | RVR rvr -> Just (_rvrNodeId rvr) 103 | CMD cmd -> Just (_cmdClientId cmd) 104 | CMDR cmdr -> Just (_cmdrNodeId cmdr) 105 | REVOLUTION rev -> Just (_revClientId rev) 106 | DBG _ -> Nothing 107 | 108 | verifyRPCWithKey :: (Binary nt, Binary et, Binary rt, Ord nt) => RPC nt et rt -> Raft nt et rt mt Bool 109 | verifyRPCWithKey rpc = 110 | case rpc of 111 | AE _ -> doVerify rpc 112 | AER _ -> doVerify rpc 113 | RV _ -> doVerify rpc 114 | RVR _ -> doVerify rpc 115 | CMDR _ -> doVerify rpc 116 | _ -> return False 117 | where 118 | doVerify rpc' = do 119 | pks <- view (cfg.publicKeys) 120 | let mk = (\k -> Map.lookup k pks) =<< senderId rpc' 121 | maybe 122 | (debug "RPC has invalid signature" >> return False) 123 | (\k -> return (verifyWrappedRPC k rpc')) 124 | mk 125 | 126 | verifyRPCWithClientKey :: (Binary nt, Binary et, Binary rt, Ord nt) => RPC nt et rt -> Raft nt et rt mt Bool 127 | verifyRPCWithClientKey rpc = 128 | case rpc of 129 | CMD _ -> doVerify rpc 130 | REVOLUTION _ -> doVerify rpc 131 | _ -> return False 132 | where 133 | doVerify rpc' = do 134 | pks <- view (cfg.clientPublicKeys) 135 | let mk = (\k -> Map.lookup k pks) =<< senderId rpc' 136 | maybe 137 | (debug "RPC has invalid signature" >> return False) 138 | (\k -> return (verifyWrappedRPC k rpc')) 139 | mk 140 | 141 | signRPCWithKey :: SigRPC rpc => rpc -> Raft nt et rt mt rpc 142 | signRPCWithKey rpc = do 143 | pk <- view (cfg.privateKey) 144 | return (signRPC pk rpc) 145 | 146 | updateTerm :: Term -> Raft nt et rt mt () 147 | updateTerm t = do 148 | _ <- rs.writeTermNumber ^$ t 149 | term .= t 150 | -------------------------------------------------------------------------------- /src/Network/Tangaroa/Client.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE RecordWildCards #-} 2 | {-# LANGUAGE TemplateHaskell #-} 3 | 4 | module Network.Tangaroa.Client 5 | ( runRaftClient 6 | ) where 7 | 8 | import Network.Tangaroa.Timer 9 | import Network.Tangaroa.Types 10 | import Network.Tangaroa.Util 11 | import Network.Tangaroa.Sender (sendRPC) 12 | 13 | import Control.Concurrent.Chan.Unagi 14 | import Control.Lens hiding (Index) 15 | import Control.Monad.RWS 16 | import qualified Data.Set as Set 17 | import qualified Data.Map as Map 18 | import Data.Foldable (traverse_) 19 | 20 | runRaftClient :: Ord nt => IO et -> (rt -> IO ()) -> Config nt -> RaftSpec nt et rt mt -> IO () 21 | runRaftClient getEntry useResult rconf spec@RaftSpec{..} = do 22 | let qsize = getQuorumSize $ Set.size $ rconf ^. otherNodes 23 | (ein, eout) <- newChan 24 | runRWS_ 25 | (raftClient (lift getEntry) (lift . useResult)) 26 | (RaftEnv rconf qsize ein eout (liftRaftSpec spec)) 27 | initialRaftState -- only use currentLeader and logEntries 28 | 29 | raftClient :: Ord nt => Raft nt et rt mt et -> (rt -> Raft nt et rt mt ()) -> Raft nt et rt mt () 30 | raftClient getEntry useResult = do 31 | nodes <- view (cfg.otherNodes) 32 | when (Set.null nodes) $ error "The client has no nodes to send requests to." 33 | currentLeader .= (Just $ Set.findMin nodes) 34 | fork_ messageReceiver 35 | fork_ $ commandGetter getEntry 36 | pendingRequests .= Map.empty 37 | clientHandleEvents useResult 38 | 39 | -- get commands with getEntry and put them on the event queue to be sent 40 | commandGetter :: Raft nt et rt mt et -> Raft nt et rt mt () 41 | commandGetter getEntry = do 42 | nid <- view (cfg.nodeId) 43 | forever $ do 44 | entry <- getEntry 45 | rid <- use nextRequestId 46 | nextRequestId += 1 47 | enqueueEvent $ ERPC $ CMD $ Command entry nid rid 48 | 49 | clientHandleEvents :: Ord nt => (rt -> Raft nt et rt mt ()) -> Raft nt et rt mt () 50 | clientHandleEvents useResult = forever $ do 51 | e <- dequeueEvent 52 | case e of 53 | ERPC (CMD cmd) -> clientSendCommand cmd -- these are commands coming from the commandGetter thread 54 | ERPC (CMDR cmdr) -> clientHandleCommandResponse useResult cmdr 55 | HeartbeatTimeout _ -> do 56 | debug "choosing a new leader and resending commands" 57 | setLeaderToNext 58 | traverse_ clientSendCommand =<< use pendingRequests 59 | _ -> return () 60 | 61 | setLeaderToFirst :: Raft nt et rt mt () 62 | setLeaderToFirst = do 63 | nodes <- view (cfg.otherNodes) 64 | when (Set.null nodes) $ error "the client has no nodes to send requests to" 65 | currentLeader .= (Just $ Set.findMin nodes) 66 | 67 | setLeaderToNext :: Ord nt => Raft nt et rt mt () 68 | setLeaderToNext = do 69 | mlid <- use currentLeader 70 | nodes <- view (cfg.otherNodes) 71 | case mlid of 72 | Just lid -> case Set.lookupGT lid nodes of 73 | Just nlid -> currentLeader .= Just nlid 74 | Nothing -> setLeaderToFirst 75 | Nothing -> setLeaderToFirst 76 | 77 | clientSendCommand :: Command nt et -> Raft nt et rt mt () 78 | clientSendCommand cmd@Command{..} = do 79 | mlid <- use currentLeader 80 | case mlid of 81 | Just lid -> do 82 | sendRPC lid $ CMD cmd 83 | prcount <- fmap Map.size (use pendingRequests) 84 | -- if this will be our only pending request, start the timer 85 | -- otherwise, it should already be running 86 | when (prcount == 0) resetHeartbeatTimer 87 | pendingRequests %= Map.insert _cmdRequestId cmd 88 | Nothing -> do 89 | setLeaderToFirst 90 | clientSendCommand cmd 91 | 92 | clientHandleCommandResponse :: (rt -> Raft nt et rt mt ()) 93 | -> CommandResponse nt rt 94 | -> Raft nt et rt mt () 95 | clientHandleCommandResponse useResult CommandResponse{..} = do 96 | prs <- use pendingRequests 97 | when (Map.member _cmdrRequestId prs) $ do 98 | useResult _cmdrResult 99 | currentLeader .= Just _cmdrLeaderId 100 | pendingRequests %= Map.delete _cmdrRequestId 101 | prcount <- fmap Map.size (use pendingRequests) 102 | -- if we still have pending requests, reset the timer 103 | -- otherwise cancel it 104 | if (prcount > 0) 105 | then resetHeartbeatTimer 106 | else cancelTimer 107 | -------------------------------------------------------------------------------- /src/Network/Tangaroa/Combinator.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE RankNTypes #-} 2 | {-# LANGUAGE KindSignatures #-} 3 | 4 | module Network.Tangaroa.Combinator 5 | ( (^$) 6 | , (^=<<.) 7 | ) where 8 | 9 | import Control.Lens 10 | import Control.Monad.RWS 11 | 12 | -- like $, but the function is a lens from the reader environment with a 13 | -- pure function as its target 14 | infixr 0 ^$ 15 | (^$) :: forall (m :: * -> *) b r a. MonadReader r m => 16 | Getting (a -> b) r (a -> b) -> a -> m b 17 | lf ^$ a = fmap ($ a) (view lf) 18 | 19 | infixr 0 ^=<<. 20 | (^=<<.) :: forall a (m :: * -> *) b r s. 21 | (MonadReader r m, MonadState s m) => 22 | Getting (a -> m b) r (a -> m b) -> Getting a s a -> m b 23 | lf ^=<<. la = view lf >>= (use la >>=) 24 | -------------------------------------------------------------------------------- /src/Network/Tangaroa/Handler.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE RecordWildCards #-} 2 | 3 | module Network.Tangaroa.Handler 4 | ( handleEvents 5 | ) where 6 | 7 | import Network.Tangaroa.Types 8 | import Network.Tangaroa.Sender 9 | import Network.Tangaroa.Util 10 | import Network.Tangaroa.Combinator 11 | import Network.Tangaroa.Role 12 | import Network.Tangaroa.Timer 13 | 14 | import Control.Monad hiding (mapM) 15 | import Control.Lens 16 | import Data.Sequence (Seq) 17 | import qualified Data.Sequence as Seq 18 | import qualified Data.Map as Map 19 | import qualified Data.Set as Set 20 | 21 | import Prelude hiding (mapM) 22 | import Data.Traversable (mapM) 23 | 24 | handleEvents :: Ord nt => Raft nt et rt mt () 25 | handleEvents = forever $ do 26 | e <- dequeueEvent 27 | case e of 28 | ERPC rpc -> handleRPC rpc 29 | ElectionTimeout s -> handleElectionTimeout s 30 | HeartbeatTimeout s -> handleHeartbeatTimeout s 31 | 32 | handleRPC :: Ord nt => RPC nt et rt -> Raft nt et rt mt () 33 | handleRPC rpc = do 34 | case rpc of 35 | AE ae -> handleAppendEntries ae 36 | AER aer -> handleAppendEntriesResponse aer 37 | RV rv -> handleRequestVote rv 38 | RVR rvr -> handleRequestVoteResponse rvr 39 | CMD cmd -> handleCommand cmd 40 | CMDR _ -> debug "got a command response RPC" 41 | DBG s -> debug $ "got a debug RPC: " ++ s 42 | 43 | handleElectionTimeout :: Ord nt => String -> Raft nt et rt mt () 44 | handleElectionTimeout s = do 45 | debug $ "election timeout: " ++ s 46 | r <- use role 47 | when (r /= Leader) becomeCandidate 48 | 49 | handleHeartbeatTimeout :: Ord nt => String -> Raft nt et rt mt () 50 | handleHeartbeatTimeout s = do 51 | debug $ "heartbeat timeout: " ++ s 52 | r <- use role 53 | when (r == Leader) $ do 54 | -- heartbeat timeouts are used to control appendEntries heartbeats 55 | fork_ sendAllAppendEntries 56 | resetHeartbeatTimer 57 | 58 | handleTermNumber :: Term -> Raft nt et rt mt () 59 | handleTermNumber rpcTerm = do 60 | ct <- use term 61 | when (rpcTerm > ct) $ do 62 | _ <- rs.writeTermNumber ^$ rpcTerm 63 | setVotedFor Nothing 64 | term .= rpcTerm 65 | becomeFollower 66 | 67 | handleAppendEntries :: AppendEntries nt et -> Raft nt et rt mt () 68 | handleAppendEntries AppendEntries{..} = do 69 | debug $ "got an appendEntries RPC: prev log entry: Index " ++ show _prevLogIndex ++ " " ++ show _prevLogTerm 70 | handleTermNumber _aeTerm 71 | r <- use role 72 | when (r == Follower) $ do 73 | ct <- use term 74 | when (_aeTerm == ct) $ do 75 | resetElectionTimer 76 | currentLeader .= Just _leaderId 77 | plmatch <- prevLogEntryMatches _prevLogIndex _prevLogTerm 78 | es <- use logEntries 79 | let oldLastEntry = Seq.length es - 1 80 | let newLastEntry = _prevLogIndex + Seq.length _aeEntries 81 | if _aeTerm < ct || not plmatch 82 | then fork_ $ sendAppendEntriesResponse _leaderId False oldLastEntry 83 | else do 84 | appendLogEntries _prevLogIndex _aeEntries 85 | fork_ $ sendAppendEntriesResponse _leaderId True newLastEntry 86 | nc <- use commitIndex 87 | when (_leaderCommit > nc) $ do 88 | commitIndex .= min _leaderCommit newLastEntry 89 | applyLogEntries 90 | 91 | prevLogEntryMatches :: LogIndex -> Term -> Raft nt et rt mt Bool 92 | prevLogEntryMatches pli plt = do 93 | es <- use logEntries 94 | case seqIndex es pli of 95 | -- if we don't have the entry, only return true if pli is startIndex 96 | Nothing -> return (pli == startIndex) 97 | -- if we do have the entry, return true if the terms match 98 | Just (t,_) -> return (t == plt) 99 | 100 | -- TODO: check this 101 | appendLogEntries :: LogIndex -> Seq (Term, Command nt et) -> Raft nt et rt mt () 102 | appendLogEntries pli es = 103 | logEntries %= (Seq.>< es) . Seq.take (pli + 1) 104 | 105 | handleAppendEntriesResponse :: Ord nt => AppendEntriesResponse nt -> Raft nt et rt mt () 106 | handleAppendEntriesResponse AppendEntriesResponse{..} = do 107 | debug "got an appendEntriesResponse RPC" 108 | handleTermNumber _aerTerm 109 | r <- use role 110 | ct <- use term 111 | when (r == Leader && _aerTerm == ct) $ 112 | if _aerSuccess 113 | then do 114 | lMatchIndex.at _aerNodeId .= Just _aerIndex 115 | lNextIndex .at _aerNodeId .= Just (_aerIndex + 1) 116 | leaderDoCommit 117 | else do 118 | lNextIndex %= Map.adjust (subtract 1) _aerNodeId 119 | fork_ $ sendAppendEntries _aerNodeId 120 | 121 | applyCommand :: Command nt et -> Raft nt et rt mt (nt, CommandResponse nt rt) 122 | applyCommand Command{..} = do 123 | apply <- view (rs.applyLogEntry) 124 | result <- apply _cmdEntry 125 | mlid <- use currentLeader 126 | nid <- view (cfg.nodeId) 127 | return $ 128 | (_cmdClientId, 129 | CommandResponse 130 | result 131 | (case mlid of Just lid -> lid; Nothing -> nid) 132 | _cmdRequestId) 133 | 134 | leaderDoCommit :: Raft nt et rt mt () 135 | leaderDoCommit = do 136 | commitUpdate <- leaderUpdateCommitIndex 137 | when commitUpdate applyLogEntries 138 | 139 | -- apply the un-applied log entries up through commitIndex 140 | -- and send results to the client if you are the leader 141 | -- TODO: have this done on a separate thread via event passing 142 | applyLogEntries :: Raft nt et rt mt () 143 | applyLogEntries = do 144 | la <- use lastApplied 145 | ci <- use commitIndex 146 | le <- use logEntries 147 | let leToApply = fmap (^. _2) . Seq.drop (la + 1) . Seq.take (ci + 1) $ le 148 | results <- mapM applyCommand leToApply 149 | r <- use role 150 | when (r == Leader) $ fork_ $ sendResults results 151 | lastApplied .= ci 152 | 153 | 154 | -- called only as leader 155 | -- checks to see what the largest N where a majority of 156 | -- the lMatchIndex set is >= N 157 | leaderUpdateCommitIndex :: Raft nt et rt mt Bool 158 | leaderUpdateCommitIndex = do 159 | ci <- use commitIndex 160 | lmi <- use lMatchIndex 161 | qsize <- view quorumSize 162 | ct <- use term 163 | es <- use logEntries 164 | 165 | -- get all indices in the log past commitIndex and take the ones where the entry's 166 | -- term is equal to the current term 167 | let ctinds = filter (\i -> maybe False ((== ct) . fst) (seqIndex es i)) 168 | [(ci + 1)..(Seq.length es - 1)] 169 | 170 | -- get the prefix of these indices where a quorum of nodes have matching 171 | -- indices for that entry. lMatchIndex doesn't include the leader, so add 172 | -- one to the size 173 | let qcinds = takeWhile (\i -> 1 + Map.size (Map.filter (>= i) lmi) >= qsize) ctinds 174 | 175 | case qcinds of 176 | [] -> return False 177 | _ -> do 178 | commitIndex .= last qcinds 179 | debug $ "commit index is now: " ++ show (last qcinds) 180 | return True 181 | 182 | handleRequestVote :: Eq nt => RequestVote nt -> Raft nt et rt mt () 183 | handleRequestVote RequestVote{..} = do 184 | debug $ "got a requestVote RPC for " ++ show _rvTerm 185 | handleTermNumber _rvTerm 186 | mvote <- use votedFor 187 | es <- use logEntries 188 | ct <- use term 189 | case mvote of 190 | _ | _rvTerm < ct -> do 191 | -- this is an old candidate 192 | debug "this is for an old term" 193 | fork_ $ sendRequestVoteResponse _candidateId False 194 | 195 | Just c | c == _candidateId -> do 196 | -- already voted for this candidate 197 | debug "already voted for this candidate" 198 | fork_ $ sendRequestVoteResponse _candidateId True 199 | 200 | Just _ -> do 201 | -- already voted for a different candidate this term 202 | debug "already voted for a different candidate" 203 | fork_ $ sendRequestVoteResponse _candidateId False 204 | 205 | Nothing -> if (_lastLogTerm, _lastLogIndex) >= lastLogInfo es 206 | -- haven't voted yet, so vote for the candidate if its log is at least as 207 | -- up to date as ours, use the Ord instance of (Term, Index) to prefer 208 | -- higher terms, and then higher last indices for equal terms 209 | then do 210 | debug "haven't voted, voting for this candidate" 211 | setVotedFor (Just _candidateId) 212 | fork_ $ sendRequestVoteResponse _candidateId True 213 | else do 214 | debug "haven't voted, but my log is better than this candidate's" 215 | fork_ $ sendRequestVoteResponse _candidateId False 216 | 217 | handleRequestVoteResponse :: Ord nt => RequestVoteResponse nt -> Raft nt et rt mt () 218 | handleRequestVoteResponse RequestVoteResponse{..} = do 219 | debug $ "got a requestVoteResponse RPC for " ++ show _rvrTerm ++ ": " ++ show _voteGranted 220 | handleTermNumber _rvrTerm 221 | r <- use role 222 | when (r == Candidate) $ 223 | if _voteGranted 224 | then do 225 | cYesVotes %= Set.insert _rvrNodeId 226 | checkElection 227 | else 228 | cPotentialVotes %= Set.delete _rvrNodeId 229 | 230 | handleCommand :: Ord nt => Command nt et -> Raft nt et rt mt () 231 | handleCommand cmd = do 232 | debug "got a command RPC" 233 | r <- use role 234 | ct <- use term 235 | mlid <- use currentLeader 236 | case (r, mlid) of 237 | (Leader, _) -> do 238 | -- we're the leader, so append this to our log with the current term 239 | -- and propagate it to replicas 240 | logEntries %= (Seq.|> (ct, cmd)) 241 | fork_ sendAllAppendEntries 242 | leaderDoCommit 243 | (_, Just lid) -> 244 | -- we're not the leader, but we know who the leader is, so forward this 245 | -- command 246 | fork_ $ sendRPC lid $ CMD cmd 247 | (_, Nothing) -> 248 | -- we're not the leader, and we don't know who the leader is, so can't do 249 | -- anything (TODO) 250 | return () 251 | -------------------------------------------------------------------------------- /src/Network/Tangaroa/Role.hs: -------------------------------------------------------------------------------- 1 | module Network.Tangaroa.Role 2 | ( becomeFollower 3 | , becomeLeader 4 | , becomeCandidate 5 | , checkElection 6 | , setVotedFor 7 | ) where 8 | 9 | import Network.Tangaroa.Timer 10 | import Network.Tangaroa.Types 11 | import Network.Tangaroa.Combinator 12 | import Network.Tangaroa.Util 13 | import Network.Tangaroa.Sender 14 | 15 | import Control.Lens hiding (Index) 16 | import Control.Monad 17 | import qualified Data.Map as Map 18 | import qualified Data.Sequence as Seq 19 | import qualified Data.Set as Set 20 | 21 | -- count the yes votes and become leader if you have reached a quorum 22 | checkElection :: Ord nt => Raft nt et rt mt () 23 | checkElection = do 24 | nyes <- Set.size <$> use cYesVotes 25 | qsize <- view quorumSize 26 | debug $ "yes votes: " ++ show nyes ++ " quorum size: " ++ show qsize 27 | when (nyes >= qsize) $ becomeLeader 28 | 29 | setVotedFor :: Maybe nt -> Raft nt et rt mt () 30 | setVotedFor mvote = do 31 | _ <- rs.writeVotedFor ^$ mvote 32 | votedFor .= mvote 33 | 34 | becomeFollower :: Raft nt et rt mt () 35 | becomeFollower = do 36 | debug "becoming follower" 37 | role .= Follower 38 | resetElectionTimer 39 | 40 | becomeCandidate :: Ord nt => Raft nt et rt mt () 41 | becomeCandidate = do 42 | debug "becoming candidate" 43 | role .= Candidate 44 | term += 1 45 | rs.writeTermNumber ^=<<. term 46 | nid <- view (cfg.nodeId) 47 | setVotedFor $ Just nid 48 | cYesVotes .= Set.singleton nid -- vote for yourself 49 | (cPotentialVotes .=) =<< view (cfg.otherNodes) 50 | resetElectionTimer 51 | -- this is necessary for a single-node cluster, as we have already won the 52 | -- election in that case. otherwise we will wait for more votes to check again 53 | checkElection -- can possibly transition to leader 54 | r <- use role 55 | when (r == Candidate) $ fork_ sendAllRequestVotes 56 | 57 | becomeLeader :: Ord nt => Raft nt et rt mt () 58 | becomeLeader = do 59 | debug "becoming leader" 60 | role .= Leader 61 | (currentLeader .=) . Just =<< view (cfg.nodeId) 62 | ni <- Seq.length <$> use logEntries 63 | (lNextIndex .=) =<< Map.fromSet (const ni) <$> view (cfg.otherNodes) 64 | (lMatchIndex .=) =<< Map.fromSet (const startIndex) <$> view (cfg.otherNodes) 65 | fork_ sendAllAppendEntries 66 | resetHeartbeatTimer 67 | -------------------------------------------------------------------------------- /src/Network/Tangaroa/Sender.hs: -------------------------------------------------------------------------------- 1 | module Network.Tangaroa.Sender 2 | ( sendAppendEntries 3 | , sendAppendEntriesResponse 4 | , sendRequestVote 5 | , sendRequestVoteResponse 6 | , sendAllAppendEntries 7 | , sendAllRequestVotes 8 | , sendResults 9 | , sendRPC 10 | ) where 11 | 12 | import Control.Lens 13 | import Data.Foldable (traverse_) 14 | import Data.Sequence (Seq) 15 | import qualified Data.Sequence as Seq 16 | 17 | import Network.Tangaroa.Util 18 | import Network.Tangaroa.Types 19 | 20 | sendAppendEntries :: Ord nt => nt -> Raft nt et rt mt () 21 | sendAppendEntries target = do 22 | mni <- use $ lNextIndex.at target 23 | es <- use logEntries 24 | let (pli,plt) = logInfoForNextIndex mni es 25 | ct <- use term 26 | nid <- view (cfg.nodeId) 27 | ci <- use commitIndex 28 | debug $ "sendAppendEntries: " ++ show ct 29 | sendRPC target $ AE $ 30 | AppendEntries ct nid pli plt (Seq.drop (pli + 1) es) ci 31 | 32 | sendAppendEntriesResponse :: nt -> Bool -> LogIndex -> Raft nt et rt mt () 33 | sendAppendEntriesResponse target success lindex = do 34 | ct <- use term 35 | nid <- view (cfg.nodeId) 36 | debug $ "sendAppendEntriesResponse: " ++ show ct 37 | sendRPC target $ AER $ AppendEntriesResponse ct nid success lindex 38 | 39 | sendRequestVote :: nt -> Raft nt et rt mt () 40 | sendRequestVote target = do 41 | ct <- use term 42 | nid <- view (cfg.nodeId) 43 | es <- use logEntries 44 | let (llt, lli) = lastLogInfo es 45 | debug $ "sendRequestVote: " ++ show ct 46 | sendRPC target $ RV $ RequestVote ct nid lli llt 47 | 48 | sendRequestVoteResponse :: nt -> Bool -> Raft nt et rt mt () 49 | sendRequestVoteResponse target vote = do 50 | ct <- use term 51 | nid <- view (cfg.nodeId) 52 | debug $ "sendRequestVoteResponse: " ++ show ct 53 | sendRPC target $ RVR $ RequestVoteResponse ct nid vote 54 | 55 | sendAllAppendEntries :: Ord nt => Raft nt et rt mt () 56 | sendAllAppendEntries = traverse_ sendAppendEntries =<< view (cfg.otherNodes) 57 | 58 | sendAllRequestVotes :: Raft nt et rt mt () 59 | sendAllRequestVotes = traverse_ sendRequestVote =<< use cPotentialVotes 60 | 61 | sendResults :: Seq (nt, CommandResponse nt rt) -> Raft nt et rt mt () 62 | sendResults results = do 63 | traverse_ (\(target,cmdr) -> sendRPC target $ CMDR cmdr) results 64 | 65 | -- called by leaders sending appendEntries. 66 | -- given a replica's nextIndex, get the index and term to send as 67 | -- prevLog(Index/Term) 68 | logInfoForNextIndex :: Maybe LogIndex -> Seq (Term,et) -> (LogIndex,Term) 69 | logInfoForNextIndex mni es = 70 | case mni of 71 | Just ni -> let pli = ni - 1 in 72 | case seqIndex es pli of 73 | Just (t,_) -> (pli, t) 74 | -- this shouldn't happen, because nextIndex - 1 should always be at 75 | -- most our last entry 76 | Nothing -> (startIndex, startTerm) 77 | Nothing -> (startIndex, startTerm) 78 | 79 | sendRPC :: nt -> RPC nt et rt -> Raft nt et rt mt () 80 | sendRPC target rpc = do 81 | send <- view (rs.sendMessage) 82 | ser <- view (rs.serializeRPC) 83 | send target $ ser rpc 84 | -------------------------------------------------------------------------------- /src/Network/Tangaroa/Server.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE RecordWildCards #-} 2 | 3 | module Network.Tangaroa.Server 4 | ( runRaft 5 | , RaftSpec(..) 6 | , Config(..), otherNodes, nodeId, electionTimeoutRange, heartbeatTimeout, enableDebug 7 | , Term, startTerm 8 | ) where 9 | 10 | import Control.Concurrent.Chan.Unagi 11 | import Control.Lens 12 | import qualified Data.Set as Set 13 | 14 | import Network.Tangaroa.Handler 15 | import Network.Tangaroa.Types 16 | import Network.Tangaroa.Util 17 | import Network.Tangaroa.Timer 18 | 19 | runRaft :: Ord nt => Config nt -> RaftSpec nt et rt mt -> IO () 20 | runRaft rconf spec@RaftSpec{..} = do 21 | let qsize = getQuorumSize $ 1 + (Set.size $ rconf ^. otherNodes) 22 | (ein, eout) <- newChan 23 | runRWS_ 24 | raft 25 | (RaftEnv rconf qsize ein eout (liftRaftSpec spec)) 26 | initialRaftState 27 | 28 | raft :: Ord nt => Raft nt et rt mt () 29 | raft = do 30 | fork_ messageReceiver 31 | resetElectionTimer 32 | handleEvents 33 | -------------------------------------------------------------------------------- /src/Network/Tangaroa/Spec/Simple.hs: -------------------------------------------------------------------------------- 1 | module Network.Tangaroa.Spec.Simple 2 | ( runServer 3 | , runClient 4 | ) where 5 | 6 | import Network.Tangaroa.Server 7 | import Network.Tangaroa.Client 8 | 9 | import Control.Lens 10 | import Data.Word 11 | import Network.Socket 12 | import System.Console.GetOpt 13 | import System.Environment 14 | import System.Exit 15 | import Text.Read 16 | import qualified Data.Set as Set 17 | 18 | options :: [OptDescr (Config NodeType -> Config NodeType)] 19 | options = 20 | [ Option ['s'] ["self"] 21 | (ReqArg setThisNode "SELF_PORT_NUMBER") 22 | "The port number of this node." 23 | , Option ['d'] ["debug"] 24 | (NoArg (enableDebug .~ True)) 25 | "Enable debugging info (show RPCs and timeouts)." 26 | ] 27 | 28 | getConfig :: IO (Config NodeType) 29 | getConfig = do 30 | argv <- getArgs 31 | case getOpt Permute options argv of 32 | (opts,args,[]) -> return $ foldr addOtherNode (foldr ($) defaultConfig opts) args 33 | (_,_,_) -> exitFailure -- TODO, print errors 34 | 35 | type NodeType = (HostAddress, Word16) 36 | localhost :: HostAddress 37 | localhost = 0x0100007f 38 | 39 | defaultPortNum :: Word16 40 | defaultPortNum = 10000 41 | 42 | defaultConfig :: Config NodeType 43 | defaultConfig = 44 | Config 45 | Set.empty -- other nodes 46 | (localhost,defaultPortNum) -- self address 47 | (3000000,6000000) -- election timeout range 48 | 1500000 -- heartbeat timeout 49 | False -- no debug 50 | 51 | nodeSockAddr :: NodeType -> SockAddr 52 | nodeSockAddr (host,port) = SockAddrInet (fromIntegral port) host 53 | 54 | setThisNode :: String -> Config NodeType -> Config NodeType 55 | setThisNode = 56 | maybe id (\p -> nodeId .~ (localhost, p)) . readMaybe 57 | 58 | addOtherNode :: String -> Config NodeType -> Config NodeType 59 | addOtherNode = 60 | maybe id (\p -> otherNodes %~ Set.insert (localhost, p)) . readMaybe 61 | 62 | getMsg :: Socket -> IO String 63 | getMsg sock = recv sock 8192 64 | 65 | msgSend :: Socket -> NodeType -> String -> IO () 66 | msgSend sock node s = 67 | sendTo sock s (nodeSockAddr node) >> return () 68 | 69 | showDebug :: NodeType -> String -> IO () 70 | showDebug node msg = putStrLn $ show (snd node) ++ " " ++ msg 71 | 72 | noDebug :: NodeType -> String -> IO () 73 | noDebug _ _ = return () 74 | 75 | simpleRaftSpec :: (Show et, Read et, Show rt, Read rt) 76 | => Socket 77 | -> (et -> IO rt) 78 | -> (NodeType -> String -> IO ()) 79 | -> RaftSpec NodeType et rt String 80 | simpleRaftSpec sock applyFn debugFn = RaftSpec 81 | { 82 | -- TODO don't read log entries 83 | __readLogEntry = return . const Nothing 84 | -- TODO don't write log entries 85 | , __writeLogEntry = \_ _ -> return () 86 | -- TODO always read startTerm 87 | , __readTermNumber = return startTerm 88 | -- TODO don't write term numbers 89 | , __writeTermNumber = return . const () 90 | -- TODO never voted for anyone 91 | , __readVotedFor = return Nothing 92 | -- TODO don't record votes 93 | , __writeVotedFor = return . const () 94 | -- apply log entries to the state machine, given by caller 95 | , __applyLogEntry = applyFn 96 | -- serialize with show 97 | , __serializeRPC = show 98 | -- deserialize with readMaybe 99 | , __deserializeRPC = readMaybe 100 | -- send messages using msgSend 101 | , __sendMessage = msgSend sock 102 | -- get messages using getMsg 103 | , __getMessage = getMsg sock 104 | -- use the debug function given by the caller 105 | , __debugPrint = debugFn 106 | } 107 | 108 | runServer :: (Show et, Read et, Show rt, Read rt) => 109 | (et -> IO rt) -> IO () 110 | runServer applyFn = do 111 | rconf <- getConfig 112 | sock <- socket AF_INET Datagram defaultProtocol 113 | bind sock $ nodeSockAddr $ rconf ^. nodeId 114 | let debugFn = if (rconf ^. enableDebug) then showDebug else noDebug 115 | runRaft rconf $ simpleRaftSpec sock applyFn debugFn 116 | 117 | runClient :: (Show et, Read et, Show rt, Read rt) => 118 | (et -> IO rt) -> IO et -> (rt -> IO ()) -> IO () 119 | runClient applyFn getEntry useResult = do 120 | rconf <- getConfig 121 | sock <- socket AF_INET Datagram defaultProtocol 122 | bind sock $ nodeSockAddr $ rconf ^. nodeId 123 | let debugFn = if (rconf ^. enableDebug) then showDebug else noDebug 124 | runRaftClient getEntry useResult rconf (simpleRaftSpec sock applyFn debugFn) 125 | -------------------------------------------------------------------------------- /src/Network/Tangaroa/Timer.hs: -------------------------------------------------------------------------------- 1 | module Network.Tangaroa.Timer 2 | ( resetElectionTimer 3 | , resetHeartbeatTimer 4 | , cancelTimer 5 | ) where 6 | 7 | import Control.Lens hiding (Index) 8 | import Control.Monad.Trans (lift) 9 | import System.Random 10 | import Control.Concurrent.Lifted 11 | 12 | import Network.Tangaroa.Types 13 | import Network.Tangaroa.Util 14 | 15 | getNewElectionTimeout :: Raft nt et rt mt Int 16 | getNewElectionTimeout = view (cfg.electionTimeoutRange) >>= lift . randomRIO 17 | 18 | resetElectionTimer :: Raft nt et rt mt () 19 | resetElectionTimer = do 20 | timeout <- getNewElectionTimeout 21 | setTimedEvent (ElectionTimeout $ show (timeout `div` 1000) ++ "ms") timeout 22 | 23 | resetHeartbeatTimer :: Raft nt et rt mt () 24 | resetHeartbeatTimer = do 25 | timeout <- view (cfg.heartbeatTimeout) 26 | setTimedEvent (HeartbeatTimeout $ show (timeout `div` 1000) ++ "ms") timeout 27 | 28 | -- | Cancel any existing timer. 29 | cancelTimer :: Raft nt et rt mt () 30 | cancelTimer = do 31 | use timerThread >>= maybe (return ()) killThread 32 | timerThread .= Nothing 33 | 34 | -- | Cancels any pending timer and sets a new timer to trigger an event after t 35 | -- microseconds. 36 | setTimedEvent :: Event nt et rt -> Int -> Raft nt et rt mt () 37 | setTimedEvent e t = do 38 | cancelTimer 39 | tmr <- fork $ wait t >> enqueueEvent e 40 | timerThread .= Just tmr 41 | -------------------------------------------------------------------------------- /src/Network/Tangaroa/Types.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DeriveGeneric #-} 2 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} 3 | {-# LANGUAGE Rank2Types #-} 4 | {-# LANGUAGE RecordWildCards #-} 5 | {-# LANGUAGE TemplateHaskell #-} 6 | 7 | module Network.Tangaroa.Types 8 | ( Raft 9 | , RaftSpec(..) 10 | , LiftedRaftSpec(..) 11 | , readLogEntry, writeLogEntry, readTermNumber, writeTermNumber 12 | , readVotedFor, writeVotedFor, applyLogEntry, serializeRPC 13 | , deserializeRPC, sendMessage, getMessage, debugPrint 14 | , liftRaftSpec 15 | , Term, startTerm 16 | , LogIndex, startIndex 17 | , RequestId, startRequestId 18 | , Config(..), otherNodes, nodeId, electionTimeoutRange, heartbeatTimeout, enableDebug 19 | , Role(..) 20 | , RaftEnv(..), cfg, quorumSize, eventIn, eventOut, rs 21 | , RaftState(..), role, votedFor, currentLeader, logEntries, commitIndex, lastApplied, timerThread 22 | , pendingRequests, nextRequestId 23 | , initialRaftState 24 | , cYesVotes, cPotentialVotes, lNextIndex, lMatchIndex 25 | , AppendEntries(..) 26 | , AppendEntriesResponse(..) 27 | , RequestVote(..) 28 | , RequestVoteResponse(..) 29 | , Command(..) 30 | , CommandResponse(..) 31 | , RPC(..) 32 | , term 33 | , Event(..) 34 | ) where 35 | 36 | import Control.Concurrent (ThreadId) 37 | import Control.Concurrent.Chan.Unagi 38 | import Control.Lens hiding (Index) 39 | import Control.Monad.RWS 40 | import Data.Binary 41 | import Data.Sequence (Seq) 42 | import qualified Data.Sequence as Seq 43 | import Data.Map (Map) 44 | import qualified Data.Map as Map 45 | import Data.Set (Set) 46 | import qualified Data.Set as Set 47 | 48 | import GHC.Generics 49 | 50 | newtype Term = Term Int 51 | deriving (Show, Read, Eq, Ord, Generic, Num) 52 | 53 | startTerm :: Term 54 | startTerm = Term (-1) 55 | 56 | type LogIndex = Int 57 | 58 | startIndex :: LogIndex 59 | startIndex = (-1) 60 | 61 | newtype RequestId = RequestId Int 62 | deriving (Show, Read, Eq, Ord, Generic, Num) 63 | 64 | startRequestId :: RequestId 65 | startRequestId = RequestId 0 66 | 67 | data Config nt = Config 68 | { _otherNodes :: Set nt 69 | , _nodeId :: nt 70 | , _electionTimeoutRange :: (Int,Int) -- in microseconds 71 | , _heartbeatTimeout :: Int -- in microseconds 72 | , _enableDebug :: Bool 73 | } 74 | deriving (Show, Generic) 75 | makeLenses ''Config 76 | 77 | data Command nt et = Command 78 | { _cmdEntry :: et 79 | , _cmdClientId :: nt 80 | , _cmdRequestId :: RequestId 81 | } 82 | deriving (Show, Read, Generic) 83 | 84 | data CommandResponse nt rt = CommandResponse 85 | { _cmdrResult :: rt 86 | , _cmdrLeaderId :: nt 87 | , _cmdrRequestId :: RequestId 88 | } 89 | deriving (Show, Read, Generic) 90 | 91 | data AppendEntries nt et = AppendEntries 92 | { _aeTerm :: Term 93 | , _leaderId :: nt 94 | , _prevLogIndex :: LogIndex 95 | , _prevLogTerm :: Term 96 | , _aeEntries :: Seq (Term, Command nt et) 97 | , _leaderCommit :: LogIndex 98 | } 99 | deriving (Show, Read, Generic) 100 | 101 | data AppendEntriesResponse nt = AppendEntriesResponse 102 | { _aerTerm :: Term 103 | , _aerNodeId :: nt 104 | , _aerSuccess :: Bool 105 | , _aerIndex :: LogIndex 106 | } 107 | deriving (Show, Read, Generic) 108 | 109 | data RequestVote nt = RequestVote 110 | { _rvTerm :: Term 111 | , _candidateId :: nt 112 | , _lastLogIndex :: LogIndex 113 | , _lastLogTerm :: Term 114 | } 115 | deriving (Show, Read, Generic) 116 | 117 | data RequestVoteResponse nt = RequestVoteResponse 118 | { _rvrTerm :: Term 119 | , _rvrNodeId :: nt 120 | , _voteGranted :: Bool 121 | } 122 | deriving (Show, Read, Generic) 123 | 124 | data RPC nt et rt = AE (AppendEntries nt et) 125 | | AER (AppendEntriesResponse nt) 126 | | RV (RequestVote nt) 127 | | RVR (RequestVoteResponse nt) 128 | | CMD (Command nt et) 129 | | CMDR (CommandResponse nt rt) 130 | | DBG String 131 | deriving (Show, Read, Generic) 132 | 133 | -- | A structure containing all the implementation details for running 134 | -- the raft protocol. 135 | data RaftSpec nt et rt mt = RaftSpec 136 | { 137 | -- ^ Function to get a log entry from persistent storage. 138 | __readLogEntry :: LogIndex -> IO (Maybe et) 139 | 140 | -- ^ Function to write a log entry to persistent storage. 141 | , __writeLogEntry :: LogIndex -> (Term,et) -> IO () 142 | 143 | -- ^ Function to get the term number from persistent storage. 144 | , __readTermNumber :: IO Term 145 | 146 | -- ^ Function to write the term number to persistent storage. 147 | , __writeTermNumber :: Term -> IO () 148 | 149 | -- ^ Function to read the node voted for from persistent storage. 150 | , __readVotedFor :: IO (Maybe nt) 151 | 152 | -- ^ Function to write the node voted for to persistent storage. 153 | , __writeVotedFor :: Maybe nt -> IO () 154 | 155 | -- ^ Function to apply a log entry to the state machine. 156 | , __applyLogEntry :: et -> IO rt 157 | 158 | -- ^ Function to serialize an RPC. 159 | , __serializeRPC :: RPC nt et rt -> mt 160 | 161 | -- ^ Function to deserialize an RPC. 162 | , __deserializeRPC :: mt -> Maybe (RPC nt et rt) 163 | 164 | -- ^ Function to send a message to a node. 165 | , __sendMessage :: nt -> mt -> IO () 166 | 167 | -- ^ Function to get the next message. 168 | , __getMessage :: IO mt 169 | 170 | -- ^ Function to log a debug message (no newline). 171 | , __debugPrint :: nt -> String -> IO () 172 | } 173 | 174 | data Role = Follower 175 | | Candidate 176 | | Leader 177 | deriving (Show, Generic, Eq) 178 | 179 | data Event nt et rt = ERPC (RPC nt et rt) 180 | | ElectionTimeout String 181 | | HeartbeatTimeout String 182 | deriving (Show) 183 | 184 | -- | A version of RaftSpec where all IO functions are lifted 185 | -- into the Raft monad. 186 | data LiftedRaftSpec nt et rt mt t = LiftedRaftSpec 187 | { 188 | -- ^ Function to get a log entry from persistent storage. 189 | _readLogEntry :: MonadTrans t => LogIndex -> t IO (Maybe et) 190 | 191 | -- ^ Function to write a log entry to persistent storage. 192 | , _writeLogEntry :: MonadTrans t => LogIndex -> (Term,et) -> t IO () 193 | 194 | -- ^ Function to get the term number from persistent storage. 195 | , _readTermNumber :: MonadTrans t => t IO Term 196 | 197 | -- ^ Function to write the term number to persistent storage. 198 | , _writeTermNumber :: MonadTrans t => Term -> t IO () 199 | 200 | -- ^ Function to read the node voted for from persistent storage. 201 | , _readVotedFor :: MonadTrans t => t IO (Maybe nt) 202 | 203 | -- ^ Function to write the node voted for to persistent storage. 204 | , _writeVotedFor :: MonadTrans t => Maybe nt -> t IO () 205 | 206 | -- ^ Function to apply a log entry to the state machine. 207 | , _applyLogEntry :: MonadTrans t => et -> t IO rt 208 | 209 | -- ^ Function to serialize an RPC. 210 | , _serializeRPC :: RPC nt et rt -> mt 211 | 212 | -- ^ Function to deserialize an RPC. 213 | , _deserializeRPC :: mt -> Maybe (RPC nt et rt) 214 | 215 | -- ^ Function to send a message to a node. 216 | , _sendMessage :: MonadTrans t => nt -> mt -> t IO () 217 | 218 | -- ^ Function to get the next message. 219 | , _getMessage :: MonadTrans t => t IO mt 220 | 221 | -- ^ Function to log a debug message (no newline). 222 | , _debugPrint :: nt -> String -> t IO () 223 | } 224 | makeLenses ''LiftedRaftSpec 225 | 226 | liftRaftSpec :: MonadTrans t => RaftSpec nt et rt mt -> LiftedRaftSpec nt et rt mt t 227 | liftRaftSpec RaftSpec{..} = 228 | LiftedRaftSpec 229 | { _readLogEntry = lift . __readLogEntry 230 | , _writeLogEntry = \i et -> lift (__writeLogEntry i et) 231 | , _readTermNumber = lift __readTermNumber 232 | , _writeTermNumber = lift . __writeTermNumber 233 | , _readVotedFor = lift __readVotedFor 234 | , _writeVotedFor = lift . __writeVotedFor 235 | , _applyLogEntry = lift . __applyLogEntry 236 | , _serializeRPC = __serializeRPC 237 | , _deserializeRPC = __deserializeRPC 238 | , _sendMessage = \n m -> lift (__sendMessage n m) 239 | , _getMessage = lift __getMessage 240 | , _debugPrint = \n s -> lift (__debugPrint n s) 241 | } 242 | 243 | data RaftState nt et = RaftState 244 | { _role :: Role 245 | , _term :: Term 246 | , _votedFor :: Maybe nt 247 | , _currentLeader :: Maybe nt 248 | , _logEntries :: Seq (Term, Command nt et) 249 | , _commitIndex :: LogIndex 250 | , _lastApplied :: LogIndex 251 | , _timerThread :: Maybe ThreadId 252 | , _cYesVotes :: Set nt 253 | , _cPotentialVotes :: Set nt 254 | , _lNextIndex :: Map nt LogIndex 255 | , _lMatchIndex :: Map nt LogIndex 256 | , _pendingRequests :: Map RequestId (Command nt et) -- used by clients 257 | , _nextRequestId :: RequestId -- used by clients 258 | } 259 | makeLenses ''RaftState 260 | 261 | initialRaftState :: RaftState nt et 262 | initialRaftState = RaftState 263 | Follower -- role 264 | startTerm -- term 265 | Nothing -- votedFor 266 | Nothing -- currentLeader 267 | Seq.empty -- log 268 | startIndex -- commitIndex 269 | startIndex -- lastApplied 270 | Nothing -- timerThread 271 | Set.empty -- cYesVotes 272 | Set.empty -- cPotentialVotes 273 | Map.empty -- lNextIndex 274 | Map.empty -- lMatchIndex 275 | Map.empty -- pendingRequests 276 | 0 -- nextRequestId 277 | 278 | data RaftEnv nt et rt mt = RaftEnv 279 | { _cfg :: Config nt 280 | , _quorumSize :: Int 281 | , _eventIn :: InChan (Event nt et rt) 282 | , _eventOut :: OutChan (Event nt et rt) 283 | , _rs :: LiftedRaftSpec nt et rt mt (RWST (RaftEnv nt et rt mt) () (RaftState nt et)) 284 | } 285 | makeLenses ''RaftEnv 286 | 287 | type Raft nt et rt mt a = RWST (RaftEnv nt et rt mt) () (RaftState nt et) IO a 288 | 289 | instance Binary Term 290 | instance Binary RequestId 291 | 292 | instance (Binary nt, Binary et) => Binary (AppendEntries nt et) 293 | instance Binary nt => Binary (AppendEntriesResponse nt) 294 | instance Binary nt => Binary (RequestVote nt) 295 | instance Binary nt => Binary (RequestVoteResponse nt) 296 | instance (Binary nt, Binary et) => Binary (Command nt et) 297 | instance (Binary nt, Binary rt) => Binary (CommandResponse nt rt) 298 | 299 | instance (Binary nt, Binary et, Binary rt) => Binary (RPC nt et rt) 300 | -------------------------------------------------------------------------------- /src/Network/Tangaroa/Util.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE FlexibleContexts #-} 2 | 3 | module Network.Tangaroa.Util 4 | ( seqIndex 5 | , lastLogInfo 6 | , getQuorumSize 7 | , debug 8 | , fork_ 9 | , wait 10 | , runRWS_ 11 | , enqueueEvent 12 | , dequeueEvent 13 | , messageReceiver 14 | ) where 15 | 16 | import Network.Tangaroa.Types 17 | import Control.Lens 18 | import Data.Sequence (Seq) 19 | import Control.Monad.RWS 20 | import Control.Concurrent.Lifted (fork, threadDelay) 21 | import Control.Monad.Trans.Control (MonadBaseControl) 22 | import Control.Concurrent.Chan.Unagi (readChan, writeChan) 23 | import qualified Data.Sequence as Seq 24 | 25 | seqIndex :: Seq a -> Int -> Maybe a 26 | seqIndex s i = 27 | if i >= 0 && i < Seq.length s 28 | then Just (Seq.index s i) 29 | else Nothing 30 | 31 | getQuorumSize :: Int -> Int 32 | getQuorumSize n = 33 | if even n 34 | then n `div` 2 + 1 35 | else (n - 1) `div` 2 + 1 36 | 37 | -- get the last term and index of a log 38 | lastLogInfo :: Seq (Term, et) -> (Term, LogIndex) 39 | lastLogInfo es = 40 | case Seq.viewr es of 41 | _ Seq.:> (t,_) -> (t, Seq.length es - 1) 42 | Seq.EmptyR -> (startTerm, startIndex) 43 | 44 | debug :: String -> Raft nt et rt mt () 45 | debug s = do 46 | dbg <- view (rs.debugPrint) 47 | nid <- view (cfg.nodeId) 48 | dbg nid s 49 | 50 | fork_ :: MonadBaseControl IO m => m () -> m () 51 | fork_ a = fork a >> return () 52 | 53 | wait :: Int -> Raft nt et rt mt () 54 | wait t = threadDelay t 55 | 56 | runRWS_ :: Monad m => RWST r w s m a -> r -> s -> m () 57 | runRWS_ ma r s = runRWST ma r s >> return () 58 | 59 | enqueueEvent :: Event nt et rt -> Raft nt et rt mt () 60 | enqueueEvent event = do 61 | ein <- view eventIn 62 | lift $ writeChan ein event 63 | 64 | dequeueEvent :: Raft nt et rt mt (Event nt et rt) 65 | dequeueEvent = lift . readChan =<< view eventOut 66 | 67 | -- | Thread to take incoming messages and write them to the event queue. 68 | messageReceiver :: Raft nt et rt mt () 69 | messageReceiver = do 70 | gm <- view (rs.getMessage) 71 | deser <- view (rs.deserializeRPC) 72 | forever $ 73 | gm >>= maybe 74 | (debug "failed to deserialize RPC") 75 | (enqueueEvent . ERPC) 76 | . deser 77 | -------------------------------------------------------------------------------- /tangaroa.cabal: -------------------------------------------------------------------------------- 1 | name: tangaroa 2 | version: 0.0.0.1 3 | synopsis: Bynzantine Fault Tolerant Raft 4 | description: An implementation of a Byzantine Fault Tolerant Raft protocol. 5 | homepage: https://github.com/chrisnc/tangaroa 6 | author: Chris Copeland 7 | maintainer: chrisnc@cs.stanford.edu 8 | copyright: Copyright (C) 2014-2015, Chris Copeland 9 | 10 | license: BSD3 11 | license-file: LICENSE 12 | 13 | category: Network 14 | build-type: Simple 15 | cabal-version: >=1.20 16 | 17 | source-repository head 18 | type: git 19 | location: git@github.com:chrisnc/tangaroa.git 20 | 21 | library 22 | exposed-modules: Network.Tangaroa.Client 23 | , Network.Tangaroa.Combinator 24 | , Network.Tangaroa.Handler 25 | , Network.Tangaroa.Role 26 | , Network.Tangaroa.Sender 27 | , Network.Tangaroa.Server 28 | , Network.Tangaroa.Spec.Simple 29 | , Network.Tangaroa.Timer 30 | , Network.Tangaroa.Types 31 | , Network.Tangaroa.Util 32 | , Network.Tangaroa.Byzantine.Client 33 | , Network.Tangaroa.Byzantine.Handler 34 | , Network.Tangaroa.Byzantine.Role 35 | , Network.Tangaroa.Byzantine.Sender 36 | , Network.Tangaroa.Byzantine.Server 37 | , Network.Tangaroa.Byzantine.Spec.Simple 38 | , Network.Tangaroa.Byzantine.Timer 39 | , Network.Tangaroa.Byzantine.Types 40 | , Network.Tangaroa.Byzantine.Util 41 | build-depends: base < 5 42 | , binary 43 | , bytestring 44 | , containers 45 | , crypto-api 46 | , lens 47 | , lifted-base 48 | , monad-control 49 | , monad-loops 50 | , mtl 51 | , network 52 | , random 53 | , stm 54 | , RSA 55 | , cryptohash-sha256 56 | , transformers 57 | , unagi-chan 58 | hs-source-dirs: src 59 | ghc-options: -Wall -Werror 60 | default-language: Haskell2010 61 | 62 | executable simpleserver 63 | main-is: Server.hs 64 | build-depends: base < 5 65 | , containers 66 | , tangaroa 67 | hs-source-dirs: bin/Simple 68 | ghc-options: -Wall -threaded -rtsopts 69 | default-language: Haskell2010 70 | 71 | executable simpleclient 72 | main-is: Client.hs 73 | build-depends: base < 5 74 | , tangaroa 75 | hs-source-dirs: bin/Simple 76 | ghc-options: -Wall -threaded -rtsopts 77 | default-language: Haskell2010 78 | 79 | executable bftserver 80 | main-is: Server.hs 81 | build-depends: base < 5 82 | , containers 83 | , tangaroa 84 | , binary 85 | hs-source-dirs: bin/Byzantine 86 | ghc-options: -Wall -threaded -rtsopts 87 | default-language: Haskell2010 88 | 89 | executable bftclient 90 | main-is: Client.hs 91 | build-depends: base < 5 92 | , tangaroa 93 | , binary 94 | hs-source-dirs: bin/Byzantine 95 | ghc-options: -Wall -threaded -rtsopts 96 | default-language: Haskell2010 97 | 98 | executable genkeys 99 | main-is: GenerateKeys.hs 100 | build-depends: base < 5 101 | , RSA 102 | , containers 103 | , crypto-api 104 | , network 105 | , directory 106 | , filepath 107 | hs-source-dirs: bin 108 | ghc-options: -Wall -threaded -rtsopts 109 | default-language: Haskell2010 110 | 111 | executable udprecv 112 | main-is: udprecv.hs 113 | build-depends: base < 5 114 | , bytestring 115 | , network 116 | hs-source-dirs: bin 117 | ghc-options: -Wall -threaded -rtsopts 118 | default-language: Haskell2010 119 | 120 | executable udpsend 121 | main-is: udpsend.hs 122 | build-depends: base < 5 123 | , bytestring 124 | , network 125 | hs-source-dirs: bin 126 | ghc-options: -Wall -threaded -rtsopts 127 | default-language: Haskell2010 128 | --------------------------------------------------------------------------------