├── .idea ├── .gitignore ├── 6.824.iml ├── deployment.xml ├── leetcode │ └── statistics.xml ├── modules.xml ├── sshConfigs.xml ├── vcs.xml └── webServers.xml ├── Makefile ├── README.md └── src ├── .gitignore ├── go.mod ├── go.sum ├── kvraft ├── client.go ├── common.go ├── config.go ├── dstest.py ├── server.go └── test_test.go ├── labgob ├── labgob.go └── test_test.go ├── labrpc ├── labrpc.go └── test_test.go ├── main ├── diskvd.go ├── lockc.go ├── lockd.go ├── mrcoordinator.go ├── mrsequential.go ├── mrworker.go ├── pbc.go ├── pbd.go ├── pg-being_ernest.txt ├── pg-dorian_gray.txt ├── pg-frankenstein.txt ├── pg-grimm.txt ├── pg-huckleberry_finn.txt ├── pg-metamorphosis.txt ├── pg-sherlock_holmes.txt ├── pg-tom_sawyer.txt ├── test-mr-many.sh ├── test-mr.sh └── viewd.go ├── models └── kv.go ├── mr ├── coordinator.go ├── rpc.go └── worker.go ├── mrapps ├── crash.go ├── early_exit.go ├── indexer.go ├── jobcount.go ├── mtiming.go ├── nocrash.go ├── rtiming.go └── wc.go ├── porcupine ├── bitset.go ├── checker.go ├── model.go ├── porcupine.go └── visualization.go ├── raft ├── config.go ├── dstest.py ├── persister.go ├── raft.go ├── test_test.go └── util.go ├── shardctrler ├── client.go ├── common.go ├── config.go ├── server.go └── test_test.go └── shardkv ├── client.go ├── common.go ├── config.go ├── server.go └── test_test.go /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/6.824.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- 
/.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 15 | -------------------------------------------------------------------------------- /.idea/leetcode/statistics.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 19 | 20 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/sshConfigs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/webServers.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 13 | 14 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # This is the Makefile helping you submit the labs. 2 | # Just create 6.824/api.key with your API key in it, 3 | # and submit your lab with the following command: 4 | # $ make [lab1|lab2a|lab2b|lab2c|lab2d|lab3a|lab3b|lab4a|lab4b] 5 | 6 | LABS=" lab1 lab2a lab2b lab2c lab2d lab3a lab3b lab4a lab4b " 7 | 8 | %: check-% 9 | @echo "Preparing $@-handin.tar.gz" 10 | @if echo $(LABS) | grep -q " $@ " ; then \ 11 | echo "Tarring up your submission..." 
; \ 12 | COPYFILE_DISABLE=1 tar cvzf $@-handin.tar.gz \ 13 | "--exclude=src/main/pg-*.txt" \ 14 | "--exclude=src/main/diskvd" \ 15 | "--exclude=src/mapreduce/824-mrinput-*.txt" \ 16 | "--exclude=src/main/mr-*" \ 17 | "--exclude=mrtmp.*" \ 18 | "--exclude=src/main/diff.out" \ 19 | "--exclude=src/main/mrcoordinator" \ 20 | "--exclude=src/main/mrsequential" \ 21 | "--exclude=src/main/mrworker" \ 22 | "--exclude=*.so" \ 23 | Makefile src; \ 24 | if ! test -e api.key ; then \ 25 | echo "Missing $(PWD)/api.key. Please create the file with your key in it or submit the $@-handin.tar.gz via the web interface."; \ 26 | else \ 27 | echo "Are you sure you want to submit $@? Enter 'yes' to continue:"; \ 28 | read line; \ 29 | if test "$$line" != "yes" ; then echo "Giving up submission"; exit; fi; \ 30 | if test `stat -c "%s" "$@-handin.tar.gz" 2>/dev/null || stat -f "%z" "$@-handin.tar.gz"` -ge 20971520 ; then echo "File exceeds 20MB."; exit; fi; \ 31 | cat api.key | tr -d '\n' > .api.key.trimmed ; \ 32 | curl --silent --fail --show-error -F file=@$@-handin.tar.gz -F "key=<.api.key.trimmed" \ 33 | https://6824.scripts.mit.edu/2022/handin.py/upload > /dev/null || { \ 34 | echo ; \ 35 | echo "Submit seems to have failed."; \ 36 | echo "Please upload the tarball manually on the submission website."; } \ 37 | fi; \ 38 | else \ 39 | echo "Bad target $@. Usage: make [$(LABS)]"; \ 40 | fi 41 | 42 | .PHONY: check-% 43 | check-%: 44 | @echo "Checking that your submission builds correctly..." 45 | @./.check-build git://g.csail.mit.edu/6.824-golabs-2022 $(patsubst check-%,%,$@) 46 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MIT6.824-2022 2 | 课程链接:https://pdos.csail.mit.edu/6.824/
3 | 4 | ### 完成的Labs:
5 | * Lab 1 : 2022/06/6->2022/06/12 [6天] 6 | * 7 | * Lab2A: 2022/06/12->2022/09/13 [10天] 8 | * Lab2B: 2022/09/13->2022/09/19 [7天] 9 | * Lab2C: 2022/09/20->2022/10/09 [19天] 10 | * Lab2D: 2022/10/11->2022/10/17 [7天] 11 | * 12 | * Lab3A: 2022/10/18->2022/10/22 [5天] 13 | * Lab3B: 2022/10/27->2022/11/24 [10天] -------------------------------------------------------------------------------- /src/.gitignore: -------------------------------------------------------------------------------- 1 | *.*/ 2 | main/mr-tmp/ 3 | mrtmp.* 4 | 824-mrinput-*.txt 5 | /main/diff.out 6 | /mapreduce/x.txt 7 | /pbservice/x.txt 8 | /kvpaxos/x.txt 9 | *.so 10 | /main/mrcoordinator 11 | /main/mrsequential 12 | /main/mrworker 13 | -------------------------------------------------------------------------------- /src/go.mod: -------------------------------------------------------------------------------- 1 | module mit6.824 2 | 3 | go 1.15 4 | -------------------------------------------------------------------------------- /src/go.sum: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vacant2333/MIT6.824-2022/937eb6e6d400c7ee7cdc6f9e978bb54687e596a4/src/go.sum -------------------------------------------------------------------------------- /src/kvraft/client.go: -------------------------------------------------------------------------------- 1 | package kvraft 2 | 3 | import ( 4 | "mit6.824/labrpc" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | type task struct { 10 | index RequestId // 对于当前Client的任务的Index 11 | op string // 任务类型 12 | key string // Get/PutAppend参数 13 | value string // PutAppend参数 14 | resultCh chan string // 传Get的返回值和Block住Get/PutAppend方法 15 | } 16 | 17 | type Clerk struct { 18 | servers []*labrpc.ClientEnd 19 | taskMu sync.Mutex 20 | taskQueue chan task // 任务队列 21 | clientTag ClientId // Client的唯一标识 22 | taskIndex RequestId // 最后一条任务的下标(包括未完成的任务) 23 | leaderIndex int // 上一次成功完成任务的Leader的Index,没有的话为-1 24 | } 25 | 26 | func 
MakeClerk(servers []*labrpc.ClientEnd) *Clerk { 27 | ck := &Clerk{ 28 | servers: servers, 29 | taskQueue: make(chan task), 30 | clientTag: nRand(), 31 | leaderIndex: -1, 32 | } 33 | go ck.doTasks() 34 | return ck 35 | } 36 | 37 | // doTasks takes tasks from ck.taskQueue one at a time and retries each until startTask returns something other than ErrNoLeader 38 | func (ck *Clerk) doTasks() { 39 | for { 40 | currentTask := <-ck.taskQueue 41 | DPrintf("C[%v] start a task:[%v]\n", ck.clientTag, currentTask) 42 | var args interface{} 43 | // Build the RPC args according to the task type 44 | if currentTask.op == "Get" { 45 | // Get task 46 | args = &GetArgs{ 47 | Key: currentTask.key, 48 | TaskIndex: currentTask.index, 49 | ClientTag: ck.clientTag, 50 | } 51 | } else { 52 | // Put/Append task 53 | args = &PutAppendArgs{ 54 | Key: currentTask.key, 55 | Value: currentTask.value, 56 | Op: currentTask.op, 57 | TaskIndex: currentTask.index, 58 | ClientTag: ck.clientTag, 59 | } 60 | } 61 | for { 62 | if err, value := ck.startTask(currentTask.op, args); err != ErrNoLeader { 63 | // Task finished; Err is not necessarily OK, it may also be ErrNoKey 64 | DPrintf("C[%v] success a task:[%v]\n", ck.clientTag, currentTask) 65 | // Get sends back the value, Put/Append send "" - either way this unblocks the waiting caller 66 | currentTask.resultCh <- value 67 | break 68 | } 69 | time.Sleep(clientNoLeaderSleepTime) 70 | } 71 | } 72 | } 73 | 74 | // startTask sends a task to the cached leader (or to every server when none is cached) and returns the first usable reply 75 | func (ck *Clerk) startTask(op string, args interface{}) (Err, string) { 76 | // Each reply travels together with the index of the server that produced it; two separate channels would let 77 | // concurrent senders interleave and pair a reply with the wrong server (and so cache the wrong leaderIndex) 78 | type serverReply struct { server int; reply interface{} } 79 | replyCh := make(chan serverReply, len(ck.servers)) // buffered: each askServer sends once, so late repliers never block 80 | // Pre-allocate one reply object per server 81 | replies := make([]interface{}, len(ck.servers)) 82 | for index := range replies { 83 | if op == "Get" { 84 | replies[index] = &GetReply{} 85 | } else { 86 | replies[index] = &PutAppendReply{} 87 | } 88 | } 89 | // Submit the task to a single server 90 | askServer := func(server int) { 91 | if op == "Get" { 92 | ck.servers[server].Call("KVServer.Get", args, replies[server]) 93 | } else { 94 | ck.servers[server].Call("KVServer.PutAppend", args, replies[server]) 95 | } 96 | // Send the reply and the sender's index as one value 97 | 
replyCh <- serverReply{server: server, reply: replies[server]} 98 | } 99 | // Number of replies we expect to receive 100 | replyCount := len(ck.servers) 101 | if ck.leaderIndex != -1 { 102 | // Prefer the server that completed the previous task 103 | go askServer(ck.leaderIndex) 104 | replyCount = 1 105 | } else { 106 | // No cached leaderIndex: ask every server 107 | for server := 0; server < len(ck.servers); server++ { 108 | go askServer(server) 109 | } 110 | } 111 | // Keep reading replyCh and return as soon as a usable reply arrives 112 | timeOut := time.After(clientDoTaskTimeOut) 113 | for ; replyCount > 0; replyCount-- { 114 | var sr serverReply 115 | select { 116 | case sr = <-replyCh: 117 | // Got a reply 118 | case <-timeOut: 119 | // The task timed out 120 | DPrintf("C[%v] task[%v] timeout,leaderIndex[%v]\n", ck.clientTag, args, ck.leaderIndex) 121 | ck.leaderIndex = -1 122 | return ErrNoLeader, "" 123 | } 124 | server, reply := sr.server, sr.reply 125 | // A successful reply ends the task and caches its sender as the leader 126 | if op == "Get" && reply != nil { 127 | // Get 128 | getReply := reply.(*GetReply) 129 | if getReply.Err == OK || getReply.Err == ErrNoKey { 130 | ck.leaderIndex = server 131 | return getReply.Err, getReply.Value 132 | } 133 | } else if reply != nil { 134 | // Put/Append 135 | putAppendReply := reply.(*PutAppendReply) 136 | if putAppendReply.Err == OK { 137 | ck.leaderIndex = server 138 | return putAppendReply.Err, "" 139 | } 140 | } 141 | } 142 | // No usable leader, or the cached leaderIndex is stale 143 | ck.leaderIndex = -1 144 | return ErrNoLeader, "" 145 | } 146 | 147 | // addTask enqueues a task with the next per-Clerk index (under taskMu) and returns the channel that carries its result 148 | func (ck *Clerk) addTask(op string, key string, value string) chan string { 149 | resultCh := make(chan string) 150 | ck.taskMu.Lock() 151 | ck.taskQueue <- task{ 152 | index: ck.taskIndex + 1, 153 | op: op, 154 | key: key, 155 | value: value, 156 | resultCh: resultCh, 157 | } 158 | ck.taskIndex++ 159 | ck.taskMu.Unlock() 160 | 
return resultCh 161 | } 162 | // Get returns the value reported for key, blocking until the task has completed 163 | func (ck *Clerk) Get(key string) string { 164 | return <-ck.addTask("Get", key, "") 165 | } 166 | // PutAppend submits a Put or Append task and blocks until it has completed 167 | func (ck *Clerk) PutAppend(key string, value string, op string) { 168 | <-ck.addTask(op, key, value) 169 | } 170 | 
171 | func (ck *Clerk) Put(key string, value string) { 172 | ck.PutAppend(key, value, "Put") 173 | } 174 | func (ck *Clerk) Append(key string, value string) { 175 | ck.PutAppend(key, value, "Append") 176 | } 177 | -------------------------------------------------------------------------------- /src/kvraft/common.go: -------------------------------------------------------------------------------- 1 | package kvraft 2 | 3 | import ( 4 | "crypto/rand" 5 | "log" 6 | "math/big" 7 | "time" 8 | ) 9 | 10 | const ( 11 | Debug = false 12 | 13 | OK = "OK" 14 | ErrNoKey = "ErrNoKey" 15 | ErrWrongLeader = "ErrWrongLeader" 16 | ErrNoLeader = "ErrNoLeader" 17 | // Timeout for a single client attempt at a task 18 | clientDoTaskTimeOut = 800 * time.Millisecond 19 | // How long the client waits before retrying when no leader was found 20 | clientNoLeaderSleepTime = 65 * time.Millisecond 21 | // Start a snapshot when the Raft state size exceeds (this value * maxRaftState) 22 | serverSnapshotStatePercent = 0.9 23 | ) 24 | 25 | type Err string 26 | 27 | // ClientId is the unique tag of each Client 28 | type ClientId int64 29 | 30 | // RequestId is the index of a task within its Client 31 | type RequestId int64 32 | 33 | type PutAppendArgs struct { 34 | Key string 35 | Value string 36 | Op string // Put/Append 37 | ClientTag ClientId 38 | TaskIndex RequestId 39 | } 40 | 41 | type PutAppendReply struct { 42 | Err Err 43 | } 44 | 45 | type GetArgs struct { 46 | Key string 47 | ClientTag ClientId 48 | TaskIndex RequestId 49 | } 50 | 51 | type GetReply struct { 52 | Err Err 53 | Value string 54 | } 55 | 56 | func DPrintf(format string, a ...interface{}) { 57 | if Debug { 58 | log.Printf(format, a...)
59 | } 60 | } 61 | 62 | // 获得一个int64的随机数(Client的Tag) 63 | func nRand() ClientId { 64 | max := big.NewInt(int64(1) << 62) 65 | bigX, _ := rand.Int(rand.Reader, max) 66 | return ClientId(bigX.Int64()) 67 | } 68 | -------------------------------------------------------------------------------- /src/kvraft/config.go: -------------------------------------------------------------------------------- 1 | package kvraft 2 | 3 | import "mit6.824/labrpc" 4 | import "testing" 5 | import "os" 6 | 7 | // import "log" 8 | import crand "crypto/rand" 9 | import "math/big" 10 | import "math/rand" 11 | import "encoding/base64" 12 | import "sync" 13 | import "runtime" 14 | import "mit6.824/raft" 15 | import "fmt" 16 | import "time" 17 | import "sync/atomic" 18 | 19 | func randstring(n int) string { 20 | b := make([]byte, 2*n) 21 | crand.Read(b) 22 | s := base64.URLEncoding.EncodeToString(b) 23 | return s[0:n] 24 | } 25 | 26 | func makeSeed() int64 { 27 | max := big.NewInt(int64(1) << 62) 28 | bigx, _ := crand.Int(crand.Reader, max) 29 | x := bigx.Int64() 30 | return x 31 | } 32 | 33 | // Randomize server handles 34 | func random_handles(kvh []*labrpc.ClientEnd) []*labrpc.ClientEnd { 35 | sa := make([]*labrpc.ClientEnd, len(kvh)) 36 | copy(sa, kvh) 37 | for i := range sa { 38 | j := rand.Intn(i + 1) 39 | sa[i], sa[j] = sa[j], sa[i] 40 | } 41 | return sa 42 | } 43 | 44 | type config struct { 45 | mu sync.Mutex 46 | t *testing.T 47 | net *labrpc.Network 48 | n int 49 | kvservers []*KVServer 50 | saved []*raft.Persister 51 | endnames [][]string // names of each server's sending ClientEnds 52 | clerks map[*Clerk][]string 53 | nextClientId int 54 | maxraftstate int 55 | start time.Time // time at which make_config() was called 56 | // begin()/end() statistics 57 | t0 time.Time // time at which test_test.go called cfg.begin() 58 | rpcs0 int // rpcTotal() at start of test 59 | ops int32 // number of clerk get/put/append method calls 60 | } 61 | 62 | func (cfg *config) checkTimeout() { 63 | // 
enforce a two minute real-time limit on each test 64 | if !cfg.t.Failed() && time.Since(cfg.start) > 120*time.Second { 65 | cfg.t.Fatal("test took longer than 120 seconds") 66 | } 67 | } 68 | 69 | func (cfg *config) cleanup() { 70 | cfg.mu.Lock() 71 | defer cfg.mu.Unlock() 72 | for i := 0; i < len(cfg.kvservers); i++ { 73 | if cfg.kvservers[i] != nil { 74 | cfg.kvservers[i].Kill() 75 | } 76 | } 77 | cfg.net.Cleanup() 78 | cfg.checkTimeout() 79 | } 80 | 81 | // Maximum log size across all servers 82 | func (cfg *config) LogSize() int { 83 | logsize := 0 84 | for i := 0; i < cfg.n; i++ { 85 | n := cfg.saved[i].RaftStateSize() 86 | if n > logsize { 87 | logsize = n 88 | } 89 | } 90 | return logsize 91 | } 92 | 93 | // Maximum snapshot size across all servers 94 | func (cfg *config) SnapshotSize() int { 95 | snapshotsize := 0 96 | for i := 0; i < cfg.n; i++ { 97 | n := cfg.saved[i].SnapshotSize() 98 | if n > snapshotsize { 99 | snapshotsize = n 100 | } 101 | } 102 | return snapshotsize 103 | } 104 | 105 | // attach server i to servers listed in to 106 | // caller must hold cfg.mu 107 | func (cfg *config) connectUnlocked(i int, to []int) { 108 | // log.Printf("connect peer %d to %v\n", i, to) 109 | 110 | // outgoing socket files 111 | for j := 0; j < len(to); j++ { 112 | endname := cfg.endnames[i][to[j]] 113 | cfg.net.Enable(endname, true) 114 | } 115 | 116 | // incoming socket files 117 | for j := 0; j < len(to); j++ { 118 | endname := cfg.endnames[to[j]][i] 119 | cfg.net.Enable(endname, true) 120 | } 121 | } 122 | 123 | func (cfg *config) connect(i int, to []int) { 124 | cfg.mu.Lock() 125 | defer cfg.mu.Unlock() 126 | cfg.connectUnlocked(i, to) 127 | } 128 | 129 | // detach server i from the servers listed in from 130 | // caller must hold cfg.mu 131 | func (cfg *config) disconnectUnlocked(i int, from []int) { 132 | // log.Printf("disconnect peer %d from %v\n", i, from) 133 | 134 | // outgoing socket files 135 | for j := 0; j < len(from); j++ { 136 | if 
cfg.endnames[i] != nil { 137 | endname := cfg.endnames[i][from[j]] 138 | cfg.net.Enable(endname, false) 139 | } 140 | } 141 | 142 | // incoming socket files: check and index endnames by the peer id from[j], not by the loop position j 143 | for j := 0; j < len(from); j++ { 144 | if cfg.endnames[from[j]] != nil { 145 | endname := cfg.endnames[from[j]][i] 146 | cfg.net.Enable(endname, false) 147 | } 148 | } 149 | } 150 | 151 | func (cfg *config) disconnect(i int, from []int) { 152 | cfg.mu.Lock() 153 | defer cfg.mu.Unlock() 154 | cfg.disconnectUnlocked(i, from) 155 | } 156 | 157 | func (cfg *config) All() []int { 158 | all := make([]int, cfg.n) 159 | for i := 0; i < cfg.n; i++ { 160 | all[i] = i 161 | } 162 | return all 163 | } 164 | 165 | func (cfg *config) ConnectAll() { 166 | cfg.mu.Lock() 167 | defer cfg.mu.Unlock() 168 | for i := 0; i < cfg.n; i++ { 169 | cfg.connectUnlocked(i, cfg.All()) 170 | } 171 | } 172 | 173 | // Sets up 2 partitions with connectivity between servers in each partition. 174 | func (cfg *config) partition(p1 []int, p2 []int) { 175 | cfg.mu.Lock() 176 | defer cfg.mu.Unlock() 177 | // log.Printf("partition servers into: %v %v\n", p1, p2) 178 | for i := 0; i < len(p1); i++ { 179 | cfg.disconnectUnlocked(p1[i], p2) 180 | cfg.connectUnlocked(p1[i], p1) 181 | } 182 | for i := 0; i < len(p2); i++ { 183 | cfg.disconnectUnlocked(p2[i], p1) 184 | cfg.connectUnlocked(p2[i], p2) 185 | } 186 | } 187 | 188 | // Create a clerk with clerk specific server names. 189 | // Give it connections to all of the servers, but for 190 | // now enable only connections to servers in to[]. 191 | func (cfg *config) makeClient(to []int) *Clerk { 192 | cfg.mu.Lock() 193 | defer cfg.mu.Unlock() 194 | 195 | // a fresh set of ClientEnds.
196 | ends := make([]*labrpc.ClientEnd, cfg.n) 197 | endnames := make([]string, cfg.n) 198 | for j := 0; j < cfg.n; j++ { 199 | endnames[j] = randstring(20) 200 | ends[j] = cfg.net.MakeEnd(endnames[j]) 201 | cfg.net.Connect(endnames[j], j) 202 | } 203 | 204 | ck := MakeClerk(random_handles(ends)) 205 | cfg.clerks[ck] = endnames 206 | cfg.nextClientId++ 207 | cfg.ConnectClientUnlocked(ck, to) 208 | return ck 209 | } 210 | 211 | func (cfg *config) deleteClient(ck *Clerk) { 212 | cfg.mu.Lock() 213 | defer cfg.mu.Unlock() 214 | 215 | v := cfg.clerks[ck] 216 | for i := 0; i < len(v); i++ { 217 | os.Remove(v[i]) 218 | } 219 | delete(cfg.clerks, ck) 220 | } 221 | 222 | // caller should hold cfg.mu 223 | func (cfg *config) ConnectClientUnlocked(ck *Clerk, to []int) { 224 | // log.Printf("ConnectClient %v to %v\n", ck, to) 225 | endnames := cfg.clerks[ck] 226 | for j := 0; j < len(to); j++ { 227 | s := endnames[to[j]] 228 | cfg.net.Enable(s, true) 229 | } 230 | } 231 | 232 | func (cfg *config) ConnectClient(ck *Clerk, to []int) { 233 | cfg.mu.Lock() 234 | defer cfg.mu.Unlock() 235 | cfg.ConnectClientUnlocked(ck, to) 236 | } 237 | 238 | // caller should hold cfg.mu 239 | func (cfg *config) DisconnectClientUnlocked(ck *Clerk, from []int) { 240 | // log.Printf("DisconnectClient %v from %v\n", ck, from) 241 | endnames := cfg.clerks[ck] 242 | for j := 0; j < len(from); j++ { 243 | s := endnames[from[j]] 244 | cfg.net.Enable(s, false) 245 | } 246 | } 247 | 248 | func (cfg *config) DisconnectClient(ck *Clerk, from []int) { 249 | cfg.mu.Lock() 250 | defer cfg.mu.Unlock() 251 | cfg.DisconnectClientUnlocked(ck, from) 252 | } 253 | 254 | // Shutdown a server by isolating it 255 | func (cfg *config) ShutdownServer(i int) { 256 | cfg.mu.Lock() 257 | defer cfg.mu.Unlock() 258 | 259 | cfg.disconnectUnlocked(i, cfg.All()) 260 | 261 | // disable client connections to the server. 
262 | // it's important to do this before creating 263 | // the new Persister in saved[i], to avoid 264 | // the possibility of the server returning a 265 | // positive reply to an Append but persisting 266 | // the result in the superseded Persister. 267 | cfg.net.DeleteServer(i) 268 | 269 | // a fresh persister, in case old instance 270 | // continues to update the Persister. 271 | // but copy old persister's content so that we always 272 | // pass Make() the last persisted state. 273 | if cfg.saved[i] != nil { 274 | cfg.saved[i] = cfg.saved[i].Copy() 275 | } 276 | 277 | kv := cfg.kvservers[i] 278 | if kv != nil { 279 | cfg.mu.Unlock() 280 | kv.Kill() 281 | cfg.mu.Lock() 282 | cfg.kvservers[i] = nil 283 | } 284 | } 285 | 286 | // If restart servers, first call ShutdownServer 287 | func (cfg *config) StartServer(i int) { 288 | cfg.mu.Lock() 289 | 290 | // a fresh set of outgoing ClientEnd names. 291 | cfg.endnames[i] = make([]string, cfg.n) 292 | for j := 0; j < cfg.n; j++ { 293 | cfg.endnames[i][j] = randstring(20) 294 | } 295 | 296 | // a fresh set of ClientEnds. 297 | ends := make([]*labrpc.ClientEnd, cfg.n) 298 | for j := 0; j < cfg.n; j++ { 299 | ends[j] = cfg.net.MakeEnd(cfg.endnames[i][j]) 300 | cfg.net.Connect(cfg.endnames[i][j], j) 301 | } 302 | 303 | // a fresh persister, so old instance doesn't overwrite 304 | // new instance's persisted state. 305 | // give the fresh persister a copy of the old persister's 306 | // state, so that the spec is that we pass StartKVServer() 307 | // the last persisted state. 
308 | if cfg.saved[i] != nil { 309 | cfg.saved[i] = cfg.saved[i].Copy() 310 | } else { 311 | cfg.saved[i] = raft.MakePersister() 312 | } 313 | cfg.mu.Unlock() 314 | 315 | cfg.kvservers[i] = StartKVServer(ends, i, cfg.saved[i], cfg.maxraftstate) 316 | 317 | kvsvc := labrpc.MakeService(cfg.kvservers[i]) 318 | rfsvc := labrpc.MakeService(cfg.kvservers[i].rf) 319 | srv := labrpc.MakeServer() 320 | srv.AddService(kvsvc) 321 | srv.AddService(rfsvc) 322 | cfg.net.AddServer(i, srv) 323 | } 324 | // Leader returns (true, i) for the first live server whose Raft reports itself leader, otherwise (false, 0) 325 | func (cfg *config) Leader() (bool, int) { 326 | cfg.mu.Lock() 327 | defer cfg.mu.Unlock() 328 | // ShutdownServer sets kvservers[i] to nil, so skip servers that are currently down instead of dereferencing nil 329 | for i := 0; i < cfg.n; i++ { 330 | if cfg.kvservers[i] == nil { continue } 331 | if _, isLeader := cfg.kvservers[i].rf.GetState(); isLeader { 332 | return true, i 333 | } 334 | } 335 | return false, 0 336 | } 337 | 338 | // Partition servers into 2 groups and put current leader in minority 339 | func (cfg *config) make_partition() ([]int, []int) { 340 | _, l := cfg.Leader() 341 | p1 := make([]int, cfg.n/2+1) 342 | p2 := make([]int, cfg.n/2) 343 | j := 0 344 | for i := 0; i < cfg.n; i++ { 345 | if i != l { 346 | if j < len(p1) { 347 | p1[j] = i 348 | } else { 349 | p2[j-len(p1)] = i 350 | } 351 | j++ 352 | } 353 | } 354 | p2[len(p2)-1] = l 355 | return p1, p2 356 | } 357 | 358 | var ncpu_once sync.Once 359 | 360 | func make_config(t *testing.T, n int, unreliable bool, maxraftstate int) *config { 361 | ncpu_once.Do(func() { 362 | if runtime.NumCPU() < 2 { 363 | fmt.Printf("warning: only one CPU, which may conceal locking bugs\n") 364 | } 365 | rand.Seed(makeSeed()) 366 | }) 367 | runtime.GOMAXPROCS(4) 368 | cfg := &config{} 369 | cfg.t = t 370 | cfg.net = labrpc.MakeNetwork() 371 | cfg.n = n 372 | cfg.kvservers = make([]*KVServer, cfg.n) 373 | cfg.saved = make([]*raft.Persister, cfg.n) 374 | cfg.endnames = make([][]string, cfg.n) 375 | cfg.clerks = make(map[*Clerk][]string) 376 | cfg.nextClientId = cfg.n + 1000 // client ids start 1000 above the highest serverid 377 | cfg.maxraftstate = maxraftstate 378 | cfg.start
= time.Now() 379 | 380 | // create a full set of KV servers. 381 | for i := 0; i < cfg.n; i++ { 382 | cfg.StartServer(i) 383 | } 384 | 385 | cfg.ConnectAll() 386 | 387 | cfg.net.Reliable(!unreliable) 388 | 389 | return cfg 390 | } 391 | 392 | func (cfg *config) rpcTotal() int { 393 | return cfg.net.GetTotalCount() 394 | } 395 | 396 | // start a Test. 397 | // print the Test message. 398 | // e.g. cfg.begin("Test (2B): RPC counts aren't too high") 399 | func (cfg *config) begin(description string) { 400 | fmt.Printf("%s ...\n", description) 401 | cfg.t0 = time.Now() 402 | cfg.rpcs0 = cfg.rpcTotal() 403 | atomic.StoreInt32(&cfg.ops, 0) 404 | } 405 | 406 | func (cfg *config) op() { 407 | atomic.AddInt32(&cfg.ops, 1) 408 | } 409 | 410 | // end a Test -- the fact that we got here means there 411 | // was no failure. 412 | // print the Passed message, 413 | // and some performance numbers. 414 | func (cfg *config) end() { 415 | cfg.checkTimeout() 416 | if cfg.t.Failed() == false { 417 | t := time.Since(cfg.t0).Seconds() // real time 418 | npeers := cfg.n // number of Raft peers 419 | nrpc := cfg.rpcTotal() - cfg.rpcs0 // number of RPC sends 420 | ops := atomic.LoadInt32(&cfg.ops) // number of clerk get/put/append calls 421 | 422 | fmt.Printf(" ... 
Passed --") 423 | fmt.Printf(" %4.1f %d %5d %4d\n", t, npeers, nrpc, ops) 424 | } 425 | } 426 | -------------------------------------------------------------------------------- /src/kvraft/dstest.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import itertools 4 | import math 5 | import signal 6 | import subprocess 7 | import tempfile 8 | import shutil 9 | import time 10 | import os 11 | import sys 12 | import datetime 13 | from collections import defaultdict 14 | from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED 15 | from dataclasses import dataclass 16 | from pathlib import Path 17 | from typing import List, Optional, Dict, DefaultDict, Tuple 18 | 19 | import typer 20 | import rich 21 | from rich import print 22 | from rich.table import Table 23 | from rich.progress import ( 24 | Progress, 25 | TimeElapsedColumn, 26 | TimeRemainingColumn, 27 | TextColumn, 28 | BarColumn, 29 | SpinnerColumn, 30 | ) 31 | from rich.live import Live 32 | from rich.panel import Panel 33 | from rich.traceback import install 34 | 35 | install(show_locals=True) 36 | 37 | 38 | @dataclass 39 | class StatsMeter: 40 | """ 41 | Auxiliary classs to keep track of online stats including: count, mean, variance 42 | Uses Welford's algorithm to compute sample mean and sample variance incrementally. 43 | https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#On-line_algorithm 44 | """ 45 | 46 | n: int = 0 47 | mean: float = 0.0 48 | S: float = 0.0 49 | 50 | def add(self, datum): 51 | self.n += 1 52 | delta = datum - self.mean 53 | # Mk = Mk-1+ (xk – Mk-1)/k 54 | self.mean += delta / self.n 55 | # Sk = Sk-1 + (xk – Mk-1)*(xk – Mk). 
56 | self.S += delta * (datum - self.mean) 57 | 58 | @property 59 | def variance(self): 60 | return self.S / self.n 61 | 62 | @property 63 | def std(self): 64 | return math.sqrt(self.variance) 65 | 66 | 67 | def print_results(results: Dict[str, Dict[str, StatsMeter]], timing=False): 68 | table = Table(show_header=True, header_style="bold") 69 | table.add_column("Test") 70 | table.add_column("Failed", justify="right") 71 | table.add_column("Total", justify="right") 72 | if not timing: 73 | table.add_column("Time", justify="right") 74 | else: 75 | table.add_column("Real Time", justify="right") 76 | table.add_column("User Time", justify="right") 77 | table.add_column("System Time", justify="right") 78 | 79 | for test, stats in results.items(): 80 | if stats["completed"].n == 0: 81 | continue 82 | color = "green" if stats["failed"].n == 0 else "red" 83 | row = [ 84 | f"[{color}]{test}[/{color}]", 85 | str(stats["failed"].n), 86 | str(stats["completed"].n), 87 | ] 88 | if not timing: 89 | row.append(f"{stats['time'].mean:.2f} ± {stats['time'].std:.2f}") 90 | else: 91 | row.extend( 92 | [ 93 | f"{stats['real_time'].mean:.2f} ± {stats['real_time'].std:.2f}", 94 | f"{stats['user_time'].mean:.2f} ± {stats['user_time'].std:.2f}", 95 | f"{stats['system_time'].mean:.2f} ± {stats['system_time'].std:.2f}", 96 | ] 97 | ) 98 | table.add_row(*row) 99 | 100 | print(table) 101 | 102 | 103 | def run_test(test: str, race: bool, timing: bool): 104 | test_cmd = ["go", "test", f"-run={test}"] 105 | if race: 106 | test_cmd.append("-race") 107 | if timing: 108 | test_cmd = ["time"] + cmd 109 | f, path = tempfile.mkstemp() 110 | start = time.time() 111 | proc = subprocess.run(test_cmd, stdout=f, stderr=f) 112 | runtime = time.time() - start 113 | os.close(f) 114 | return test, path, proc.returncode, runtime 115 | 116 | 117 | def last_line(file: str) -> str: 118 | with open(file, "rb") as f: 119 | f.seek(-2, os.SEEK_END) 120 | while f.read(1) != b"\n": 121 | f.seek(-2, os.SEEK_CUR) 122 | line 
= f.readline().decode() 123 | return line 124 | 125 | 126 | # fmt: off 127 | def run_tests( 128 | tests: List[str], 129 | sequential: bool = typer.Option(False, '--sequential', '-s', help='Run all test of each group in order'), 130 | workers: int = typer.Option(1, '--workers', '-p', help='Number of parallel tasks'), 131 | iterations: int = typer.Option(10, '--iter', '-n', help='Number of iterations to run'), 132 | output: Optional[Path] = typer.Option(None, '--output', '-o', help='Output path to use'), 133 | verbose: int = typer.Option(0, '--verbose', '-v', help='Verbosity level', count=True), 134 | archive: bool = typer.Option(False, '--archive', '-a', help='Save all logs intead of only failed ones'), 135 | race: bool = typer.Option(False, '--race/--no-race', '-r/-R', help='Run with race checker'), 136 | loop: bool = typer.Option(False, '--loop', '-l', help='Run continuously'), 137 | growth: int = typer.Option(10, '--growth', '-g', help='Growth ratio of iterations when using --loop'), 138 | timing: bool = typer.Option(False, '--timing', '-t', help='Report timing, only works on macOS'), 139 | # fmt: on 140 | ): 141 | 142 | if output is None: 143 | timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") 144 | output = Path(timestamp) 145 | 146 | if race: 147 | print("[yellow]Running with the race detector\n[/yellow]") 148 | 149 | if verbose > 0: 150 | print(f"[yellow] Verbosity level set to {verbose}[/yellow]") 151 | os.environ['VERBOSE'] = str(verbose) 152 | 153 | while True: 154 | 155 | total = iterations * len(tests) 156 | completed = 0 157 | 158 | results = {test: defaultdict(StatsMeter) for test in tests} 159 | 160 | if sequential: 161 | test_instances = itertools.chain.from_iterable(itertools.repeat(test, iterations) for test in tests) 162 | else: 163 | test_instances = itertools.chain.from_iterable(itertools.repeat(tests, iterations)) 164 | test_instances = iter(test_instances) 165 | 166 | total_progress = Progress( 167 | 
"[progress.description]{task.description}", 168 | BarColumn(), 169 | TimeRemainingColumn(), 170 | "[progress.percentage]{task.percentage:>3.0f}%", 171 | TimeElapsedColumn(), 172 | ) 173 | total_task = total_progress.add_task("[yellow]Tests[/yellow]", total=total) 174 | 175 | task_progress = Progress( 176 | "[progress.description]{task.description}", 177 | SpinnerColumn(), 178 | BarColumn(), 179 | "{task.completed}/{task.total}", 180 | ) 181 | tasks = {test: task_progress.add_task(test, total=iterations) for test in tests} 182 | 183 | progress_table = Table.grid() 184 | progress_table.add_row(total_progress) 185 | progress_table.add_row(Panel.fit(task_progress)) 186 | 187 | with Live(progress_table, transient=True) as live: 188 | 189 | def handler(_, frame): 190 | live.stop() 191 | print('\n') 192 | print_results(results) 193 | sys.exit(1) 194 | 195 | signal.signal(signal.SIGINT, handler) 196 | 197 | with ThreadPoolExecutor(max_workers=workers) as executor: 198 | 199 | futures = [] 200 | while completed < total: 201 | n = len(futures) 202 | if n < workers: 203 | for test in itertools.islice(test_instances, workers-n): 204 | futures.append(executor.submit(run_test, test, race, timing)) 205 | 206 | done, not_done = wait(futures, return_when=FIRST_COMPLETED) 207 | 208 | for future in done: 209 | test, path, rc, runtime = future.result() 210 | 211 | results[test]['completed'].add(1) 212 | results[test]['time'].add(runtime) 213 | task_progress.update(tasks[test], advance=1) 214 | dest = (output / f"{test}_{completed}.log").as_posix() 215 | if rc != 0: 216 | print(f"Failed test {test} - {dest}") 217 | task_progress.update(tasks[test], description=f"[red]{test}[/red]") 218 | results[test]['failed'].add(1) 219 | else: 220 | if results[test]['completed'].n == iterations and results[test]['failed'].n == 0: 221 | task_progress.update(tasks[test], description=f"[green]{test}[/green]") 222 | 223 | if rc != 0 or archive: 224 | output.mkdir(exist_ok=True, parents=True) 225 | 
shutil.copy(path, dest) 226 | 227 | if timing: 228 | line = last_line(path) 229 | real, _, user, _, system, _ = line.replace(' '*8, '').split(' ') 230 | results[test]['real_time'].add(float(real)) 231 | results[test]['user_time'].add(float(user)) 232 | results[test]['system_time'].add(float(system)) 233 | 234 | os.remove(path) 235 | 236 | completed += 1 237 | total_progress.update(total_task, advance=1) 238 | 239 | futures = list(not_done) 240 | 241 | print_results(results, timing) 242 | 243 | if loop: 244 | iterations *= growth 245 | print(f"[yellow]Increasing iterations to {iterations}[/yellow]") 246 | else: 247 | break 248 | 249 | 250 | if __name__ == "__main__": 251 | typer.run(run_tests) 252 | -------------------------------------------------------------------------------- /src/kvraft/server.go: -------------------------------------------------------------------------------- 1 | package kvraft 2 | 3 | import ( 4 | "bytes" 5 | "mit6.824/labgob" 6 | "mit6.824/labrpc" 7 | "mit6.824/raft" 8 | "sync" 9 | "sync/atomic" 10 | ) 11 | 12 | type Op struct { 13 | Type string // 任务类型 14 | Key string 15 | Value string 16 | ClientId ClientId // 任务的Client 17 | RequestId RequestId // 对应的Client的这条任务的下标 18 | } 19 | 20 | type KVServer struct { 21 | mu sync.Mutex 22 | me int 23 | rf *raft.Raft // 该状态机对应的Raft 24 | applyCh chan raft.ApplyMsg // Raft apply的Logs 25 | dead int32 // 同Raft的dead 26 | // 3A 27 | kv map[string]string // (持久化)Key/Value数据库 28 | clientLastTaskIndex map[ClientId]RequestId // (持久化)每个客户端已完成的最后一个任务的下标 29 | taskTerm map[int]int // 已完成的任务(用于校验完成的Index对应的任务是不是自己发布的任务) 30 | doneCond map[int]*sync.Cond // Client发送到该Server的任务,任务完成后通知Cond回复Client 31 | // 3B 32 | persister *raft.Persister 33 | maxRaftState int // 当Raft的RaftStateSize接近该值时进行Snapshot 34 | } 35 | 36 | func (kv *KVServer) Get(args *GetArgs, reply *GetReply) { 37 | kv.mu.Lock() 38 | done, isLeader, index, term := kv.startOp("Get", args.Key, "", args.ClientTag, args.TaskIndex) 39 | if done { 40 | // 任务已完成过 41 | 
reply.Err = OK 42 | reply.Value = kv.kv[args.Key] 43 | } else if !isLeader { 44 | // 不是Leader 45 | reply.Err = ErrWrongLeader 46 | } else { 47 | cond := kv.doneCond[index] 48 | kv.mu.Unlock() 49 | cond.L.Lock() 50 | // 等待任务完成,推送至cond唤醒 51 | cond.Wait() 52 | cond.L.Unlock() 53 | kv.mu.Lock() 54 | // 任务对应的Index已经被Apply,检查完成的任务是否是自己发布的那个 55 | if term == kv.taskTerm[index] { 56 | // 完成的任务和自己发布的任务相同 57 | if value, ok := kv.kv[args.Key]; ok { 58 | reply.Err = OK 59 | reply.Value = value 60 | } else { 61 | reply.Err = ErrNoKey 62 | } 63 | } else { 64 | // 任务被其他Server处理了 65 | reply.Err = ErrWrongLeader 66 | } 67 | delete(kv.taskTerm, index) 68 | delete(kv.doneCond, index) 69 | } 70 | kv.mu.Unlock() 71 | } 72 | 73 | func (kv *KVServer) PutAppend(args *PutAppendArgs, reply *PutAppendReply) { 74 | kv.mu.Lock() 75 | done, isLeader, index, term := kv.startOp(args.Op, args.Key, args.Value, args.ClientTag, args.TaskIndex) 76 | if done { 77 | // 任务已完成过 78 | reply.Err = OK 79 | } else if !isLeader { 80 | // 不是Leader 81 | reply.Err = ErrWrongLeader 82 | } else { 83 | cond := kv.doneCond[index] 84 | kv.mu.Unlock() 85 | cond.L.Lock() 86 | // 等待任务完成,推送至cond唤醒 87 | cond.Wait() 88 | cond.L.Unlock() 89 | kv.mu.Lock() 90 | if term == kv.taskTerm[index] { 91 | // 完成的任务和自己发布的任务相同 92 | reply.Err = OK 93 | } else { 94 | // 任务被其他Server处理了 95 | reply.Err = ErrWrongLeader 96 | } 97 | delete(kv.taskTerm, index) 98 | delete(kv.doneCond, index) 99 | } 100 | kv.mu.Unlock() 101 | } 102 | 103 | // 开始执行一条Log,Lock使用(已完成过该任务,是否为Leader,新任务的index) 104 | func (kv *KVServer) startOp(op string, key string, value string, clientTag ClientId, clientTaskIndex RequestId) (bool, bool, int, int) { 105 | if kv.getClientLastIndex(clientTag) >= clientTaskIndex { 106 | // 这个任务已经完成过,直接返回 107 | return true, false, 0, 0 108 | } 109 | // 要求Raft开始一次提交 110 | index, term, isLeader := kv.rf.Start(Op{ 111 | Type: op, 112 | Key: key, 113 | Value: value, 114 | ClientId: clientTag, 115 | RequestId: clientTaskIndex, 116 | }) 117 | if 
!isLeader { 118 | // 不是Leader 119 | return false, false, 0, 0 120 | } 121 | if _, ok := kv.doneCond[index]; !ok { 122 | // 没有执行过这条Log,存入一个cond,完成该任务后通过这个cond通知所有goroutine 123 | kv.doneCond[index] = &sync.Cond{L: &sync.Mutex{}} 124 | } 125 | return false, true, index, term 126 | } 127 | 128 | // 持续接受来自Raft的Log 129 | func (kv *KVServer) applier() { 130 | for kv.killed() == false { 131 | // 接受一条被Apply的Log 132 | msg := <-kv.applyCh 133 | kv.mu.Lock() 134 | if msg.CommandValid { 135 | // Command Log,解析Log中的command 136 | command, _ := msg.Command.(Op) 137 | // 检查任务是否已完成过(一个任务/Log可能会发送多次,因为前几次可能因为某种原因没有及时提交) 138 | // 最后一条已完成的任务的Index必须小于当前任务才算没有完成过,因为线性一致性 139 | DPrintf("S[%v] apply %v[%v, %v] index[%v]\n", kv.me, command.Type, command.Key, command.Value, msg.CommandIndex) 140 | if command.Type != "Get" && kv.getClientLastIndex(command.ClientId) < command.RequestId { 141 | // 如果是第一次完成该任务/Log,才保存到KV中 142 | if command.Type == "Put" { 143 | // Put 144 | kv.kv[command.Key] = command.Value 145 | } else { 146 | // Append 147 | if _, ok := kv.kv[command.Key]; ok { 148 | kv.kv[command.Key] += command.Value 149 | } else { 150 | kv.kv[command.Key] = command.Value 151 | } 152 | } 153 | // 该任务的Index比之前存的任务Index大,更新 154 | kv.clientLastTaskIndex[command.ClientId] = command.RequestId 155 | } 156 | if cond, ok := kv.doneCond[msg.CommandIndex]; ok { 157 | // 这个任务被给到过自己,保存它的Term,用来校验 158 | kv.taskTerm[msg.CommandIndex] = msg.CommandTerm 159 | // 通知所有在等待该任务的goroutine 160 | cond.Broadcast() 161 | } 162 | // 检查是否需要Snapshot 163 | if kv.maxRaftState != -1 && float64(kv.persister.RaftStateSize()) > float64(kv.maxRaftState)*serverSnapshotStatePercent { 164 | // Raft状态的大小接近阈值,要求Raft进行Snapshot 165 | kv.saveSnapshot(msg.CommandIndex) 166 | } 167 | } else if msg.SnapshotValid { 168 | // Snapshot Log,只有在Leader发给该Server的InstallSnapshot种才会走到这里,这表明该Server的Logs过于老旧 169 | if kv.rf.CondInstallSnapshot(msg.SnapshotTerm, msg.SnapshotIndex, msg.Snapshot) { 170 | kv.readSnapshot(msg.Snapshot) 171 | } 172 | } 
173 | kv.mu.Unlock() 174 | } 175 | } 176 | 177 | // 通过ClientTag获得该Client完成的最后一条任务的下标,0则没有完成 178 | func (kv *KVServer) getClientLastIndex(client ClientId) RequestId { 179 | if last, ok := kv.clientLastTaskIndex[client]; ok { 180 | return last 181 | } 182 | return 0 183 | } 184 | 185 | func StartKVServer(servers []*labrpc.ClientEnd, me int, persister *raft.Persister, maxRaftState int) *KVServer { 186 | labgob.Register(Op{}) 187 | applyCh := make(chan raft.ApplyMsg) 188 | kv := &KVServer{ 189 | me: me, 190 | rf: raft.Make(servers, me, persister, applyCh), 191 | applyCh: applyCh, 192 | maxRaftState: maxRaftState, 193 | kv: make(map[string]string), 194 | clientLastTaskIndex: make(map[ClientId]RequestId), 195 | taskTerm: make(map[int]int), 196 | doneCond: make(map[int]*sync.Cond), 197 | persister: persister, 198 | } 199 | kv.readSnapshot(kv.persister.ReadSnapshot()) 200 | go kv.applier() 201 | return kv 202 | } 203 | 204 | // 保存Snapshot(被动快照) 205 | func (kv *KVServer) saveSnapshot(lastIndex int) { 206 | writer := new(bytes.Buffer) 207 | encoder := labgob.NewEncoder(writer) 208 | if encoder.Encode(kv.kv) == nil && 209 | encoder.Encode(kv.clientLastTaskIndex) == nil { 210 | kv.rf.Snapshot(lastIndex, writer.Bytes()) 211 | DPrintf("S[%v] save snapshot(%v) size[%v]\n", kv.me, lastIndex, len(writer.Bytes())) 212 | } 213 | } 214 | 215 | // 读取Snapshot 216 | func (kv *KVServer) readSnapshot(data []byte) { 217 | if data == nil || len(data) < 1 { 218 | return 219 | } 220 | decoder := labgob.NewDecoder(bytes.NewBuffer(data)) 221 | var kvMap map[string]string 222 | var clientLastTaskIndex map[ClientId]RequestId 223 | if decoder.Decode(&kvMap) == nil && 224 | decoder.Decode(&clientLastTaskIndex) == nil { 225 | kv.kv = kvMap 226 | kv.clientLastTaskIndex = clientLastTaskIndex 227 | DPrintf("S[%v] readSnapshot size[%v]\n", kv.me, len(data)) 228 | } 229 | } 230 | 231 | func (kv *KVServer) Kill() { 232 | atomic.StoreInt32(&kv.dead, 1) 233 | DPrintf("S[%v] killed\n", kv.me) 234 | 
kv.rf.Kill() 235 | } 236 | 237 | func (kv *KVServer) killed() bool { 238 | z := atomic.LoadInt32(&kv.dead) 239 | return z == 1 240 | } 241 | -------------------------------------------------------------------------------- /src/labgob/labgob.go: -------------------------------------------------------------------------------- 1 | package labgob 2 | 3 | // 4 | // trying to send non-capitalized fields over RPC produces a range of 5 | // misbehavior, including both mysterious incorrect computation and 6 | // outright crashes. so this wrapper around Go's encoding/gob warns 7 | // about non-capitalized field names. 8 | // 9 | 10 | import "encoding/gob" 11 | import "io" 12 | import "reflect" 13 | import "fmt" 14 | import "sync" 15 | import "unicode" 16 | import "unicode/utf8" 17 | 18 | var mu sync.Mutex 19 | var errorCount int // for TestCapital 20 | var checked map[reflect.Type]bool 21 | 22 | type LabEncoder struct { 23 | gob *gob.Encoder 24 | } 25 | 26 | func NewEncoder(w io.Writer) *LabEncoder { 27 | enc := &LabEncoder{} 28 | enc.gob = gob.NewEncoder(w) 29 | return enc 30 | } 31 | 32 | func (enc *LabEncoder) Encode(e interface{}) error { 33 | checkValue(e) 34 | return enc.gob.Encode(e) 35 | } 36 | 37 | func (enc *LabEncoder) EncodeValue(value reflect.Value) error { 38 | checkValue(value.Interface()) 39 | return enc.gob.EncodeValue(value) 40 | } 41 | 42 | type LabDecoder struct { 43 | gob *gob.Decoder 44 | } 45 | 46 | func NewDecoder(r io.Reader) *LabDecoder { 47 | dec := &LabDecoder{} 48 | dec.gob = gob.NewDecoder(r) 49 | return dec 50 | } 51 | 52 | func (dec *LabDecoder) Decode(e interface{}) error { 53 | checkValue(e) 54 | checkDefault(e) 55 | return dec.gob.Decode(e) 56 | } 57 | 58 | func Register(value interface{}) { 59 | checkValue(value) 60 | gob.Register(value) 61 | } 62 | 63 | func RegisterName(name string, value interface{}) { 64 | checkValue(value) 65 | gob.RegisterName(name, value) 66 | } 67 | 68 | func checkValue(value interface{}) { 69 | 
checkType(reflect.TypeOf(value)) 70 | } 71 | 72 | func checkType(t reflect.Type) { 73 | k := t.Kind() 74 | 75 | mu.Lock() 76 | // only complain once, and avoid recursion. 77 | if checked == nil { 78 | checked = map[reflect.Type]bool{} 79 | } 80 | if checked[t] { 81 | mu.Unlock() 82 | return 83 | } 84 | checked[t] = true 85 | mu.Unlock() 86 | 87 | switch k { 88 | case reflect.Struct: 89 | for i := 0; i < t.NumField(); i++ { 90 | f := t.Field(i) 91 | rune, _ := utf8.DecodeRuneInString(f.Name) 92 | if unicode.IsUpper(rune) == false { 93 | // ta da 94 | fmt.Printf("labgob error: lower-case field %v of %v in RPC or persist/snapshot will break your Raft\n", 95 | f.Name, t.Name()) 96 | mu.Lock() 97 | errorCount += 1 98 | mu.Unlock() 99 | } 100 | checkType(f.Type) 101 | } 102 | return 103 | case reflect.Slice, reflect.Array, reflect.Ptr: 104 | checkType(t.Elem()) 105 | return 106 | case reflect.Map: 107 | checkType(t.Elem()) 108 | checkType(t.Key()) 109 | return 110 | default: 111 | return 112 | } 113 | } 114 | 115 | // 116 | // warn if the value contains non-default values, 117 | // as it would if one sent an RPC but the reply 118 | // struct was already modified. if the RPC reply 119 | // contains default values, GOB won't overwrite 120 | // the non-default value. 121 | // 122 | func checkDefault(value interface{}) { 123 | if value == nil { 124 | return 125 | } 126 | checkDefault1(reflect.ValueOf(value), 1, "") 127 | } 128 | 129 | func checkDefault1(value reflect.Value, depth int, name string) { 130 | if depth > 3 { 131 | return 132 | } 133 | 134 | t := value.Type() 135 | k := t.Kind() 136 | 137 | switch k { 138 | case reflect.Struct: 139 | for i := 0; i < t.NumField(); i++ { 140 | vv := value.Field(i) 141 | name1 := t.Field(i).Name 142 | if name != "" { 143 | name1 = name + "." 
+ name1 144 | } 145 | checkDefault1(vv, depth+1, name1) 146 | } 147 | return 148 | case reflect.Ptr: 149 | if value.IsNil() { 150 | return 151 | } 152 | checkDefault1(value.Elem(), depth+1, name) 153 | return 154 | case reflect.Bool, 155 | reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, 156 | reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, 157 | reflect.Uintptr, reflect.Float32, reflect.Float64, 158 | reflect.String: 159 | if reflect.DeepEqual(reflect.Zero(t).Interface(), value.Interface()) == false { 160 | mu.Lock() 161 | if errorCount < 1 { 162 | what := name 163 | if what == "" { 164 | what = t.Name() 165 | } 166 | // this warning typically arises if code re-uses the same RPC reply 167 | // variable for multiple RPC calls, or if code restores persisted 168 | // state into variable that already have non-default values. 169 | fmt.Printf("labgob warning: Decoding into a non-default variable/field %v may not work\n", 170 | what) 171 | } 172 | errorCount += 1 173 | mu.Unlock() 174 | } 175 | return 176 | } 177 | } 178 | -------------------------------------------------------------------------------- /src/labgob/test_test.go: -------------------------------------------------------------------------------- 1 | package labgob 2 | 3 | import "testing" 4 | 5 | import "bytes" 6 | 7 | type T1 struct { 8 | T1int0 int 9 | T1int1 int 10 | T1string0 string 11 | T1string1 string 12 | } 13 | 14 | type T2 struct { 15 | T2slice []T1 16 | T2map map[int]*T1 17 | T2t3 interface{} 18 | } 19 | 20 | type T3 struct { 21 | T3int999 int 22 | } 23 | 24 | // 25 | // test that we didn't break GOB. 
26 | // 27 | func TestGOB(t *testing.T) { 28 | e0 := errorCount 29 | 30 | w := new(bytes.Buffer) 31 | 32 | Register(T3{}) 33 | 34 | { 35 | x0 := 0 36 | x1 := 1 37 | t1 := T1{} 38 | t1.T1int1 = 1 39 | t1.T1string1 = "6.824" 40 | t2 := T2{} 41 | t2.T2slice = []T1{T1{}, t1} 42 | t2.T2map = map[int]*T1{} 43 | t2.T2map[99] = &T1{1, 2, "x", "y"} 44 | t2.T2t3 = T3{999} 45 | 46 | e := NewEncoder(w) 47 | e.Encode(x0) 48 | e.Encode(x1) 49 | e.Encode(t1) 50 | e.Encode(t2) 51 | } 52 | data := w.Bytes() 53 | 54 | { 55 | var x0 int 56 | var x1 int 57 | var t1 T1 58 | var t2 T2 59 | 60 | r := bytes.NewBuffer(data) 61 | d := NewDecoder(r) 62 | if d.Decode(&x0) != nil || 63 | d.Decode(&x1) != nil || 64 | d.Decode(&t1) != nil || 65 | d.Decode(&t2) != nil { 66 | t.Fatalf("Decode failed") 67 | } 68 | 69 | if x0 != 0 { 70 | t.Fatalf("wrong x0 %v\n", x0) 71 | } 72 | if x1 != 1 { 73 | t.Fatalf("wrong x1 %v\n", x1) 74 | } 75 | if t1.T1int0 != 0 { 76 | t.Fatalf("wrong t1.T1int0 %v\n", t1.T1int0) 77 | } 78 | if t1.T1int1 != 1 { 79 | t.Fatalf("wrong t1.T1int1 %v\n", t1.T1int1) 80 | } 81 | if t1.T1string0 != "" { 82 | t.Fatalf("wrong t1.T1string0 %v\n", t1.T1string0) 83 | } 84 | if t1.T1string1 != "6.824" { 85 | t.Fatalf("wrong t1.T1string1 %v\n", t1.T1string1) 86 | } 87 | if len(t2.T2slice) != 2 { 88 | t.Fatalf("wrong t2.T2slice len %v\n", len(t2.T2slice)) 89 | } 90 | if t2.T2slice[1].T1int1 != 1 { 91 | t.Fatalf("wrong slice value\n") 92 | } 93 | if len(t2.T2map) != 1 { 94 | t.Fatalf("wrong t2.T2map len %v\n", len(t2.T2map)) 95 | } 96 | if t2.T2map[99].T1string1 != "y" { 97 | t.Fatalf("wrong map value\n") 98 | } 99 | t3 := (t2.T2t3).(T3) 100 | if t3.T3int999 != 999 { 101 | t.Fatalf("wrong t2.T2t3.T3int999\n") 102 | } 103 | } 104 | 105 | if errorCount != e0 { 106 | t.Fatalf("there were errors, but should not have been") 107 | } 108 | } 109 | 110 | type T4 struct { 111 | Yes int 112 | no int 113 | } 114 | 115 | // 116 | // make sure we check capitalization 117 | // labgob prints one warning 
during this test. 118 | // 119 | func TestCapital(t *testing.T) { 120 | e0 := errorCount 121 | 122 | v := []map[*T4]int{} 123 | 124 | w := new(bytes.Buffer) 125 | e := NewEncoder(w) 126 | e.Encode(v) 127 | data := w.Bytes() 128 | 129 | var v1 []map[T4]int 130 | r := bytes.NewBuffer(data) 131 | d := NewDecoder(r) 132 | d.Decode(&v1) 133 | 134 | if errorCount != e0+1 { 135 | t.Fatalf("failed to warn about lower-case field") 136 | } 137 | } 138 | 139 | // 140 | // check that we warn when someone sends a default value over 141 | // RPC but the target into which we're decoding holds a non-default 142 | // value, which GOB seems not to overwrite as you'd expect. 143 | // 144 | // labgob does not print a warning. 145 | // 146 | func TestDefault(t *testing.T) { 147 | e0 := errorCount 148 | 149 | type DD struct { 150 | X int 151 | } 152 | 153 | // send a default value... 154 | dd1 := DD{} 155 | 156 | w := new(bytes.Buffer) 157 | e := NewEncoder(w) 158 | e.Encode(dd1) 159 | data := w.Bytes() 160 | 161 | // and receive it into memory that already 162 | // holds non-default values. 163 | reply := DD{99} 164 | 165 | r := bytes.NewBuffer(data) 166 | d := NewDecoder(r) 167 | d.Decode(&reply) 168 | 169 | if errorCount != e0+1 { 170 | t.Fatalf("failed to warn about decoding into non-default value") 171 | } 172 | } 173 | -------------------------------------------------------------------------------- /src/labrpc/labrpc.go: -------------------------------------------------------------------------------- 1 | package labrpc 2 | 3 | // 4 | // channel-based RPC, for 824 labs. 5 | // 6 | // simulates a network that can lose requests, lose replies, 7 | // delay messages, and entirely disconnect particular hosts. 8 | // 9 | // we will use the original labrpc.go to test your code for grading. 10 | // so, while you can modify this code to help you debug, please 11 | // test against the original before submitting. 12 | // 13 | // adapted from Go net/rpc/server.go. 
14 | // 15 | // sends labgob-encoded values to ensure that RPCs 16 | // don't include references to program objects. 17 | // 18 | // net := MakeNetwork() -- holds network, clients, servers. 19 | // end := net.MakeEnd(endname) -- create a client end-point, to talk to one server. 20 | // net.AddServer(servername, server) -- adds a named server to network. 21 | // net.DeleteServer(servername) -- eliminate the named server. 22 | // net.Connect(endname, servername) -- connect a client to a server. 23 | // net.Enable(endname, enabled) -- enable/disable a client. 24 | // net.Reliable(bool) -- false means drop/delay messages 25 | // 26 | // end.Call("Raft.AppendEntries", &args, &reply) -- send an RPC, wait for reply. 27 | // the "Raft" is the name of the server struct to be called. 28 | // the "AppendEntries" is the name of the method to be called. 29 | // Call() returns true to indicate that the server executed the request 30 | // and the reply is valid. 31 | // Call() returns false if the network lost the request or reply 32 | // or the server is down. 33 | // It is OK to have multiple Call()s in progress at the same time on the 34 | // same ClientEnd. 35 | // Concurrent calls to Call() may be delivered to the server out of order, 36 | // since the network may re-order messages. 37 | // Call() is guaranteed to return (perhaps after a delay) *except* if the 38 | // handler function on the server side does not return. 39 | // the server RPC handler function must declare its args and reply arguments 40 | // as pointers, so that their types exactly match the types of the arguments 41 | // to Call(). 42 | // 43 | // srv := MakeServer() 44 | // srv.AddService(svc) -- a server can have multiple services, e.g. 
Raft and k/v 45 | // pass srv to net.AddServer() 46 | // 47 | // svc := MakeService(receiverObject) -- obj's methods will handle RPCs 48 | // much like Go's rpcs.Register() 49 | // pass svc to srv.AddService() 50 | // 51 | 52 | import "mit6.824/labgob" 53 | import "bytes" 54 | import "reflect" 55 | import "sync" 56 | import "log" 57 | import "strings" 58 | import "math/rand" 59 | import "time" 60 | import "sync/atomic" 61 | 62 | type reqMsg struct { 63 | endname interface{} // name of sending ClientEnd 64 | svcMeth string // e.g. "Raft.AppendEntries" 65 | argsType reflect.Type 66 | args []byte 67 | replyCh chan replyMsg 68 | } 69 | 70 | type replyMsg struct { 71 | ok bool 72 | reply []byte 73 | } 74 | 75 | type ClientEnd struct { 76 | endname interface{} // this end-point's name 77 | ch chan reqMsg // copy of Network.endCh 78 | done chan struct{} // closed when Network is cleaned up 79 | } 80 | 81 | // send an RPC, wait for the reply. 82 | // the return value indicates success; false means that 83 | // no reply was received from the server. 84 | func (e *ClientEnd) Call(svcMeth string, args interface{}, reply interface{}) bool { 85 | req := reqMsg{} 86 | req.endname = e.endname 87 | req.svcMeth = svcMeth 88 | req.argsType = reflect.TypeOf(args) 89 | req.replyCh = make(chan replyMsg) 90 | 91 | qb := new(bytes.Buffer) 92 | qe := labgob.NewEncoder(qb) 93 | if err := qe.Encode(args); err != nil { 94 | panic(err) 95 | } 96 | req.args = qb.Bytes() 97 | 98 | // 99 | // send the request. 100 | // 101 | select { 102 | case e.ch <- req: 103 | // the request has been sent. 104 | case <-e.done: 105 | // entire Network has been destroyed. 106 | return false 107 | } 108 | 109 | // 110 | // wait for the reply. 
111 | // 112 | rep := <-req.replyCh 113 | if rep.ok { 114 | rb := bytes.NewBuffer(rep.reply) 115 | rd := labgob.NewDecoder(rb) 116 | if err := rd.Decode(reply); err != nil { 117 | log.Fatalf("ClientEnd.Call(): decode reply: %v\n", err) 118 | } 119 | return true 120 | } else { 121 | return false 122 | } 123 | } 124 | 125 | type Network struct { 126 | mu sync.Mutex 127 | reliable bool 128 | longDelays bool // pause a long time on send on disabled connection 129 | longReordering bool // sometimes delay replies a long time 130 | ends map[interface{}]*ClientEnd // ends, by name 131 | enabled map[interface{}]bool // by end name 132 | servers map[interface{}]*Server // servers, by name 133 | connections map[interface{}]interface{} // endname -> servername 134 | endCh chan reqMsg 135 | done chan struct{} // closed when Network is cleaned up 136 | count int32 // total RPC count, for statistics 137 | bytes int64 // total bytes send, for statistics 138 | } 139 | 140 | func MakeNetwork() *Network { 141 | rn := &Network{} 142 | rn.reliable = true 143 | rn.ends = map[interface{}]*ClientEnd{} 144 | rn.enabled = map[interface{}]bool{} 145 | rn.servers = map[interface{}]*Server{} 146 | rn.connections = map[interface{}](interface{}){} 147 | rn.endCh = make(chan reqMsg) 148 | rn.done = make(chan struct{}) 149 | 150 | // single goroutine to handle all ClientEnd.Call()s 151 | go func() { 152 | for { 153 | select { 154 | case xreq := <-rn.endCh: 155 | atomic.AddInt32(&rn.count, 1) 156 | atomic.AddInt64(&rn.bytes, int64(len(xreq.args))) 157 | go rn.processReq(xreq) 158 | case <-rn.done: 159 | return 160 | } 161 | } 162 | }() 163 | 164 | return rn 165 | } 166 | 167 | func (rn *Network) Cleanup() { 168 | close(rn.done) 169 | } 170 | 171 | func (rn *Network) Reliable(yes bool) { 172 | rn.mu.Lock() 173 | defer rn.mu.Unlock() 174 | 175 | rn.reliable = yes 176 | } 177 | 178 | func (rn *Network) LongReordering(yes bool) { 179 | rn.mu.Lock() 180 | defer rn.mu.Unlock() 181 | 182 | 
rn.longReordering = yes 183 | } 184 | 185 | func (rn *Network) LongDelays(yes bool) { 186 | rn.mu.Lock() 187 | defer rn.mu.Unlock() 188 | 189 | rn.longDelays = yes 190 | } 191 | 192 | func (rn *Network) readEndnameInfo(endname interface{}) (enabled bool, 193 | servername interface{}, server *Server, reliable bool, longreordering bool, 194 | ) { 195 | rn.mu.Lock() 196 | defer rn.mu.Unlock() 197 | 198 | enabled = rn.enabled[endname] 199 | servername = rn.connections[endname] 200 | if servername != nil { 201 | server = rn.servers[servername] 202 | } 203 | reliable = rn.reliable 204 | longreordering = rn.longReordering 205 | return 206 | } 207 | 208 | func (rn *Network) isServerDead(endname interface{}, servername interface{}, server *Server) bool { 209 | rn.mu.Lock() 210 | defer rn.mu.Unlock() 211 | 212 | if rn.enabled[endname] == false || rn.servers[servername] != server { 213 | return true 214 | } 215 | return false 216 | } 217 | 218 | func (rn *Network) processReq(req reqMsg) { 219 | enabled, servername, server, reliable, longreordering := rn.readEndnameInfo(req.endname) 220 | 221 | if enabled && servername != nil && server != nil { 222 | if reliable == false { 223 | // short delay 224 | ms := (rand.Int() % 27) 225 | time.Sleep(time.Duration(ms) * time.Millisecond) 226 | } 227 | 228 | if reliable == false && (rand.Int()%1000) < 100 { 229 | // drop the request, return as if timeout 230 | req.replyCh <- replyMsg{false, nil} 231 | return 232 | } 233 | 234 | // execute the request (call the RPC handler). 235 | // in a separate thread so that we can periodically check 236 | // if the server has been killed and the RPC should get a 237 | // failure reply. 238 | ech := make(chan replyMsg) 239 | go func() { 240 | r := server.dispatch(req) 241 | ech <- r 242 | }() 243 | 244 | // wait for handler to return, 245 | // but stop waiting if DeleteServer() has been called, 246 | // and return an error. 
247 | var reply replyMsg 248 | replyOK := false 249 | serverDead := false 250 | for replyOK == false && serverDead == false { 251 | select { 252 | case reply = <-ech: 253 | replyOK = true 254 | case <-time.After(100 * time.Millisecond): 255 | serverDead = rn.isServerDead(req.endname, servername, server) 256 | if serverDead { 257 | go func() { 258 | <-ech // drain channel to let the goroutine created earlier terminate 259 | }() 260 | } 261 | } 262 | } 263 | 264 | // do not reply if DeleteServer() has been called, i.e. 265 | // the server has been killed. this is needed to avoid 266 | // situation in which a client gets a positive reply 267 | // to an Append, but the server persisted the update 268 | // into the old Persister. config.go is careful to call 269 | // DeleteServer() before superseding the Persister. 270 | serverDead = rn.isServerDead(req.endname, servername, server) 271 | 272 | if replyOK == false || serverDead == true { 273 | // server was killed while we were waiting; return error. 274 | req.replyCh <- replyMsg{false, nil} 275 | } else if reliable == false && (rand.Int()%1000) < 100 { 276 | // drop the reply, return as if timeout 277 | req.replyCh <- replyMsg{false, nil} 278 | } else if longreordering == true && rand.Intn(900) < 600 { 279 | // delay the response for a while 280 | ms := 200 + rand.Intn(1+rand.Intn(2000)) 281 | // Russ points out that this timer arrangement will decrease 282 | // the number of goroutines, so that the race 283 | // detector is less likely to get upset. 284 | time.AfterFunc(time.Duration(ms)*time.Millisecond, func() { 285 | atomic.AddInt64(&rn.bytes, int64(len(reply.reply))) 286 | req.replyCh <- reply 287 | }) 288 | } else { 289 | atomic.AddInt64(&rn.bytes, int64(len(reply.reply))) 290 | req.replyCh <- reply 291 | } 292 | } else { 293 | // simulate no reply and eventual timeout. 294 | ms := 0 295 | if rn.longDelays { 296 | // let Raft tests check that leader doesn't send 297 | // RPCs synchronously. 
298 | ms = (rand.Int() % 7000) 299 | } else { 300 | // many kv tests require the client to try each 301 | // server in fairly rapid succession. 302 | ms = (rand.Int() % 100) 303 | } 304 | time.AfterFunc(time.Duration(ms)*time.Millisecond, func() { 305 | req.replyCh <- replyMsg{false, nil} 306 | }) 307 | } 308 | 309 | } 310 | 311 | // create a client end-point. 312 | // start the thread that listens and delivers. 313 | func (rn *Network) MakeEnd(endname interface{}) *ClientEnd { 314 | rn.mu.Lock() 315 | defer rn.mu.Unlock() 316 | 317 | if _, ok := rn.ends[endname]; ok { 318 | log.Fatalf("MakeEnd: %v already exists\n", endname) 319 | } 320 | 321 | e := &ClientEnd{} 322 | e.endname = endname 323 | e.ch = rn.endCh 324 | e.done = rn.done 325 | rn.ends[endname] = e 326 | rn.enabled[endname] = false 327 | rn.connections[endname] = nil 328 | 329 | return e 330 | } 331 | 332 | func (rn *Network) AddServer(servername interface{}, rs *Server) { 333 | rn.mu.Lock() 334 | defer rn.mu.Unlock() 335 | 336 | rn.servers[servername] = rs 337 | } 338 | 339 | func (rn *Network) DeleteServer(servername interface{}) { 340 | rn.mu.Lock() 341 | defer rn.mu.Unlock() 342 | 343 | rn.servers[servername] = nil 344 | } 345 | 346 | // connect a ClientEnd to a server. 347 | // a ClientEnd can only be connected once in its lifetime. 348 | func (rn *Network) Connect(endname interface{}, servername interface{}) { 349 | rn.mu.Lock() 350 | defer rn.mu.Unlock() 351 | 352 | rn.connections[endname] = servername 353 | } 354 | 355 | // enable/disable a ClientEnd. 356 | func (rn *Network) Enable(endname interface{}, enabled bool) { 357 | rn.mu.Lock() 358 | defer rn.mu.Unlock() 359 | 360 | rn.enabled[endname] = enabled 361 | } 362 | 363 | // get a server's count of incoming RPCs. 
364 | func (rn *Network) GetCount(servername interface{}) int { 365 | rn.mu.Lock() 366 | defer rn.mu.Unlock() 367 | 368 | svr := rn.servers[servername] 369 | return svr.GetCount() 370 | } 371 | 372 | func (rn *Network) GetTotalCount() int { 373 | x := atomic.LoadInt32(&rn.count) 374 | return int(x) 375 | } 376 | 377 | func (rn *Network) GetTotalBytes() int64 { 378 | x := atomic.LoadInt64(&rn.bytes) 379 | return x 380 | } 381 | 382 | // 383 | // a server is a collection of services, all sharing 384 | // the same rpc dispatcher. so that e.g. both a Raft 385 | // and a k/v server can listen to the same rpc endpoint. 386 | // 387 | type Server struct { 388 | mu sync.Mutex 389 | services map[string]*Service 390 | count int // incoming RPCs 391 | } 392 | 393 | func MakeServer() *Server { 394 | rs := &Server{} 395 | rs.services = map[string]*Service{} 396 | return rs 397 | } 398 | 399 | func (rs *Server) AddService(svc *Service) { 400 | rs.mu.Lock() 401 | defer rs.mu.Unlock() 402 | rs.services[svc.name] = svc 403 | } 404 | 405 | func (rs *Server) dispatch(req reqMsg) replyMsg { 406 | rs.mu.Lock() 407 | 408 | rs.count += 1 409 | 410 | // split Raft.AppendEntries into service and method 411 | dot := strings.LastIndex(req.svcMeth, ".") 412 | serviceName := req.svcMeth[:dot] 413 | methodName := req.svcMeth[dot+1:] 414 | 415 | service, ok := rs.services[serviceName] 416 | 417 | rs.mu.Unlock() 418 | 419 | if ok { 420 | return service.dispatch(methodName, req) 421 | } else { 422 | choices := []string{} 423 | for k, _ := range rs.services { 424 | choices = append(choices, k) 425 | } 426 | log.Fatalf("labrpc.Server.dispatch(): unknown service %v in %v.%v; expecting one of %v\n", 427 | serviceName, serviceName, methodName, choices) 428 | return replyMsg{false, nil} 429 | } 430 | } 431 | 432 | func (rs *Server) GetCount() int { 433 | rs.mu.Lock() 434 | defer rs.mu.Unlock() 435 | return rs.count 436 | } 437 | 438 | // an object with methods that can be called via RPC. 
439 | // a single server may have more than one Service. 440 | type Service struct { 441 | name string 442 | rcvr reflect.Value 443 | typ reflect.Type 444 | methods map[string]reflect.Method 445 | } 446 | 447 | func MakeService(rcvr interface{}) *Service { 448 | svc := &Service{} 449 | svc.typ = reflect.TypeOf(rcvr) 450 | svc.rcvr = reflect.ValueOf(rcvr) 451 | svc.name = reflect.Indirect(svc.rcvr).Type().Name() 452 | svc.methods = map[string]reflect.Method{} 453 | 454 | for m := 0; m < svc.typ.NumMethod(); m++ { 455 | method := svc.typ.Method(m) 456 | mtype := method.Type 457 | mname := method.Name 458 | 459 | //fmt.Printf("%v pp %v ni %v 1k %v 2k %v no %v\n", 460 | // mname, method.PkgPath, mtype.NumIn(), mtype.In(1).Kind(), mtype.In(2).Kind(), mtype.NumOut()) 461 | 462 | if method.PkgPath != "" || // capitalized? 463 | mtype.NumIn() != 3 || 464 | //mtype.In(1).Kind() != reflect.Ptr || 465 | mtype.In(2).Kind() != reflect.Ptr || 466 | mtype.NumOut() != 0 { 467 | // the method is not suitable for a handler 468 | //fmt.Printf("bad method: %v\n", mname) 469 | } else { 470 | // the method looks like a handler 471 | svc.methods[mname] = method 472 | } 473 | } 474 | 475 | return svc 476 | } 477 | 478 | func (svc *Service) dispatch(methname string, req reqMsg) replyMsg { 479 | if method, ok := svc.methods[methname]; ok { 480 | // prepare space into which to read the argument. 481 | // the Value's type will be a pointer to req.argsType. 482 | args := reflect.New(req.argsType) 483 | 484 | // decode the argument. 485 | ab := bytes.NewBuffer(req.args) 486 | ad := labgob.NewDecoder(ab) 487 | ad.Decode(args.Interface()) 488 | 489 | // allocate space for the reply. 490 | replyType := method.Type.In(2) 491 | replyType = replyType.Elem() 492 | replyv := reflect.New(replyType) 493 | 494 | // call the method. 495 | function := method.Func 496 | function.Call([]reflect.Value{svc.rcvr, args.Elem(), replyv}) 497 | 498 | // encode the reply. 
499 | rb := new(bytes.Buffer) 500 | re := labgob.NewEncoder(rb) 501 | re.EncodeValue(replyv) 502 | 503 | return replyMsg{true, rb.Bytes()} 504 | } else { 505 | choices := []string{} 506 | for k, _ := range svc.methods { 507 | choices = append(choices, k) 508 | } 509 | log.Fatalf("labrpc.Service.dispatch(): unknown method %v in %v; expecting one of %v\n", 510 | methname, req.svcMeth, choices) 511 | return replyMsg{false, nil} 512 | } 513 | } 514 | -------------------------------------------------------------------------------- /src/labrpc/test_test.go: -------------------------------------------------------------------------------- 1 | package labrpc 2 | 3 | import "testing" 4 | import "strconv" 5 | import "sync" 6 | import "runtime" 7 | import "time" 8 | import "fmt" 9 | 10 | type JunkArgs struct { 11 | X int 12 | } 13 | type JunkReply struct { 14 | X string 15 | } 16 | 17 | type JunkServer struct { 18 | mu sync.Mutex 19 | log1 []string 20 | log2 []int 21 | } 22 | 23 | func (js *JunkServer) Handler1(args string, reply *int) { 24 | js.mu.Lock() 25 | defer js.mu.Unlock() 26 | js.log1 = append(js.log1, args) 27 | *reply, _ = strconv.Atoi(args) 28 | } 29 | 30 | func (js *JunkServer) Handler2(args int, reply *string) { 31 | js.mu.Lock() 32 | defer js.mu.Unlock() 33 | js.log2 = append(js.log2, args) 34 | *reply = "handler2-" + strconv.Itoa(args) 35 | } 36 | 37 | func (js *JunkServer) Handler3(args int, reply *int) { 38 | js.mu.Lock() 39 | defer js.mu.Unlock() 40 | time.Sleep(20 * time.Second) 41 | *reply = -args 42 | } 43 | 44 | // args is a pointer 45 | func (js *JunkServer) Handler4(args *JunkArgs, reply *JunkReply) { 46 | reply.X = "pointer" 47 | } 48 | 49 | // args is a not pointer 50 | func (js *JunkServer) Handler5(args JunkArgs, reply *JunkReply) { 51 | reply.X = "no pointer" 52 | } 53 | 54 | func (js *JunkServer) Handler6(args string, reply *int) { 55 | js.mu.Lock() 56 | defer js.mu.Unlock() 57 | *reply = len(args) 58 | } 59 | 60 | func (js *JunkServer) 
Handler7(args int, reply *string) { 61 | js.mu.Lock() 62 | defer js.mu.Unlock() 63 | *reply = "" 64 | for i := 0; i < args; i++ { 65 | *reply = *reply + "y" 66 | } 67 | } 68 | 69 | func TestBasic(t *testing.T) { 70 | runtime.GOMAXPROCS(4) 71 | 72 | rn := MakeNetwork() 73 | defer rn.Cleanup() 74 | 75 | e := rn.MakeEnd("end1-99") 76 | 77 | js := &JunkServer{} 78 | svc := MakeService(js) 79 | 80 | rs := MakeServer() 81 | rs.AddService(svc) 82 | rn.AddServer("server99", rs) 83 | 84 | rn.Connect("end1-99", "server99") 85 | rn.Enable("end1-99", true) 86 | 87 | { 88 | reply := "" 89 | e.Call("JunkServer.Handler2", 111, &reply) 90 | if reply != "handler2-111" { 91 | t.Fatalf("wrong reply from Handler2") 92 | } 93 | } 94 | 95 | { 96 | reply := 0 97 | e.Call("JunkServer.Handler1", "9099", &reply) 98 | if reply != 9099 { 99 | t.Fatalf("wrong reply from Handler1") 100 | } 101 | } 102 | } 103 | 104 | func TestTypes(t *testing.T) { 105 | runtime.GOMAXPROCS(4) 106 | 107 | rn := MakeNetwork() 108 | defer rn.Cleanup() 109 | 110 | e := rn.MakeEnd("end1-99") 111 | 112 | js := &JunkServer{} 113 | svc := MakeService(js) 114 | 115 | rs := MakeServer() 116 | rs.AddService(svc) 117 | rn.AddServer("server99", rs) 118 | 119 | rn.Connect("end1-99", "server99") 120 | rn.Enable("end1-99", true) 121 | 122 | { 123 | var args JunkArgs 124 | var reply JunkReply 125 | // args must match type (pointer or not) of handler. 126 | e.Call("JunkServer.Handler4", &args, &reply) 127 | if reply.X != "pointer" { 128 | t.Fatalf("wrong reply from Handler4") 129 | } 130 | } 131 | 132 | { 133 | var args JunkArgs 134 | var reply JunkReply 135 | // args must match type (pointer or not) of handler. 136 | e.Call("JunkServer.Handler5", args, &reply) 137 | if reply.X != "no pointer" { 138 | t.Fatalf("wrong reply from Handler5") 139 | } 140 | } 141 | } 142 | 143 | // 144 | // does net.Enable(endname, false) really disconnect a client? 
145 | // 146 | func TestDisconnect(t *testing.T) { 147 | runtime.GOMAXPROCS(4) 148 | 149 | rn := MakeNetwork() 150 | defer rn.Cleanup() 151 | 152 | e := rn.MakeEnd("end1-99") 153 | 154 | js := &JunkServer{} 155 | svc := MakeService(js) 156 | 157 | rs := MakeServer() 158 | rs.AddService(svc) 159 | rn.AddServer("server99", rs) 160 | 161 | rn.Connect("end1-99", "server99") 162 | 163 | { 164 | reply := "" 165 | e.Call("JunkServer.Handler2", 111, &reply) 166 | if reply != "" { 167 | t.Fatalf("unexpected reply from Handler2") 168 | } 169 | } 170 | 171 | rn.Enable("end1-99", true) 172 | 173 | { 174 | reply := 0 175 | e.Call("JunkServer.Handler1", "9099", &reply) 176 | if reply != 9099 { 177 | t.Fatalf("wrong reply from Handler1") 178 | } 179 | } 180 | } 181 | 182 | // 183 | // test net.GetCount() 184 | // 185 | func TestCounts(t *testing.T) { 186 | runtime.GOMAXPROCS(4) 187 | 188 | rn := MakeNetwork() 189 | defer rn.Cleanup() 190 | 191 | e := rn.MakeEnd("end1-99") 192 | 193 | js := &JunkServer{} 194 | svc := MakeService(js) 195 | 196 | rs := MakeServer() 197 | rs.AddService(svc) 198 | rn.AddServer(99, rs) 199 | 200 | rn.Connect("end1-99", 99) 201 | rn.Enable("end1-99", true) 202 | 203 | for i := 0; i < 17; i++ { 204 | reply := "" 205 | e.Call("JunkServer.Handler2", i, &reply) 206 | wanted := "handler2-" + strconv.Itoa(i) 207 | if reply != wanted { 208 | t.Fatalf("wrong reply %v from Handler1, expecting %v", reply, wanted) 209 | } 210 | } 211 | 212 | n := rn.GetCount(99) 213 | if n != 17 { 214 | t.Fatalf("wrong GetCount() %v, expected 17\n", n) 215 | } 216 | } 217 | 218 | // 219 | // test net.GetTotalBytes() 220 | // 221 | func TestBytes(t *testing.T) { 222 | runtime.GOMAXPROCS(4) 223 | 224 | rn := MakeNetwork() 225 | defer rn.Cleanup() 226 | 227 | e := rn.MakeEnd("end1-99") 228 | 229 | js := &JunkServer{} 230 | svc := MakeService(js) 231 | 232 | rs := MakeServer() 233 | rs.AddService(svc) 234 | rn.AddServer(99, rs) 235 | 236 | rn.Connect("end1-99", 99) 237 | 
rn.Enable("end1-99", true) 238 | 239 | for i := 0; i < 17; i++ { 240 | args := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" 241 | args = args + args 242 | args = args + args 243 | reply := 0 244 | e.Call("JunkServer.Handler6", args, &reply) 245 | wanted := len(args) 246 | if reply != wanted { 247 | t.Fatalf("wrong reply %v from Handler6, expecting %v", reply, wanted) 248 | } 249 | } 250 | 251 | n := rn.GetTotalBytes() 252 | if n < 4828 || n > 6000 { 253 | t.Fatalf("wrong GetTotalBytes() %v, expected about 5000\n", n) 254 | } 255 | 256 | for i := 0; i < 17; i++ { 257 | args := 107 258 | reply := "" 259 | e.Call("JunkServer.Handler7", args, &reply) 260 | wanted := args 261 | if len(reply) != wanted { 262 | t.Fatalf("wrong reply len=%v from Handler6, expecting %v", len(reply), wanted) 263 | } 264 | } 265 | 266 | nn := rn.GetTotalBytes() - n 267 | if nn < 1800 || nn > 2500 { 268 | t.Fatalf("wrong GetTotalBytes() %v, expected about 2000\n", nn) 269 | } 270 | } 271 | 272 | // 273 | // test RPCs from concurrent ClientEnds 274 | // 275 | func TestConcurrentMany(t *testing.T) { 276 | runtime.GOMAXPROCS(4) 277 | 278 | rn := MakeNetwork() 279 | defer rn.Cleanup() 280 | 281 | js := &JunkServer{} 282 | svc := MakeService(js) 283 | 284 | rs := MakeServer() 285 | rs.AddService(svc) 286 | rn.AddServer(1000, rs) 287 | 288 | ch := make(chan int) 289 | 290 | nclients := 20 291 | nrpcs := 10 292 | for ii := 0; ii < nclients; ii++ { 293 | go func(i int) { 294 | n := 0 295 | defer func() { ch <- n }() 296 | 297 | e := rn.MakeEnd(i) 298 | rn.Connect(i, 1000) 299 | rn.Enable(i, true) 300 | 301 | for j := 0; j < nrpcs; j++ { 302 | arg := i*100 + j 303 | reply := "" 304 | e.Call("JunkServer.Handler2", arg, &reply) 305 | wanted := "handler2-" + strconv.Itoa(arg) 306 | if reply != wanted { 307 | t.Fatalf("wrong reply %v from Handler1, expecting %v", reply, wanted) 308 | } 309 | n += 1 310 | } 311 | }(ii) 312 | } 313 | 314 | total := 0 315 | for ii := 0; ii < 
nclients; ii++ { 316 | x := <-ch 317 | total += x 318 | } 319 | 320 | if total != nclients*nrpcs { 321 | t.Fatalf("wrong number of RPCs completed, got %v, expected %v", total, nclients*nrpcs) 322 | } 323 | 324 | n := rn.GetCount(1000) 325 | if n != total { 326 | t.Fatalf("wrong GetCount() %v, expected %v\n", n, total) 327 | } 328 | } 329 | 330 | // 331 | // test unreliable 332 | // 333 | func TestUnreliable(t *testing.T) { 334 | runtime.GOMAXPROCS(4) 335 | 336 | rn := MakeNetwork() 337 | defer rn.Cleanup() 338 | rn.Reliable(false) 339 | 340 | js := &JunkServer{} 341 | svc := MakeService(js) 342 | 343 | rs := MakeServer() 344 | rs.AddService(svc) 345 | rn.AddServer(1000, rs) 346 | 347 | ch := make(chan int) 348 | 349 | nclients := 300 350 | for ii := 0; ii < nclients; ii++ { 351 | go func(i int) { 352 | n := 0 353 | defer func() { ch <- n }() 354 | 355 | e := rn.MakeEnd(i) 356 | rn.Connect(i, 1000) 357 | rn.Enable(i, true) 358 | 359 | arg := i * 100 360 | reply := "" 361 | ok := e.Call("JunkServer.Handler2", arg, &reply) 362 | if ok { 363 | wanted := "handler2-" + strconv.Itoa(arg) 364 | if reply != wanted { 365 | t.Fatalf("wrong reply %v from Handler1, expecting %v", reply, wanted) 366 | } 367 | n += 1 368 | } 369 | }(ii) 370 | } 371 | 372 | total := 0 373 | for ii := 0; ii < nclients; ii++ { 374 | x := <-ch 375 | total += x 376 | } 377 | 378 | if total == nclients || total == 0 { 379 | t.Fatalf("all RPCs succeeded despite unreliable") 380 | } 381 | } 382 | 383 | // 384 | // test concurrent RPCs from a single ClientEnd 385 | // 386 | func TestConcurrentOne(t *testing.T) { 387 | runtime.GOMAXPROCS(4) 388 | 389 | rn := MakeNetwork() 390 | defer rn.Cleanup() 391 | 392 | js := &JunkServer{} 393 | svc := MakeService(js) 394 | 395 | rs := MakeServer() 396 | rs.AddService(svc) 397 | rn.AddServer(1000, rs) 398 | 399 | e := rn.MakeEnd("c") 400 | rn.Connect("c", 1000) 401 | rn.Enable("c", true) 402 | 403 | ch := make(chan int) 404 | 405 | nrpcs := 20 406 | for ii := 0; ii < 
nrpcs; ii++ { 407 | go func(i int) { 408 | n := 0 409 | defer func() { ch <- n }() 410 | 411 | arg := 100 + i 412 | reply := "" 413 | e.Call("JunkServer.Handler2", arg, &reply) 414 | wanted := "handler2-" + strconv.Itoa(arg) 415 | if reply != wanted { 416 | t.Fatalf("wrong reply %v from Handler2, expecting %v", reply, wanted) 417 | } 418 | n += 1 419 | }(ii) 420 | } 421 | 422 | total := 0 423 | for ii := 0; ii < nrpcs; ii++ { 424 | x := <-ch 425 | total += x 426 | } 427 | 428 | if total != nrpcs { 429 | t.Fatalf("wrong number of RPCs completed, got %v, expected %v", total, nrpcs) 430 | } 431 | 432 | js.mu.Lock() 433 | defer js.mu.Unlock() 434 | if len(js.log2) != nrpcs { 435 | t.Fatalf("wrong number of RPCs delivered") 436 | } 437 | 438 | n := rn.GetCount(1000) 439 | if n != total { 440 | t.Fatalf("wrong GetCount() %v, expected %v\n", n, total) 441 | } 442 | } 443 | 444 | // 445 | // regression: an RPC that's delayed during Enabled=false 446 | // should not delay subsequent RPCs (e.g. after Enabled=true). 447 | // 448 | func TestRegression1(t *testing.T) { 449 | runtime.GOMAXPROCS(4) 450 | 451 | rn := MakeNetwork() 452 | defer rn.Cleanup() 453 | 454 | js := &JunkServer{} 455 | svc := MakeService(js) 456 | 457 | rs := MakeServer() 458 | rs.AddService(svc) 459 | rn.AddServer(1000, rs) 460 | 461 | e := rn.MakeEnd("c") 462 | rn.Connect("c", 1000) 463 | 464 | // start some RPCs while the ClientEnd is disabled. 465 | // they'll be delayed. 466 | rn.Enable("c", false) 467 | ch := make(chan bool) 468 | nrpcs := 20 469 | for ii := 0; ii < nrpcs; ii++ { 470 | go func(i int) { 471 | ok := false 472 | defer func() { ch <- ok }() 473 | 474 | arg := 100 + i 475 | reply := "" 476 | // this call ought to return false. 477 | e.Call("JunkServer.Handler2", arg, &reply) 478 | ok = true 479 | }(ii) 480 | } 481 | 482 | time.Sleep(100 * time.Millisecond) 483 | 484 | // now enable the ClientEnd and check that an RPC completes quickly. 
485 | t0 := time.Now() 486 | rn.Enable("c", true) 487 | { 488 | arg := 99 489 | reply := "" 490 | e.Call("JunkServer.Handler2", arg, &reply) 491 | wanted := "handler2-" + strconv.Itoa(arg) 492 | if reply != wanted { 493 | t.Fatalf("wrong reply %v from Handler2, expecting %v", reply, wanted) 494 | } 495 | } 496 | dur := time.Since(t0).Seconds() 497 | 498 | if dur > 0.03 { 499 | t.Fatalf("RPC took too long (%v) after Enable", dur) 500 | } 501 | 502 | for ii := 0; ii < nrpcs; ii++ { 503 | <-ch 504 | } 505 | 506 | js.mu.Lock() 507 | defer js.mu.Unlock() 508 | if len(js.log2) != 1 { 509 | t.Fatalf("wrong number (%v) of RPCs delivered, expected 1", len(js.log2)) 510 | } 511 | 512 | n := rn.GetCount(1000) 513 | if n != 1 { 514 | t.Fatalf("wrong GetCount() %v, expected %v\n", n, 1) 515 | } 516 | } 517 | 518 | // 519 | // if an RPC is stuck in a server, and the server 520 | // is killed with DeleteServer(), does the RPC 521 | // get un-stuck? 522 | // 523 | func TestKilled(t *testing.T) { 524 | runtime.GOMAXPROCS(4) 525 | 526 | rn := MakeNetwork() 527 | defer rn.Cleanup() 528 | 529 | e := rn.MakeEnd("end1-99") 530 | 531 | js := &JunkServer{} 532 | svc := MakeService(js) 533 | 534 | rs := MakeServer() 535 | rs.AddService(svc) 536 | rn.AddServer("server99", rs) 537 | 538 | rn.Connect("end1-99", "server99") 539 | rn.Enable("end1-99", true) 540 | 541 | doneCh := make(chan bool) 542 | go func() { 543 | reply := 0 544 | ok := e.Call("JunkServer.Handler3", 99, &reply) 545 | doneCh <- ok 546 | }() 547 | 548 | time.Sleep(1000 * time.Millisecond) 549 | 550 | select { 551 | case <-doneCh: 552 | t.Fatalf("Handler3 should not have returned yet") 553 | case <-time.After(100 * time.Millisecond): 554 | } 555 | 556 | rn.DeleteServer("server99") 557 | 558 | select { 559 | case x := <-doneCh: 560 | if x != false { 561 | t.Fatalf("Handler3 returned successfully despite DeleteServer()") 562 | } 563 | case <-time.After(100 * time.Millisecond): 564 | t.Fatalf("Handler3 should return after 
DeleteServer()") 565 | } 566 | } 567 | 568 | func TestBenchmark(t *testing.T) { 569 | runtime.GOMAXPROCS(4) 570 | 571 | rn := MakeNetwork() 572 | defer rn.Cleanup() 573 | 574 | e := rn.MakeEnd("end1-99") 575 | 576 | js := &JunkServer{} 577 | svc := MakeService(js) 578 | 579 | rs := MakeServer() 580 | rs.AddService(svc) 581 | rn.AddServer("server99", rs) 582 | 583 | rn.Connect("end1-99", "server99") 584 | rn.Enable("end1-99", true) 585 | 586 | t0 := time.Now() 587 | n := 100000 588 | for iters := 0; iters < n; iters++ { 589 | reply := "" 590 | e.Call("JunkServer.Handler2", 111, &reply) 591 | if reply != "handler2-111" { 592 | t.Fatalf("wrong reply from Handler2") 593 | } 594 | } 595 | fmt.Printf("%v for %v\n", time.Since(t0), n) 596 | // march 2016, rtm laptop, 22 microseconds per RPC 597 | } 598 | -------------------------------------------------------------------------------- /src/main/diskvd.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // 4 | // start a diskvd server. it's a member of some replica 5 | // group, which has other members, and it needs to know 6 | // how to talk to the members of the shardmaster service. 7 | // used by ../diskv/test_test.go 8 | // 9 | // arguments: 10 | // -g groupid 11 | // -m masterport1 -m masterport2 ... 12 | // -s replicaport1 -s replicaport2 ... 13 | // -i my-index-in-server-port-list 14 | // -u unreliable 15 | // -d directory 16 | // -r restart 17 | 18 | import "time" 19 | import "6.824/diskv" 20 | import "os" 21 | import "fmt" 22 | import "strconv" 23 | import "runtime" 24 | 25 | func usage() { 26 | fmt.Printf("Usage: diskvd -g gid -m master... -s server... 
-i my-index -d dir\n") 27 | os.Exit(1) 28 | } 29 | 30 | func main() { 31 | var gid int64 = -1 // my replica group ID 32 | masters := []string{} // ports of shardmasters 33 | replicas := []string{} // ports of servers in my replica group 34 | me := -1 // my index in replicas[] 35 | unreliable := false 36 | dir := "" // store persistent data here 37 | restart := false 38 | 39 | for i := 1; i+1 < len(os.Args); i += 2 { 40 | a0 := os.Args[i] 41 | a1 := os.Args[i+1] 42 | if a0 == "-g" { 43 | gid, _ = strconv.ParseInt(a1, 10, 64) 44 | } else if a0 == "-m" { 45 | masters = append(masters, a1) 46 | } else if a0 == "-s" { 47 | replicas = append(replicas, a1) 48 | } else if a0 == "-i" { 49 | me, _ = strconv.Atoi(a1) 50 | } else if a0 == "-u" { 51 | unreliable, _ = strconv.ParseBool(a1) 52 | } else if a0 == "-d" { 53 | dir = a1 54 | } else if a0 == "-r" { 55 | restart, _ = strconv.ParseBool(a1) 56 | } else { 57 | usage() 58 | } 59 | } 60 | 61 | if gid < 0 || me < 0 || len(masters) < 1 || me >= len(replicas) || dir == "" { 62 | usage() 63 | } 64 | 65 | runtime.GOMAXPROCS(4) 66 | 67 | srv := diskv.StartServer(gid, masters, replicas, me, dir, restart) 68 | srv.Setunreliable(unreliable) 69 | 70 | // for safety, force quit after 10 minutes. 
71 | time.Sleep(10 * 60 * time.Second) 72 | mep, _ := os.FindProcess(os.Getpid()) 73 | mep.Kill() 74 | } 75 | -------------------------------------------------------------------------------- /src/main/lockc.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // 4 | // see comments in lockd.go 5 | // 6 | 7 | import "6.824/lockservice" 8 | import "os" 9 | import "fmt" 10 | 11 | func usage() { 12 | fmt.Printf("Usage: lockc -l|-u primaryport backupport lockname\n") 13 | os.Exit(1) 14 | } 15 | 16 | func main() { 17 | if len(os.Args) == 5 { 18 | ck := lockservice.MakeClerk(os.Args[2], os.Args[3]) 19 | var ok bool 20 | if os.Args[1] == "-l" { 21 | ok = ck.Lock(os.Args[4]) 22 | } else if os.Args[1] == "-u" { 23 | ok = ck.Unlock(os.Args[4]) 24 | } else { 25 | usage() 26 | } 27 | fmt.Printf("reply: %v\n", ok) 28 | } else { 29 | usage() 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /src/main/lockd.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // export GOPATH=~/6.824 4 | // go build lockd.go 5 | // go build lockc.go 6 | // ./lockd -p a b & 7 | // ./lockd -b a b & 8 | // ./lockc -l a b lx 9 | // ./lockc -u a b lx 10 | // 11 | // on Athena, use /tmp/myname-a and /tmp/myname-b 12 | // instead of a and b. 
13 | 14 | import "time" 15 | import "6.824/lockservice" 16 | import "os" 17 | import "fmt" 18 | 19 | func main() { 20 | if len(os.Args) == 4 && os.Args[1] == "-p" { 21 | lockservice.StartServer(os.Args[2], os.Args[3], true) 22 | } else if len(os.Args) == 4 && os.Args[1] == "-b" { 23 | lockservice.StartServer(os.Args[2], os.Args[3], false) 24 | } else { 25 | fmt.Printf("Usage: lockd -p|-b primaryport backupport\n") 26 | os.Exit(1) 27 | } 28 | for { 29 | time.Sleep(100 * time.Second) 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /src/main/mrcoordinator.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // 4 | // start the coordinator process, which is implemented 5 | // in ../mr/coordinator.go 6 | // 7 | // go run mrcoordinator.go pg*.txt 8 | // 9 | // Please do not change this file. 10 | // 11 | 12 | import "mit6.824/mr" 13 | import "time" 14 | import "os" 15 | import "fmt" 16 | 17 | func main() { 18 | if len(os.Args) < 2 { 19 | fmt.Fprintf(os.Stderr, "Usage: mrcoordinator inputfiles...\n") 20 | os.Exit(1) 21 | } 22 | 23 | m := mr.MakeCoordinator(os.Args[1:], 10) 24 | for m.Done() == false { 25 | time.Sleep(time.Second) 26 | } 27 | 28 | time.Sleep(time.Second) 29 | } 30 | -------------------------------------------------------------------------------- /src/main/mrsequential.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // 4 | // simple sequential MapReduce. 5 | // 6 | // go run mrsequential.go wc.so pg*.txt 7 | // 8 | 9 | import "fmt" 10 | import "mit6.824/mr" 11 | import "plugin" 12 | import "os" 13 | import "log" 14 | import "io/ioutil" 15 | import "sort" 16 | 17 | // for sorting by key. 18 | type ByKey []mr.KeyValue 19 | 20 | // for sorting by key. 
21 | func (a ByKey) Len() int { return len(a) } 22 | func (a ByKey) Swap(i, j int) { a[i], a[j] = a[j], a[i] } 23 | func (a ByKey) Less(i, j int) bool { return a[i].Key < a[j].Key } 24 | 25 | func main() { 26 | if len(os.Args) < 3 { 27 | fmt.Fprintf(os.Stderr, "Usage: mrsequential xxx.so inputfiles...\n") 28 | os.Exit(1) 29 | } 30 | // 读出Map Reduce函数 31 | mapf, reducef := loadPlugin(os.Args[1]) 32 | 33 | // 34 | // read each input file, 35 | // pass it to Map, 36 | // accumulate the intermediate Map output. 37 | // 38 | // 遍历文件名,读取单词后存入intermediate 39 | var intermediate []mr.KeyValue 40 | for _, filename := range os.Args[2:] { 41 | //fmt.Println(filename) 42 | file, err := os.Open(filename) 43 | if err != nil { 44 | log.Fatalf("cannot open %v", filename) 45 | } 46 | content, err := ioutil.ReadAll(file) 47 | if err != nil { 48 | log.Fatalf("cannot read %v", filename) 49 | } 50 | file.Close() 51 | kva := mapf(filename, string(content)) 52 | intermediate = append(intermediate, kva...) 53 | } 54 | 55 | // 56 | // a big difference from real MapReduce is that all the 57 | // intermediate data is in one place, intermediate[], 58 | // rather than being partitioned into NxM buckets. 59 | // 60 | 61 | sort.Sort(ByKey(intermediate)) 62 | 63 | oname := "mr-out-0" 64 | ofile, _ := os.Create(oname) 65 | 66 | // 67 | // call Reduce on each distinct key in intermediate[], 68 | // and print the result to mr-out-0. 69 | // 70 | i := 0 71 | for i < len(intermediate) { 72 | j := i + 1 73 | for j < len(intermediate) && intermediate[j].Key == intermediate[i].Key { 74 | j++ 75 | } 76 | var values []string 77 | for k := i; k < j; k++ { 78 | values = append(values, intermediate[k].Value) 79 | } 80 | output := reducef(intermediate[i].Key, values) 81 | 82 | // this is the correct format for each line of Reduce output. 
83 | fmt.Fprintf(ofile, "%v %v\n", intermediate[i].Key, output) 84 | 85 | i = j 86 | } 87 | 88 | ofile.Close() 89 | } 90 | 91 | // 92 | // load the application Map and Reduce functions 93 | // from a plugin file, e.g. ../mrapps/wc.so 94 | // 95 | // 读取插件中的Map Reduce函数,然后返回 96 | func loadPlugin(filename string) (func(string, string) []mr.KeyValue, func(string, []string) string) { 97 | p, err := plugin.Open(filename) 98 | if err != nil { 99 | log.Fatalf("cannot load plugin %v", filename) 100 | } 101 | xmapf, err := p.Lookup("Map") 102 | if err != nil { 103 | log.Fatalf("cannot find Map in %v", filename) 104 | } 105 | mapf := xmapf.(func(string, string) []mr.KeyValue) 106 | xreducef, err := p.Lookup("Reduce") 107 | if err != nil { 108 | log.Fatalf("cannot find Reduce in %v", filename) 109 | } 110 | reducef := xreducef.(func(string, []string) string) 111 | 112 | return mapf, reducef 113 | } 114 | -------------------------------------------------------------------------------- /src/main/mrworker.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // 4 | // start a worker process, which is implemented 5 | // in ../mr/worker.go. typically there will be 6 | // multiple worker processes, talking to one coordinator. 7 | // 8 | // go run mrworker.go wc.so 9 | // 10 | // Please do not change this file. 11 | // 12 | 13 | import "mit6.824/mr" 14 | import "plugin" 15 | import "os" 16 | import "fmt" 17 | import "log" 18 | 19 | func main() { 20 | if len(os.Args) != 2 { 21 | fmt.Fprintf(os.Stderr, "Usage: mrworker xxx.so\n") 22 | os.Exit(1) 23 | } 24 | 25 | mapf, reducef := loadPlugin(os.Args[1]) 26 | 27 | mr.Worker(mapf, reducef) 28 | } 29 | 30 | // 31 | // load the application Map and Reduce functions 32 | // from a plugin file, e.g. 
../mrapps/wc.so 33 | // 34 | func loadPlugin(filename string) (func(string, string) []mr.KeyValue, func(string, []string) string) { 35 | p, err := plugin.Open(filename) 36 | if err != nil { 37 | log.Fatalf("cannot load plugin %v", filename) 38 | } 39 | xmapf, err := p.Lookup("Map") 40 | if err != nil { 41 | log.Fatalf("cannot find Map in %v", filename) 42 | } 43 | mapf := xmapf.(func(string, string) []mr.KeyValue) 44 | xreducef, err := p.Lookup("Reduce") 45 | if err != nil { 46 | log.Fatalf("cannot find Reduce in %v", filename) 47 | } 48 | reducef := xreducef.(func(string, []string) string) 49 | 50 | return mapf, reducef 51 | } 52 | -------------------------------------------------------------------------------- /src/main/pbc.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // 4 | // pbservice client application 5 | // 6 | // export GOPATH=~/6.824 7 | // go build viewd.go 8 | // go build pbd.go 9 | // go build pbc.go 10 | // ./viewd /tmp/rtm-v & 11 | // ./pbd /tmp/rtm-v /tmp/rtm-1 & 12 | // ./pbd /tmp/rtm-v /tmp/rtm-2 & 13 | // ./pbc /tmp/rtm-v key1 value1 14 | // ./pbc /tmp/rtm-v key1 15 | // 16 | // change "rtm" to your user name. 17 | // start the pbd programs in separate windows and kill 18 | // and restart them to exercise fault tolerance. 
19 | // 20 | 21 | import "6.824/pbservice" 22 | import "os" 23 | import "fmt" 24 | 25 | func usage() { 26 | fmt.Printf("Usage: pbc viewport key\n") 27 | fmt.Printf(" pbc viewport key value\n") 28 | os.Exit(1) 29 | } 30 | 31 | func main() { 32 | if len(os.Args) == 3 { 33 | // get 34 | ck := pbservice.MakeClerk(os.Args[1], "") 35 | v := ck.Get(os.Args[2]) 36 | fmt.Printf("%v\n", v) 37 | } else if len(os.Args) == 4 { 38 | // put 39 | ck := pbservice.MakeClerk(os.Args[1], "") 40 | ck.Put(os.Args[2], os.Args[3]) 41 | } else { 42 | usage() 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /src/main/pbd.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // 4 | // see directions in pbc.go 5 | // 6 | 7 | import "time" 8 | import "6.824/pbservice" 9 | import "os" 10 | import "fmt" 11 | 12 | func main() { 13 | if len(os.Args) != 3 { 14 | fmt.Printf("Usage: pbd viewport myport\n") 15 | os.Exit(1) 16 | } 17 | 18 | pbservice.StartServer(os.Args[1], os.Args[2]) 19 | 20 | for { 21 | time.Sleep(100 * time.Second) 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /src/main/test-mr-many.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | if [ $# -ne 1 ]; then 4 | echo "Usage: $0 numTrials" 5 | exit 1 6 | fi 7 | 8 | trap 'kill -INT -$pid; exit 1' INT 9 | 10 | # Note: because the socketID is based on the current userID, 11 | # ./test-mr.sh cannot be run in parallel 12 | runs=$1 13 | chmod +x test-mr.sh 14 | 15 | for i in $(seq 1 $runs); do 16 | timeout -k 2s 900s ./test-mr.sh & 17 | pid=$! 18 | if ! 
wait $pid; then 19 | echo '***' FAILED TESTS IN TRIAL $i 20 | exit 1 21 | fi 22 | done 23 | echo '***' PASSED ALL $i TESTING TRIALS 24 | -------------------------------------------------------------------------------- /src/main/test-mr.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # 4 | # map-reduce tests 5 | # 6 | 7 | # comment this out to run the tests without the Go race detector. 8 | RACE=-race 9 | 10 | if [[ "$OSTYPE" = "darwin"* ]] 11 | then 12 | if go version | grep 'go1.17.[012345]' 13 | then 14 | # -race with plug-ins on x86 MacOS 12 with 15 | # go1.17 before 1.17.6 sometimes crash. 16 | RACE= 17 | echo '*** Turning off -race since it may not work on a Mac' 18 | echo ' with ' `go version` 19 | fi 20 | fi 21 | 22 | TIMEOUT=timeout 23 | if timeout 2s sleep 1 > /dev/null 2>&1 24 | then 25 | : 26 | else 27 | if gtimeout 2s sleep 1 > /dev/null 2>&1 28 | then 29 | TIMEOUT=gtimeout 30 | else 31 | # no timeout command 32 | TIMEOUT= 33 | # echo '*** Cannot find timeout command; proceeding without timeouts.' 34 | fi 35 | fi 36 | if [ "$TIMEOUT" != "" ] 37 | then 38 | TIMEOUT+=" -k 2s 180s " 39 | fi 40 | 41 | # run the test in a fresh sub-directory. 42 | rm -rf mr-tmp 43 | mkdir mr-tmp || exit 1 44 | cd mr-tmp || exit 1 45 | # 没用吧这个指令 :) 46 | # rm -f mr-* 47 | 48 | # make sure software is freshly built. 49 | (cd ../../mrapps && go clean) 50 | (cd .. 
&& go clean) 51 | (cd ../../mrapps && go build $RACE -buildmode=plugin wc.go) || exit 1 52 | (cd ../../mrapps && go build $RACE -buildmode=plugin indexer.go) || exit 1 53 | (cd ../../mrapps && go build $RACE -buildmode=plugin mtiming.go) || exit 1 54 | (cd ../../mrapps && go build $RACE -buildmode=plugin rtiming.go) || exit 1 55 | (cd ../../mrapps && go build $RACE -buildmode=plugin jobcount.go) || exit 1 56 | (cd ../../mrapps && go build $RACE -buildmode=plugin early_exit.go) || exit 1 57 | (cd ../../mrapps && go build $RACE -buildmode=plugin crash.go) || exit 1 58 | (cd ../../mrapps && go build $RACE -buildmode=plugin nocrash.go) || exit 1 59 | (cd .. && go build $RACE mrcoordinator.go) || exit 1 60 | (cd .. && go build $RACE mrworker.go) || exit 1 61 | (cd .. && go build $RACE mrsequential.go) || exit 1 62 | 63 | failed_any=0 64 | 65 | ######################################################### 66 | # first word-count 67 | 68 | # generate the correct output 69 | ../mrsequential ../../mrapps/wc.so ../pg*txt || exit 1 70 | sort mr-out-0 > mr-correct-wc.txt 71 | rm -f mr-out* 72 | 73 | echo '***' Starting wc test. 74 | 75 | $TIMEOUT ../mrcoordinator ../pg*txt & 76 | pid=$! 77 | 78 | # give the coordinator time to create the sockets. 79 | sleep 1 80 | 81 | # start multiple workers. 82 | $TIMEOUT ../mrworker ../../mrapps/wc.so & 83 | $TIMEOUT ../mrworker ../../mrapps/wc.so & 84 | $TIMEOUT ../mrworker ../../mrapps/wc.so & 85 | 86 | # wait for the coordinator to exit. 87 | wait $pid 88 | 89 | # since workers are required to exit when a job is completely finished, 90 | # and not before, that means the job has finished. 91 | sort mr-out* | grep . > mr-wc-all 92 | if cmp mr-wc-all mr-correct-wc.txt 93 | then 94 | echo '---' wc test: PASS 95 | else 96 | echo '---' wc output is not the same as mr-correct-wc.txt 97 | echo '---' wc test: FAIL 98 | failed_any=1 99 | fi 100 | 101 | # wait for remaining workers and coordinator to exit. 
102 | wait 103 | 104 | # 断点 只测试wc 105 | # exit 1 106 | 107 | ######################################################### 108 | # now indexer 109 | rm -f mr-* 110 | 111 | # generate the correct output 112 | ../mrsequential ../../mrapps/indexer.so ../pg*txt || exit 1 113 | sort mr-out-0 > mr-correct-indexer.txt 114 | rm -f mr-out* 115 | 116 | echo '***' Starting indexer test. 117 | 118 | $TIMEOUT ../mrcoordinator ../pg*txt & 119 | sleep 1 120 | 121 | # start multiple workers 122 | $TIMEOUT ../mrworker ../../mrapps/indexer.so & 123 | $TIMEOUT ../mrworker ../../mrapps/indexer.so 124 | 125 | sort mr-out* | grep . > mr-indexer-all 126 | if cmp mr-indexer-all mr-correct-indexer.txt 127 | then 128 | echo '---' indexer test: PASS 129 | else 130 | echo '---' indexer output is not the same as mr-correct-indexer.txt 131 | echo '---' indexer test: FAIL 132 | failed_any=1 133 | fi 134 | 135 | wait 136 | 137 | ######################################################### 138 | echo '***' Starting map parallelism test. 139 | 140 | rm -f mr-* 141 | 142 | $TIMEOUT ../mrcoordinator ../pg*txt & 143 | sleep 1 144 | 145 | $TIMEOUT ../mrworker ../../mrapps/mtiming.so & 146 | $TIMEOUT ../mrworker ../../mrapps/mtiming.so 147 | 148 | NT=`cat mr-out* | grep '^times-' | wc -l | sed 's/ //g'` 149 | if [ "$NT" != "2" ] 150 | then 151 | echo '---' saw "$NT" workers rather than 2 152 | echo '---' map parallelism test: FAIL 153 | failed_any=1 154 | fi 155 | 156 | if cat mr-out* | grep '^parallel.* 2' > /dev/null 157 | then 158 | echo '---' map parallelism test: PASS 159 | else 160 | echo '---' map workers did not run in parallel 161 | echo '---' map parallelism test: FAIL 162 | failed_any=1 163 | fi 164 | 165 | wait 166 | 167 | 168 | ######################################################### 169 | echo '***' Starting reduce parallelism test. 
170 | 171 | rm -f mr-* 172 | 173 | $TIMEOUT ../mrcoordinator ../pg*txt & 174 | sleep 1 175 | 176 | $TIMEOUT ../mrworker ../../mrapps/rtiming.so & 177 | $TIMEOUT ../mrworker ../../mrapps/rtiming.so 178 | 179 | NT=`cat mr-out* | grep '^[a-z] 2' | wc -l | sed 's/ //g'` 180 | if [ "$NT" -lt "2" ] 181 | then 182 | echo '---' too few parallel reduces. 183 | echo '---' reduce parallelism test: FAIL 184 | failed_any=1 185 | else 186 | echo '---' reduce parallelism test: PASS 187 | fi 188 | 189 | wait 190 | 191 | ######################################################### 192 | echo '***' Starting job count test. 193 | 194 | rm -f mr-* 195 | 196 | $TIMEOUT ../mrcoordinator ../pg*txt & 197 | sleep 1 198 | 199 | $TIMEOUT ../mrworker ../../mrapps/jobcount.so & 200 | $TIMEOUT ../mrworker ../../mrapps/jobcount.so 201 | $TIMEOUT ../mrworker ../../mrapps/jobcount.so & 202 | $TIMEOUT ../mrworker ../../mrapps/jobcount.so 203 | 204 | NT=`cat mr-out* | awk '{print $2}'` 205 | if [ "$NT" -eq "8" ] 206 | then 207 | echo '---' job count test: PASS 208 | else 209 | echo '---' map jobs ran incorrect number of times "($NT != 8)" 210 | echo '---' job count test: FAIL 211 | failed_any=1 212 | fi 213 | 214 | wait 215 | 216 | ######################################################### 217 | # test whether any worker or coordinator exits before the 218 | # task has completed (i.e., all output files have been finalized) 219 | rm -f mr-* 220 | 221 | echo '***' Starting early exit test. 222 | 223 | DF=anydone$$ 224 | rm -f $DF 225 | 226 | ($TIMEOUT ../mrcoordinator ../pg*txt ; touch $DF) & 227 | 228 | # give the coordinator time to create the sockets. 229 | sleep 1 230 | 231 | # start multiple workers. 232 | ($TIMEOUT ../mrworker ../../mrapps/early_exit.so ; touch $DF) & 233 | ($TIMEOUT ../mrworker ../../mrapps/early_exit.so ; touch $DF) & 234 | ($TIMEOUT ../mrworker ../../mrapps/early_exit.so ; touch $DF) & 235 | 236 | # wait for any of the coord or workers to exit. 
237 | # `jobs` ensures that any completed old processes from other tests 238 | # are not waited upon. 239 | jobs &> /dev/null 240 | if [[ "$OSTYPE" = "darwin"* ]] 241 | then 242 | # bash on the Mac doesn't have wait -n 243 | while [ ! -e $DF ] 244 | do 245 | sleep 0.2 246 | done 247 | else 248 | # the -n causes wait to wait for just one child process, 249 | # rather than waiting for all to finish. 250 | wait -n 251 | fi 252 | 253 | rm -f $DF 254 | 255 | # a process has exited. this means that the output should be finalized 256 | # otherwise, either a worker or the coordinator exited early 257 | sort mr-out* | grep . > mr-wc-all-initial 258 | 259 | # wait for remaining workers and coordinator to exit. 260 | wait 261 | 262 | # compare initial and final outputs 263 | sort mr-out* | grep . > mr-wc-all-final 264 | if cmp mr-wc-all-final mr-wc-all-initial 265 | then 266 | echo '---' early exit test: PASS 267 | else 268 | echo '---' output changed after first worker exited 269 | echo '---' early exit test: FAIL 270 | failed_any=1 271 | fi 272 | rm -f mr-* 273 | 274 | ######################################################### 275 | echo '***' Starting crash test. 276 | 277 | # generate the correct output 278 | ../mrsequential ../../mrapps/nocrash.so ../pg*txt || exit 1 279 | sort mr-out-0 > mr-correct-crash.txt 280 | rm -f mr-out* 281 | 282 | rm -f mr-done 283 | ($TIMEOUT ../mrcoordinator ../pg*txt ; touch mr-done ) & 284 | sleep 1 285 | 286 | # start multiple workers 287 | $TIMEOUT ../mrworker ../../mrapps/crash.so & 288 | 289 | # mimic rpc.go's coordinatorSock() 290 | SOCKNAME=/var/tmp/824-mr-`id -u` 291 | 292 | ( while [ -e $SOCKNAME -a ! -f mr-done ] 293 | do 294 | $TIMEOUT ../mrworker ../../mrapps/crash.so 295 | sleep 1 296 | done ) & 297 | 298 | ( while [ -e $SOCKNAME -a ! -f mr-done ] 299 | do 300 | $TIMEOUT ../mrworker ../../mrapps/crash.so 301 | sleep 1 302 | done ) & 303 | 304 | while [ -e $SOCKNAME -a ! 
-f mr-done ] 305 | do 306 | $TIMEOUT ../mrworker ../../mrapps/crash.so 307 | sleep 1 308 | done 309 | 310 | wait 311 | 312 | rm $SOCKNAME 313 | sort mr-out* | grep . > mr-crash-all 314 | if cmp mr-crash-all mr-correct-crash.txt 315 | then 316 | echo '---' crash test: PASS 317 | else 318 | echo '---' crash output is not the same as mr-correct-crash.txt 319 | echo '---' crash test: FAIL 320 | failed_any=1 321 | fi 322 | 323 | ######################################################### 324 | if [ $failed_any -eq 0 ]; then 325 | echo '***' PASSED ALL TESTS 326 | else 327 | echo '***' FAILED SOME TESTS 328 | exit 1 329 | fi 330 | -------------------------------------------------------------------------------- /src/main/viewd.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // 4 | // see directions in pbc.go 5 | // 6 | 7 | import "time" 8 | import "6.824/viewservice" 9 | import "os" 10 | import "fmt" 11 | 12 | func main() { 13 | if len(os.Args) != 2 { 14 | fmt.Printf("Usage: viewd port\n") 15 | os.Exit(1) 16 | } 17 | 18 | viewservice.StartServer(os.Args[1]) 19 | 20 | for { 21 | time.Sleep(100 * time.Second) 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /src/models/kv.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import "mit6.824/porcupine" 4 | import "fmt" 5 | import "sort" 6 | 7 | type KvInput struct { 8 | Op uint8 // 0 => get, 1 => put, 2 => append 9 | Key string 10 | Value string 11 | } 12 | 13 | type KvOutput struct { 14 | Value string 15 | } 16 | 17 | var KvModel = porcupine.Model{ 18 | Partition: func(history []porcupine.Operation) [][]porcupine.Operation { 19 | m := make(map[string][]porcupine.Operation) 20 | for _, v := range history { 21 | key := v.Input.(KvInput).Key 22 | m[key] = append(m[key], v) 23 | } 24 | keys := make([]string, 0, len(m)) 25 | for k := range m { 26 | keys = 
append(keys, k) 27 | } 28 | sort.Strings(keys) 29 | ret := make([][]porcupine.Operation, 0, len(keys)) 30 | for _, k := range keys { 31 | ret = append(ret, m[k]) 32 | } 33 | return ret 34 | }, 35 | Init: func() interface{} { 36 | // note: we are modeling a single key's value here; 37 | // we're partitioning by key, so this is okay 38 | return "" 39 | }, 40 | Step: func(state, input, output interface{}) (bool, interface{}) { 41 | inp := input.(KvInput) 42 | out := output.(KvOutput) 43 | st := state.(string) 44 | if inp.Op == 0 { 45 | // get 46 | return out.Value == st, state 47 | } else if inp.Op == 1 { 48 | // put 49 | return true, inp.Value 50 | } else { 51 | // append 52 | return true, (st + inp.Value) 53 | } 54 | }, 55 | DescribeOperation: func(input, output interface{}) string { 56 | inp := input.(KvInput) 57 | out := output.(KvOutput) 58 | switch inp.Op { 59 | case 0: 60 | return fmt.Sprintf("get('%s') -> '%s'", inp.Key, out.Value) 61 | case 1: 62 | return fmt.Sprintf("put('%s', '%s')", inp.Key, inp.Value) 63 | case 2: 64 | return fmt.Sprintf("append('%s', '%s')", inp.Key, inp.Value) 65 | default: 66 | return "" 67 | } 68 | }, 69 | } 70 | -------------------------------------------------------------------------------- /src/mr/coordinator.go: -------------------------------------------------------------------------------- 1 | package mr 2 | 3 | import ( 4 | "log" 5 | "sync" 6 | "time" 7 | ) 8 | import "net" 9 | import "os" 10 | import "net/rpc" 11 | import "net/http" 12 | 13 | type Coordinator struct { 14 | // Mutex锁 15 | lock sync.Mutex 16 | // Reduce任务数量 17 | reduceCount int 18 | // Worker进程ID 19 | workers []int 20 | // 当前状态,0正在做Map任务,1正在做Reduce任务,2等Worker全部退出 21 | status int 22 | 23 | // Map任务 24 | mapTasks map[string]*mapTask 25 | // Map已完成数量 26 | mapTaskDoneCount int 27 | 28 | // Reduce任务 29 | reduceTasks []*reduceTask 30 | // Reduce任务完成数量 31 | reduceTaskDoneCount int 32 | } 33 | 34 | // Reduce任务结构 35 | type reduceTask struct { 36 | id int 37 | working 
bool 38 | done bool 39 | workerID int 40 | } 41 | 42 | // Map任务结构 43 | type mapTask struct { 44 | id int 45 | name string 46 | working bool 47 | done bool 48 | workerID int 49 | } 50 | 51 | // Your code here -- RPC handlers for the worker to call. 52 | /* 53 | 1.分Map任务给worker,worker完成之后call一个task ok(10s计时) 54 | 2.全部Map完成后开始reduce,每个key的所有values传给worker 55 | 3.写入文件mr-out-x 先排序好所有的intermediate 56 | 4.关闭所有worker后退出自己 57 | */ 58 | 59 | // Fuck Worker访问此接口拿到一个任务 60 | func (c *Coordinator) Fuck(args *FuckArgs, reply *FuckReply) error { 61 | c.lock.Lock() 62 | defer c.lock.Unlock() 63 | // 检查状态 64 | if c.status == 0 { 65 | // Map任务没做完,找一个没做的任务给worker 66 | mapName := "" 67 | for _, task := range c.mapTasks { 68 | if task.working == false && task.done == false { 69 | // 找到了可以给他的map任务 70 | mapName = task.name 71 | task.workerID = args.WorkerID 72 | task.working = true 73 | break 74 | } 75 | } 76 | if mapName == "" { 77 | // 可能找不到可以做的map任务,就传0让worker等一会 78 | reply.TaskType = 0 79 | return nil 80 | } 81 | // 回传给worker的数据 82 | reply.MapID = c.mapTasks[mapName].id 83 | reply.MapName = mapName 84 | reply.TaskType = 1 85 | reply.ReduceCount = c.reduceCount 86 | go c.checkTaskTimeOut(1, mapName, 0) 87 | } else if c.status == 1 { 88 | // Map做完了,在做Reduce任务 89 | reduceID := -1 90 | for i, task := range c.reduceTasks { 91 | if task.working == false && task.done == false { 92 | // 找到了可以给他的reduce任务 93 | reduceID = i 94 | task.workerID = args.WorkerID 95 | task.working = true 96 | break 97 | } 98 | } 99 | if reduceID == -1 { 100 | // 没找到可以做的reduce任务,也传0 101 | reply.TaskType = 0 102 | return nil 103 | } 104 | reply.TaskType = 2 105 | reply.ReduceID = reduceID 106 | reply.MapTaskCount = c.mapTaskDoneCount 107 | go c.checkTaskTimeOut(2, "", reduceID) 108 | } else if c.status == 2 { 109 | // 发送退出信号,任务都完成了 110 | reply.Exit = true 111 | } 112 | return nil 113 | } 114 | 115 | // WorkerExit Worker退出回传 116 | func (c *Coordinator) WorkerExit(args *WorkerExitArgs, n *None) error { 117 | 
c.lock.Lock() 118 | log.Printf("Worker[%v] exit!", args.WorkerID) 119 | c.deleteWorker(args.WorkerID) 120 | c.lock.Unlock() 121 | return nil 122 | } 123 | 124 | // 删除worker,需要提前lock 125 | func (c *Coordinator) deleteWorker(workerID int) { 126 | // 从workers删除这个worker 127 | workerKey := -1 128 | for i, worker := range c.workers { 129 | if worker == workerID { 130 | workerKey = i 131 | break 132 | } 133 | } 134 | // 检查是否有这个worker,可能这是以前没死完的worker 135 | if workerKey == -1 { 136 | log.Printf("Worker[%v] exit error! its not my worker!", workerID) 137 | } else { 138 | // 删除这个worker 139 | c.workers = append(c.workers[:workerKey], c.workers[workerKey+1:]...) 140 | } 141 | } 142 | 143 | // checkTaskTimeOut 检查任务超时 144 | func (c *Coordinator) checkTaskTimeOut(taskType int, mapName string, reduceID int) { 145 | time.Sleep(10 * time.Second) 146 | c.lock.Lock() 147 | if taskType == 1 { 148 | // 检查map任务 149 | if c.mapTasks[mapName].done == false { 150 | log.Printf("Map task[%v] dead, worker[%v] dead!:", mapName, c.mapTasks[mapName].workerID) 151 | c.mapTasks[mapName].working = false 152 | c.deleteWorker(c.mapTasks[mapName].workerID) 153 | } 154 | } else if taskType == 2 { 155 | // 检查reduce任务 156 | if c.reduceTasks[reduceID].done == false { 157 | log.Printf("Reduce task[%v] dead, worker[%v] dead!", reduceID, c.reduceTasks[reduceID].workerID) 158 | c.reduceTasks[reduceID].working = false 159 | c.deleteWorker(c.reduceTasks[reduceID].workerID) 160 | } 161 | } 162 | c.lock.Unlock() 163 | } 164 | 165 | // TaskDone worker任务完成回传 166 | func (c *Coordinator) TaskDone(args *TaskDoneArgs, n *None) error { 167 | c.lock.Lock() 168 | if args.TaskType == 1 { 169 | // Map任务完成 170 | c.mapTasks[args.MapName].done = true 171 | c.mapTasks[args.MapName].working = false 172 | c.mapTaskDoneCount++ 173 | log.Printf("Map task[%v] done", args.MapName) 174 | if c.mapTaskDoneCount == len(c.mapTasks) { 175 | // 所有Map任务已完成 176 | c.status = 1 177 | log.Println("All map tasks done!") 178 | } 179 | } else if 
args.TaskType == 2 { 180 | // Reduce任务完成 181 | c.reduceTasks[args.ReduceID].done = true 182 | c.reduceTasks[args.ReduceID].working = false 183 | c.reduceTaskDoneCount++ 184 | log.Printf("Reduce Task[%v] done", args.ReduceID) 185 | if c.reduceTaskDoneCount == len(c.reduceTasks) { 186 | // 所有Reduce任务已完成 187 | c.status = 2 188 | log.Printf("All reduce tasks done!") 189 | } 190 | } 191 | c.lock.Unlock() 192 | return nil 193 | } 194 | 195 | // RegisterWorker Worker访问此接口来注册到Master,传回一个id 196 | func (c *Coordinator) RegisterWorker(n *None, workerID *int) error { 197 | c.lock.Lock() 198 | *workerID = len(c.workers) 199 | c.workers = append(c.workers, *workerID) 200 | log.Printf("Worker[%v] register to master now!", *workerID) 201 | c.lock.Unlock() 202 | return nil 203 | } 204 | 205 | // code end 206 | 207 | // 208 | // start a thread that listens for RPCs from worker.go 209 | // 210 | func (c *Coordinator) server() { 211 | rpc.Register(c) 212 | rpc.HandleHTTP() 213 | sockname := coordinatorSock() 214 | os.Remove(sockname) 215 | l, e := net.Listen("unix", sockname) 216 | if e != nil { 217 | log.Fatal("listen error:", e) 218 | } 219 | go http.Serve(l, nil) 220 | } 221 | 222 | // Done 223 | // main/mrcoordinator.go calls Done() periodically to find out 224 | // if the entire job has finished. 225 | // 226 | // mrcoordinator通过这个来检查是否所有任务已完成 227 | func (c *Coordinator) Done() bool { 228 | ret := false 229 | // Your code here. 230 | c.lock.Lock() 231 | if c.status == 2 && len(c.workers) == 0 { 232 | log.Printf("Master done now!") 233 | ret = true 234 | } 235 | c.lock.Unlock() 236 | // code end 237 | return ret 238 | } 239 | 240 | // MakeCoordinator 241 | // create a Coordinator. 242 | // main/mrcoordinator.go calls this function. 243 | // nReduce is the number of reduce tasks to use. 244 | // 245 | func MakeCoordinator(files []string, nReduce int) *Coordinator { 246 | c := Coordinator{} 247 | // Your code here. 
248 | // 初始化coordinator 249 | c.reduceCount = nReduce 250 | c.workers = make([]int, 0) 251 | c.status = 0 252 | // 初始化Map任务 253 | c.mapTasks = make(map[string]*mapTask) 254 | for i, fileName := range files { 255 | c.mapTasks[fileName] = &mapTask{i, fileName, false, false, 0} 256 | } 257 | // 初始化Reduce任务 258 | c.reduceTasks = make([]*reduceTask, nReduce) 259 | for i := 0; i < nReduce; i++ { 260 | c.reduceTasks[i] = &reduceTask{i, false, false, 0} 261 | } 262 | // code end 263 | c.server() 264 | return &c 265 | } 266 | -------------------------------------------------------------------------------- /src/mr/rpc.go: -------------------------------------------------------------------------------- 1 | package mr 2 | 3 | // 4 | // RPC definitions. 5 | // 6 | // remember to capitalize all names. 7 | // 8 | 9 | import ( 10 | "os" 11 | ) 12 | import "strconv" 13 | 14 | // 15 | // example to show how to declare the arguments 16 | // and reply for an RPC. 17 | // 18 | 19 | // Args Add your RPC definitions here. 
20 | 21 | type FuckArgs struct { 22 | WorkerID int 23 | } 24 | 25 | type FuckReply struct { 26 | // 任务类型:[1] Map任务, [2] Reduce任务 27 | TaskType int 28 | // Reduce任务数量 29 | ReduceCount int 30 | // Map任务文件名 31 | MapName string 32 | // Map任务ID 33 | MapID int 34 | // Reduce任务ID 35 | ReduceID int 36 | // Reduce任务用的,map任务数量 37 | MapTaskCount int 38 | // 是否退出Worker 39 | Exit bool 40 | } 41 | 42 | // TaskDoneArgs 任务完成结构 43 | type TaskDoneArgs struct { 44 | // 任务类型:[1] Map任务, [2] Reduce任务 45 | TaskType int 46 | // Map任务文件名 47 | MapName string 48 | // Reduce任务ID 49 | ReduceID int 50 | } 51 | 52 | // WorkerExitArgs worker退出 请求 53 | type WorkerExitArgs struct { 54 | WorkerID int 55 | } 56 | 57 | // None 空结构,用来占位 58 | type None struct{} 59 | 60 | // 中间文件格式 mr-MapID-ReduceID 61 | var interFileName = "mr-%v-%v" 62 | 63 | // 输出文件格式 mr-out-ReduceID 64 | var outFileName = "mr-out-%v" 65 | 66 | // code end 67 | 68 | // Cook up a unique-ish UNIX-domain socket name 69 | // in /var/tmp, for the coordinator. 70 | // Can't use the current directory since 71 | // Athena AFS doesn't support UNIX-domain sockets. 72 | func coordinatorSock() string { 73 | s := "/var/tmp/824-mr-" 74 | s += strconv.Itoa(os.Getuid()) 75 | return s 76 | } 77 | -------------------------------------------------------------------------------- /src/mr/worker.go: -------------------------------------------------------------------------------- 1 | package mr 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io/ioutil" 7 | "os" 8 | "sort" 9 | ) 10 | import "log" 11 | import "net/rpc" 12 | import "hash/fnv" 13 | 14 | // KeyValue 15 | // Map functions return a slice of KeyValue. 16 | // 17 | type KeyValue struct { 18 | Key string 19 | Value string 20 | } 21 | 22 | // ByKey for sorting by key. 23 | type ByKey []KeyValue 24 | 25 | // Len for sorting by key. 
26 | func (a ByKey) Len() int { return len(a) } 27 | func (a ByKey) Swap(i, j int) { a[i], a[j] = a[j], a[i] } 28 | func (a ByKey) Less(i, j int) bool { return a[i].Key < a[j].Key } 29 | 30 | // 31 | // use ihash(key) % NReduce to choose the reduce 32 | // task number for each KeyValue emitted by Map. 33 | // 34 | func ihash(key string) int { 35 | h := fnv.New32a() 36 | h.Write([]byte(key)) 37 | return int(h.Sum32() & 0x7fffffff) 38 | } 39 | 40 | // Worker 41 | // main/mrworker.go calls this function. 42 | // 43 | func Worker(mapf func(string, string) []KeyValue, reducef func(string, []string) string) { 44 | workerID := register() 45 | args := FuckArgs{workerID} 46 | for { 47 | reply := FuckReply{} 48 | ok := call("Coordinator.Fuck", &args, &reply) 49 | if !ok { 50 | log.Fatalf("Worker get task fail!") 51 | } 52 | if reply.Exit == false { 53 | if reply.TaskType == 0 { 54 | // 没拿到任务 休息一会 55 | // time.Sleep(time.Millisecond) 56 | } else if reply.TaskType == 1 { 57 | // Map任务 58 | mapResult := mapf(reply.MapName, readFile(reply.MapName)) 59 | // 根据ReduceCount拆分 60 | reduceContent := make([][]KeyValue, reply.ReduceCount) 61 | // 存进content 62 | for _, kv := range mapResult { 63 | key := ihash(kv.Key) % reply.ReduceCount 64 | reduceContent[key] = append(reduceContent[key], kv) 65 | } 66 | // 写Inter文件,nReduce有多少就写多少 67 | for i, content := range reduceContent { 68 | fileName := fmt.Sprintf(interFileName, reply.MapID, i) 69 | f, _ := os.Create(fileName) 70 | enc := json.NewEncoder(f) 71 | for _, line := range content { 72 | enc.Encode(&line) 73 | } 74 | f.Close() 75 | } 76 | // 回传Map任务完成 77 | taskDone(1, reply.MapName, 0) 78 | } else if reply.TaskType == 2 { 79 | // reduce任务,把所有key相同的values传给reduce函数,然后写入文件 80 | // 要先读入同一个reduceID的文件,然后排序,整理 81 | inter := make([]KeyValue, 0) 82 | for i := 0; i < reply.MapTaskCount; i++ { 83 | // 读取所有这个reduceID的文件 84 | fileName := fmt.Sprintf(interFileName, i, reply.ReduceID) 85 | reduceF, _ := os.Open(fileName) 86 | dec := 
json.NewDecoder(reduceF) 87 | for { 88 | var kv KeyValue 89 | if err := dec.Decode(&kv); err != nil { 90 | break 91 | } 92 | inter = append(inter, kv) 93 | } 94 | reduceF.Close() 95 | } 96 | // 读入了所有的kv 排序 97 | sort.Sort(ByKey(inter)) 98 | // 合并同类kv且写入文件 99 | fileName := fmt.Sprintf(outFileName, reply.ReduceID) 100 | outF, _ := os.Create(fileName) 101 | i := 0 102 | for i < len(inter) { 103 | j := i + 1 104 | for j < len(inter) && inter[j].Key == inter[i].Key { 105 | j++ 106 | } 107 | var values []string 108 | for k := i; k < j; k++ { 109 | values = append(values, inter[k].Value) 110 | } 111 | output := reducef(inter[i].Key, values) 112 | fmt.Fprintf(outF, "%v %v\n", inter[i].Key, output) 113 | i = j 114 | } 115 | outF.Close() 116 | // 回传Task任务完成 117 | taskDone(2, "", reply.ReduceID) 118 | } 119 | } else { 120 | workerExit(workerID) 121 | os.Exit(1) 122 | } 123 | } 124 | } 125 | 126 | // worker退出 回传给master 127 | func workerExit(workerID int) { 128 | args := WorkerExitArgs{workerID} 129 | ok := call("Coordinator.WorkerExit", &args, &None{}) 130 | if !ok { 131 | log.Fatalf("Worker[%v] exit fail!", workerID) 132 | } 133 | } 134 | 135 | // 任务完成回传 [1]:map [2]:reduce 136 | func taskDone(taskType int, mapName string, reduceID int) { 137 | args := TaskDoneArgs{taskType, mapName, reduceID} 138 | ok := call("Coordinator.TaskDone", &args, &None{}) 139 | if !ok { 140 | log.Fatalf("Worker matTaskDone fail!") 141 | } 142 | } 143 | 144 | // 注册当前worker到master,返回master给的id 145 | func register() int { 146 | var workerID int 147 | ok := call("Coordinator.RegisterWorker", &None{}, &workerID) 148 | if !ok { 149 | log.Fatalf("Worker register to master fail!") 150 | } 151 | return workerID 152 | } 153 | 154 | // 读取文件,返回内容 155 | func readFile(fileName string) string { 156 | file, err := os.Open(fileName) 157 | if err != nil { 158 | log.Fatalf("Master cannot open %v", fileName) 159 | } 160 | content, err := ioutil.ReadAll(file) 161 | if err != nil { 162 | log.Fatalf("Master cannot read 
%v", fileName) 163 | } 164 | err = file.Close() 165 | if err != nil { 166 | log.Fatalf("Master cannot close %v", fileName) 167 | } 168 | return string(content) 169 | } 170 | 171 | // 172 | // send an RPC request to the coordinator, wait for the response. 173 | // usually returns true. 174 | // returns false if something goes wrong. 175 | // 176 | func call(rpcname string, args interface{}, reply interface{}) bool { 177 | sockname := coordinatorSock() 178 | c, err := rpc.DialHTTP("unix", sockname) 179 | if err != nil { 180 | log.Fatal("dialing:", err) 181 | } 182 | defer c.Close() 183 | err = c.Call(rpcname, args, reply) 184 | if err == nil { 185 | return true 186 | } 187 | fmt.Println(err) 188 | return false 189 | } 190 | -------------------------------------------------------------------------------- /src/mrapps/crash.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // 4 | // a MapReduce pseudo-application that sometimes crashes, 5 | // and sometimes takes a long time, 6 | // to test MapReduce's ability to recover. 7 | // 8 | // go build -buildmode=plugin crash.go 9 | // 10 | 11 | import "mit6.824/mr" 12 | import crand "crypto/rand" 13 | import "math/big" 14 | import "strings" 15 | import "os" 16 | import "sort" 17 | import "strconv" 18 | import "time" 19 | 20 | func maybeCrash() { 21 | max := big.NewInt(1000) 22 | rr, _ := crand.Int(crand.Reader, max) 23 | if rr.Int64() < 330 { 24 | // crash! 25 | os.Exit(1) 26 | } else if rr.Int64() < 660 { 27 | // delay for a while. 
28 | maxms := big.NewInt(10 * 1000) 29 | ms, _ := crand.Int(crand.Reader, maxms) 30 | time.Sleep(time.Duration(ms.Int64()) * time.Millisecond) 31 | } 32 | } 33 | 34 | func Map(filename string, contents string) []mr.KeyValue { 35 | maybeCrash() 36 | 37 | kva := []mr.KeyValue{} 38 | kva = append(kva, mr.KeyValue{"a", filename}) 39 | kva = append(kva, mr.KeyValue{"b", strconv.Itoa(len(filename))}) 40 | kva = append(kva, mr.KeyValue{"c", strconv.Itoa(len(contents))}) 41 | kva = append(kva, mr.KeyValue{"d", "xyzzy"}) 42 | return kva 43 | } 44 | 45 | func Reduce(key string, values []string) string { 46 | maybeCrash() 47 | 48 | // sort values to ensure deterministic output. 49 | vv := make([]string, len(values)) 50 | copy(vv, values) 51 | sort.Strings(vv) 52 | 53 | val := strings.Join(vv, " ") 54 | return val 55 | } 56 | -------------------------------------------------------------------------------- /src/mrapps/early_exit.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // 4 | // a word-count application "plugin" for MapReduce. 5 | // 6 | // go build -buildmode=plugin wc_long.go 7 | // 8 | 9 | import ( 10 | "strconv" 11 | "strings" 12 | "time" 13 | 14 | "mit6.824/mr" 15 | ) 16 | 17 | // 18 | // The map function is called once for each file of input. 19 | // This map function just returns 1 for each file 20 | // 21 | func Map(filename string, contents string) []mr.KeyValue { 22 | kva := []mr.KeyValue{} 23 | kva = append(kva, mr.KeyValue{filename, "1"}) 24 | return kva 25 | } 26 | 27 | // 28 | // The reduce function is called once for each key generated by the 29 | // map tasks, with a list of all the values created for that key by 30 | // any map task. 
31 | // 32 | func Reduce(key string, values []string) string { 33 | // some reduce tasks sleep for a long time; potentially seeing if 34 | // a worker will accidentally exit early 35 | if strings.Contains(key, "sherlock") || strings.Contains(key, "tom") { 36 | time.Sleep(time.Duration(3 * time.Second)) 37 | } 38 | // return the number of occurrences of this file. 39 | return strconv.Itoa(len(values)) 40 | } 41 | -------------------------------------------------------------------------------- /src/mrapps/indexer.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // 4 | // an indexing application "plugin" for MapReduce. 5 | // 6 | // go build -buildmode=plugin indexer.go 7 | // 8 | 9 | import "fmt" 10 | import "mit6.824/mr" 11 | 12 | import "strings" 13 | import "unicode" 14 | import "sort" 15 | 16 | // Map The mapping function is called once for each piece of the input. 17 | // In this framework, the key is the name of the file that is being processed, 18 | // and the value is the file's contents. The return value should be a slice of 19 | // key/value pairs, each represented by a mr.KeyValue. 20 | func Map(document string, value string) (res []mr.KeyValue) { 21 | m := make(map[string]bool) 22 | words := strings.FieldsFunc(value, func(x rune) bool { return !unicode.IsLetter(x) }) 23 | for _, w := range words { 24 | m[w] = true 25 | } 26 | for w := range m { 27 | kv := mr.KeyValue{w, document} 28 | res = append(res, kv) 29 | } 30 | return 31 | } 32 | 33 | // Reduce The reduce function is called once for each key generated by Map, with a 34 | // list of that key's string value (merged across all inputs). The return value 35 | // should be a single output value for that key. 
36 | func Reduce(key string, values []string) string { 37 | sort.Strings(values) // note: sorts the caller's values slice in place 38 | return fmt.Sprintf("%d %s", len(values), strings.Join(values, ",")) 39 | } 40 | -------------------------------------------------------------------------------- /src/mrapps/jobcount.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // 4 | // a MapReduce pseudo-application that counts the number of times map/reduce 5 | // tasks are run, to test whether jobs are assigned multiple times even when 6 | // there is no failure. 7 | // 8 | // go build -buildmode=plugin jobcount.go 9 | // 10 | 11 | import "mit6.824/mr" 12 | import "math/rand" 13 | import "strings" 14 | import "strconv" 15 | import "time" 16 | import "fmt" 17 | import "os" 18 | import "io/ioutil" 19 | // count numbers this process's Map calls so that each call writes a distinct mr-worker-jobcount-<pid>-<count> marker file. 20 | var count int 21 | 22 | func Map(filename string, contents string) []mr.KeyValue { 23 | me := os.Getpid() 24 | f := fmt.Sprintf("mr-worker-jobcount-%d-%d", me, count) 25 | count++ 26 | err := ioutil.WriteFile(f, []byte("x"), 0666) 27 | if err != nil { 28 | panic(err) 29 | } 30 | time.Sleep(time.Duration(2000+rand.Intn(3000)) * time.Millisecond) 31 | return []mr.KeyValue{mr.KeyValue{"a", "x"}} 32 | } 33 | // Reduce ignores its input and returns the number of mr-worker-jobcount marker files in the current directory, i.e. the Map invocations recorded so far across all workers. 
34 | func Reduce(key string, values []string) string { 35 | files, err := ioutil.ReadDir(".") 36 | if err != nil { 37 | panic(err) 38 | } 39 | invocations := 0 40 | for _, f := range files { 41 | if strings.HasPrefix(f.Name(), "mr-worker-jobcount") { 42 | invocations++ 43 | } 44 | } 45 | return strconv.Itoa(invocations) 46 | } 47 | -------------------------------------------------------------------------------- /src/mrapps/mtiming.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // 4 | // a MapReduce pseudo-application to test that workers 5 | // execute map tasks in parallel. 
6 | // 7 | // go build -buildmode=plugin mtiming.go 8 | // 9 | 10 | import "mit6.824/mr" 11 | import "strings" 12 | import "fmt" 13 | import "os" 14 | import "syscall" 15 | import "time" 16 | import "sort" 17 | import "io/ioutil" 18 | 19 | func nparallel(phase string) int { 20 | // create a file so that other workers will see that 21 | // we're running at the same time as them. 22 | pid := os.Getpid() 23 | myfilename := fmt.Sprintf("mr-worker-%s-%d", phase, pid) 24 | err := ioutil.WriteFile(myfilename, []byte("x"), 0666) 25 | if err != nil { 26 | panic(err) 27 | } 28 | 29 | // are any other workers running? 30 | // find their PIDs by scanning directory for mr-worker-XXX files. 31 | dd, err := os.Open(".") 32 | if err != nil { 33 | panic(err) 34 | } 35 | names, err := dd.Readdirnames(1000000) 36 | if err != nil { 37 | panic(err) 38 | } 39 | ret := 0 40 | for _, name := range names { 41 | var xpid int 42 | pat := fmt.Sprintf("mr-worker-%s-%%d", phase) 43 | n, err := fmt.Sscanf(name, pat, &xpid) 44 | if n == 1 && err == nil { 45 | err := syscall.Kill(xpid, 0) 46 | if err == nil { 47 | // if err == nil, xpid is alive. 48 | ret += 1 49 | } 50 | } 51 | } 52 | dd.Close() 53 | 54 | time.Sleep(1 * time.Second) 55 | 56 | err = os.Remove(myfilename) 57 | if err != nil { 58 | panic(err) 59 | } 60 | 61 | return ret 62 | } 63 | 64 | func Map(filename string, contents string) []mr.KeyValue { 65 | t0 := time.Now() 66 | ts := float64(t0.Unix()) + (float64(t0.Nanosecond()) / 1000000000.0) 67 | pid := os.Getpid() 68 | 69 | n := nparallel("map") 70 | 71 | kva := []mr.KeyValue{} 72 | kva = append(kva, mr.KeyValue{ 73 | fmt.Sprintf("times-%v", pid), 74 | fmt.Sprintf("%.1f", ts)}) 75 | kva = append(kva, mr.KeyValue{ 76 | fmt.Sprintf("parallel-%v", pid), 77 | fmt.Sprintf("%d", n)}) 78 | return kva 79 | } 80 | 81 | func Reduce(key string, values []string) string { 82 | //n := nparallel("reduce") 83 | 84 | // sort values to ensure deterministic output. 
85 | vv := make([]string, len(values)) 86 | copy(vv, values) 87 | sort.Strings(vv) 88 | 89 | val := strings.Join(vv, " ") 90 | return val 91 | } 92 | -------------------------------------------------------------------------------- /src/mrapps/nocrash.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // 4 | // same as crash.go but doesn't actually crash. 5 | // 6 | // go build -buildmode=plugin nocrash.go 7 | // 8 | 9 | import "mit6.824/mr" 10 | import crand "crypto/rand" 11 | import "math/big" 12 | import "strings" 13 | import "os" 14 | import "sort" 15 | import "strconv" 16 | 17 | func maybeCrash() { 18 | max := big.NewInt(1000) 19 | rr, _ := crand.Int(crand.Reader, max) 20 | if false && rr.Int64() < 500 { 21 | // crash! 22 | os.Exit(1) 23 | } 24 | } 25 | 26 | func Map(filename string, contents string) []mr.KeyValue { 27 | maybeCrash() 28 | 29 | kva := []mr.KeyValue{} 30 | kva = append(kva, mr.KeyValue{"a", filename}) 31 | kva = append(kva, mr.KeyValue{"b", strconv.Itoa(len(filename))}) 32 | kva = append(kva, mr.KeyValue{"c", strconv.Itoa(len(contents))}) 33 | kva = append(kva, mr.KeyValue{"d", "xyzzy"}) 34 | return kva 35 | } 36 | 37 | func Reduce(key string, values []string) string { 38 | maybeCrash() 39 | 40 | // sort values to ensure deterministic output. 41 | vv := make([]string, len(values)) 42 | copy(vv, values) 43 | sort.Strings(vv) 44 | 45 | val := strings.Join(vv, " ") 46 | return val 47 | } 48 | -------------------------------------------------------------------------------- /src/mrapps/rtiming.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // 4 | // a MapReduce pseudo-application to test that workers 5 | // execute reduce tasks in parallel. 
6 | // 7 | // go build -buildmode=plugin rtiming.go 8 | // 9 | 10 | import "mit6.824/mr" 11 | import "fmt" 12 | import "os" 13 | import "syscall" 14 | import "time" 15 | import "io/ioutil" 16 | 17 | func nparallel(phase string) int { 18 | // create a file so that other workers will see that 19 | // we're running at the same time as them. 20 | pid := os.Getpid() 21 | myfilename := fmt.Sprintf("mr-worker-%s-%d", phase, pid) 22 | err := ioutil.WriteFile(myfilename, []byte("x"), 0666) 23 | if err != nil { 24 | panic(err) 25 | } 26 | 27 | // are any other workers running? 28 | // find their PIDs by scanning directory for mr-worker-XXX files. 29 | dd, err := os.Open(".") 30 | if err != nil { 31 | panic(err) 32 | } 33 | names, err := dd.Readdirnames(1000000) 34 | if err != nil { 35 | panic(err) 36 | } 37 | ret := 0 38 | for _, name := range names { 39 | var xpid int 40 | pat := fmt.Sprintf("mr-worker-%s-%%d", phase) 41 | n, err := fmt.Sscanf(name, pat, &xpid) 42 | if n == 1 && err == nil { 43 | err := syscall.Kill(xpid, 0) 44 | if err == nil { 45 | // if err == nil, xpid is alive. 
46 | ret += 1 47 | } 48 | } 49 | } 50 | dd.Close() 51 | 52 | time.Sleep(1 * time.Second) 53 | 54 | err = os.Remove(myfilename) 55 | if err != nil { 56 | panic(err) 57 | } 58 | 59 | return ret 60 | } 61 | 62 | func Map(filename string, contents string) []mr.KeyValue { 63 | 64 | kva := []mr.KeyValue{} 65 | kva = append(kva, mr.KeyValue{"a", "1"}) 66 | kva = append(kva, mr.KeyValue{"b", "1"}) 67 | kva = append(kva, mr.KeyValue{"c", "1"}) 68 | kva = append(kva, mr.KeyValue{"d", "1"}) 69 | kva = append(kva, mr.KeyValue{"e", "1"}) 70 | kva = append(kva, mr.KeyValue{"f", "1"}) 71 | kva = append(kva, mr.KeyValue{"g", "1"}) 72 | kva = append(kva, mr.KeyValue{"h", "1"}) 73 | kva = append(kva, mr.KeyValue{"i", "1"}) 74 | kva = append(kva, mr.KeyValue{"j", "1"}) 75 | return kva 76 | } 77 | 78 | func Reduce(key string, values []string) string { 79 | n := nparallel("reduce") 80 | 81 | val := fmt.Sprintf("%d", n) 82 | 83 | return val 84 | } 85 | -------------------------------------------------------------------------------- /src/mrapps/wc.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // 4 | // a word-count application "plugin" for MapReduce. 5 | // 6 | // go build -buildmode=plugin wc.go 7 | // 8 | 9 | import "mit6.824/mr" 10 | import "unicode" 11 | import "strings" 12 | import "strconv" 13 | 14 | // Map 15 | // The map function is called once for each file of input. The first 16 | // argument is the name of the input file, and the second is the 17 | // file's complete contents. You should ignore the input file name, 18 | // and look only at the contents argument. The return value is a slice 19 | // of key/value pairs. 20 | // 21 | // Map processes the raw input data. 22 | func Map(filename string, contents string) []mr.KeyValue { 23 | // function to detect word separators. 24 | ff := func(r rune) bool { return !unicode.IsLetter(r) } 25 | 26 | // split contents into an array of words. 
27 | words := strings.FieldsFunc(contents, ff) 28 | 29 | kva := []mr.KeyValue{} 30 | for _, w := range words { 31 | kv := mr.KeyValue{w, "1"} 32 | kva = append(kva, kv) 33 | } 34 | return kva 35 | } 36 | 37 | // Reduce 38 | // The reduce function is called once for each key generated by the 39 | // map tasks, with a list of all the values created for that key by 40 | // any map task. 41 | // 42 | // Reduce processes the grouped data: every value emitted for one key. 43 | func Reduce(key string, values []string) string { 44 | // return the number of occurrences of this word. 45 | return strconv.Itoa(len(values)) 46 | } 47 | -------------------------------------------------------------------------------- /src/porcupine/bitset.go: -------------------------------------------------------------------------------- 1 | package porcupine 2 | 3 | import "math/bits" 4 | 5 | type bitset []uint64 6 | 7 | // data layout: 8 | // bits 0-63 are in data[0], the next are in data[1], etc. 9 | 10 | func newBitset(bits uint) bitset { 11 | extra := uint(0) 12 | if bits%64 != 0 { 13 | extra = 1 14 | } 15 | chunks := bits/64 + extra 16 | return bitset(make([]uint64, chunks)) 17 | } 18 | 19 | func (b bitset) clone() bitset { 20 | dataCopy := make([]uint64, len(b)) 21 | copy(dataCopy, b) 22 | return bitset(dataCopy) 23 | } 24 | 25 | func bitsetIndex(pos uint) (uint, uint) { 26 | return pos / 64, pos % 64 27 | } 28 | 29 | func (b bitset) set(pos uint) bitset { 30 | major, minor := bitsetIndex(pos) 31 | b[major] |= (1 << minor) 32 | return b 33 | } 34 | 35 | func (b bitset) clear(pos uint) bitset { 36 | major, minor := bitsetIndex(pos) 37 | b[major] &^= (1 << minor) 38 | return b 39 | } 40 | 41 | func (b bitset) get(pos uint) bool { 42 | major, minor := bitsetIndex(pos) 43 | return b[major]&(1<= 0; i-- { 125 | elem := entries[i] 126 | if elem.kind == returnEntry { 127 | entry := &node{value: elem.value, match: nil, id: elem.id} 128 | match[elem.id] = entry 129 | insertBefore(entry, root) 130 | root = entry 131 | } else { 132 | entry := 
&node{value: elem.value, match: match[elem.id], id: elem.id} 133 | insertBefore(entry, root) 134 | root = entry 135 | } 136 | } 137 | return root 138 | } 139 | 140 | type cacheEntry struct { 141 | linearized bitset 142 | state interface{} 143 | } 144 | 145 | func cacheContains(model Model, cache map[uint64][]cacheEntry, entry cacheEntry) bool { 146 | for _, elem := range cache[entry.linearized.hash()] { 147 | if entry.linearized.equals(elem.linearized) && model.Equal(entry.state, elem.state) { 148 | return true 149 | } 150 | } 151 | return false 152 | } 153 | 154 | type callsEntry struct { 155 | entry *node 156 | state interface{} 157 | } 158 | // lift unlinks a call entry and its matching return entry from the doubly linked history list. 159 | func lift(entry *node) { 160 | entry.prev.next = entry.next 161 | entry.next.prev = entry.prev 162 | match := entry.match 163 | match.prev.next = match.next 164 | if match.next != nil { 165 | match.next.prev = match.prev 166 | } 167 | } 168 | // unlift undoes lift, relinking the matching return entry and then the call entry at their previous positions. 169 | func unlift(entry *node) { 170 | match := entry.match 171 | match.prev.next = match 172 | if match.next != nil { 173 | match.next.prev = match 174 | } 175 | entry.prev.next = entry 176 | entry.next.prev = entry 177 | } 178 | 179 | func checkSingle(model Model, history []entry, computePartial bool, kill *int32) (bool, []*[]int) { 180 | entry := makeLinkedEntries(history) 181 | n := length(entry) / 2 182 | linearized := newBitset(uint(n)) 183 | cache := make(map[uint64][]cacheEntry) // map from hash to cache entry 184 | var calls []callsEntry 185 | // longest linearizable prefix that includes the given entry 186 | longest := make([]*[]int, n) 187 | 188 | state := model.Init() 189 | headEntry := insertBefore(&node{value: nil, match: nil, id: -1}, entry) 190 | for headEntry.next != nil { 191 | if atomic.LoadInt32(kill) != 0 { 192 | return false, longest 193 | } 194 | if entry.match != nil { 195 | matching 
:= entry.match // the return entry 196 | ok, newState := model.Step(state, entry.value, matching.value) 197 | if ok { 198 | newLinearized := 
linearized.clone().set(uint(entry.id)) 199 | newCacheEntry := cacheEntry{newLinearized, newState} 200 | if !cacheContains(model, cache, newCacheEntry) { 201 | hash := newLinearized.hash() 202 | cache[hash] = append(cache[hash], newCacheEntry) 203 | calls = append(calls, callsEntry{entry, state}) 204 | state = newState 205 | linearized.set(uint(entry.id)) 206 | lift(entry) 207 | entry = headEntry.next 208 | } else { 209 | entry = entry.next 210 | } 211 | } else { 212 | entry = entry.next 213 | } 214 | } else { 215 | if len(calls) == 0 { 216 | return false, longest 217 | } 218 | // longest 219 | if computePartial { 220 | callsLen := len(calls) 221 | var seq []int = nil 222 | for _, v := range calls { 223 | if longest[v.entry.id] == nil || callsLen > len(*longest[v.entry.id]) { 224 | // create seq lazily 225 | if seq == nil { 226 | seq = make([]int, len(calls)) 227 | for i, v := range calls { 228 | seq[i] = v.entry.id 229 | } 230 | } 231 | longest[v.entry.id] = &seq 232 | } 233 | } 234 | } 235 | callsTop := calls[len(calls)-1] 236 | entry = callsTop.entry 237 | state = callsTop.state 238 | linearized.clear(uint(entry.id)) 239 | calls = calls[:len(calls)-1] 240 | unlift(entry) 241 | entry = entry.next 242 | } 243 | } 244 | // longest linearization is the complete linearization, which is calls 245 | seq := make([]int, len(calls)) 246 | for i, v := range calls { 247 | seq[i] = v.entry.id 248 | } 249 | for i := 0; i < n; i++ { 250 | longest[i] = &seq 251 | } 252 | return true, longest 253 | } 254 | 255 | func fillDefault(model Model) Model { 256 | if model.Partition == nil { 257 | model.Partition = NoPartition 258 | } 259 | if model.PartitionEvent == nil { 260 | model.PartitionEvent = NoPartitionEvent 261 | } 262 | if model.Equal == nil { 263 | model.Equal = ShallowEqual 264 | } 265 | if model.DescribeOperation == nil { 266 | model.DescribeOperation = DefaultDescribeOperation 267 | } 268 | if model.DescribeState == nil { 269 | model.DescribeState = DefaultDescribeState 
270 | } 271 | return model 272 | } 273 | 274 | func checkParallel(model Model, history [][]entry, computeInfo bool, timeout time.Duration) (CheckResult, linearizationInfo) { 275 | ok := true 276 | timedOut := false 277 | results := make(chan bool, len(history)) 278 | longest := make([][]*[]int, len(history)) 279 | kill := int32(0) 280 | for i, subhistory := range history { 281 | go func(i int, subhistory []entry) { 282 | ok, l := checkSingle(model, subhistory, computeInfo, &kill) 283 | longest[i] = l 284 | results <- ok 285 | }(i, subhistory) 286 | } 287 | var timeoutChan <-chan time.Time 288 | if timeout > 0 { 289 | timeoutChan = time.After(timeout) 290 | } 291 | count := 0 292 | loop: 293 | for { 294 | select { 295 | case result := <-results: 296 | count++ 297 | ok = ok && result 298 | if !ok && !computeInfo { 299 | atomic.StoreInt32(&kill, 1) 300 | break loop 301 | } 302 | if count >= len(history) { 303 | break loop 304 | } 305 | case <-timeoutChan: 306 | timedOut = true 307 | atomic.StoreInt32(&kill, 1) 308 | break loop // if we time out, we might get a false positive 309 | } 310 | } 311 | var info linearizationInfo 312 | if computeInfo { 313 | // make sure we've waited for all goroutines to finish, 314 | // otherwise we might race on access to longest[] 315 | for count < len(history) { 316 | <-results 317 | count++ 318 | } 319 | // return longest linearizable prefixes that include each history element 320 | partialLinearizations := make([][][]int, len(history)) 321 | for i := 0; i < len(history); i++ { 322 | var partials [][]int 323 | // turn longest into a set of unique linearizations 324 | set := make(map[*[]int]struct{}) 325 | for _, v := range longest[i] { 326 | if v != nil { 327 | set[v] = struct{}{} 328 | } 329 | } 330 | for k := range set { 331 | arr := make([]int, len(*k)) 332 | for i, v := range *k { 333 | arr[i] = v 334 | } 335 | partials = append(partials, arr) 336 | } 337 | partialLinearizations[i] = partials 338 | } 339 | info.history = history 
340 | info.partialLinearizations = partialLinearizations 341 | } 342 | var result CheckResult 343 | if !ok { 344 | result = Illegal 345 | } else { 346 | if timedOut { 347 | result = Unknown 348 | } else { 349 | result = Ok 350 | } 351 | } 352 | return result, info 353 | } 354 | 355 | func checkEvents(model Model, history []Event, verbose bool, timeout time.Duration) (CheckResult, linearizationInfo) { 356 | model = fillDefault(model) 357 | partitions := model.PartitionEvent(history) 358 | l := make([][]entry, len(partitions)) 359 | for i, subhistory := range partitions { 360 | l[i] = convertEntries(renumber(subhistory)) 361 | } 362 | return checkParallel(model, l, verbose, timeout) 363 | } 364 | 365 | func checkOperations(model Model, history []Operation, verbose bool, timeout time.Duration) (CheckResult, linearizationInfo) { 366 | model = fillDefault(model) 367 | partitions := model.Partition(history) 368 | l := make([][]entry, len(partitions)) 369 | for i, subhistory := range partitions { 370 | l[i] = makeEntries(subhistory) 371 | } 372 | return checkParallel(model, l, verbose, timeout) 373 | } 374 | -------------------------------------------------------------------------------- /src/porcupine/model.go: -------------------------------------------------------------------------------- 1 | package porcupine 2 | 3 | import "fmt" 4 | 5 | type Operation struct { 6 | ClientId int // optional, unless you want a visualization; zero-indexed 7 | Input interface{} 8 | Call int64 // invocation time 9 | Output interface{} 10 | Return int64 // response time 11 | } 12 | 13 | type EventKind bool 14 | 15 | const ( 16 | CallEvent EventKind = false 17 | ReturnEvent EventKind = true 18 | ) 19 | 20 | type Event struct { 21 | ClientId int // optional, unless you want a visualization; zero-indexed 22 | Kind EventKind 23 | Value interface{} 24 | Id int 25 | } 26 | 27 | type Model struct { 28 | // Partition functions, such that a history is linearizable if and only 29 | // if each 
partition is linearizable. If you don't want to implement 30 | // this, you can always use the `NoPartition` functions implemented 31 | // below. 32 | Partition func(history []Operation) [][]Operation 33 | PartitionEvent func(history []Event) [][]Event 34 | // Initial state of the system. 35 | Init func() interface{} 36 | // Step function for the system. Returns whether or not the system 37 | // could take this step with the given inputs and outputs and also 38 | // returns the new state. This should not mutate the existing state. 39 | Step func(state interface{}, input interface{}, output interface{}) (bool, interface{}) 40 | // Equality on states. If you are using a simple data type for states, 41 | // you can use the `ShallowEqual` function implemented below. 42 | Equal func(state1, state2 interface{}) bool 43 | // For visualization, describe an operation as a string. 44 | // For example, "Get('x') -> 'y'". 45 | DescribeOperation func(input interface{}, output interface{}) string 46 | // For visualization purposes, describe a state as a string. 
47 | // For example, "{'x' -> 'y', 'z' -> 'w'}" 48 | DescribeState func(state interface{}) string 49 | } 50 | 51 | func NoPartition(history []Operation) [][]Operation { 52 | return [][]Operation{history} 53 | } 54 | 55 | func NoPartitionEvent(history []Event) [][]Event { 56 | return [][]Event{history} 57 | } 58 | 59 | func ShallowEqual(state1, state2 interface{}) bool { 60 | return state1 == state2 61 | } 62 | 63 | func DefaultDescribeOperation(input interface{}, output interface{}) string { 64 | return fmt.Sprintf("%v -> %v", input, output) 65 | } 66 | 67 | func DefaultDescribeState(state interface{}) string { 68 | return fmt.Sprintf("%v", state) 69 | } 70 | 71 | type CheckResult string 72 | 73 | const ( 74 | Unknown CheckResult = "Unknown" // timed out 75 | Ok = "Ok" 76 | Illegal = "Illegal" 77 | ) 78 | -------------------------------------------------------------------------------- /src/porcupine/porcupine.go: -------------------------------------------------------------------------------- 1 | package porcupine 2 | 3 | import "time" 4 | 5 | func CheckOperations(model Model, history []Operation) bool { 6 | res, _ := checkOperations(model, history, false, 0) 7 | return res == Ok 8 | } 9 | 10 | // timeout = 0 means no timeout 11 | // if this operation times out, then a false positive is possible 12 | func CheckOperationsTimeout(model Model, history []Operation, timeout time.Duration) CheckResult { 13 | res, _ := checkOperations(model, history, false, timeout) 14 | return res 15 | } 16 | 17 | // timeout = 0 means no timeout 18 | // if this operation times out, then a false positive is possible 19 | func CheckOperationsVerbose(model Model, history []Operation, timeout time.Duration) (CheckResult, linearizationInfo) { 20 | return checkOperations(model, history, true, timeout) 21 | } 22 | 23 | func CheckEvents(model Model, history []Event) bool { 24 | res, _ := checkEvents(model, history, false, 0) 25 | return res == Ok 26 | } 27 | 28 | // timeout = 0 means no timeout 
// if this operation times out, then a false positive is possible
//
// CheckEventsTimeout returns Ok, Illegal, or Unknown; Unknown is reported
// when the timeout fired before any violation was found.
func CheckEventsTimeout(model Model, history []Event, timeout time.Duration) CheckResult {
	res, _ := checkEvents(model, history, false, timeout)
	return res
}

// timeout = 0 means no timeout
// if this operation times out, then a false positive is possible
//
// CheckEventsVerbose runs the checker with partial-linearization tracking
// enabled and additionally returns that linearizationInfo.
func CheckEventsVerbose(model Model, history []Event, timeout time.Duration) (CheckResult, linearizationInfo) {
	return checkEvents(model, history, true, timeout)
}

--------------------------------------------------------------------------------
/src/raft/config.go:
--------------------------------------------------------------------------------
package raft

//
// support for Raft tester.
//
// we will use the original config.go to test your code for grading.
// so, while you can modify this code to help you debug, please
// test with the original before submitting.
//

import "mit6.824/labgob"
import "mit6.824/labrpc"
import "bytes"
import "log"
import "sync"
import "sync/atomic"
import "testing"
import "runtime"
import "math/rand"
import crand "crypto/rand"
import "math/big"
import "encoding/base64"
import "time"
import "fmt"

// randstring returns a random URL-safe base64 string of length n.
// 2*n random bytes always encode to more than n characters, so the
// slice below is in range. start1 uses it to name ClientEnds.
func randstring(n int) string {
	b := make([]byte, 2*n)
	crand.Read(b)
	s := base64.URLEncoding.EncodeToString(b)
	return s[0:n]
}

// makeSeed returns a cryptographically random int64 in [0, 2^62);
// make_config uses it to seed math/rand once per process.
func makeSeed() int64 {
	max := big.NewInt(int64(1) << 62)
	bigx, _ := crand.Int(crand.Reader, max)
	x := bigx.Int64()
	return x
}

// config is the state of one Raft test harness: the simulated network,
// the peers, their persisters, and what each peer has applied.
type config struct {
	mu          sync.Mutex
	t           *testing.T
	finished    int32
	net         *labrpc.Network
	n           int
	rafts       []*Raft
	applyErr    []string // from apply channel readers
	connected   []bool   // whether each server is on the net
	saved       []*Persister
	endnames    [][]string            // the port file names each sends to
	logs        []map[int]interface{} // copy of each server's committed entries
	lastApplied []int
	start       time.Time // time at which make_config() was called
	// begin()/end() statistics
	t0        time.Time // time at which test_test.go called cfg.begin()
	rpcs0     int       // rpcTotal() at start of test
	cmds0     int       // number of agreements
	bytes0    int64
	maxIndex  int
	maxIndex0 int
}

// ncpu_once makes the single-CPU warning and the math/rand seeding in
// make_config happen only once per process.
var ncpu_once sync.Once

// make_config builds a tester with n Raft peers on a fresh labrpc network
// (reliable unless unreliable is true, with long delays enabled), starts every
// peer with applierSnap when snapshot is true (applier otherwise), and then
// connects all peers to each other.
func make_config(t *testing.T, n int, unreliable bool, snapshot bool) *config {
	ncpu_once.Do(func() {
		if runtime.NumCPU() < 2 {
			fmt.Printf("warning: only one CPU, which may conceal locking bugs\n")
		}
		rand.Seed(makeSeed())
	})
	runtime.GOMAXPROCS(4)
	cfg := &config{}
	cfg.t = t
	cfg.net = labrpc.MakeNetwork()
	cfg.n = n
	cfg.applyErr = make([]string, cfg.n)
	cfg.rafts = make([]*Raft, cfg.n)
	cfg.connected = make([]bool, cfg.n)
	cfg.saved = make([]*Persister, cfg.n)
	cfg.endnames = make([][]string, cfg.n)
	cfg.logs = make([]map[int]interface{}, cfg.n)
	cfg.lastApplied = make([]int, cfg.n)
	cfg.start = time.Now()

	cfg.setunreliable(unreliable)

	cfg.net.LongDelays(true)

	applier := cfg.applier
	if snapshot {
		applier = cfg.applierSnap
	}
	// create a full set of Rafts.
	for i := 0; i < cfg.n; i++ {
		cfg.logs[i] = map[int]interface{}{}
		cfg.start1(i, applier)
	}

	// connect everyone
	for i := 0; i < cfg.n; i++ {
		cfg.connect(i)
	}

	return cfg
}

// shut down a Raft server but save its persistent state.
// After this returns, cfg.rafts[i] is nil.
func (cfg *config) crash1(i int) {
	cfg.disconnect(i)
	cfg.net.DeleteServer(i) // disable client connections to the server.

	cfg.mu.Lock()
	defer cfg.mu.Unlock()

	// a fresh persister, in case old instance
	// continues to update the Persister.
	// but copy old persister's content so that we always
	// pass Make() the last persisted state.
	if cfg.saved[i] != nil {
		cfg.saved[i] = cfg.saved[i].Copy()
	}

	rf := cfg.rafts[i]
	if rf != nil {
		// Drop cfg.mu while killing the peer; it is re-acquired before
		// clearing the slot (the deferred Unlock releases it).
		cfg.mu.Unlock()
		rf.Kill()
		cfg.mu.Lock()
		cfg.rafts[i] = nil
	}

	if cfg.saved[i] != nil {
		raftlog := cfg.saved[i].ReadRaftState()
		snapshot := cfg.saved[i].ReadSnapshot()
		cfg.saved[i] = &Persister{}
		cfg.saved[i].SaveStateAndSnapshot(raftlog, snapshot)
	}
}

// checkLogs records m.Command as server i's committed command at
// m.CommandIndex. It returns an error message ("" if none) when some server
// already recorded a different command at that index. It also returns whether
// server i had already recorded index m.CommandIndex-1, and raises
// cfg.maxIndex if needed. Both callers in this file hold cfg.mu around the call.
func (cfg *config) checkLogs(i int, m ApplyMsg) (string, bool) {
	err_msg := ""
	v := m.Command
	for j := 0; j < len(cfg.logs); j++ {
		if old, oldok := cfg.logs[j][m.CommandIndex]; oldok && old != v {
			log.Printf("%v: log %v; server %v\n", i, cfg.logs[i], cfg.logs[j])
			// some server has already committed a different value for this entry!
			err_msg = fmt.Sprintf("commit index=%v server=%v %v != server=%v %v",
				m.CommandIndex, i, m.Command, j, old)
		}
	}
	_, prevok := cfg.logs[i][m.CommandIndex-1]
	cfg.logs[i][m.CommandIndex] = v
	if m.CommandIndex > cfg.maxIndex {
		cfg.maxIndex = m.CommandIndex
	}
	return err_msg, prevok
}

// applier reads messages from applyCh and checks that they match the log
// contents recorded for the other servers.
func (cfg *config) applier(i int, applyCh chan ApplyMsg) {
	for m := range applyCh {
		if m.CommandValid == false {
			// ignore other types of ApplyMsg
		} else {
			cfg.mu.Lock()
			err_msg, prevok := cfg.checkLogs(i, m)
			cfg.mu.Unlock()
			if m.CommandIndex > 1 && prevok == false {
				err_msg = fmt.Sprintf("server %v apply out of order %v", i, m.CommandIndex)
			}
			if err_msg != "" {
				// log.Fatalf exits the process, so the assignment below
				// is never reached as written.
				log.Fatalf("apply error: %v", err_msg)
				cfg.applyErr[i] = err_msg
				// keep reading after error so that Raft doesn't block
				// holding locks...
			}
		}
	}
}

// ingestSnap decodes a snapshot (lastIncludedIndex followed by the log
// prefix, as encoded by applierSnap) and replaces server i's recorded log
// with it. index == -1 skips the comparison with lastIncludedIndex.
// returns "" or error string
func (cfg *config) ingestSnap(i int, snapshot []byte, index int) string {
	if snapshot == nil {
		log.Fatalf("nil snapshot")
		return "nil snapshot"
	}
	r := bytes.NewBuffer(snapshot)
	d := labgob.NewDecoder(r)
	var lastIncludedIndex int
	var xlog []interface{}
	if d.Decode(&lastIncludedIndex) != nil ||
		d.Decode(&xlog) != nil {
		log.Fatalf("snapshot decode error")
		return "snapshot Decode() error"
	}
	if index != -1 && index != lastIncludedIndex {
		err := fmt.Sprintf("server %v snapshot doesn't match m.SnapshotIndex", i)
		return err
	}
	cfg.logs[i] = map[int]interface{}{}
	for j := 0; j < len(xlog); j++ {
		cfg.logs[i][j] = xlog[j]
	}
	cfg.lastApplied[i] = lastIncludedIndex
	return ""
}

// SnapShotInterval: applierSnap asks Raft to snapshot whenever
// (CommandIndex+1) is a multiple of this value.
const SnapShotInterval = 10

// periodically snapshot raft state
//
// applierSnap is the applyCh reader used when make_config is called with
// snapshot=true. It installs snapshots that Raft offers (via
// CondInstallSnapshot) and checks that commands arrive in index order. Every
// SnapShotInterval commands it encodes the applied log prefix and hands it
// to rf.Snapshot.
func (cfg *config) applierSnap(i int, applyCh chan ApplyMsg) {
	cfg.mu.Lock()
	rf := cfg.rafts[i]
	cfg.mu.Unlock()
	if rf == nil {
		return // ???
	}

	for m := range applyCh {
		err_msg := ""
		if m.SnapshotValid {
			if rf.CondInstallSnapshot(m.SnapshotTerm, m.SnapshotIndex, m.Snapshot) {
				cfg.mu.Lock()
				err_msg = cfg.ingestSnap(i, m.Snapshot, m.SnapshotIndex)
				cfg.mu.Unlock()
			}
		} else if m.CommandValid {
			if m.CommandIndex != cfg.lastApplied[i]+1 {
				err_msg = fmt.Sprintf("server %v apply out of order, expected index %v, got %v", i, cfg.lastApplied[i]+1, m.CommandIndex)
			}

			if err_msg == "" {
				cfg.mu.Lock()
				var prevok bool
				err_msg, prevok = cfg.checkLogs(i, m)
				cfg.mu.Unlock()
				if m.CommandIndex > 1 && prevok == false {
					err_msg = fmt.Sprintf("server %v apply out of order %v", i, m.CommandIndex)
				}
			}

			cfg.mu.Lock()
			cfg.lastApplied[i] = m.CommandIndex
			cfg.mu.Unlock()

			if (m.CommandIndex+1)%SnapShotInterval == 0 {
				// Snapshot format read back by ingestSnap:
				// CommandIndex, then entries 0..CommandIndex.
				w := new(bytes.Buffer)
				e := labgob.NewEncoder(w)
				e.Encode(m.CommandIndex)
				var xlog []interface{}
				for j := 0; j <= m.CommandIndex; j++ {
					xlog = append(xlog, cfg.logs[i][j])
				}
				e.Encode(xlog)
				rf.Snapshot(m.CommandIndex, w.Bytes())
			}
		} else {
			// Ignore other types of ApplyMsg.
		}
		if err_msg != "" {
			// log.Fatalf exits the process, so the assignment below
			// is never reached as written.
			log.Fatalf("apply error: %v", err_msg)
			cfg.applyErr[i] = err_msg
			// keep reading after error so that Raft doesn't block
			// holding locks...
		}
	}
}

//
// start or re-start a Raft.
// if one already exists, "kill" it first.
// allocate new outgoing port file names, and a new
// state persister, to isolate previous instance of
// this server. since we cannot really kill it.
//
func (cfg *config) start1(i int, applier func(int, chan ApplyMsg)) {
	cfg.crash1(i)

	// a fresh set of outgoing ClientEnd names.
	// so that old crashed instance's ClientEnds can't send.
	cfg.endnames[i] = make([]string, cfg.n)
	for j := 0; j < cfg.n; j++ {
		cfg.endnames[i][j] = randstring(20)
	}

	// a fresh set of ClientEnds.
	ends := make([]*labrpc.ClientEnd, cfg.n)
	for j := 0; j < cfg.n; j++ {
		ends[j] = cfg.net.MakeEnd(cfg.endnames[i][j])
		cfg.net.Connect(cfg.endnames[i][j], j)
	}

	cfg.mu.Lock()

	cfg.lastApplied[i] = 0

	// a fresh persister, so old instance doesn't overwrite
	// new instance's persisted state.
	// but copy old persister's content so that we always
	// pass Make() the last persisted state.
	if cfg.saved[i] != nil {
		cfg.saved[i] = cfg.saved[i].Copy()

		snapshot := cfg.saved[i].ReadSnapshot()
		if snapshot != nil && len(snapshot) > 0 {
			// mimic KV server and process snapshot now.
			// ideally Raft should send it up on applyCh...
			err := cfg.ingestSnap(i, snapshot, -1)
			if err != "" {
				cfg.t.Fatal(err)
			}
		}
	} else {
		cfg.saved[i] = MakePersister()
	}

	cfg.mu.Unlock()

	applyCh := make(chan ApplyMsg)

	rf := Make(ends, i, cfg.saved[i], applyCh)

	cfg.mu.Lock()
	cfg.rafts[i] = rf
	cfg.mu.Unlock()

	go applier(i, applyCh)

	svc := labrpc.MakeService(rf)
	srv := labrpc.MakeServer()
	srv.AddService(svc)
	cfg.net.AddServer(i, srv)
}

// checkTimeout fails the test if it has not already failed and more than
// 120 seconds have passed since make_config.
func (cfg *config) checkTimeout() {
	// enforce a two minute real-time limit on each test
	if !cfg.t.Failed() && time.Since(cfg.start) > 120*time.Second {
		cfg.t.Fatal("test took longer than 120 seconds")
	}
}

// checkFinished reports whether cleanup() has been called.
func (cfg *config) checkFinished() bool {
	z := atomic.LoadInt32(&cfg.finished)
	return z != 0
}

// cleanup marks the tester finished, kills every live Raft instance,
// shuts down the network, and then enforces the time limit.
func (cfg *config) cleanup() {
	atomic.StoreInt32(&cfg.finished, 1)
	for i := 0; i < len(cfg.rafts); i++ {
		if cfg.rafts[i] != nil {
			cfg.rafts[i].Kill()
		}
	}
	cfg.net.Cleanup()
	cfg.checkTimeout()
}

// attach server i to the net.
func (cfg *config) connect(i int) {
	DPrintf("Test: connect(%d)\n", i)

	cfg.connected[i] = true

	// outgoing ClientEnds
	for j := 0; j < cfg.n; j++ {
		if cfg.connected[j] {
			endname := cfg.endnames[i][j]
			cfg.net.Enable(endname, true)
		}
	}

	// incoming ClientEnds
	for j := 0; j < cfg.n; j++ {
		if cfg.connected[j] {
			endname := cfg.endnames[j][i]
			cfg.net.Enable(endname, true)
		}
	}
}

// detach server i from the net.
func (cfg *config) disconnect(i int) {
	DPrintf("Test: disconnect(%d)\n", i)

	cfg.connected[i] = false

	// outgoing ClientEnds
	for j := 0; j < cfg.n; j++ {
		if cfg.endnames[i] != nil {
			endname := cfg.endnames[i][j]
			cfg.net.Enable(endname, false)
		}
	}

	// incoming ClientEnds
	for j := 0; j < cfg.n; j++ {
		if cfg.endnames[j] != nil {
			endname := cfg.endnames[j][i]
			cfg.net.Enable(endname, false)
		}
	}
}

// rpcCount returns the labrpc network's count for server
// (presumably RPCs delivered to it — see labrpc.Network.GetCount).
func (cfg *config) rpcCount(server int) int {
	return cfg.net.GetCount(server)
}

// rpcTotal returns the network-wide RPC total; begin()/end() diff it.
func (cfg *config) rpcTotal() int {
	return cfg.net.GetTotalCount()
}

// setunreliable makes the network unreliable when unrel is true.
func (cfg *config) setunreliable(unrel bool) {
	cfg.net.Reliable(!unrel)
}

// bytesTotal returns the network-wide byte total; begin()/end() diff it.
func (cfg *config) bytesTotal() int64 {
	return cfg.net.GetTotalBytes()
}

// setlongreordering toggles the network's long-reordering mode.
func (cfg *config) setlongreordering(longrel bool) {
	cfg.net.LongReordering(longrel)
}

//
// check that one of the connected servers thinks
// it is the leader, and that no other connected
// server thinks otherwise.
//
// try a few times in case re-elections are needed.
431 | // 432 | func (cfg *config) checkOneLeader() int { 433 | for iters := 0; iters < 10; iters++ { 434 | ms := 450 + (rand.Int63() % 100) 435 | time.Sleep(time.Duration(ms) * time.Millisecond) 436 | 437 | leaders := make(map[int][]int) 438 | for i := 0; i < cfg.n; i++ { 439 | if cfg.connected[i] { 440 | if term, leader := cfg.rafts[i].GetState(); leader { 441 | leaders[term] = append(leaders[term], i) 442 | } 443 | } 444 | } 445 | 446 | lastTermWithLeader := -1 447 | for term, leaders := range leaders { 448 | if len(leaders) > 1 { 449 | cfg.t.Fatalf("term %d has %d (>1) leaders", term, len(leaders)) 450 | } 451 | if term > lastTermWithLeader { 452 | lastTermWithLeader = term 453 | } 454 | } 455 | 456 | if len(leaders) != 0 { 457 | return leaders[lastTermWithLeader][0] 458 | } 459 | } 460 | cfg.t.Fatalf("expected one leader, got none") 461 | return -1 462 | } 463 | 464 | // check that everyone agrees on the term. 465 | func (cfg *config) checkTerms() int { 466 | term := -1 467 | for i := 0; i < cfg.n; i++ { 468 | if cfg.connected[i] { 469 | xterm, _ := cfg.rafts[i].GetState() 470 | if term == -1 { 471 | term = xterm 472 | } else if term != xterm { 473 | cfg.t.Fatalf("servers disagree on term") 474 | } 475 | } 476 | } 477 | return term 478 | } 479 | 480 | // 481 | // check that none of the connected servers 482 | // thinks it is the leader. 483 | // 484 | func (cfg *config) checkNoLeader() { 485 | for i := 0; i < cfg.n; i++ { 486 | if cfg.connected[i] { 487 | _, is_leader := cfg.rafts[i].GetState() 488 | if is_leader { 489 | cfg.t.Fatalf("expected no leader among connected servers, but %v claims to be leader", i) 490 | } 491 | } 492 | } 493 | } 494 | 495 | // how many servers think a log entry is committed? 
496 | func (cfg *config) nCommitted(index int) (int, interface{}) { 497 | count := 0 498 | var cmd interface{} = nil 499 | for i := 0; i < len(cfg.rafts); i++ { 500 | if cfg.applyErr[i] != "" { 501 | cfg.t.Fatal(cfg.applyErr[i]) 502 | } 503 | 504 | cfg.mu.Lock() 505 | cmd1, ok := cfg.logs[i][index] 506 | cfg.mu.Unlock() 507 | 508 | if ok { 509 | if count > 0 && cmd != cmd1 { 510 | cfg.t.Fatalf("committed values do not match: index %v, %v, %v", 511 | index, cmd, cmd1) 512 | } 513 | count += 1 514 | cmd = cmd1 515 | } 516 | } 517 | return count, cmd 518 | } 519 | 520 | // wait for at least n servers to commit. 521 | // but don't wait forever. 522 | func (cfg *config) wait(index int, n int, startTerm int) interface{} { 523 | to := 10 * time.Millisecond 524 | for iters := 0; iters < 30; iters++ { 525 | nd, _ := cfg.nCommitted(index) 526 | if nd >= n { 527 | break 528 | } 529 | time.Sleep(to) 530 | if to < time.Second { 531 | to *= 2 532 | } 533 | if startTerm > -1 { 534 | for _, r := range cfg.rafts { 535 | if t, _ := r.GetState(); t > startTerm { 536 | // someone has moved on 537 | // can no longer guarantee that we'll "win" 538 | return -1 539 | } 540 | } 541 | } 542 | } 543 | nd, cmd := cfg.nCommitted(index) 544 | if nd < n { 545 | cfg.t.Fatalf("only %d decided for index %d; wanted %d", 546 | nd, index, n) 547 | } 548 | return cmd 549 | } 550 | 551 | // do a complete agreement. 552 | // it might choose the wrong leader initially, 553 | // and have to re-submit after giving up. 554 | // entirely gives up after about 10 seconds. 555 | // indirectly checks that the servers agree on the 556 | // same value, since nCommitted() checks this, 557 | // as do the threads that read from applyCh. 558 | // returns index. 559 | // if retry==true, may submit the command multiple 560 | // times, in case a leader fails just after Start(). 561 | // if retry==false, calls Start() only once, in order 562 | // to simplify the early Lab 2B tests. 
563 | func (cfg *config) one(cmd interface{}, expectedServers int, retry bool) int { 564 | t0 := time.Now() 565 | starts := 0 566 | for time.Since(t0).Seconds() < 10 && cfg.checkFinished() == false { 567 | // try all the servers, maybe one is the leader. 568 | index := -1 569 | for si := 0; si < cfg.n; si++ { 570 | starts = (starts + 1) % cfg.n 571 | var rf *Raft 572 | cfg.mu.Lock() 573 | if cfg.connected[starts] { 574 | rf = cfg.rafts[starts] 575 | } 576 | cfg.mu.Unlock() 577 | if rf != nil { 578 | index1, _, ok := rf.Start(cmd) 579 | if ok { 580 | index = index1 581 | break 582 | } 583 | } 584 | } 585 | 586 | if index != -1 { 587 | // somebody claimed to be the leader and to have 588 | // submitted our command; wait a while for agreement. 589 | t1 := time.Now() 590 | for time.Since(t1).Seconds() < 2 { 591 | nd, cmd1 := cfg.nCommitted(index) 592 | if nd > 0 && nd >= expectedServers { 593 | // committed 594 | if cmd1 == cmd { 595 | // and it was the command we submitted. 596 | return index 597 | } 598 | } 599 | time.Sleep(20 * time.Millisecond) 600 | } 601 | if retry == false { 602 | cfg.t.Fatalf("one(%v) failed to reach agreement", cmd) 603 | } 604 | } else { 605 | time.Sleep(50 * time.Millisecond) 606 | } 607 | } 608 | if cfg.checkFinished() == false { 609 | cfg.t.Fatalf("one(%v) failed to reach agreement", cmd) 610 | } 611 | return -1 612 | } 613 | 614 | // start a Test. 615 | // print the Test message. 616 | // e.g. cfg.begin("Test (2B): RPC counts aren't too high") 617 | func (cfg *config) begin(description string) { 618 | fmt.Printf("%s ...\n", description) 619 | cfg.t0 = time.Now() 620 | cfg.rpcs0 = cfg.rpcTotal() 621 | cfg.bytes0 = cfg.bytesTotal() 622 | cfg.cmds0 = 0 623 | cfg.maxIndex0 = cfg.maxIndex 624 | } 625 | 626 | // end a Test -- the fact that we got here means there 627 | // was no failure. 628 | // print the Passed message, 629 | // and some performance numbers. 
630 | func (cfg *config) end() { 631 | cfg.checkTimeout() 632 | if cfg.t.Failed() == false { 633 | cfg.mu.Lock() 634 | t := time.Since(cfg.t0).Seconds() // real time 635 | npeers := cfg.n // number of Raft peers 636 | nrpc := cfg.rpcTotal() - cfg.rpcs0 // number of RPC sends 637 | nbytes := cfg.bytesTotal() - cfg.bytes0 // number of bytes 638 | ncmds := cfg.maxIndex - cfg.maxIndex0 // number of Raft agreements reported 639 | cfg.mu.Unlock() 640 | 641 | fmt.Printf(" ... Passed --") 642 | fmt.Printf(" %4.1f %d %4d %7d %4d\n", t, npeers, nrpc, nbytes, ncmds) 643 | } 644 | } 645 | 646 | // Maximum log size across all servers 647 | func (cfg *config) LogSize() int { 648 | logsize := 0 649 | for i := 0; i < cfg.n; i++ { 650 | n := cfg.saved[i].RaftStateSize() 651 | if n > logsize { 652 | logsize = n 653 | } 654 | } 655 | return logsize 656 | } 657 | -------------------------------------------------------------------------------- /src/raft/dstest.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import itertools 4 | import math 5 | import signal 6 | import subprocess 7 | import tempfile 8 | import shutil 9 | import time 10 | import os 11 | import sys 12 | import datetime 13 | from collections import defaultdict 14 | from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED 15 | from dataclasses import dataclass 16 | from pathlib import Path 17 | from typing import List, Optional, Dict, DefaultDict, Tuple 18 | 19 | import typer 20 | import rich 21 | from rich import print 22 | from rich.table import Table 23 | from rich.progress import ( 24 | Progress, 25 | TimeElapsedColumn, 26 | TimeRemainingColumn, 27 | TextColumn, 28 | BarColumn, 29 | SpinnerColumn, 30 | ) 31 | from rich.live import Live 32 | from rich.panel import Panel 33 | from rich.traceback import install 34 | 35 | install(show_locals=True) 36 | 37 | 38 | @dataclass 39 | class StatsMeter: 40 | """ 41 | Auxiliary classs to keep 
track of online stats including: count, mean, variance 42 | Uses Welford's algorithm to compute sample mean and sample variance incrementally. 43 | https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#On-line_algorithm 44 | """ 45 | 46 | n: int = 0 47 | mean: float = 0.0 48 | S: float = 0.0 49 | 50 | def add(self, datum): 51 | self.n += 1 52 | delta = datum - self.mean 53 | # Mk = Mk-1+ (xk – Mk-1)/k 54 | self.mean += delta / self.n 55 | # Sk = Sk-1 + (xk – Mk-1)*(xk – Mk). 56 | self.S += delta * (datum - self.mean) 57 | 58 | @property 59 | def variance(self): 60 | return self.S / self.n 61 | 62 | @property 63 | def std(self): 64 | return math.sqrt(self.variance) 65 | 66 | 67 | def print_results(results: Dict[str, Dict[str, StatsMeter]], timing=False): 68 | table = Table(show_header=True, header_style="bold") 69 | table.add_column("Test") 70 | table.add_column("Failed", justify="right") 71 | table.add_column("Total", justify="right") 72 | if not timing: 73 | table.add_column("Time", justify="right") 74 | else: 75 | table.add_column("Real Time", justify="right") 76 | table.add_column("User Time", justify="right") 77 | table.add_column("System Time", justify="right") 78 | 79 | for test, stats in results.items(): 80 | if stats["completed"].n == 0: 81 | continue 82 | color = "green" if stats["failed"].n == 0 else "red" 83 | row = [ 84 | f"[{color}]{test}[/{color}]", 85 | str(stats["failed"].n), 86 | str(stats["completed"].n), 87 | ] 88 | if not timing: 89 | row.append(f"{stats['time'].mean:.2f} ± {stats['time'].std:.2f}") 90 | else: 91 | row.extend( 92 | [ 93 | f"{stats['real_time'].mean:.2f} ± {stats['real_time'].std:.2f}", 94 | f"{stats['user_time'].mean:.2f} ± {stats['user_time'].std:.2f}", 95 | f"{stats['system_time'].mean:.2f} ± {stats['system_time'].std:.2f}", 96 | ] 97 | ) 98 | table.add_row(*row) 99 | 100 | print(table) 101 | 102 | 103 | def run_test(test: str, race: bool, timing: bool): 104 | test_cmd = ["go", "test", f"-run={test}"] 105 | if 
race: 106 | test_cmd.append("-race") 107 | if timing: 108 | test_cmd = ["time"] + cmd 109 | f, path = tempfile.mkstemp() 110 | start = time.time() 111 | proc = subprocess.run(test_cmd, stdout=f, stderr=f) 112 | runtime = time.time() - start 113 | os.close(f) 114 | return test, path, proc.returncode, runtime 115 | 116 | 117 | def last_line(file: str) -> str: 118 | with open(file, "rb") as f: 119 | f.seek(-2, os.SEEK_END) 120 | while f.read(1) != b"\n": 121 | f.seek(-2, os.SEEK_CUR) 122 | line = f.readline().decode() 123 | return line 124 | 125 | 126 | # fmt: off 127 | def run_tests( 128 | tests: List[str], 129 | sequential: bool = typer.Option(False, '--sequential', '-s', help='Run all test of each group in order'), 130 | workers: int = typer.Option(1, '--workers', '-p', help='Number of parallel tasks'), 131 | iterations: int = typer.Option(10, '--iter', '-n', help='Number of iterations to run'), 132 | output: Optional[Path] = typer.Option(None, '--output', '-o', help='Output path to use'), 133 | verbose: int = typer.Option(0, '--verbose', '-v', help='Verbosity level', count=True), 134 | archive: bool = typer.Option(False, '--archive', '-a', help='Save all logs intead of only failed ones'), 135 | race: bool = typer.Option(False, '--race/--no-race', '-r/-R', help='Run with race checker'), 136 | loop: bool = typer.Option(False, '--loop', '-l', help='Run continuously'), 137 | growth: int = typer.Option(10, '--growth', '-g', help='Growth ratio of iterations when using --loop'), 138 | timing: bool = typer.Option(False, '--timing', '-t', help='Report timing, only works on macOS'), 139 | # fmt: on 140 | ): 141 | 142 | if output is None: 143 | timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") 144 | output = Path(timestamp) 145 | 146 | if race: 147 | print("[yellow]Running with the race detector\n[/yellow]") 148 | 149 | if verbose > 0: 150 | print(f"[yellow] Verbosity level set to {verbose}[/yellow]") 151 | os.environ['VERBOSE'] = str(verbose) 152 | 153 | while 
True: 154 | 155 | total = iterations * len(tests) 156 | completed = 0 157 | 158 | results = {test: defaultdict(StatsMeter) for test in tests} 159 | 160 | if sequential: 161 | test_instances = itertools.chain.from_iterable(itertools.repeat(test, iterations) for test in tests) 162 | else: 163 | test_instances = itertools.chain.from_iterable(itertools.repeat(tests, iterations)) 164 | test_instances = iter(test_instances) 165 | 166 | total_progress = Progress( 167 | "[progress.description]{task.description}", 168 | BarColumn(), 169 | TimeRemainingColumn(), 170 | "[progress.percentage]{task.percentage:>3.0f}%", 171 | TimeElapsedColumn(), 172 | ) 173 | total_task = total_progress.add_task("[yellow]Tests[/yellow]", total=total) 174 | 175 | task_progress = Progress( 176 | "[progress.description]{task.description}", 177 | SpinnerColumn(), 178 | BarColumn(), 179 | "{task.completed}/{task.total}", 180 | ) 181 | tasks = {test: task_progress.add_task(test, total=iterations) for test in tests} 182 | 183 | progress_table = Table.grid() 184 | progress_table.add_row(total_progress) 185 | progress_table.add_row(Panel.fit(task_progress)) 186 | 187 | with Live(progress_table, transient=True) as live: 188 | 189 | def handler(_, frame): 190 | live.stop() 191 | print('\n') 192 | print_results(results) 193 | sys.exit(1) 194 | 195 | signal.signal(signal.SIGINT, handler) 196 | 197 | with ThreadPoolExecutor(max_workers=workers) as executor: 198 | 199 | futures = [] 200 | while completed < total: 201 | n = len(futures) 202 | if n < workers: 203 | for test in itertools.islice(test_instances, workers-n): 204 | futures.append(executor.submit(run_test, test, race, timing)) 205 | 206 | done, not_done = wait(futures, return_when=FIRST_COMPLETED) 207 | 208 | for future in done: 209 | test, path, rc, runtime = future.result() 210 | 211 | results[test]['completed'].add(1) 212 | results[test]['time'].add(runtime) 213 | task_progress.update(tasks[test], advance=1) 214 | dest = (output / 
f"{test}_{completed}.log").as_posix() 215 | if rc != 0: 216 | print(f"Failed test {test} - {dest}") 217 | task_progress.update(tasks[test], description=f"[red]{test}[/red]") 218 | results[test]['failed'].add(1) 219 | else: 220 | if results[test]['completed'].n == iterations and results[test]['failed'].n == 0: 221 | task_progress.update(tasks[test], description=f"[green]{test}[/green]") 222 | 223 | if rc != 0 or archive: 224 | output.mkdir(exist_ok=True, parents=True) 225 | shutil.copy(path, dest) 226 | 227 | if timing: 228 | line = last_line(path) 229 | real, _, user, _, system, _ = line.replace(' '*8, '').split(' ') 230 | results[test]['real_time'].add(float(real)) 231 | results[test]['user_time'].add(float(user)) 232 | results[test]['system_time'].add(float(system)) 233 | 234 | os.remove(path) 235 | 236 | completed += 1 237 | total_progress.update(total_task, advance=1) 238 | 239 | futures = list(not_done) 240 | 241 | print_results(results, timing) 242 | 243 | if loop: 244 | iterations *= growth 245 | print(f"[yellow]Increasing iterations to {iterations}[/yellow]") 246 | else: 247 | break 248 | 249 | 250 | if __name__ == "__main__": 251 | typer.run(run_tests) 252 | -------------------------------------------------------------------------------- /src/raft/persister.go: -------------------------------------------------------------------------------- 1 | package raft 2 | 3 | // 4 | // support for Raft and kvraft to save persistent 5 | // Raft state (log &c) and k/v server snapshots. 6 | // 7 | // we will use the original persister.go to test your code for grading. 8 | // so, while you can modify this code to help you debug, please 9 | // test with the original before submitting. 
//

import "sync"

// Persister keeps one server's "persistent" state in memory: the serialized
// Raft state and the most recent snapshot. Every method takes mu, so all
// access to the two slices is serialized.
type Persister struct {
	mu        sync.Mutex
	raftstate []byte
	snapshot  []byte
}

// MakePersister returns an empty Persister.
func MakePersister() *Persister {
	return &Persister{}
}

// clone returns a freshly allocated copy of orig, so the caller and the
// Persister never share a backing array. A nil orig yields an empty,
// non-nil slice.
func clone(orig []byte) []byte {
	x := make([]byte, len(orig))
	copy(x, orig)
	return x
}

// Copy returns a new Persister holding the same Raft state and snapshot.
// The slices are shared rather than cloned. That is safe within this file
// because the Save* methods replace them with fresh clones (never writing
// into them in place) and the Read* methods hand out clones.
func (ps *Persister) Copy() *Persister {
	ps.mu.Lock()
	defer ps.mu.Unlock()
	np := MakePersister()
	np.raftstate = ps.raftstate
	np.snapshot = ps.snapshot
	return np
}

// SaveRaftState stores a private copy of state, replacing any previous
// Raft state. The snapshot is left unchanged.
func (ps *Persister) SaveRaftState(state []byte) {
	ps.mu.Lock()
	defer ps.mu.Unlock()
	ps.raftstate = clone(state)
}

// ReadRaftState returns a copy of the stored Raft state (empty if none).
func (ps *Persister) ReadRaftState() []byte {
	ps.mu.Lock()
	defer ps.mu.Unlock()
	return clone(ps.raftstate)
}

// RaftStateSize returns the length in bytes of the stored Raft state.
func (ps *Persister) RaftStateSize() int {
	ps.mu.Lock()
	defer ps.mu.Unlock()
	return len(ps.raftstate)
}

// Save both Raft state and K/V snapshot as a single atomic action,
// to help avoid them getting out of sync.
// Both arguments are cloned, so later changes by the caller are not seen.
func (ps *Persister) SaveStateAndSnapshot(state []byte, snapshot []byte) {
	ps.mu.Lock()
	defer ps.mu.Unlock()
	ps.raftstate = clone(state)
	ps.snapshot = clone(snapshot)
}

// ReadSnapshot returns a copy of the stored snapshot (empty if none).
func (ps *Persister) ReadSnapshot() []byte {
	ps.mu.Lock()
	defer ps.mu.Unlock()
	return clone(ps.snapshot)
}

// SnapshotSize returns the length in bytes of the stored snapshot.
func (ps *Persister) SnapshotSize() int {
	ps.mu.Lock()
	defer ps.mu.Unlock()
	return len(ps.snapshot)
}

--------------------------------------------------------------------------------
/src/raft/util.go:
--------------------------------------------------------------------------------
package raft

import (
	"log"
	"math/rand"
	"time"
)

// DPrintf forwards to log.Printf only when the package-level Debug flag
// (declared elsewhere in package raft) is true.
func DPrintf(format string, a ...interface{}) {
	if Debug {
		log.Printf(format, a...)
	}
}

// min returns the smaller of a and b.
// NOTE(review): on Go 1.21+ this package-level declaration shadows the
// built-in min inside package raft.
func min(a int, b int) int {
	if a <= b {
		return a
	}
	return b
}

// max returns the larger of a and b.
// NOTE(review): on Go 1.21+ this package-level declaration shadows the
// built-in max inside package raft.
func max(a int, b int) int {
	if a >= b {
		return a
	}
	return b
}

// getRandElectionTimeOut returns a random election timeout.
// The value lies in [electionTimeOutMin, electionTimeOutMax) milliseconds.
// Both bounds are declared elsewhere; this presumably requires
// electionTimeOutMax > electionTimeOutMin (a zero divisor would fail) —
// TODO confirm.
func getRandElectionTimeOut() time.Duration {
	return time.Duration((rand.Int()%(electionTimeOutMax-electionTimeOutMin))+electionTimeOutMin) * time.Millisecond
}

--------------------------------------------------------------------------------
/src/shardctrler/client.go:
--------------------------------------------------------------------------------
package shardctrler

//
// Shardctrler clerk.
//

import "mit6.824/labrpc"
import "time"
import "crypto/rand"
import "math/big"

// Clerk is the client-side handle to the shard controller replicas.
// servers holds one RPC endpoint per replica.
type Clerk struct {
	servers []*labrpc.ClientEnd
	// Your data here.
}

// nrand returns a non-negative random int64 below 2^62, drawn from
// crypto/rand. The error from rand.Int is ignored.
func nrand() int64 {
	max := big.NewInt(int64(1) << 62)
	bigx, _ := rand.Int(rand.Reader, max)
	x := bigx.Int64()
	return x
}

// MakeClerk returns a Clerk that talks to the given controller replicas.
func MakeClerk(servers []*labrpc.ClientEnd) *Clerk {
	ck := new(Clerk)
	ck.servers = servers
	// Your code here.
	return ck
}

// Query fetches configuration number num (-1 means the latest, per the RPC
// description in common.go). It tries every server in order and sleeps
// 100ms after each full pass. It never gives up: it returns only once a
// call succeeds with WrongLeader == false.
func (ck *Clerk) Query(num int) Config {
	args := &QueryArgs{}
	// Your code here.
	args.Num = num
	for {
		// try each known server.
		for _, srv := range ck.servers {
			var reply QueryReply
			ok := srv.Call("ShardCtrler.Query", args, &reply)
			if ok && reply.WrongLeader == false {
				return reply.Config
			}
		}
		time.Sleep(100 * time.Millisecond)
	}
}

// Join asks the controller to add the replica groups in servers
// (gid -> server names). It uses the same retry-forever loop as Query.
func (ck *Clerk) Join(servers map[int][]string) {
	args := &JoinArgs{}
	// Your code here.
	args.Servers = servers

	for {
		// try each known server.
		for _, srv := range ck.servers {
			var reply JoinReply
			ok := srv.Call("ShardCtrler.Join", args, &reply)
			if ok && reply.WrongLeader == false {
				return
			}
		}
		time.Sleep(100 * time.Millisecond)
	}
}

// Leave asks the controller to remove the replica groups listed in gids.
// It uses the same retry-forever loop as Query.
func (ck *Clerk) Leave(gids []int) {
	args := &LeaveArgs{}
	// Your code here.
	args.GIDs = gids

	for {
		// try each known server.
		for _, srv := range ck.servers {
			var reply LeaveReply
			ok := srv.Call("ShardCtrler.Leave", args, &reply)
			if ok && reply.WrongLeader == false {
				return
			}
		}
		time.Sleep(100 * time.Millisecond)
	}
}

// Move asks the controller to assign shard to group gid.
// It uses the same retry-forever loop as Query.
func (ck *Clerk) Move(shard int, gid int) {
	args := &MoveArgs{}
	// Your code here.
	args.Shard = shard
	args.GID = gid

	for {
		// try each known server.
		for _, srv := range ck.servers {
			var reply MoveReply
			ok := srv.Call("ShardCtrler.Move", args, &reply)
			if ok && reply.WrongLeader == false {
				return
			}
		}
		time.Sleep(100 * time.Millisecond)
	}
}

--------------------------------------------------------------------------------
/src/shardctrler/common.go:
--------------------------------------------------------------------------------
package shardctrler

//
// Shard controller: assigns shards to replication groups.
//
// RPC interface:
// Join(servers) -- add a set of groups (gid -> server-list mapping).
// Leave(gids) -- delete a set of groups.
// Move(shard, gid) -- hand off one shard from current owner to gid.
// Query(num) -> fetch Config # num, or latest config if num==-1.
//
// A Config (configuration) describes a set of replica groups, and the
// replica group responsible for each shard. Configs are numbered. Config
// #0 is the initial configuration, with no groups and all shards
// assigned to group 0 (the invalid group).
16 | // 17 | // You will need to add fields to the RPC argument structs. 18 | // 19 | 20 | // The number of shards. 21 | const NShards = 10 22 | 23 | // A configuration -- an assignment of shards to groups. 24 | // Please don't change this. 25 | type Config struct { 26 | Num int // config number 27 | Shards [NShards]int // shard -> gid 28 | Groups map[int][]string // gid -> servers[] 29 | } 30 | 31 | const ( 32 | OK = "OK" 33 | ) 34 | 35 | type Err string 36 | 37 | type JoinArgs struct { 38 | Servers map[int][]string // new GID -> servers mappings 39 | } 40 | 41 | type JoinReply struct { 42 | WrongLeader bool 43 | Err Err 44 | } 45 | 46 | type LeaveArgs struct { 47 | GIDs []int 48 | } 49 | 50 | type LeaveReply struct { 51 | WrongLeader bool 52 | Err Err 53 | } 54 | 55 | type MoveArgs struct { 56 | Shard int 57 | GID int 58 | } 59 | 60 | type MoveReply struct { 61 | WrongLeader bool 62 | Err Err 63 | } 64 | 65 | type QueryArgs struct { 66 | Num int // desired config number 67 | } 68 | 69 | type QueryReply struct { 70 | WrongLeader bool 71 | Err Err 72 | Config Config 73 | } 74 | -------------------------------------------------------------------------------- /src/shardctrler/config.go: -------------------------------------------------------------------------------- 1 | package shardctrler 2 | 3 | import "mit6.824/labrpc" 4 | import "mit6.824/raft" 5 | import "testing" 6 | import "os" 7 | 8 | // import "log" 9 | import crand "crypto/rand" 10 | import "math/rand" 11 | import "encoding/base64" 12 | import "sync" 13 | import "runtime" 14 | import "time" 15 | 16 | func randstring(n int) string { 17 | b := make([]byte, 2*n) 18 | crand.Read(b) 19 | s := base64.URLEncoding.EncodeToString(b) 20 | return s[0:n] 21 | } 22 | 23 | // Randomize server handles 24 | func random_handles(kvh []*labrpc.ClientEnd) []*labrpc.ClientEnd { 25 | sa := make([]*labrpc.ClientEnd, len(kvh)) 26 | copy(sa, kvh) 27 | for i := range sa { 28 | j := rand.Intn(i + 1) 29 | sa[i], sa[j] = sa[j], sa[i] 
30 | } 31 | return sa 32 | } 33 | 34 | type config struct { 35 | mu sync.Mutex 36 | t *testing.T 37 | net *labrpc.Network 38 | n int 39 | servers []*ShardCtrler 40 | saved []*raft.Persister 41 | endnames [][]string // names of each server's sending ClientEnds 42 | clerks map[*Clerk][]string 43 | nextClientId int 44 | start time.Time // time at which make_config() was called 45 | } 46 | 47 | func (cfg *config) checkTimeout() { 48 | // enforce a two minute real-time limit on each test 49 | if !cfg.t.Failed() && time.Since(cfg.start) > 120*time.Second { 50 | cfg.t.Fatal("test took longer than 120 seconds") 51 | } 52 | } 53 | 54 | func (cfg *config) cleanup() { 55 | cfg.mu.Lock() 56 | defer cfg.mu.Unlock() 57 | for i := 0; i < len(cfg.servers); i++ { 58 | if cfg.servers[i] != nil { 59 | cfg.servers[i].Kill() 60 | } 61 | } 62 | cfg.net.Cleanup() 63 | cfg.checkTimeout() 64 | } 65 | 66 | // Maximum log size across all servers 67 | func (cfg *config) LogSize() int { 68 | logsize := 0 69 | for i := 0; i < cfg.n; i++ { 70 | n := cfg.saved[i].RaftStateSize() 71 | if n > logsize { 72 | logsize = n 73 | } 74 | } 75 | return logsize 76 | } 77 | 78 | // attach server i to servers listed in to 79 | // caller must hold cfg.mu 80 | func (cfg *config) connectUnlocked(i int, to []int) { 81 | // log.Printf("connect peer %d to %v\n", i, to) 82 | 83 | // outgoing socket files 84 | for j := 0; j < len(to); j++ { 85 | endname := cfg.endnames[i][to[j]] 86 | cfg.net.Enable(endname, true) 87 | } 88 | 89 | // incoming socket files 90 | for j := 0; j < len(to); j++ { 91 | endname := cfg.endnames[to[j]][i] 92 | cfg.net.Enable(endname, true) 93 | } 94 | } 95 | 96 | func (cfg *config) connect(i int, to []int) { 97 | cfg.mu.Lock() 98 | defer cfg.mu.Unlock() 99 | cfg.connectUnlocked(i, to) 100 | } 101 | 102 | // detach server i from the servers listed in from 103 | // caller must hold cfg.mu 104 | func (cfg *config) disconnectUnlocked(i int, from []int) { 105 | // log.Printf("disconnect peer %d from 
%v\n", i, from) 106 | 107 | // outgoing socket files 108 | for j := 0; j < len(from); j++ { 109 | if cfg.endnames[i] != nil { 110 | endname := cfg.endnames[i][from[j]] 111 | cfg.net.Enable(endname, false) 112 | } 113 | } 114 | 115 | // incoming socket files 116 | for j := 0; j < len(from); j++ { 117 | if cfg.endnames[j] != nil { 118 | endname := cfg.endnames[from[j]][i] 119 | cfg.net.Enable(endname, false) 120 | } 121 | } 122 | } 123 | 124 | func (cfg *config) disconnect(i int, from []int) { 125 | cfg.mu.Lock() 126 | defer cfg.mu.Unlock() 127 | cfg.disconnectUnlocked(i, from) 128 | } 129 | 130 | func (cfg *config) All() []int { 131 | all := make([]int, cfg.n) 132 | for i := 0; i < cfg.n; i++ { 133 | all[i] = i 134 | } 135 | return all 136 | } 137 | 138 | func (cfg *config) ConnectAll() { 139 | cfg.mu.Lock() 140 | defer cfg.mu.Unlock() 141 | for i := 0; i < cfg.n; i++ { 142 | cfg.connectUnlocked(i, cfg.All()) 143 | } 144 | } 145 | 146 | // Sets up 2 partitions with connectivity between servers in each partition. 147 | func (cfg *config) partition(p1 []int, p2 []int) { 148 | cfg.mu.Lock() 149 | defer cfg.mu.Unlock() 150 | // log.Printf("partition servers into: %v %v\n", p1, p2) 151 | for i := 0; i < len(p1); i++ { 152 | cfg.disconnectUnlocked(p1[i], p2) 153 | cfg.connectUnlocked(p1[i], p1) 154 | } 155 | for i := 0; i < len(p2); i++ { 156 | cfg.disconnectUnlocked(p2[i], p1) 157 | cfg.connectUnlocked(p2[i], p2) 158 | } 159 | } 160 | 161 | // Create a clerk with clerk specific server names. 162 | // Give it connections to all of the servers, but for 163 | // now enable only connections to servers in to[]. 164 | func (cfg *config) makeClient(to []int) *Clerk { 165 | cfg.mu.Lock() 166 | defer cfg.mu.Unlock() 167 | 168 | // a fresh set of ClientEnds. 
169 | ends := make([]*labrpc.ClientEnd, cfg.n) 170 | endnames := make([]string, cfg.n) 171 | for j := 0; j < cfg.n; j++ { 172 | endnames[j] = randstring(20) 173 | ends[j] = cfg.net.MakeEnd(endnames[j]) 174 | cfg.net.Connect(endnames[j], j) 175 | } 176 | 177 | ck := MakeClerk(random_handles(ends)) 178 | cfg.clerks[ck] = endnames 179 | cfg.nextClientId++ 180 | cfg.ConnectClientUnlocked(ck, to) 181 | return ck 182 | } 183 | 184 | func (cfg *config) deleteClient(ck *Clerk) { 185 | cfg.mu.Lock() 186 | defer cfg.mu.Unlock() 187 | 188 | v := cfg.clerks[ck] 189 | for i := 0; i < len(v); i++ { 190 | os.Remove(v[i]) 191 | } 192 | delete(cfg.clerks, ck) 193 | } 194 | 195 | // caller should hold cfg.mu 196 | func (cfg *config) ConnectClientUnlocked(ck *Clerk, to []int) { 197 | // log.Printf("ConnectClient %v to %v\n", ck, to) 198 | endnames := cfg.clerks[ck] 199 | for j := 0; j < len(to); j++ { 200 | s := endnames[to[j]] 201 | cfg.net.Enable(s, true) 202 | } 203 | } 204 | 205 | func (cfg *config) ConnectClient(ck *Clerk, to []int) { 206 | cfg.mu.Lock() 207 | defer cfg.mu.Unlock() 208 | cfg.ConnectClientUnlocked(ck, to) 209 | } 210 | 211 | // caller should hold cfg.mu 212 | func (cfg *config) DisconnectClientUnlocked(ck *Clerk, from []int) { 213 | // log.Printf("DisconnectClient %v from %v\n", ck, from) 214 | endnames := cfg.clerks[ck] 215 | for j := 0; j < len(from); j++ { 216 | s := endnames[from[j]] 217 | cfg.net.Enable(s, false) 218 | } 219 | } 220 | 221 | func (cfg *config) DisconnectClient(ck *Clerk, from []int) { 222 | cfg.mu.Lock() 223 | defer cfg.mu.Unlock() 224 | cfg.DisconnectClientUnlocked(ck, from) 225 | } 226 | 227 | // Shutdown a server by isolating it 228 | func (cfg *config) ShutdownServer(i int) { 229 | cfg.mu.Lock() 230 | defer cfg.mu.Unlock() 231 | 232 | cfg.disconnectUnlocked(i, cfg.All()) 233 | 234 | // disable client connections to the server. 
235 | // it's important to do this before creating 236 | // the new Persister in saved[i], to avoid 237 | // the possibility of the server returning a 238 | // positive reply to an Append but persisting 239 | // the result in the superseded Persister. 240 | cfg.net.DeleteServer(i) 241 | 242 | // a fresh persister, in case old instance 243 | // continues to update the Persister. 244 | // but copy old persister's content so that we always 245 | // pass Make() the last persisted state. 246 | if cfg.saved[i] != nil { 247 | cfg.saved[i] = cfg.saved[i].Copy() 248 | } 249 | 250 | kv := cfg.servers[i] 251 | if kv != nil { 252 | cfg.mu.Unlock() 253 | kv.Kill() 254 | cfg.mu.Lock() 255 | cfg.servers[i] = nil 256 | } 257 | } 258 | 259 | // If restart servers, first call ShutdownServer 260 | func (cfg *config) StartServer(i int) { 261 | cfg.mu.Lock() 262 | 263 | // a fresh set of outgoing ClientEnd names. 264 | cfg.endnames[i] = make([]string, cfg.n) 265 | for j := 0; j < cfg.n; j++ { 266 | cfg.endnames[i][j] = randstring(20) 267 | } 268 | 269 | // a fresh set of ClientEnds. 270 | ends := make([]*labrpc.ClientEnd, cfg.n) 271 | for j := 0; j < cfg.n; j++ { 272 | ends[j] = cfg.net.MakeEnd(cfg.endnames[i][j]) 273 | cfg.net.Connect(cfg.endnames[i][j], j) 274 | } 275 | 276 | // a fresh persister, so old instance doesn't overwrite 277 | // new instance's persisted state. 278 | // give the fresh persister a copy of the old persister's 279 | // state, so that the spec is that we pass StartKVServer() 280 | // the last persisted state. 
281 | if cfg.saved[i] != nil { 282 | cfg.saved[i] = cfg.saved[i].Copy() 283 | } else { 284 | cfg.saved[i] = raft.MakePersister() 285 | } 286 | 287 | cfg.mu.Unlock() 288 | 289 | cfg.servers[i] = StartServer(ends, i, cfg.saved[i]) 290 | 291 | kvsvc := labrpc.MakeService(cfg.servers[i]) 292 | rfsvc := labrpc.MakeService(cfg.servers[i].rf) 293 | srv := labrpc.MakeServer() 294 | srv.AddService(kvsvc) 295 | srv.AddService(rfsvc) 296 | cfg.net.AddServer(i, srv) 297 | } 298 | 299 | func (cfg *config) Leader() (bool, int) { 300 | cfg.mu.Lock() 301 | defer cfg.mu.Unlock() 302 | 303 | for i := 0; i < cfg.n; i++ { 304 | if cfg.servers[i] != nil { 305 | _, is_leader := cfg.servers[i].rf.GetState() 306 | if is_leader { 307 | return true, i 308 | } 309 | } 310 | } 311 | return false, 0 312 | } 313 | 314 | // Partition servers into 2 groups and put current leader in minority 315 | func (cfg *config) make_partition() ([]int, []int) { 316 | _, l := cfg.Leader() 317 | p1 := make([]int, cfg.n/2+1) 318 | p2 := make([]int, cfg.n/2) 319 | j := 0 320 | for i := 0; i < cfg.n; i++ { 321 | if i != l { 322 | if j < len(p1) { 323 | p1[j] = i 324 | } else { 325 | p2[j-len(p1)] = i 326 | } 327 | j++ 328 | } 329 | } 330 | p2[len(p2)-1] = l 331 | return p1, p2 332 | } 333 | 334 | func make_config(t *testing.T, n int, unreliable bool) *config { 335 | runtime.GOMAXPROCS(4) 336 | cfg := &config{} 337 | cfg.t = t 338 | cfg.net = labrpc.MakeNetwork() 339 | cfg.n = n 340 | cfg.servers = make([]*ShardCtrler, cfg.n) 341 | cfg.saved = make([]*raft.Persister, cfg.n) 342 | cfg.endnames = make([][]string, cfg.n) 343 | cfg.clerks = make(map[*Clerk][]string) 344 | cfg.nextClientId = cfg.n + 1000 // client ids start 1000 above the highest serverid 345 | cfg.start = time.Now() 346 | 347 | // create a full set of KV servers. 
348 | for i := 0; i < cfg.n; i++ { 349 | cfg.StartServer(i) 350 | } 351 | 352 | cfg.ConnectAll() 353 | 354 | cfg.net.Reliable(!unreliable) 355 | 356 | return cfg 357 | } 358 | -------------------------------------------------------------------------------- /src/shardctrler/server.go: -------------------------------------------------------------------------------- 1 | package shardctrler 2 | 3 | import "mit6.824/raft" 4 | import "mit6.824/labrpc" 5 | import "sync" 6 | import "mit6.824/labgob" 7 | 8 | type ShardCtrler struct { 9 | mu sync.Mutex 10 | me int 11 | rf *raft.Raft 12 | applyCh chan raft.ApplyMsg 13 | 14 | // Your data here. 15 | 16 | configs []Config // indexed by config num 17 | } 18 | 19 | type Op struct { 20 | // Your data here. 21 | } 22 | 23 | func (sc *ShardCtrler) Join(args *JoinArgs, reply *JoinReply) { 24 | // Your code here. 25 | } 26 | 27 | func (sc *ShardCtrler) Leave(args *LeaveArgs, reply *LeaveReply) { 28 | // Your code here. 29 | } 30 | 31 | func (sc *ShardCtrler) Move(args *MoveArgs, reply *MoveReply) { 32 | // Your code here. 33 | } 34 | 35 | func (sc *ShardCtrler) Query(args *QueryArgs, reply *QueryReply) { 36 | // Your code here. 37 | } 38 | 39 | // the tester calls Kill() when a ShardCtrler instance won't 40 | // be needed again. you are not required to do anything 41 | // in Kill(), but it might be convenient to (for example) 42 | // turn off debug output from this instance. 43 | func (sc *ShardCtrler) Kill() { 44 | sc.rf.Kill() 45 | // Your code here, if desired. 46 | } 47 | 48 | // needed by shardkv tester 49 | func (sc *ShardCtrler) Raft() *raft.Raft { 50 | return sc.rf 51 | } 52 | 53 | // servers[] contains the ports of the set of 54 | // servers that will cooperate via Raft to 55 | // form the fault-tolerant shardctrler service. 56 | // me is the index of the current server in servers[]. 
func StartServer(servers []*labrpc.ClientEnd, me int, persister *raft.Persister) *ShardCtrler {
	sc := new(ShardCtrler)
	sc.me = me

	// Config #0: no groups (empty, non-nil map); every shard is gid 0.
	sc.configs = make([]Config, 1)
	sc.configs[0].Groups = map[int][]string{}

	// Op must be registered so it can travel through Raft via labgob.
	labgob.Register(Op{})
	sc.applyCh = make(chan raft.ApplyMsg)
	sc.rf = raft.Make(servers, me, persister, sc.applyCh)

	// Your code here.

	return sc
}

--------------------------------------------------------------------------------
/src/shardctrler/test_test.go:
--------------------------------------------------------------------------------
package shardctrler

import (
	"fmt"
	"sync"
	"testing"
	"time"
)

// import "time"

// check queries the latest configuration through ck and fails the test
// unless: it contains exactly len(groups) groups, all of which are listed
// in groups; (when groups is non-empty) every shard maps to a present
// group; and per-group shard counts differ by at most one.
func check(t *testing.T, groups []int, ck *Clerk) {
	c := ck.Query(-1)
	if len(c.Groups) != len(groups) {
		t.Fatalf("wanted %v groups, got %v", len(groups), len(c.Groups))
	}

	// are the groups as expected?
	for _, g := range groups {
		_, ok := c.Groups[g]
		if ok != true {
			t.Fatalf("missing group %v", g)
		}
	}

	// any un-allocated shards?
	if len(groups) > 0 {
		for s, g := range c.Shards {
			_, ok := c.Groups[g]
			if ok == false {
				t.Fatalf("shard %v -> invalid group %v", s, g)
			}
		}
	}

	// more or less balanced sharding?
	counts := map[int]int{}
	for _, g := range c.Shards {
		counts[g] += 1
	}
	// 257 is just a starting value above any possible per-group count.
	min := 257
	max := 0
	for g, _ := range c.Groups {
		if counts[g] > max {
			max = counts[g]
		}
		if counts[g] < min {
			min = counts[g]
		}
	}
	if max > min+1 {
		t.Fatalf("max %v too much larger than min %v", max, min)
	}
}

// check_same_config fails the test unless c1 and c2 have equal Num, equal
// Shards, and the same Groups (same gids, each with an identical server
// list in the same order).
func check_same_config(t *testing.T, c1 Config, c2 Config) {
	if c1.Num != c2.Num {
		t.Fatalf("Num wrong")
	}
	if c1.Shards != c2.Shards {
		t.Fatalf("Shards wrong")
	}
	if len(c1.Groups) != len(c2.Groups) {
		t.Fatalf("number of Groups is wrong")
	}
	for gid, sa := range c1.Groups {
		sa1, ok := c2.Groups[gid]
		if ok == false || len(sa1) != len(sa) {
			t.Fatalf("len(Groups) wrong")
		}
		if ok && len(sa1) == len(sa) {
			for j := 0; j < len(sa); j++ {
				if sa[j] != sa1[j] {
					t.Fatalf("Groups wrong")
				}
			}
		}
	}
}

// TestBasic exercises single-group Join/Leave, historical Query while each
// server in turn is shut down, Move, concurrent Join/Leave, and
// minimal-transfer checks, all on a reliable 3-server controller.
func TestBasic(t *testing.T) {
	const nservers = 3
	cfg := make_config(t, nservers, false)
	defer cfg.cleanup()

	ck := cfg.makeClient(cfg.All())

	fmt.Printf("Test: Basic leave/join ...\n")

	// NOTE(review): cfa[3] is never assigned below, so it stays the zero
	// Config (Num 0) and the historical-query loop compares Query(0) with it.
	cfa := make([]Config, 6)
	cfa[0] = ck.Query(-1)

	check(t, []int{}, ck)

	var gid1 int = 1
	ck.Join(map[int][]string{gid1: []string{"x", "y", "z"}})
	check(t, []int{gid1}, ck)
	cfa[1] = ck.Query(-1)

	var gid2 int = 2
	ck.Join(map[int][]string{gid2: []string{"a", "b", "c"}})
	check(t, []int{gid1, gid2}, ck)
	cfa[2] = ck.Query(-1)

	cfx := ck.Query(-1)
	sa1 := cfx.Groups[gid1]
	if len(sa1) != 3 || sa1[0] != "x" || sa1[1] != "y" || sa1[2] != "z" {
		t.Fatalf("wrong servers for gid %v: %v\n", gid1, sa1)
	}
	sa2 := cfx.Groups[gid2]
	if len(sa2) != 3 || sa2[0] != "a" || sa2[1] != "b" || sa2[2] != "c" {
		t.Fatalf("wrong servers for gid %v: %v\n", gid2, sa2)
	}

	ck.Leave([]int{gid1})
	check(t, []int{gid2}, ck)
	cfa[4] = ck.Query(-1)

	ck.Leave([]int{gid2})
	cfa[5] = ck.Query(-1)

	fmt.Printf(" ... Passed\n")

	fmt.Printf("Test: Historical queries ...\n")

	// With each server shut down in turn, every recorded configuration must
	// still be retrievable unchanged by number.
	for s := 0; s < nservers; s++ {
		cfg.ShutdownServer(s)
		for i := 0; i < len(cfa); i++ {
			c := ck.Query(cfa[i].Num)
			check_same_config(t, c, cfa[i])
		}
		cfg.StartServer(s)
		cfg.ConnectAll()
	}

	fmt.Printf(" ... Passed\n")

	fmt.Printf("Test: Move ...\n")
	{
		var gid3 int = 503
		ck.Join(map[int][]string{gid3: []string{"3a", "3b", "3c"}})
		var gid4 int = 504
		ck.Join(map[int][]string{gid4: []string{"4a", "4b", "4c"}})
		// Move the first half of the shards to gid3 and the rest to gid4;
		// any Move that changes ownership must bump Config.Num.
		for i := 0; i < NShards; i++ {
			cf := ck.Query(-1)
			if i < NShards/2 {
				ck.Move(i, gid3)
				if cf.Shards[i] != gid3 {
					cf1 := ck.Query(-1)
					if cf1.Num <= cf.Num {
						t.Fatalf("Move should increase Config.Num")
					}
				}
			} else {
				ck.Move(i, gid4)
				if cf.Shards[i] != gid4 {
					cf1 := ck.Query(-1)
					if cf1.Num <= cf.Num {
						t.Fatalf("Move should increase Config.Num")
					}
				}
			}
		}
		cf2 := ck.Query(-1)
		for i := 0; i < NShards; i++ {
			if i < NShards/2 {
				if cf2.Shards[i] != gid3 {
					t.Fatalf("expected shard %v on gid %v actually %v",
						i, gid3, cf2.Shards[i])
				}
			} else {
				if cf2.Shards[i] != gid4 {
					t.Fatalf("expected shard %v on gid %v actually %v",
						i, gid4, cf2.Shards[i])
				}
			}
		}
		ck.Leave([]int{gid3})
		ck.Leave([]int{gid4})
	}
	fmt.Printf(" ... Passed\n")

	fmt.Printf("Test: Concurrent leave/join ...\n")

	// npara clerks each join a temporary group (gid+1000) and a permanent
	// group (gid), then leave the temporary one; only gids[] should remain.
	const npara = 10
	var cka [npara]*Clerk
	for i := 0; i < len(cka); i++ {
		cka[i] = cfg.makeClient(cfg.All())
	}
	gids := make([]int, npara)
	ch := make(chan bool)
	for xi := 0; xi < npara; xi++ {
		gids[xi] = int((xi * 10) + 100)
		go func(i int) {
			defer func() { ch <- true }()
			var gid int = gids[i]
			var sid1 = fmt.Sprintf("s%da", gid)
			var sid2 = fmt.Sprintf("s%db", gid)
			cka[i].Join(map[int][]string{gid + 1000: []string{sid1}})
			cka[i].Join(map[int][]string{gid: []string{sid2}})
			cka[i].Leave([]int{gid + 1000})
		}(xi)
	}
	for i := 0; i < npara; i++ {
		<-ch
	}
	check(t, gids, ck)

	fmt.Printf(" ... Passed\n")

	fmt.Printf("Test: Minimal transfers after joins ...\n")

	// NOTE(review): the loops below only inspect gids 1..npara. gids 1 and
	// 2 have already left, and the groups present now are gids[] (100..190)
	// plus the new 11..15, so these checks appear unable to fail. Presumably
	// the intent was to check gids[] — confirm against the upstream test.
	// (The third server name below repeats "%db"; "%dc" was likely meant.)
	c1 := ck.Query(-1)
	for i := 0; i < 5; i++ {
		var gid = int(npara + 1 + i)
		ck.Join(map[int][]string{gid: []string{
			fmt.Sprintf("%da", gid),
			fmt.Sprintf("%db", gid),
			fmt.Sprintf("%db", gid)}})
	}
	c2 := ck.Query(-1)
	for i := int(1); i <= npara; i++ {
		for j := 0; j < len(c1.Shards); j++ {
			if c2.Shards[j] == i {
				if c1.Shards[j] != i {
					t.Fatalf("non-minimal transfer after Join()s")
				}
			}
		}
	}

	fmt.Printf(" ... Passed\n")

	fmt.Printf("Test: Minimal transfers after leaves ...\n")

	for i := 0; i < 5; i++ {
		ck.Leave([]int{int(npara + 1 + i)})
	}
	c3 := ck.Query(-1)
	for i := int(1); i <= npara; i++ {
		for j := 0; j < len(c1.Shards); j++ {
			if c2.Shards[j] == i {
				if c3.Shards[j] != i {
					t.Fatalf("non-minimal transfer after Leave()s")
				}
			}
		}
	}

	fmt.Printf(" ... Passed\n")
}

253 | func TestMulti(t *testing.T) { 254 | const nservers = 3 255 | cfg := make_config(t, nservers, false) 256 | defer cfg.cleanup() 257 | 258 | ck := cfg.makeClient(cfg.All()) 259 | 260 | fmt.Printf("Test: Multi-group join/leave ...\n") 261 | 262 | cfa := make([]Config, 6) 263 | cfa[0] = ck.Query(-1) 264 | 265 | check(t, []int{}, ck) 266 | 267 | var gid1 int = 1 268 | var gid2 int = 2 269 | ck.Join(map[int][]string{ 270 | gid1: []string{"x", "y", "z"}, 271 | gid2: []string{"a", "b", "c"}, 272 | }) 273 | check(t, []int{gid1, gid2}, ck) 274 | cfa[1] = ck.Query(-1) 275 | 276 | var gid3 int = 3 277 | ck.Join(map[int][]string{gid3: []string{"j", "k", "l"}}) 278 | check(t, []int{gid1, gid2, gid3}, ck) 279 | cfa[2] = ck.Query(-1) 280 | 281 | cfx := ck.Query(-1) 282 | sa1 := cfx.Groups[gid1] 283 | if len(sa1) != 3 || sa1[0] != "x" || sa1[1] != "y" || sa1[2] != "z" { 284 | t.Fatalf("wrong servers for gid %v: %v\n", gid1, sa1) 285 | } 286 | sa2 := cfx.Groups[gid2] 287 | if len(sa2) != 3 || sa2[0] != "a" || sa2[1] != "b" || sa2[2] != "c" { 288 | t.Fatalf("wrong servers for gid %v: %v\n", gid2, sa2) 289 | } 290 | sa3 := cfx.Groups[gid3] 291 | if len(sa3) != 3 || sa3[0] != "j" || sa3[1] != "k" || sa3[2] != "l" { 292 | t.Fatalf("wrong servers for gid %v: %v\n", gid3, sa3) 293 | } 294 | 295 | ck.Leave([]int{gid1, gid3}) 296 | check(t, []int{gid2}, ck) 297 | cfa[3] = ck.Query(-1) 298 | 299 | cfx = ck.Query(-1) 300 | sa2 = cfx.Groups[gid2] 301 | if len(sa2) != 3 || sa2[0] != "a" || sa2[1] != "b" || sa2[2] != "c" { 302 | t.Fatalf("wrong servers for gid %v: %v\n", gid2, sa2) 303 | } 304 | 305 | ck.Leave([]int{gid2}) 306 | 307 | fmt.Printf(" ... 
Passed\n") 308 | 309 | fmt.Printf("Test: Concurrent multi leave/join ...\n") 310 | 311 | const npara = 10 312 | var cka [npara]*Clerk 313 | for i := 0; i < len(cka); i++ { 314 | cka[i] = cfg.makeClient(cfg.All()) 315 | } 316 | gids := make([]int, npara) 317 | var wg sync.WaitGroup 318 | for xi := 0; xi < npara; xi++ { 319 | wg.Add(1) 320 | gids[xi] = int(xi + 1000) 321 | go func(i int) { 322 | defer wg.Done() 323 | var gid int = gids[i] 324 | cka[i].Join(map[int][]string{ 325 | gid: []string{ 326 | fmt.Sprintf("%da", gid), 327 | fmt.Sprintf("%db", gid), 328 | fmt.Sprintf("%dc", gid)}, 329 | gid + 1000: []string{fmt.Sprintf("%da", gid+1000)}, 330 | gid + 2000: []string{fmt.Sprintf("%da", gid+2000)}, 331 | }) 332 | cka[i].Leave([]int{gid + 1000, gid + 2000}) 333 | }(xi) 334 | } 335 | wg.Wait() 336 | check(t, gids, ck) 337 | 338 | fmt.Printf(" ... Passed\n") 339 | 340 | fmt.Printf("Test: Minimal transfers after multijoins ...\n") 341 | 342 | c1 := ck.Query(-1) 343 | m := make(map[int][]string) 344 | for i := 0; i < 5; i++ { 345 | var gid = npara + 1 + i 346 | m[gid] = []string{fmt.Sprintf("%da", gid), fmt.Sprintf("%db", gid)} 347 | } 348 | ck.Join(m) 349 | c2 := ck.Query(-1) 350 | for i := int(1); i <= npara; i++ { 351 | for j := 0; j < len(c1.Shards); j++ { 352 | if c2.Shards[j] == i { 353 | if c1.Shards[j] != i { 354 | t.Fatalf("non-minimal transfer after Join()s") 355 | } 356 | } 357 | } 358 | } 359 | 360 | fmt.Printf(" ... Passed\n") 361 | 362 | fmt.Printf("Test: Minimal transfers after multileaves ...\n") 363 | 364 | var l []int 365 | for i := 0; i < 5; i++ { 366 | l = append(l, npara+1+i) 367 | } 368 | ck.Leave(l) 369 | c3 := ck.Query(-1) 370 | for i := int(1); i <= npara; i++ { 371 | for j := 0; j < len(c1.Shards); j++ { 372 | if c2.Shards[j] == i { 373 | if c3.Shards[j] != i { 374 | t.Fatalf("non-minimal transfer after Leave()s") 375 | } 376 | } 377 | } 378 | } 379 | 380 | fmt.Printf(" ... 
Passed\n") 381 | 382 | fmt.Printf("Test: Check Same config on servers ...\n") 383 | 384 | isLeader, leader := cfg.Leader() 385 | if !isLeader { 386 | t.Fatalf("Leader not found") 387 | } 388 | c := ck.Query(-1) // Config leader claims 389 | 390 | cfg.ShutdownServer(leader) 391 | 392 | attempts := 0 393 | for isLeader, leader = cfg.Leader(); isLeader; time.Sleep(1 * time.Second) { 394 | if attempts++; attempts >= 3 { 395 | t.Fatalf("Leader not found") 396 | } 397 | } 398 | 399 | c1 = ck.Query(-1) 400 | check_same_config(t, c, c1) 401 | 402 | fmt.Printf(" ... Passed\n") 403 | } 404 | -------------------------------------------------------------------------------- /src/shardkv/client.go: -------------------------------------------------------------------------------- 1 | package shardkv 2 | 3 | // 4 | // client code to talk to a sharded key/value service. 5 | // 6 | // the client first talks to the shardctrler to find out 7 | // the assignment of shards (keys) to groups, and then 8 | // talks to the group that holds the key's shard. 9 | // 10 | 11 | import "mit6.824/labrpc" 12 | import "crypto/rand" 13 | import "math/big" 14 | import "mit6.824/shardctrler" 15 | import "time" 16 | 17 | // which shard is a key in? 18 | // please use this function, 19 | // and please do not change it. 20 | func key2shard(key string) int { 21 | shard := 0 22 | if len(key) > 0 { 23 | shard = int(key[0]) 24 | } 25 | shard %= shardctrler.NShards 26 | return shard 27 | } 28 | 29 | func nrand() int64 { 30 | max := big.NewInt(int64(1) << 62) 31 | bigx, _ := rand.Int(rand.Reader, max) 32 | x := bigx.Int64() 33 | return x 34 | } 35 | 36 | type Clerk struct { 37 | sm *shardctrler.Clerk 38 | config shardctrler.Config 39 | make_end func(string) *labrpc.ClientEnd 40 | // You will have to modify this struct. 41 | } 42 | 43 | // the tester calls MakeClerk. 44 | // 45 | // ctrlers[] is needed to call shardctrler.MakeClerk(). 
46 | // 47 | // make_end(servername) turns a server name from a 48 | // Config.Groups[gid][i] into a labrpc.ClientEnd on which you can 49 | // send RPCs. 50 | func MakeClerk(ctrlers []*labrpc.ClientEnd, make_end func(string) *labrpc.ClientEnd) *Clerk { 51 | ck := new(Clerk) 52 | ck.sm = shardctrler.MakeClerk(ctrlers) 53 | ck.make_end = make_end 54 | // You'll have to add code here. 55 | return ck 56 | } 57 | 58 | // fetch the current value for a key. 59 | // returns "" if the key does not exist. 60 | // keeps trying forever in the face of all other errors. 61 | // You will have to modify this function. 62 | func (ck *Clerk) Get(key string) string { 63 | args := GetArgs{} 64 | args.Key = key 65 | 66 | for { 67 | shard := key2shard(key) 68 | gid := ck.config.Shards[shard] 69 | if servers, ok := ck.config.Groups[gid]; ok { 70 | // try each server for the shard. 71 | for si := 0; si < len(servers); si++ { 72 | srv := ck.make_end(servers[si]) 73 | var reply GetReply 74 | ok := srv.Call("ShardKV.Get", &args, &reply) 75 | if ok && (reply.Err == OK || reply.Err == ErrNoKey) { 76 | return reply.Value 77 | } 78 | if ok && (reply.Err == ErrWrongGroup) { 79 | break 80 | } 81 | // ... not ok, or ErrWrongLeader 82 | } 83 | } 84 | time.Sleep(100 * time.Millisecond) 85 | // ask controler for the latest configuration. 86 | ck.config = ck.sm.Query(-1) 87 | } 88 | 89 | return "" 90 | } 91 | 92 | // shared by Put and Append. 93 | // You will have to modify this function. 
94 | func (ck *Clerk) PutAppend(key string, value string, op string) { 95 | args := PutAppendArgs{} 96 | args.Key = key 97 | args.Value = value 98 | args.Op = op 99 | 100 | for { 101 | shard := key2shard(key) 102 | gid := ck.config.Shards[shard] 103 | if servers, ok := ck.config.Groups[gid]; ok { 104 | for si := 0; si < len(servers); si++ { 105 | srv := ck.make_end(servers[si]) 106 | var reply PutAppendReply 107 | ok := srv.Call("ShardKV.PutAppend", &args, &reply) 108 | if ok && reply.Err == OK { 109 | return 110 | } 111 | if ok && reply.Err == ErrWrongGroup { 112 | break 113 | } 114 | // ... not ok, or ErrWrongLeader 115 | } 116 | } 117 | time.Sleep(100 * time.Millisecond) 118 | // ask controler for the latest configuration. 119 | ck.config = ck.sm.Query(-1) 120 | } 121 | } 122 | 123 | func (ck *Clerk) Put(key string, value string) { 124 | ck.PutAppend(key, value, "Put") 125 | } 126 | func (ck *Clerk) Append(key string, value string) { 127 | ck.PutAppend(key, value, "Append") 128 | } 129 | -------------------------------------------------------------------------------- /src/shardkv/common.go: -------------------------------------------------------------------------------- 1 | package shardkv 2 | 3 | // 4 | // Sharded key/value server. 5 | // Lots of replica groups, each running Raft. 6 | // Shardctrler decides which group serves each shard. 7 | // Shardctrler may change shard assignment from time to time. 8 | // 9 | // You will have to modify these definitions. 10 | // 11 | 12 | const ( 13 | OK = "OK" 14 | ErrNoKey = "ErrNoKey" 15 | ErrWrongGroup = "ErrWrongGroup" 16 | ErrWrongLeader = "ErrWrongLeader" 17 | ) 18 | 19 | type Err string 20 | 21 | // Put or Append 22 | type PutAppendArgs struct { 23 | // You'll have to add definitions here. 24 | Key string 25 | Value string 26 | Op string // "Put" or "Append" 27 | // You'll have to add definitions here. 28 | // Field names must start with capital letters, 29 | // otherwise RPC will break. 
30 | } 31 | 32 | type PutAppendReply struct { 33 | Err Err 34 | } 35 | 36 | type GetArgs struct { 37 | Key string 38 | // You'll have to add definitions here. 39 | } 40 | 41 | type GetReply struct { 42 | Err Err 43 | Value string 44 | } 45 | -------------------------------------------------------------------------------- /src/shardkv/config.go: -------------------------------------------------------------------------------- 1 | package shardkv 2 | 3 | import "mit6.824/shardctrler" 4 | import "mit6.824/labrpc" 5 | import "testing" 6 | import "os" 7 | 8 | // import "log" 9 | import crand "crypto/rand" 10 | import "math/big" 11 | import "math/rand" 12 | import "encoding/base64" 13 | import "sync" 14 | import "runtime" 15 | import "mit6.824/raft" 16 | import "strconv" 17 | import "fmt" 18 | import "time" 19 | 20 | func randstring(n int) string { 21 | b := make([]byte, 2*n) 22 | crand.Read(b) 23 | s := base64.URLEncoding.EncodeToString(b) 24 | return s[0:n] 25 | } 26 | 27 | func makeSeed() int64 { 28 | max := big.NewInt(int64(1) << 62) 29 | bigx, _ := crand.Int(crand.Reader, max) 30 | x := bigx.Int64() 31 | return x 32 | } 33 | 34 | // Randomize server handles 35 | func random_handles(kvh []*labrpc.ClientEnd) []*labrpc.ClientEnd { 36 | sa := make([]*labrpc.ClientEnd, len(kvh)) 37 | copy(sa, kvh) 38 | for i := range sa { 39 | j := rand.Intn(i + 1) 40 | sa[i], sa[j] = sa[j], sa[i] 41 | } 42 | return sa 43 | } 44 | 45 | type group struct { 46 | gid int 47 | servers []*ShardKV 48 | saved []*raft.Persister 49 | endnames [][]string 50 | mendnames [][]string 51 | } 52 | 53 | type config struct { 54 | mu sync.Mutex 55 | t *testing.T 56 | net *labrpc.Network 57 | start time.Time // time at which make_config() was called 58 | 59 | nctrlers int 60 | ctrlerservers []*shardctrler.ShardCtrler 61 | mck *shardctrler.Clerk 62 | 63 | ngroups int 64 | n int // servers per k/v group 65 | groups []*group 66 | 67 | clerks map[*Clerk][]string 68 | nextClientId int 69 | maxraftstate int 70 | } 
71 | 72 | func (cfg *config) checkTimeout() { 73 | // enforce a two minute real-time limit on each test 74 | if !cfg.t.Failed() && time.Since(cfg.start) > 120*time.Second { 75 | cfg.t.Fatal("test took longer than 120 seconds") 76 | } 77 | } 78 | 79 | func (cfg *config) cleanup() { 80 | for gi := 0; gi < cfg.ngroups; gi++ { 81 | cfg.ShutdownGroup(gi) 82 | } 83 | for i := 0; i < cfg.nctrlers; i++ { 84 | cfg.ctrlerservers[i].Kill() 85 | } 86 | cfg.net.Cleanup() 87 | cfg.checkTimeout() 88 | } 89 | 90 | // check that no server's log is too big. 91 | func (cfg *config) checklogs() { 92 | for gi := 0; gi < cfg.ngroups; gi++ { 93 | for i := 0; i < cfg.n; i++ { 94 | raft := cfg.groups[gi].saved[i].RaftStateSize() 95 | snap := len(cfg.groups[gi].saved[i].ReadSnapshot()) 96 | if cfg.maxraftstate >= 0 && raft > 8*cfg.maxraftstate { 97 | cfg.t.Fatalf("persister.RaftStateSize() %v, but maxraftstate %v", 98 | raft, cfg.maxraftstate) 99 | } 100 | if cfg.maxraftstate < 0 && snap > 0 { 101 | cfg.t.Fatalf("maxraftstate is -1, but snapshot is non-empty!") 102 | } 103 | } 104 | } 105 | } 106 | 107 | // controler server name for labrpc. 108 | func (cfg *config) ctrlername(i int) string { 109 | return "ctrler" + strconv.Itoa(i) 110 | } 111 | 112 | // shard server name for labrpc. 113 | // i'th server of group gid. 114 | func (cfg *config) servername(gid int, i int) string { 115 | return "server-" + strconv.Itoa(gid) + "-" + strconv.Itoa(i) 116 | } 117 | 118 | func (cfg *config) makeClient() *Clerk { 119 | cfg.mu.Lock() 120 | defer cfg.mu.Unlock() 121 | 122 | // ClientEnds to talk to controler service. 
123 | ends := make([]*labrpc.ClientEnd, cfg.nctrlers) 124 | endnames := make([]string, cfg.n) 125 | for j := 0; j < cfg.nctrlers; j++ { 126 | endnames[j] = randstring(20) 127 | ends[j] = cfg.net.MakeEnd(endnames[j]) 128 | cfg.net.Connect(endnames[j], cfg.ctrlername(j)) 129 | cfg.net.Enable(endnames[j], true) 130 | } 131 | 132 | ck := MakeClerk(ends, func(servername string) *labrpc.ClientEnd { 133 | name := randstring(20) 134 | end := cfg.net.MakeEnd(name) 135 | cfg.net.Connect(name, servername) 136 | cfg.net.Enable(name, true) 137 | return end 138 | }) 139 | cfg.clerks[ck] = endnames 140 | cfg.nextClientId++ 141 | return ck 142 | } 143 | 144 | func (cfg *config) deleteClient(ck *Clerk) { 145 | cfg.mu.Lock() 146 | defer cfg.mu.Unlock() 147 | 148 | v := cfg.clerks[ck] 149 | for i := 0; i < len(v); i++ { 150 | os.Remove(v[i]) 151 | } 152 | delete(cfg.clerks, ck) 153 | } 154 | 155 | // Shutdown i'th server of gi'th group, by isolating it 156 | func (cfg *config) ShutdownServer(gi int, i int) { 157 | cfg.mu.Lock() 158 | defer cfg.mu.Unlock() 159 | 160 | gg := cfg.groups[gi] 161 | 162 | // prevent this server from sending 163 | for j := 0; j < len(gg.servers); j++ { 164 | name := gg.endnames[i][j] 165 | cfg.net.Enable(name, false) 166 | } 167 | for j := 0; j < len(gg.mendnames[i]); j++ { 168 | name := gg.mendnames[i][j] 169 | cfg.net.Enable(name, false) 170 | } 171 | 172 | // disable client connections to the server. 173 | // it's important to do this before creating 174 | // the new Persister in saved[i], to avoid 175 | // the possibility of the server returning a 176 | // positive reply to an Append but persisting 177 | // the result in the superseded Persister. 178 | cfg.net.DeleteServer(cfg.servername(gg.gid, i)) 179 | 180 | // a fresh persister, in case old instance 181 | // continues to update the Persister. 182 | // but copy old persister's content so that we always 183 | // pass Make() the last persisted state. 
184 | if gg.saved[i] != nil { 185 | gg.saved[i] = gg.saved[i].Copy() 186 | } 187 | 188 | kv := gg.servers[i] 189 | if kv != nil { 190 | cfg.mu.Unlock() 191 | kv.Kill() 192 | cfg.mu.Lock() 193 | gg.servers[i] = nil 194 | } 195 | } 196 | 197 | func (cfg *config) ShutdownGroup(gi int) { 198 | for i := 0; i < cfg.n; i++ { 199 | cfg.ShutdownServer(gi, i) 200 | } 201 | } 202 | 203 | // start i'th server in gi'th group 204 | func (cfg *config) StartServer(gi int, i int) { 205 | cfg.mu.Lock() 206 | 207 | gg := cfg.groups[gi] 208 | 209 | // a fresh set of outgoing ClientEnd names 210 | // to talk to other servers in this group. 211 | gg.endnames[i] = make([]string, cfg.n) 212 | for j := 0; j < cfg.n; j++ { 213 | gg.endnames[i][j] = randstring(20) 214 | } 215 | 216 | // and the connections to other servers in this group. 217 | ends := make([]*labrpc.ClientEnd, cfg.n) 218 | for j := 0; j < cfg.n; j++ { 219 | ends[j] = cfg.net.MakeEnd(gg.endnames[i][j]) 220 | cfg.net.Connect(gg.endnames[i][j], cfg.servername(gg.gid, j)) 221 | cfg.net.Enable(gg.endnames[i][j], true) 222 | } 223 | 224 | // ends to talk to shardctrler service 225 | mends := make([]*labrpc.ClientEnd, cfg.nctrlers) 226 | gg.mendnames[i] = make([]string, cfg.nctrlers) 227 | for j := 0; j < cfg.nctrlers; j++ { 228 | gg.mendnames[i][j] = randstring(20) 229 | mends[j] = cfg.net.MakeEnd(gg.mendnames[i][j]) 230 | cfg.net.Connect(gg.mendnames[i][j], cfg.ctrlername(j)) 231 | cfg.net.Enable(gg.mendnames[i][j], true) 232 | } 233 | 234 | // a fresh persister, so old instance doesn't overwrite 235 | // new instance's persisted state. 236 | // give the fresh persister a copy of the old persister's 237 | // state, so that the spec is that we pass StartKVServer() 238 | // the last persisted state. 
239 | if gg.saved[i] != nil { 240 | gg.saved[i] = gg.saved[i].Copy() 241 | } else { 242 | gg.saved[i] = raft.MakePersister() 243 | } 244 | cfg.mu.Unlock() 245 | 246 | gg.servers[i] = StartServer(ends, i, gg.saved[i], cfg.maxraftstate, 247 | gg.gid, mends, 248 | func(servername string) *labrpc.ClientEnd { 249 | name := randstring(20) 250 | end := cfg.net.MakeEnd(name) 251 | cfg.net.Connect(name, servername) 252 | cfg.net.Enable(name, true) 253 | return end 254 | }) 255 | 256 | kvsvc := labrpc.MakeService(gg.servers[i]) 257 | rfsvc := labrpc.MakeService(gg.servers[i].rf) 258 | srv := labrpc.MakeServer() 259 | srv.AddService(kvsvc) 260 | srv.AddService(rfsvc) 261 | cfg.net.AddServer(cfg.servername(gg.gid, i), srv) 262 | } 263 | 264 | func (cfg *config) StartGroup(gi int) { 265 | for i := 0; i < cfg.n; i++ { 266 | cfg.StartServer(gi, i) 267 | } 268 | } 269 | 270 | func (cfg *config) StartCtrlerserver(i int) { 271 | // ClientEnds to talk to other controler replicas. 272 | ends := make([]*labrpc.ClientEnd, cfg.nctrlers) 273 | for j := 0; j < cfg.nctrlers; j++ { 274 | endname := randstring(20) 275 | ends[j] = cfg.net.MakeEnd(endname) 276 | cfg.net.Connect(endname, cfg.ctrlername(j)) 277 | cfg.net.Enable(endname, true) 278 | } 279 | 280 | p := raft.MakePersister() 281 | 282 | cfg.ctrlerservers[i] = shardctrler.StartServer(ends, i, p) 283 | 284 | msvc := labrpc.MakeService(cfg.ctrlerservers[i]) 285 | rfsvc := labrpc.MakeService(cfg.ctrlerservers[i].Raft()) 286 | srv := labrpc.MakeServer() 287 | srv.AddService(msvc) 288 | srv.AddService(rfsvc) 289 | cfg.net.AddServer(cfg.ctrlername(i), srv) 290 | } 291 | 292 | func (cfg *config) shardclerk() *shardctrler.Clerk { 293 | // ClientEnds to talk to ctrler service. 
294 | ends := make([]*labrpc.ClientEnd, cfg.nctrlers) 295 | for j := 0; j < cfg.nctrlers; j++ { 296 | name := randstring(20) 297 | ends[j] = cfg.net.MakeEnd(name) 298 | cfg.net.Connect(name, cfg.ctrlername(j)) 299 | cfg.net.Enable(name, true) 300 | } 301 | 302 | return shardctrler.MakeClerk(ends) 303 | } 304 | 305 | // tell the shardctrler that a group is joining. 306 | func (cfg *config) join(gi int) { 307 | cfg.joinm([]int{gi}) 308 | } 309 | 310 | func (cfg *config) joinm(gis []int) { 311 | m := make(map[int][]string, len(gis)) 312 | for _, g := range gis { 313 | gid := cfg.groups[g].gid 314 | servernames := make([]string, cfg.n) 315 | for i := 0; i < cfg.n; i++ { 316 | servernames[i] = cfg.servername(gid, i) 317 | } 318 | m[gid] = servernames 319 | } 320 | cfg.mck.Join(m) 321 | } 322 | 323 | // tell the shardctrler that a group is leaving. 324 | func (cfg *config) leave(gi int) { 325 | cfg.leavem([]int{gi}) 326 | } 327 | 328 | func (cfg *config) leavem(gis []int) { 329 | gids := make([]int, 0, len(gis)) 330 | for _, g := range gis { 331 | gids = append(gids, cfg.groups[g].gid) 332 | } 333 | cfg.mck.Leave(gids) 334 | } 335 | 336 | var ncpu_once sync.Once 337 | 338 | func make_config(t *testing.T, n int, unreliable bool, maxraftstate int) *config { 339 | ncpu_once.Do(func() { 340 | if runtime.NumCPU() < 2 { 341 | fmt.Printf("warning: only one CPU, which may conceal locking bugs\n") 342 | } 343 | rand.Seed(makeSeed()) 344 | }) 345 | runtime.GOMAXPROCS(4) 346 | cfg := &config{} 347 | cfg.t = t 348 | cfg.maxraftstate = maxraftstate 349 | cfg.net = labrpc.MakeNetwork() 350 | cfg.start = time.Now() 351 | 352 | // controler 353 | cfg.nctrlers = 3 354 | cfg.ctrlerservers = make([]*shardctrler.ShardCtrler, cfg.nctrlers) 355 | for i := 0; i < cfg.nctrlers; i++ { 356 | cfg.StartCtrlerserver(i) 357 | } 358 | cfg.mck = cfg.shardclerk() 359 | 360 | cfg.ngroups = 3 361 | cfg.groups = make([]*group, cfg.ngroups) 362 | cfg.n = n 363 | for gi := 0; gi < cfg.ngroups; gi++ { 364 | 
gg := &group{} 365 | cfg.groups[gi] = gg 366 | gg.gid = 100 + gi 367 | gg.servers = make([]*ShardKV, cfg.n) 368 | gg.saved = make([]*raft.Persister, cfg.n) 369 | gg.endnames = make([][]string, cfg.n) 370 | gg.mendnames = make([][]string, cfg.nctrlers) 371 | for i := 0; i < cfg.n; i++ { 372 | cfg.StartServer(gi, i) 373 | } 374 | } 375 | 376 | cfg.clerks = make(map[*Clerk][]string) 377 | cfg.nextClientId = cfg.n + 1000 // client ids start 1000 above the highest serverid 378 | 379 | cfg.net.Reliable(!unreliable) 380 | 381 | return cfg 382 | } 383 | -------------------------------------------------------------------------------- /src/shardkv/server.go: -------------------------------------------------------------------------------- 1 | package shardkv 2 | 3 | import "mit6.824/labrpc" 4 | import "mit6.824/raft" 5 | import "sync" 6 | import "mit6.824/labgob" 7 | 8 | type Op struct { 9 | // Your definitions here. 10 | // Field names must start with capital letters, 11 | // otherwise RPC will break. 12 | } 13 | 14 | type ShardKV struct { 15 | mu sync.Mutex 16 | me int 17 | rf *raft.Raft 18 | applyCh chan raft.ApplyMsg 19 | make_end func(string) *labrpc.ClientEnd 20 | gid int 21 | ctrlers []*labrpc.ClientEnd 22 | maxraftstate int // snapshot if log grows this big 23 | 24 | // Your definitions here. 25 | } 26 | 27 | func (kv *ShardKV) Get(args *GetArgs, reply *GetReply) { 28 | // Your code here. 29 | } 30 | 31 | func (kv *ShardKV) PutAppend(args *PutAppendArgs, reply *PutAppendReply) { 32 | // Your code here. 33 | } 34 | 35 | // the tester calls Kill() when a ShardKV instance won't 36 | // be needed again. you are not required to do anything 37 | // in Kill(), but it might be convenient to (for example) 38 | // turn off debug output from this instance. 39 | func (kv *ShardKV) Kill() { 40 | kv.rf.Kill() 41 | // Your code here, if desired. 42 | } 43 | 44 | // servers[] contains the ports of the servers in this group. 
45 | // 46 | // me is the index of the current server in servers[]. 47 | // 48 | // the k/v server should store snapshots through the underlying Raft 49 | // implementation, which should call persister.SaveStateAndSnapshot() to 50 | // atomically save the Raft state along with the snapshot. 51 | // 52 | // the k/v server should snapshot when Raft's saved state exceeds 53 | // maxraftstate bytes, in order to allow Raft to garbage-collect its 54 | // log. if maxraftstate is -1, you don't need to snapshot. 55 | // 56 | // gid is this group's GID, for interacting with the shardctrler. 57 | // 58 | // pass ctrlers[] to shardctrler.MakeClerk() so you can send 59 | // RPCs to the shardctrler. 60 | // 61 | // make_end(servername) turns a server name from a 62 | // Config.Groups[gid][i] into a labrpc.ClientEnd on which you can 63 | // send RPCs. You'll need this to send RPCs to other groups. 64 | // 65 | // look at client.go for examples of how to use ctrlers[] 66 | // and make_end() to send RPCs to the group owning a specific shard. 67 | // 68 | // StartServer() must return quickly, so it should start goroutines 69 | // for any long-running work. 70 | func StartServer(servers []*labrpc.ClientEnd, me int, persister *raft.Persister, maxraftstate int, gid int, ctrlers []*labrpc.ClientEnd, make_end func(string) *labrpc.ClientEnd) *ShardKV { 71 | // call labgob.Register on structures you want 72 | // Go's RPC library to marshall/unmarshall. 73 | labgob.Register(Op{}) 74 | 75 | kv := new(ShardKV) 76 | kv.me = me 77 | kv.maxraftstate = maxraftstate 78 | kv.make_end = make_end 79 | kv.gid = gid 80 | kv.ctrlers = ctrlers 81 | 82 | // Your initialization code here. 
83 | 84 | // Use something like this to talk to the shardctrler: 85 | // kv.mck = shardctrler.MakeClerk(kv.ctrlers) 86 | 87 | kv.applyCh = make(chan raft.ApplyMsg) 88 | kv.rf = raft.Make(servers, me, persister, kv.applyCh) 89 | 90 | return kv 91 | } 92 | --------------------------------------------------------------------------------