├── .gitignore ├── init.py ├── raft ├── status.go ├── timeouts.go ├── commit.go ├── clientRequests.go ├── README.md ├── server.go ├── handler.go ├── voting.go └── appending.go ├── db ├── handler.go ├── README.md ├── log.go ├── db.go └── db_test.go ├── LICENSE ├── main.go └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | db*.json 2 | log*.json 3 | config.json 4 | mmapd 5 | -------------------------------------------------------------------------------- /init.py: -------------------------------------------------------------------------------- 1 | ports = ['3001', '3002', '3003'] 2 | for port in ports: 3 | db = open('db:'+port+'.json', 'w') 4 | log = open('log:'+port+'.json', 'w') 5 | db.write('{}') 6 | log.write('[]') 7 | db.close() 8 | log.close() 9 | 10 | -------------------------------------------------------------------------------- /raft/status.go: -------------------------------------------------------------------------------- 1 | package raft 2 | 3 | type status struct { 4 | ID string `json:"id"` 5 | State string `json:"state"` 6 | Term int `json:"term"` 7 | VotedFor string `json:"voted for"` 8 | CommitIndex int `json:"commit index"` 9 | LastApplied int `json:"last applied"` 10 | NextIndex []int `json:"next index"` 11 | MatchIndex []int `json:"match index"` 12 | } 13 | -------------------------------------------------------------------------------- /db/handler.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "fmt" 5 | "github.com/julienschmidt/httprouter" 6 | "net/http" 7 | ) 8 | 9 | func (db *DB) handler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { 10 | switch r.Method { 11 | case "GET": 12 | key := ps.ByName("key") 13 | c := make(chan ReturnChanMessage) 14 | m := ReadChanMessage{key, c} 15 | db.ReadChan <- m 16 | resp := <-c 17 | close(c) 18 | if resp.Err != nil { 19 | http.NotFound(w, r) 20 | } else { 21 | fmt.Fprint(w, resp.Json) 22 | } 23 | case "POST": 24 | key := ps.ByName("key") 25 | value := ps.ByName("value") 26 | c := make(chan ReturnChanMessage) 27 | m := WriteChanMessage{key, value, c} 28 | db.WriteChan <- m 29 | fmt.Fprint(w, r.FormValue("value")) 30 | } 31 | } 32 | 33 | func NewHandler(db *DB) func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { 34 | return db.handler 35 | } 36 | -------------------------------------------------------------------------------- /raft/timeouts.go: -------------------------------------------------------------------------------- 1 | package raft 2 | 3 | import ( 4 | "crypto/rand" 5 | "fmt" 6 | "math/big" 7 | "time" 8 | ) 9 | 10 | type timeout struct { 11 | period time.Duration 12 | ticker time.Ticker 13 | } 14 | 15 | func (t *timeout) reset() { 16 | t.ticker.Stop() 17 | t.ticker = *time.NewTicker(t.period) 18 | } 19 | 20 | func createTimeout(period time.Duration) *timeout { 21 | return &timeout{period, *time.NewTicker(period)} 22 | } 23 | 24 | func generateRandomInt(lower int, upper int) int { 25 | l := int64(lower) 26 | u := int64(upper) 27 | max := big.NewInt(u - l) 28 | r, err := rand.Int(rand.Reader, max) 29 | if err != nil { 30 | fmt.Println("Couldn't generate random int!") 31 | } 32 | return int(l + r.Int64()) 33 | } 34 | 35 | func createRandomTimeout(lower int, upper int, period time.Duration) *timeout { 36 | randomInt := generateRandomInt(lower, upper) 37 | period = time.Duration(randomInt) * period 38 | return createTimeout(period) 39 | } 40 | -------------------------------------------------------------------------------- /raft/commit.go: -------------------------------------------------------------------------------- 1 | package raft 2 | 3 | import ( 4 | "fmt" 5 | "github.com/arpith/mmapd/db" 6 | ) 7 | 8 | func (s *server) commitEntries(leaderCommit int) { 9 | fmt.Println("GOING TO COMMIT STUFFFFFFF!!!!") 10 | fmt.Println(s.commitIndex, leaderCommit) 11 | for i := s.commitIndex; i <= leaderCommit; i++ { 12 | fmt.Println(i, s.db.Log.Entries) 13 | if i > len(s.db.Log.Entries)-1 { 14 | fmt.Println("Can't commit an entry that's not in the log :)") 15 | return 16 | } 17 | if i == -1 { 18 | // s.commitIndex is initialised to -1 19 | // Will commit on the next iteration 20 | continue 21 | } 22 | entry := s.db.Log.Entries[i] 23 | c := make(chan db.ReturnChanMessage) 24 | m := db.WriteChanMessage{entry.Key, entry.Value, c} 25 | s.db.WriteChan <- m 26 | r := <-c 27 | if r.Err != nil { 28 | fmt.Println("Error committing entry: ", entry) 29 | fmt.Println(r.Err) 30 | break 31 | } else { 32 | fmt.Println("COMMITTED SOMETHING!!!!") 33 | s.commitIndex = i 34 | } 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /raft/clientRequests.go: -------------------------------------------------------------------------------- 1 | package raft 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "github.com/arpith/mmapd/db" 7 | ) 8 | 9 | type returnChanMessage struct { 10 | err error 11 | json string 12 | } 13 | 14 | type readRequest struct { 15 | key string 16 | returnChan chan db.ReturnChanMessage 17 | } 18 | 19 | type writeRequest struct { 20 | key string 21 | value string 22 | returnChan chan db.ReturnChanMessage 23 | } 24 | 25 | func (s *server) handleReadRequest(req readRequest) { 26 | key := req.key 27 | c := req.returnChan 28 | m := db.ReadChanMessage{key, c} 29 | s.db.ReadChan <- m 30 | } 31 | 32 | func (s *server) handleWriteRequest(req writeRequest) { 33 | key := req.key 34 | value := req.value 35 | command := "SET" 36 | c := make(chan bool) 37 | fmt.Println("GOT WRITE REQUEST!!!!") 38 | go s.appendEntry(command, key, value, c) 39 | isCommitted := <-c 40 | close(c) 41 | if isCommitted { 42 | m := db.WriteChanMessage{key, value, req.returnChan} 43 | s.db.WriteChan <- m 44 | } else { 45 | m := &db.ReturnChanMessage{errors.New("Couldn't Commit"), ""} 46 | req.returnChan <- *m 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Arpith Siromoney 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "github.com/arpith/mmapd/db" 6 | "github.com/arpith/mmapd/raft" 7 | "github.com/julienschmidt/httprouter" 8 | "net/http" 9 | "os" 10 | "strings" 11 | ) 12 | 13 | func main() { 14 | dbFilenamePtr := flag.String("db", "db.json", "database filename") 15 | logFilenamePtr := flag.String("log", "log.json", "log filename") 16 | configFilenamePtr := flag.String("config", "config.json", "config file name") 17 | portPtr := flag.String("port", "3001", "port to listen on") 18 | ipPtr := flag.String("ip", "localhost", "ip that the server is running on") 19 | flag.Parse() 20 | 21 | port := *portPtr 22 | if port == "3001" { 23 | envPort := strings.TrimSpace(os.Getenv("PORT")) 24 | if envPort != "" { 25 | port = envPort 26 | } 27 | } 28 | 29 | id := *ipPtr + ":" + port 30 | 31 | DB := db.Init(*dbFilenamePtr, *logFilenamePtr) 32 | server := raft.Init(id, *configFilenamePtr, DB) 33 | appendEntryHandler := raft.NewHandler(server, "Append Entry") 34 | requestForVoteHandler := raft.NewHandler(server, "Request For Vote") 35 | clientRequestHandler := raft.NewHandler(server, "Client Request") 36 | statusRequestHandler := raft.NewHandler(server, "Status Request") 37 | 38 | router := httprouter.New() 39 | router.POST("/append", appendEntryHandler) 40 | router.POST("/votes", requestForVoteHandler) 41 | router.GET("/get/:key", clientRequestHandler) 42 | router.POST("/set/:key", clientRequestHandler) 43 | router.GET("/status", statusRequestHandler) 44 | 45 | http.ListenAndServe(":"+port, router) 46 | } 47 | -------------------------------------------------------------------------------- /db/README.md: -------------------------------------------------------------------------------- 1 | # Data 2 | Primarily, data is stored in-memory and persisted to disk by the OS. This is done by memory mapping the log and database files into Go byte slices. The byte slices are marshalled/unmarshalled into structures using json. When a request is made a lookup is done. The database is a map of strings - key to values; the log is a slice of structs with 3 | where the lookup is typically based on index. 4 | 5 | Since these structures are not concurrency safe, a single listener goroutine takes ownership of the writes/reads and all other goroutines communicate via channels to send their requests to the listener. The messages contain a response channel over which the listener will send the response. This lets the goroutine that makes the read/write request wait on a single channel for the response. 6 | 7 | ## Tell me more! 8 | ### mmap 9 | Mmap is a system call that reads the content of a file into a byte slice in Go. Changes made to the byte slice are synced to disk by the OS. In future, this will be forced by the database. Stay tuned! 10 | 11 | ### ftruncate 12 | Ftruncate is a system call that resizes a file - this allows the database to grow the file used for persistence. Currently, the strategy is to resize and remap the file everytime it needs to grow, but this can be optimized by doubling each time instead. 13 | 14 | ### Data structures 15 | #### Database 16 | The database is primarily a map of strings, keys to values. There is also data stored like the file descriptor, etc, for resizing, and the actual byte slice that contains the data. 17 | 18 | #### Log 19 | The log is a slice of entries where each entry has a command (for now, that is just "SET"), a key, value and term (used by Raft). Again, the file descriptor etc is also stored, and the byte slice. 20 | 21 | ### JSON 22 | When the database starts up, the db/log files are read into byte slices which are then unmarshalled into the appropriate structs. When writes are made to the structs they are marshalled into the memory-mapped byte slices. This is then eventually synced to disk. 23 | 24 | ## Concurrency 25 | The strategy is to have a single goroutine that is responsible for the data structures which then listens on channels for read/write requests. These requests come from multiple goroutines that the Raft server sets up, so they will be coming in concurrently. Each request (which is a channel message) also contains a return channel on which the requester goroutine will be waiting on. The listener processes the request and sends the response over the return channel. The requester goroutine can then close the channel. 26 | -------------------------------------------------------------------------------- /db/log.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "os" 7 | "syscall" 8 | ) 9 | 10 | type Entry struct { 11 | Command string 12 | Key string 13 | Value string 14 | Term int 15 | } 16 | 17 | type Log struct { 18 | data []byte 19 | Entries []Entry 20 | fd int 21 | filename string 22 | file *os.File 23 | } 24 | 25 | func (log *Log) load() { 26 | err := json.Unmarshal(log.data, &log.Entries) 27 | if err != nil { 28 | fmt.Println("Error unmarshalling initial data into map: ", err) 29 | } 30 | } 31 | 32 | func (log *Log) mmap(size int) { 33 | fmt.Println("mmapping log file: ", size) 34 | data, err := syscall.Mmap(log.fd, 0, size, syscall.PROT_WRITE|syscall.PROT_READ, syscall.MAP_SHARED) 35 | if err != nil { 36 | fmt.Println("Error mmapping: ", err) 37 | } 38 | log.data = data 39 | } 40 | 41 | func (log *Log) resize(size int) { 42 | fmt.Println("Resizing log file: ", size) 43 | err := syscall.Ftruncate(log.fd, int64(size)) 44 | if err != nil { 45 | fmt.Println("Error resizing log file: ", err) 46 | } 47 | } 48 | 49 | func (log *Log) open() { 50 | fmt.Println("Getting log file descriptor") 51 | f, err := os.OpenFile(log.filename, os.O_CREATE|os.O_RDWR, 0) 52 | if err != nil { 53 | fmt.Println("Could not open log file: ", err) 54 | } 55 | log.fd = int(f.Fd()) 56 | log.file = f 57 | } 58 | 59 | func (log *Log) extend(size int) { 60 | log.file.Close() 61 | log.open() 62 | log.resize(size) 63 | log.mmap(size) 64 | } 65 | 66 | func (log *Log) SetEntries(entries []Entry) { 67 | log.Entries = entries 68 | b, err := json.Marshal(log.Entries) 69 | if err != nil { 70 | fmt.Println("Error marshalling log: ", err) 71 | } 72 | if len(b) > len(log.data) { 73 | log.extend(len(b)) 74 | } 75 | copy(log.data, b) 76 | } 77 | 78 | func (log *Log) AppendEntry(entry Entry) { 79 | log.Entries = append(log.Entries, entry) 80 | b, err := json.Marshal(log.Entries) 81 | if err != nil { 82 | fmt.Println("Error marshalling log: ", err) 83 | } 84 | if len(b) > len(log.data) { 85 | log.extend(len(b)) 86 | } 87 | copy(log.data, b) 88 | } 89 | 90 | func initLog(filename string) *Log { 91 | var data []byte 92 | var entries []Entry 93 | var fd int 94 | var file *os.File 95 | log := &Log{data, entries, fd, filename, file} 96 | log.open() 97 | f, err := os.Stat(filename) 98 | if err != nil { 99 | fmt.Println("Could not stat file: ", err) 100 | } 101 | size := int(f.Size()) 102 | if size == 0 { 103 | size = 10 104 | log.resize(size) 105 | } 106 | log.mmap(size) 107 | log.load() 108 | return log 109 | } 110 | -------------------------------------------------------------------------------- /raft/README.md: -------------------------------------------------------------------------------- 1 | # Raft 2 | Raft is a distributed consensus algorithm/protocol. This is my implementation in-progress. 3 | 4 | ## Issues 5 | ### Stable initial election 6 | For a cluster of three nodes, stable election happens only when you kill a node a couple of times. My guess is that this is because a leader steps down if it gets a heartbeat with the same term. 7 | 8 | ### Bullying 9 | When a leader dies, and comes up after sometime, the desired behaviour is that it respects the state of the cluster and accepts missing entries from the current leader. This is not what happens here! Old leaders regain leadership when they come up! This could be related to the previous issue. 10 | 11 | ## How it works 12 | ### Timeouts 13 | There are two randomized timeouts, typically between 150 and 300ms. The heartbeat timeout is used by the leader to inform the followers that it is still up. When a follower gets a request from the leader it resets its election timeout. 14 | 15 | ### Leadership 16 | If a follower's election timeout goes off before receiving a request from the leader, promotes itself to candidate, starts an election, votes for itself, and requests votes from all other nodes. Each election is in a new term - an integer that is incremented each time. 17 | 18 | The other nodes vote for the first request they receive, and reject subsequent requests. If the candidate receives votes from a majority of the nodes, it promotes itself to leader and begins replicating its log to the other nodes. 19 | This is also the protocol to elect the leader when the nodes first start up. 20 | 21 | ### RPCs 22 | #### Request for votes 23 | When a candidate requests votes from the other nodes, it sends its id, the index of the last entry in its log and the term it was appended in. It also sends the current term. 24 | 25 | Nodes respond with the term they have seen last, and whether they are granting the vote or not. 26 | 27 | #### Append Entry 28 | When a leader gets a client request to set a key/value pair it appends it to its log and sends a request to all the followers to append the entry to their logs. The request also includes (apart from the entry itself) the current term, the leader's id, the index of the last log entry before this new one, and the term of that entry. Finally, the request also includes the index of the entry that was last committed to the leader's database. 29 | 30 | When a follower receives an append entry request, it responds with its term and success or failure. If the receiver's term is greater than the leader's, it responds false. It also responds false if the entry it stored at the leader's previous log index has a term that is different from the previous log term sent in the request. 31 | -------------------------------------------------------------------------------- /raft/server.go: -------------------------------------------------------------------------------- 1 | package raft 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "github.com/arpith/mmapd/db" 7 | "io/ioutil" 8 | "time" 9 | ) 10 | 11 | type server struct { 12 | id string 13 | state string 14 | term int 15 | votedFor string 16 | db db.DB 17 | electionTimeout timeout 18 | heartbeatTimeout timeout 19 | config []string 20 | commitIndex int 21 | lastApplied int 22 | nextIndex []int 23 | matchIndex []int 24 | voteRequests chan voteRequest 25 | appendRequests chan appendRequest 26 | writeRequests chan writeRequest 27 | readRequests chan readRequest 28 | } 29 | 30 | func (s *server) listener() { 31 | for { 32 | select { 33 | case v := <-s.voteRequests: 34 | fmt.Println("Got vote request") 35 | s.handleRequestForVote(v) 36 | case e := <-s.appendRequests: 37 | s.handleAppendEntryRequest(e) 38 | case r := <-s.readRequests: 39 | s.handleReadRequest(r) 40 | case w := <-s.writeRequests: 41 | s.handleWriteRequest(w) 42 | case <-s.heartbeatTimeout.ticker.C: 43 | if s.state == "leader" { 44 | fmt.Println("Going to send heartbeats") 45 | c := make(chan bool) 46 | go s.appendEntry("", "", "", c) 47 | } 48 | case <-s.electionTimeout.ticker.C: 49 | if s.state != "leader" { 50 | fmt.Println("Going to start election") 51 | go s.startElection() 52 | } 53 | } 54 | } 55 | } 56 | 57 | func readConfig(filename string) []string { 58 | var config []string 59 | content, err := ioutil.ReadFile(filename) 60 | if err != nil { 61 | fmt.Println("Couldn't read config file") 62 | } 63 | err = json.Unmarshal(content, &config) 64 | if err != nil { 65 | fmt.Println("Error unmarshalling config file: ", err) 66 | } 67 | return config 68 | } 69 | 70 | func Init(id string, configFilename string, db *db.DB) *server { 71 | config := readConfig(configFilename) 72 | server := &server{ 73 | id: id, 74 | state: "follower", 75 | term: 0, 76 | votedFor: "", 77 | db: *db, 78 | electionTimeout: *createRandomTimeout(150, 300, time.Millisecond), 79 | heartbeatTimeout: *createRandomTimeout(150, 300, time.Millisecond), 80 | config: config, 81 | commitIndex: -1, 82 | lastApplied: -1, 83 | nextIndex: make([]int, len(config)), 84 | matchIndex: make([]int, len(config)), 85 | voteRequests: make(chan voteRequest), 86 | appendRequests: make(chan appendRequest), 87 | writeRequests: make(chan writeRequest), 88 | readRequests: make(chan readRequest), 89 | } 90 | go server.listener() 91 | return server 92 | } 93 | -------------------------------------------------------------------------------- /raft/handler.go: -------------------------------------------------------------------------------- 1 | package raft 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "github.com/arpith/mmapd/db" 7 | "github.com/julienschmidt/httprouter" 8 | "net/http" 9 | ) 10 | 11 | func (s *server) appendEntryHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { 12 | var a appendEntryRequest 13 | decoder := json.NewDecoder(r.Body) 14 | err := decoder.Decode(&a) 15 | if err != nil { 16 | fmt.Println("Couldn't decode append entry request as json", err) 17 | } 18 | returnChan := make(chan appendEntryResponse) 19 | req := &appendRequest{ 20 | Req: a, 21 | ReturnChan: returnChan, 22 | } 23 | s.appendRequests <- *req 24 | resp := <-req.ReturnChan 25 | defer close(req.ReturnChan) 26 | json.NewEncoder(w).Encode(resp) 27 | } 28 | 29 | func (s *server) requestForVoteHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { 30 | var v requestForVote 31 | decoder := json.NewDecoder(r.Body) 32 | err := decoder.Decode(&v) 33 | if err != nil { 34 | fmt.Println("Couldn't decode vote request as json", err) 35 | } 36 | returnChan := make(chan requestForVoteResponse) 37 | req := &voteRequest{ 38 | Req: v, 39 | ReturnChan: returnChan, 40 | } 41 | s.voteRequests <- *req 42 | resp := <-req.ReturnChan 43 | defer close(req.ReturnChan) 44 | json.NewEncoder(w).Encode(resp) 45 | } 46 | 47 | func (s *server) clientRequestHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { 48 | switch r.Method { 49 | case "GET": 50 | key := ps.ByName("key") 51 | c := make(chan db.ReturnChanMessage) 52 | m := readRequest{key, c} 53 | s.readRequests <- m 54 | resp := <-c 55 | close(c) 56 | if resp.Err != nil { 57 | http.NotFound(w, r) 58 | } else { 59 | fmt.Fprint(w, resp.Json) 60 | } 61 | case "POST": 62 | key := ps.ByName("key") 63 | value := r.FormValue("value") 64 | c := make(chan db.ReturnChanMessage) 65 | m := writeRequest{key, value, c} 66 | s.writeRequests <- m 67 | resp := <-c 68 | close(c) 69 | if resp.Err != nil { 70 | fmt.Fprint(w, resp.Err) 71 | } else { 72 | fmt.Fprint(w, resp.Json) 73 | } 74 | } 75 | } 76 | 77 | func (s *server) statusRequestHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { 78 | status := &status{ 79 | s.id, 80 | s.state, 81 | s.term, 82 | s.votedFor, 83 | s.commitIndex, 84 | s.lastApplied, 85 | s.nextIndex, 86 | s.matchIndex, 87 | } 88 | fmt.Println(status) 89 | json.NewEncoder(w).Encode(*status) 90 | } 91 | 92 | func NewHandler(s *server, handlerType string) func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { 93 | switch handlerType { 94 | case "Append Entry": 95 | return s.appendEntryHandler 96 | case "Request For Vote": 97 | return s.requestForVoteHandler 98 | case "Status Request": 99 | return s.statusRequestHandler 100 | default: 101 | return s.clientRequestHandler 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mmapd 2 | Distributed key value datastore written in Go. Uses the Raft distributed consensus algorithm to replicate log, and uses mmap for persistence. 3 | 4 | ## Should I use it? 5 | Not yet! This project is under development, and you WILL lose data. 6 | 7 | ### What works? 8 | Currently, you can get three nodes running (tested on a single machine) and set values on the leader - they will be replicated (and persisted) in the followers' logs and dbs. You can also crash the leader and continue using the followers. 9 | 10 | ### What doesn't work? 11 | Lots of stuff! But primarily, if you re-start the former leader, it WILL try to regain leadership, and won't accept the entries it missed! 12 | 13 | ### Why is this interesting? 14 | My primary goal is to explore the Raft consensus algorithm and mmap. They are both quite cool! Follow my explorations on [Medium](medium.com/@arpith)! 15 | 16 | ## Under the hood 17 | ### Database 18 | `/db` has the files that handle persistence. The basic idea is that the database and log files are memory mapped into a byte slice, which is then parsed as JSON into a map of keys -> values. When the files need to grow, `ftruncate` is used to extend the underlying file, which is then re-mapped into memory. Finally the new byte slice is copied into the mmaped byte slice, which the OS then copies onto the file eventually. Going forward, this is going to be forced using `msync`. 19 | 20 | ### Consensus 21 | `/raft` has my implementation of the Raft consensus algorithm. When a leader receives a request from the client, it sends an append entry request to the followers. The followers then add this to their log and reply if successful. If a majority of the followers successfully replicate this entry onto their log, the leader commits the entry (to the database) and responds to the client. On the next heartbeat, the leader includes the new commit index and the followers commit the latest entry. 22 | 23 | If a follower doesn't receive a heartbeat before its election timeout goes off, it sends vote requests to all the other nodes. Nodes respond with success to the first vote request they receive (and reject the others). If a candidate receives successful responses from a majority of the nodes, it promotes itself to leader. This is also the protocol for the initial election. 24 | 25 | ## Usage 26 | ### Set a value 27 | Make a `POST` request to `/set/key` with the value as a parameter. 28 | 29 | ### Get a value 30 | Make a `GET` request to `/get/key` 31 | 32 | ### Get node status 33 | Make a `GET` request to `/status` 34 | 35 | ## Install 36 | `go get github.com/arpith/mmapd` 37 | 38 | ## Run 39 | `mmapd -port PORT -db DATABASE_FILENAME -log LOG_FILENAME` 40 | 41 | ### -port 42 | Set the port you want this node to be listening on. The default value is `3001` 43 | 44 | ### -db 45 | The filename used to store the database (for persistence). The default value is `db.json` 46 | 47 | ### -log 48 | The filename used to store the log (for persistence). The default value is `log.json` 49 | 50 | ## Example 51 | ``` 52 | $ mmapd -port 3001 -db db3001.json -log log3001.json \ 53 | && mmapd -port 3002 -db db3002.json -log log3002.json \ 54 | && mmapd -port 3003 -db db3003.json -log log3003.json 55 | ``` 56 | -------------------------------------------------------------------------------- /db/db.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | "os" 8 | "syscall" 9 | ) 10 | 11 | type ReturnChanMessage struct { 12 | Err error 13 | Json string 14 | } 15 | 16 | type ReadChanMessage struct { 17 | Key string 18 | ReturnChan chan ReturnChanMessage 19 | } 20 | 21 | type WriteChanMessage struct { 22 | Key string 23 | Value string 24 | ReturnChan chan ReturnChanMessage 25 | } 26 | 27 | type DB struct { 28 | data []byte 29 | dataMap map[string]string 30 | Log *Log 31 | fd int 32 | filename string 33 | file *os.File 34 | WriteChan chan WriteChanMessage 35 | ReadChan chan ReadChanMessage 36 | } 37 | 38 | func (db *DB) load() { 39 | err := json.Unmarshal(db.data, &db.dataMap) 40 | if err != nil { 41 | fmt.Println("Error unmarshalling initial data into map: ", err) 42 | } 43 | fmt.Println(db.dataMap) 44 | } 45 | 46 | func (db *DB) mmap(size int) { 47 | fmt.Println("mmapping db file: ", size) 48 | data, err := syscall.Mmap(db.fd, 0, size, syscall.PROT_WRITE|syscall.PROT_READ, syscall.MAP_SHARED) 49 | if err != nil { 50 | fmt.Println("Error mmapping: ", err) 51 | } 52 | db.data = data 53 | } 54 | 55 | func (db *DB) resize(size int) { 56 | fmt.Println("Resizing db file: ", size) 57 | err := syscall.Ftruncate(db.fd, int64(size)) 58 | if err != nil { 59 | fmt.Println("Error resizing: ", err) 60 | } 61 | } 62 | 63 | func (db *DB) open() { 64 | fmt.Println("Getting db file descriptor") 65 | f, err := os.OpenFile(db.filename, os.O_CREATE|os.O_RDWR, 0) 66 | if err != nil { 67 | fmt.Println("Could not open file: ", err) 68 | } 69 | db.fd = int(f.Fd()) 70 | db.file = f 71 | } 72 | 73 | func (db *DB) extend(size int) { 74 | db.file.Close() 75 | db.open() 76 | db.resize(size) 77 | db.mmap(size) 78 | } 79 | 80 | func (db *DB) write(key string, value string, returnChan chan ReturnChanMessage) { 81 | db.dataMap[key] = value 82 | b, err := json.Marshal(db.dataMap) 83 | if err != nil { 84 | fmt.Println("Error marshalling db: ", err) 85 | } 86 | if len(b) > len(db.data) { 87 | db.extend(len(b)) 88 | } 89 | copy(db.data, b) 90 | m := &ReturnChanMessage{Err: nil, Json: value} 91 | returnChan <- *m 92 | } 93 | 94 | func (db *DB) listener() { 95 | for { 96 | select { 97 | case writeReq := <-db.WriteChan: 98 | key := writeReq.Key 99 | value := writeReq.Value 100 | returnChan := writeReq.ReturnChan 101 | db.write(key, value, returnChan) 102 | 103 | case readReq := <-db.ReadChan: 104 | key := readReq.Key 105 | returnChan := readReq.ReturnChan 106 | if value, ok := db.dataMap[key]; ok { 107 | m := &ReturnChanMessage{nil, value} 108 | returnChan <- *m 109 | } else { 110 | m := &ReturnChanMessage{errors.New("Invalid Key"), ""} 111 | returnChan <- *m 112 | } 113 | } 114 | } 115 | } 116 | 117 | func Init(dbFilename string, logFilename string) *DB { 118 | log := initLog(logFilename) 119 | writeChan := make(chan WriteChanMessage) 120 | readChan := make(chan ReadChanMessage) 121 | dataMap := make(map[string]string) 122 | var data []byte 123 | var fd int 124 | var file *os.File 125 | db := &DB{data, dataMap, log, fd, dbFilename, file, writeChan, readChan} 126 | db.open() 127 | f, err := os.Stat(dbFilename) 128 | if err != nil { 129 | fmt.Println("Could not stat file: ", err) 130 | } 131 | 132 | db.mmap(int(f.Size())) 133 | db.load() 134 | go db.listener() 135 | return db 136 | } 137 | -------------------------------------------------------------------------------- /db/db_test.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "crypto/rand" 5 | "encoding/base64" 6 | "fmt" 7 | "io/ioutil" 8 | "net/http" 9 | "net/url" 10 | "os" 11 | "strconv" 12 | "testing" 13 | ) 14 | 15 | func TestGet(t *testing.T) { 16 | fmt.Println("Testing GET") 17 | resp, err := http.Get("http://localhost:3001/get/testKey") 18 | if err != nil { 19 | t.Error("Couldn't get testKey: ", err) 20 | } 21 | defer resp.Body.Close() 22 | body, err := ioutil.ReadAll(resp.Body) 23 | if err != nil { 24 | t.Error("Expected testValue got ", err) 25 | } else if string(body) != "testValue" { 26 | t.Error("Expected testValue got ", string(body)) 27 | } else { 28 | fmt.Println("Test Passed!") 29 | } 30 | } 31 | 32 | func TestSet(t *testing.T) { 33 | fmt.Println("Testing SET") 34 | c := 9 35 | b := make([]byte, c) 36 | _, err := rand.Read(b) 37 | if err != nil { 38 | t.Error("Couldn't generate random string: ", err) 39 | } 40 | s := base64.URLEncoding.EncodeToString(b) 41 | resp, err := http.PostForm("http://localhost:3001/set/"+s, url.Values{"value": {s}}) 42 | if err != nil { 43 | t.Error("Expected ", s, " got ", err) 44 | } 45 | defer resp.Body.Close() 46 | body, err := ioutil.ReadAll(resp.Body) 47 | if err != nil { 48 | t.Error("Expected ", s, " got ", err) 49 | } else if string(body) != s { 50 | t.Error("Expected ", s, " got ", string(body)) 51 | } else { 52 | fmt.Println("Test passed!") 53 | } 54 | } 55 | 56 | func TestSetAndGet(t *testing.T) { 57 | fmt.Println("Testing SET & GET") 58 | c := 9 59 | b := make([]byte, c) 60 | _, err := rand.Read(b) 61 | if err != nil { 62 | t.Error("Couldn't generate random string: ", err) 63 | } 64 | s := base64.URLEncoding.EncodeToString(b) 65 | resp, err := http.PostForm("http://localhost:3001/set/"+s, url.Values{"value": {s}}) 66 | if err != nil { 67 | t.Error("Couldn't set ", s, ": ", err) 68 | } 69 | defer resp.Body.Close() 70 | body, err := ioutil.ReadAll(resp.Body) 71 | if err != nil { 72 | t.Error("Expected ", s, " got ", err) 73 | } else if string(body) != s { 74 | t.Error("Expected ", s, " got ", string(body)) 75 | } else { 76 | fmt.Println("Set ", s) 77 | } 78 | resp, err = http.Get("http://localhost:3001/get/" + s) 79 | if err != nil { 80 | t.Error("Couldn't get ", s, ": ", err) 81 | } 82 | defer resp.Body.Close() 83 | body, err = ioutil.ReadAll(resp.Body) 84 | if err != nil { 85 | t.Error("Expected ", s, " got ", err) 86 | } else if string(body) != s { 87 | t.Error("Expected ", s, " got ", string(body)) 88 | } else { 89 | fmt.Println("Test passed!") 90 | } 91 | } 92 | 93 | func TestConcurrency(t *testing.T) { 94 | fmt.Println("Testing Concurrency") 95 | maxRequests := os.Getenv("MAX_REQUESTS") 96 | max, err := strconv.Atoi(maxRequests) 97 | if err != nil { 98 | max = 250 99 | } 100 | for i := 0; i < max; i++ { 101 | s := strconv.Itoa(i) 102 | resp, err := http.PostForm("http://localhost:3001/set/"+s, url.Values{"value": {s}}) 103 | if err != nil { 104 | t.Error("Expected ", s, " got ", err) 105 | } 106 | defer resp.Body.Close() 107 | } 108 | for i := 0; i < max; i++ { 109 | s := strconv.Itoa(i) 110 | resp, err := http.Get("http://localhost:3001/get/" + s) 111 | if err != nil { 112 | t.Error("Expected ", s, " got ", err) 113 | } 114 | body, err := ioutil.ReadAll(resp.Body) 115 | if err != nil || string(body) != s { 116 | t.Error("Expected ", s, " got ", err) 117 | } 118 | defer resp.Body.Close() 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /raft/voting.go: -------------------------------------------------------------------------------- 1 | package raft 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | "net/http" 8 | "time" 9 | ) 10 | 11 | type requestForVote struct { 12 | Term int 13 | CandidateID string 14 | LastLogIndex int 15 | LastLogTerm int 16 | } 17 | 18 | type voteRequest struct { 19 | Req requestForVote 20 | ReturnChan chan requestForVoteResponse 21 | } 22 | 23 | type requestForVoteResponse struct { 24 | Term int 25 | HasGrantedVote bool 26 | } 27 | 28 | type voteResponse struct { 29 | ServerIndex int 30 | Resp requestForVoteResponse 31 | } 32 | 33 | func (s *server) stepDown(reason string) { 34 | t := time.Now() 35 | fmt.Println(t.Format("15:04:05.999999"), "SETTING FOLLOWER & RESETTING ELECTION TIMEOUT: ", reason) 36 | s.state = "follower" 37 | s.electionTimeout.reset() 38 | } 39 | 40 | func (s *server) becomeLeader() { 41 | //Initialize nextIndex values to the index after the last one in leader's log 42 | for follower, _ := range s.nextIndex { 43 | s.nextIndex[follower] = len(s.db.Log.Entries) 44 | } 45 | fmt.Println(s.nextIndex) 46 | fmt.Println("SETTING LEADER: got majority vote") 47 | s.state = "leader" 48 | fmt.Println("IM THE LEADER :D :D :D :D :D ") 49 | } 50 | 51 | func (s *server) handleRequestForVote(v voteRequest) { 52 | fmt.Println("Got vote request:", v.Req) 53 | req := v.Req 54 | returnChan := v.ReturnChan 55 | if req.Term < s.term { 56 | resp := &requestForVoteResponse{s.term, false} 57 | returnChan <- *resp 58 | } else { 59 | if req.Term > s.term { 60 | s.votedFor = "" 61 | s.stepDown("Got vote request with term > current term") 62 | } 63 | cond1 := s.votedFor == "" 64 | cond2 := s.votedFor == req.CandidateID 65 | cond3 := req.LastLogIndex >= len(s.db.Log.Entries)-1 66 | if (cond1 || cond2) && cond3 { 67 | s.votedFor = req.CandidateID 68 | s.term = req.Term 69 | resp := &requestForVoteResponse{s.term, true} 70 | returnChan <- *resp 71 | } 72 | } 73 | } 74 | 75 | func (s *server) sendRequestForVote(receiverIndex int, respChan chan voteResponse) { 76 | receiver := s.config[receiverIndex] 77 | lastLogIndex := len(s.db.Log.Entries) - 1 78 | lastLogTerm := 0 79 | if lastLogIndex > 0 { 80 | lastLogTerm = s.db.Log.Entries[lastLogIndex].Term 81 | } 82 | v := &requestForVote{s.term, s.id, lastLogIndex, lastLogTerm} 83 | b := new(bytes.Buffer) 84 | json.NewEncoder(b).Encode(v) 85 | resp, err := http.Post("http://"+receiver+"/votes", "application/json", b) 86 | if err != nil { 87 | r := &requestForVoteResponse{0, false} 88 | v := &voteResponse{receiverIndex, *r} 89 | fmt.Println("Couldn't send request for votes to " + receiver) 90 | fmt.Println(err) 91 | respChan <- *v 92 | } else { 93 | r := &requestForVoteResponse{} 94 | err := json.NewDecoder(resp.Body).Decode(r) 95 | if err != nil { 96 | r := &requestForVoteResponse{0, false} 97 | v := &voteResponse{receiverIndex, *r} 98 | fmt.Println("Couldn't decode request for vote response from ", s.config[receiverIndex]) 99 | respChan <- *v 100 | return 101 | } 102 | resp.Body.Close() 103 | voteResp := &voteResponse{receiverIndex, *r} 104 | respChan <- *voteResp 105 | } 106 | } 107 | 108 | func (s *server) startElection() { 109 | t := time.Now() 110 | fmt.Println(t.Format("15:04:05.999999"), "SETTING CANDIDATE: Going to start election") 111 | s.state = "candidate" 112 | s.term += 1 113 | s.votedFor = s.id 114 | voteCount := 1 115 | respChan := make(chan voteResponse) 116 | for receiverIndex, receiverId := range s.config { 117 | if receiverId != s.id { 118 | go s.sendRequestForVote(receiverIndex, respChan) 119 | } 120 | } 121 | responseCount := 0 122 | if len(s.config) > 1 { 123 | for { 124 | vote := <-respChan 125 | responseCount++ 126 | fmt.Println("Got vote response", vote) 127 | if vote.Resp.Term > s.term { 128 | s.stepDown("Got vote response with term greater than current term") 129 | break 130 | } 131 | if vote.Resp.HasGrantedVote { 132 | voteCount++ 133 | } 134 | if voteCount > (len(s.config)-1)/2 { 135 | s.becomeLeader() 136 | break 137 | } 138 | if responseCount == len(s.config)-2 { 139 | fmt.Println("Got all the responses") 140 | break 141 | } 142 | } 143 | } 144 | } 145 | -------------------------------------------------------------------------------- /raft/appending.go: -------------------------------------------------------------------------------- 1 | package raft 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | "github.com/arpith/mmapd/db" 8 | "net/http" 9 | ) 10 | 11 | type appendEntryRequest struct { 12 | Term int 13 | LeaderID string 14 | PrevLogIndex int 15 | PrevLogTerm int 16 | Entry db.Entry 17 | LeaderCommit int 18 | } 19 | 20 | type appendRequest struct { 21 | Req appendEntryRequest 22 | ReturnChan chan appendEntryResponse 23 | } 24 | 25 | type appendEntryResponse struct { 26 | Term int 27 | Success bool 28 | } 29 | 30 | type followerResponse struct { 31 | ServerIndex int 32 | Resp appendEntryResponse 33 | } 34 | 35 | func (s *server) appendEntry(command string, key string, value string, isCommitted chan bool) { 36 | entry := &db.Entry{command, key, value, s.term} 37 | index := -1 38 | if command != "" { 39 | s.db.Log.AppendEntry(*entry) 40 | index = len(s.db.Log.Entries) - 1 41 | } 42 | respChan := make(chan followerResponse) 43 | for i := 0; i < len(s.config); i++ { 44 | if s.config[i] != s.id { 45 | go s.sendAppendEntryRequest(i, index, respChan) 46 | } else { 47 | if command != "" { 48 | //Update nextIndex and matchIndex for self 49 | s.nextIndex[i]++ 50 | s.matchIndex[i]++ 51 | } 52 | } 53 | } 54 | responseCount := 0 55 | if len(s.config) > 1 { 56 | for { 57 | _ = <-respChan 58 | responseCount++ 59 | for N := s.commitIndex + 1; N < len(s.db.Log.Entries); N++ { 60 | //Check if there exists an N > commitIndex 61 | count := 0 62 | for i := 0; i < len(s.matchIndex); i++ { 63 | if s.matchIndex[i] >= N { 64 | count++ 65 | } 66 | } 67 | // Check if a majority of matchIndex[i] >= N 68 | cond1 := count > len(s.matchIndex)/2 69 | // Check if log[N].term == currentTerm 70 | cond2 := s.db.Log.Entries[N].Term == s.term 71 | if cond1 && cond2 { 72 | // Set commitIndex to N 73 | s.commitIndex = N 74 | } else { 75 | break 76 | } 77 | } 78 | if s.commitIndex == index { 79 | isCommitted <- true 80 | return 81 | } 82 | if responseCount == len(s.config)-1 { 83 | fmt.Println("Got all responses!") 84 | /* 85 | isCommitted <- false 86 | return 87 | */ 88 | } 89 | } 90 | } else { 91 | if command != "" { 92 | s.commitIndex++ 93 | isCommitted <- true 94 | } 95 | } 96 | } 97 | 98 | func (s *server) sendAppendEntryRequest(followerIndex int, entryIndex int, respChan chan followerResponse) { 99 | entryP := &db.Entry{"", "", "", s.term} 100 | entry := *entryP 101 | follower := s.config[followerIndex] 102 | prevLogIndex := -1 103 | prevLogTerm := 0 104 | if entryIndex == -1 { 105 | fmt.Println(s.nextIndex) 106 | } 107 | if entryIndex == -1 && len(s.db.Log.Entries)-1 > s.nextIndex[followerIndex] { 108 | // When sending a heartbeat, if nextIndex < last log index, send the missing entry! 109 | entryIndex = s.nextIndex[followerIndex] 110 | } 111 | if entryIndex > -1 { 112 | if entryIndex < len(s.db.Log.Entries) { 113 | entry = s.db.Log.Entries[entryIndex] 114 | } 115 | prevLogIndex = entryIndex - 1 116 | } 117 | if prevLogIndex >= 0 && len(s.db.Log.Entries) > prevLogIndex { 118 | prevLogTerm = s.db.Log.Entries[prevLogIndex].Term 119 | } 120 | fmt.Println("Going to send Append Entry RPC to ", follower, " for entry ", entry, " (prevLogIndex: ", prevLogIndex, " )") 121 | a := &appendEntryRequest{s.term, s.id, prevLogIndex, prevLogTerm, entry, s.commitIndex} 122 | b := new(bytes.Buffer) 123 | json.NewEncoder(b).Encode(a) 124 | resp, err := http.Post("http://"+follower+"/append", "application/json", b) 125 | if err != nil { 126 | fmt.Println("Couldn't send append entry request to " + follower) 127 | fmt.Println(err) 128 | r := &appendEntryResponse{Success: false} 129 | f := &followerResponse{followerIndex, *r} 130 | respChan <- *f 131 | //go s.sendAppendEntryRequest(followerIndex, entryIndex, respChan) 132 | } else { 133 | r := &appendEntryResponse{} 134 | err := json.NewDecoder(resp.Body).Decode(r) 135 | if err != nil { 136 | fmt.Println("Couldn't decode append entries response from " + s.config[followerIndex]) 137 | return 138 | } 139 | resp.Body.Close() 140 | if r.Term > s.term { 141 | s.term = r.Term 142 | s.stepDown("Got append entry RPC response with term > current term") 143 | } 144 | if r.Success { 145 | s.term = r.Term 146 | if entryIndex != -1 { 147 | s.nextIndex[followerIndex]++ 148 | s.matchIndex[followerIndex]++ 149 | } 150 | followerResp := &followerResponse{followerIndex, *r} 151 | respChan <- *followerResp 152 | } else { 153 | if s.nextIndex[followerIndex] > -1 { 154 | s.nextIndex[followerIndex]-- 155 | } 156 | fmt.Println(s.nextIndex[followerIndex], len(s.db.Log.Entries)) 157 | go s.sendAppendEntryRequest(followerIndex, s.nextIndex[followerIndex], respChan) 158 | } 159 | } 160 | } 161 | 162 | func (s *server) handleAppendEntryRequest(a appendRequest) { 163 | fmt.Println("Append entry: ", a.Req) 164 | returnChan := a.ReturnChan 165 | req := a.Req 166 | if req.Term < s.term { 167 | fmt.Println("append entry RPC has term < current term, responding false") 168 | resp := &appendEntryResponse{s.term, false} 169 | returnChan <- *resp 170 | } else if req.PrevLogIndex > -1 && 171 | len(s.db.Log.Entries) > req.PrevLogIndex && 172 | s.db.Log.Entries[req.PrevLogIndex].Term != req.PrevLogTerm { 173 | fmt.Println("responding false to append entry RPC: prev log term clash") 174 | resp := &appendEntryResponse{s.term, false} 175 | returnChan <- *resp 176 | } else { 177 | if req.Term > s.term { 178 | s.term = req.Term 179 | s.stepDown("append entries RPC has term greater than current term") 180 | } 181 | if req.Entry.Command == "" { 182 | s.term = req.Term 183 | s.stepDown("append entries RPC has term >= current term AND is a heartbeat") 184 | fmt.Println("got a heartbeat, responding true: term >= current term") 185 | } else { 186 | if len(s.db.Log.Entries) > req.PrevLogIndex+2 { 187 | if s.db.Log.Entries[req.PrevLogIndex+1].Term != req.Term { 188 | // If existing entry conflicts with new entry 189 | // Delete entry and all that follow it 190 | s.db.Log.SetEntries(s.db.Log.Entries[:req.PrevLogIndex+1]) 191 | } 192 | } 193 | s.db.Log.AppendEntry(req.Entry) 194 | } 195 | if req.LeaderCommit > s.commitIndex { 196 | // Set commit index to the min of the leader's commit index and index of last new entry 197 | if req.Entry.Command == "" || req.LeaderCommit < req.PrevLogIndex+1 { 198 | fmt.Println("Going to commit entries with leader commit: ", req.LeaderCommit) 199 | s.commitEntries(req.LeaderCommit) 200 | } else { 201 | fmt.Println("Going to commit entries with prevLogIndex: ", req.PrevLogIndex) 202 | s.commitEntries(req.PrevLogIndex + 1) 203 | } 204 | } 205 | resp := &appendEntryResponse{s.term, true} 206 | returnChan <- *resp 207 | } 208 | } 209 | --------------------------------------------------------------------------------