├── README.md ├── atom_test.go ├── atomrpc ├── atomrpc.go ├── tls_helper.go └── types.go ├── client └── client.go ├── cmd ├── client │ └── client.go ├── db │ └── db.go ├── directory │ └── directory.go ├── keygen │ └── keygen.go ├── server │ └── server.go └── trustee │ └── trustee.go ├── common ├── group_gen.go ├── group_gen_test.go ├── lib.go ├── types.go └── xor.go ├── crypto ├── cca2.go ├── crypto.go ├── crypto_helper.go ├── crypto_test.go ├── encoding.go ├── encoding_test.go ├── keys.go ├── nizk.go ├── nizk_test.go ├── rand.go ├── shuffle.go ├── threshold.go ├── threshold_test.go └── types.go ├── db ├── db.go └── db_test.go ├── directory ├── directory.go └── helper.go ├── scripts └── run.py ├── server ├── helper.go ├── member.go ├── server.go └── server_test.go └── trustee ├── trustee.go └── trustee_test.go /README.md: -------------------------------------------------------------------------------- 1 | # Atom 2 | 3 | Atom is an anonymous broadcasting system that allows users to send short 4 | messages while preserving their anonymity. This is particularly useful for 5 | things like anonymous whistleblowing and protest organization, where the sender 6 | may fear retaliation from powerful adversaries for sending certain messages. 7 | Atom is the first anonymous communication system that scales horizontally while 8 | protecting against traffic analysis by global adversaries. Our 9 | [SOSP'17 paper](http://people.csail.mit.edu/devadas/pubs/atom.pdf) 10 | explains the system in detail. 11 | 12 | The code posted here is a *research prototype*. While the code performs all the 13 | necessary crypto operations and should be fairly accurate in terms of 14 | performance, it is likely full of security bugs and security-critical TODOs 15 | that haven't been addressed. Please be careful if any part of this code is reused 16 | for other projects. 17 | 18 | ## Requirements 19 | 20 | The code requires Go 1.7 or later. The scripts are written in Python.
Most of 21 | the crypto operations rely on the [DeDiS kyber 22 | library](https://github.com/dedis/kyber). 23 | 24 | ## Components 25 | 26 | * crypto: This implements all crypto operations used by Atom. There is a level 27 | of indirection from here to the kyber library, so that we can easily replace 28 | the crypto section of Atom without impacting the rest of the code base. 29 | 30 | * server: This implements both a physical server, and a logical server (member) 31 | which can be part of many groups. This part of the code actually carries out 32 | the protocol. 33 | 34 | * client: The client program handles sending of the messages. Currently, each client 35 | program is responsible for sending many messages. From the user's perspective, only the 36 | Submit function should be relevant. 37 | 38 | * directory: This is a very simple directory that keeps track of all 39 | participants and their keys. 40 | 41 | * trustee: Trustees are only used in one variant of our protocol, and they serve 42 | as the final line of protection for users. 43 | 44 | * db: This is a very simple database that stores all messages published in a 45 | given round so that users can download them. 46 | 47 | ## Running the code 48 | 49 | The code currently builds on the experimental main branch of the Kyber library. 50 | There is a chance future updates may break the code base. If you catch this, 51 | please file an issue and let us know! 52 | 53 | To create all the executables, run 54 | 55 | $ go install -tags experimental ./... 56 | 57 | in the root folder. 58 | 59 | There is an integration test available in `atom_test.go`. This serves as both an 60 | example of the overall flow, and a test function. You can run this simply by 61 | doing 62 | 63 | $ go test -v -tags experimental 64 | 65 | in the root atom folder. 66 | 67 | We also provide a way, `run.py`, to start individual processes, both local and 68 | remote, to test the code, rather than using `go test`, which just uses 69 | goroutines.
To run this, you first have to generate enough keys by running 70 | something like 71 | 72 | $ mkdir $GOPATH/src/github.com/kwonalbert/atom/keys 73 | $ $GOPATH/bin/keygen -numServers 1024 -numTrustees 32 -serverKeys $GOPATH/src/github.com/kwonalbert/atom/keys/server_keys.json -trusteeKeys $GOPATH/src/github.com/kwonalbert/atom/keys/trustee_keys.json 74 | 75 | The same keys can be used for all experiments afterwards. Once the keys are set 76 | up, you are ready to run `run.py`. Running 77 | 78 | $ run.py --help 79 | 80 | will give you all available options. For example, the following command 81 | 82 | $ ./run.py --port 8000 --servers 8 --gsize 4 --groups 4 --clients 4 --trustees 4 --msgs 16 --msize 160 --type 1 --mode 1 83 | 84 | runs a local Atom experiment with 85 | 86 | * 16 servers 87 | * 4 groups 88 | * 4 trustees 89 | * 16 messages per group 90 | * 160-byte messages 91 | * square network 92 | * trap-based protection. 93 | 94 | ## Known problems and limitations 95 | 96 | The current implementation just runs one round. There is some work that needs 97 | to be done to extend this code to do multiple rounds of communication. 98 | 99 | This code was *very* recently migrated to the kyber library (from the old DeDiS 100 | library), and I've caught some weird bugs that arose as a result. I tried to 101 | squash as many as I could, but I think there are some more. If you run into 102 | them, please let me know. The current code uses ed25519 instead of nist-p256, 103 | since that portion of the code is not compiled by default in the new kyber library. 104 | 105 | The script provided is very simple, and does not do fancy things for AWS; 106 | currently you need to have a separate script to set up your AWS network, and 107 | pass the instance description (from `aws ec2 describe-instances`) to the 108 | script to run it on AWS.
109 | 110 | ## Contacts 111 | 112 | If you have any problems with the code or want to learn more about it, don't 113 | hesitate to contact me at `kwonal at mit.edu`, or file an issue here. 114 | -------------------------------------------------------------------------------- /atom_test.go: -------------------------------------------------------------------------------- 1 | package atom 2 | 3 | import ( 4 | "crypto/rand" 5 | "flag" 6 | "fmt" 7 | "log" 8 | "os" 9 | "runtime/pprof" 10 | "sync" 11 | "testing" 12 | 13 | "github.com/kwonalbert/atom/client" 14 | "github.com/kwonalbert/atom/db" 15 | "github.com/kwonalbert/atom/directory" 16 | "github.com/kwonalbert/atom/server" 17 | "github.com/kwonalbert/atom/trustee" 18 | 19 | . "github.com/kwonalbert/atom/common" 20 | . "github.com/kwonalbert/atom/crypto" 21 | ) 22 | 23 | var addr = "127.0.0.1:%d" 24 | var dirPort = 8000 25 | var port = 8001 26 | var trusteePort = 9001 27 | var dbPort = 10001 28 | 29 | var testNet = SQUARE 30 | 31 | var numServers = 9 32 | var numGroups = 4 33 | var perGroup = 6 34 | var numTrustees = 2 35 | var faultTolerence = 1 36 | 37 | var numMsgs = 16 38 | var msgSize = 10 // in bytes 39 | var threshold = perGroup - faultTolerence 40 | var numClients = numGroups 41 | 42 | var cpuprofile = "cpuprofile" 43 | 44 | func memberMessage(msg Message, msgs []Message) bool { 45 | for m := range msgs { 46 | if msg.Equal(msgs[m]) { 47 | return true 48 | } 49 | } 50 | return false 51 | } 52 | 53 | // Should call defer pprof.StopCPUProfile() for the test as well 54 | func profile() { 55 | flag.Parse() 56 | if cpuprofile != "" { 57 | f, err := os.Create(cpuprofile) 58 | if err != nil { 59 | log.Fatal(err) 60 | } 61 | pprof.StartCPUProfile(f) 62 | } 63 | } 64 | 65 | func TestNIZKMixing(t *testing.T) { 66 | dir, _, servers, clients, db := setup(VER_MODE) 67 | 68 | plaintextss := make([][][]byte, len(clients)) 69 | for c := range clients { 70 | plaintextss[c] = make([][]byte, numMsgs) 71 | for p := range 
plaintextss[c] { 72 | plaintextss[c][p] = make([]byte, msgSize) 73 | rand.Read(plaintextss[c][p]) 74 | } 75 | } 76 | 77 | results := make(chan [][]byte, len(clients)) 78 | for c := range clients { 79 | go func(c int) { 80 | clients[c].Submit(c, 0, plaintextss[c]) 81 | res, err := clients[c].DownloadMsgs(0) 82 | if err != nil { 83 | t.Error(err) 84 | } 85 | results <- res 86 | }(c) 87 | } 88 | 89 | var exp [][]byte 90 | for _, plaintexts := range plaintextss { 91 | exp = append(exp, plaintexts...) 92 | } 93 | 94 | for _ = range clients { 95 | res := <-results 96 | for r := range res { 97 | if !MemberByteSlice(res[r], exp) { 98 | t.Error("Missing plaintexts") 99 | } 100 | } 101 | } 102 | 103 | dir.Close() 104 | db.Close() 105 | for _, server := range servers { 106 | server.Close() 107 | } 108 | } 109 | 110 | func TestTrapMixing(t *testing.T) { 111 | //profile() 112 | //defer pprof.StopCPUProfile() 113 | dir, trustees, servers, clients, db := setup(TRAP_MODE) 114 | 115 | plaintextss := make([][][]byte, len(clients)) 116 | for c := range clients { 117 | plaintextss[c] = make([][]byte, numMsgs) 118 | for p := range plaintextss[c] { 119 | plaintextss[c][p] = make([]byte, msgSize) 120 | rand.Read(plaintextss[c][p]) 121 | } 122 | } 123 | 124 | results := make(chan [][]byte, len(clients)) 125 | for c := range clients { 126 | go func(c int) { 127 | clients[c].Submit(c, 0, plaintextss[c]) 128 | res, err := clients[c].DownloadMsgs(0) 129 | if err != nil { 130 | t.Error(err) 131 | } 132 | results <- res 133 | }(c) 134 | } 135 | 136 | var exp [][]byte 137 | for _, plaintexts := range plaintextss { 138 | exp = append(exp, plaintexts...) 
139 | } 140 | 141 | for _ = range clients { 142 | res := <-results 143 | for r := range res { 144 | if !MemberByteSlice(res[r], exp) { 145 | t.Error("Missing plaintexts") 146 | } 147 | } 148 | } 149 | 150 | dir.Close() 151 | db.Close() 152 | for _, trustee := range trustees { 153 | trustee.Close() 154 | } 155 | for _, server := range servers { 156 | server.Close() 157 | } 158 | } 159 | 160 | func setup(testMode int) (*directory.Directory, []*trustee.Trustee, []*server.Server, []*client.Client, *db.DB) { 161 | numTrustees_ := numTrustees 162 | if testMode == VER_MODE { 163 | numTrustees_ = 0 164 | } 165 | 166 | wg := new(sync.WaitGroup) 167 | 168 | dir, err := directory.NewDirectory(0, dirPort, testMode, testNet, 169 | numServers, numGroups, perGroup, numTrustees_, 170 | numMsgs, msgSize, threshold, 171 | numClients) 172 | if err != nil { 173 | log.Fatal("Directory creation err:", err) 174 | } 175 | 176 | db, err := db.NewDB(dbPort) 177 | if err != nil { 178 | log.Fatal("DB creation err:", err) 179 | } 180 | 181 | trustees := make([]*trustee.Trustee, numTrustees_) 182 | servers := make([]*server.Server, numServers) 183 | clients := make([]*client.Client, numGroups) 184 | 185 | dirAddrs := []string{fmt.Sprintf(addr, dirPort)} 186 | dbAddr := fmt.Sprintf(addr, dbPort) 187 | 188 | // start the servers 189 | for i := range servers { 190 | servers[i], err = server.NewServer(fmt.Sprintf(addr, port+i), i, 191 | "", dirAddrs, dbAddr) 192 | if err != nil { 193 | log.Fatal("Server creation err:", err) 194 | } 195 | } 196 | 197 | for i := range trustees { 198 | trustees[i], err = trustee.NewTrustee(fmt.Sprintf(addr, trusteePort+i), i, 199 | "", dirAddrs) 200 | if err != nil { 201 | log.Fatal("Trustee creation err:", err) 202 | } 203 | } 204 | 205 | // connect the servers and setup group keys 206 | for i := range servers { 207 | wg.Add(1) 208 | go func(i int) { 209 | defer wg.Done() 210 | servers[i].Setup() 211 | }(i) 212 | } 213 | 214 | // start the trustees 215 | for i := 
range trustees { 216 | wg.Add(1) 217 | go func(i int) { 218 | defer wg.Done() 219 | trustees[i].Setup() 220 | trustees[i].RegisterRound() 221 | }(i) 222 | } 223 | 224 | wg.Wait() 225 | 226 | for i := range clients { 227 | clients[i], _ = client.NewClient(i, dirAddrs, dbAddr) 228 | clients[i].Setup() 229 | } 230 | 231 | return dir, trustees, servers, clients, db 232 | } 233 | -------------------------------------------------------------------------------- /atomrpc/atomrpc.go: -------------------------------------------------------------------------------- 1 | package atomrpc 2 | 3 | import ( 4 | "net/rpc" 5 | "time" 6 | ) 7 | 8 | type AtomRPCError struct { 9 | error 10 | err string 11 | timeout bool 12 | } 13 | 14 | func (e *AtomRPCError) Error() string { 15 | return e.err 16 | } 17 | 18 | func (e *AtomRPCError) Timeout() bool { 19 | return e.timeout 20 | } 21 | 22 | // RPC with timeout 23 | func AtomRPC(client *rpc.Client, method string, args interface{}, reply interface{}, timeout time.Duration) error { 24 | done := make(chan *rpc.Call, 1) 25 | client.Go(method, args, reply, done) 26 | select { 27 | case res := <-done: 28 | if res.Error != nil { 29 | return res.Error 30 | } else { 31 | return nil 32 | } 33 | case <-time.After(timeout): 34 | return &AtomRPCError{err: "Timeout", timeout: true} 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /atomrpc/tls_helper.go: -------------------------------------------------------------------------------- 1 | package atomrpc 2 | 3 | import ( 4 | "crypto/ecdsa" 5 | "crypto/elliptic" 6 | "crypto/rand" 7 | "crypto/tls" 8 | "crypto/x509" 9 | "crypto/x509/pkix" 10 | "encoding/pem" 11 | "log" 12 | "math/big" 13 | "time" 14 | ) 15 | 16 | // NOTE:TLS USED THIS WAY IS INSECURE. 
ONLY TO MEASURE PERFORMANCE 17 | func AtomTLSConfig() (*tls.Certificate, *tls.Config) { 18 | _, certB, keyB, err := GenCert() 19 | if err != nil { 20 | log.Fatal("Couldn't generate TLS cert", err) 21 | } 22 | cert, err := tls.X509KeyPair(certB, keyB) 23 | if err != nil { 24 | log.Fatal("Couldn't load TLS cert:", err) 25 | } 26 | var config tls.Config 27 | config.InsecureSkipVerify = true 28 | config.Certificates = []tls.Certificate{cert} 29 | config.ClientAuth = tls.NoClientCert 30 | return &cert, &config 31 | } 32 | 33 | func GenCert() (*x509.Certificate, []byte, []byte, error) { 34 | // generate a random serial number (a real cert authority would have some logic behind this) 35 | tmpl := &x509.Certificate{ 36 | SignatureAlgorithm: x509.ECDSAWithSHA256, 37 | PublicKeyAlgorithm: x509.ECDSA, 38 | Version: 2, // x509v3 39 | SerialNumber: new(big.Int).SetInt64(1), 40 | Subject: pkix.Name{Organization: []string{"Atom"}}, 41 | NotBefore: time.Now(), 42 | NotAfter: time.Now().AddDate(1 /* years */, 0 /* months */, 0 /* days */), 43 | KeyUsage: x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature, 44 | ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, 45 | } 46 | 47 | priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 48 | pub := &priv.PublicKey 49 | certB, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, pub, priv) 50 | if err != nil { 51 | return nil, nil, nil, err 52 | } 53 | privB, err := x509.MarshalECPrivateKey(priv) 54 | 55 | cert, err := x509.ParseCertificate(certB) 56 | if err != nil { 57 | return nil, nil, nil, err 58 | } 59 | // PEM encode the certificate (this is a standard TLS encoding) 60 | b := pem.Block{Type: "CERTIFICATE", Bytes: certB} 61 | certB = pem.EncodeToMemory(&b) 62 | b = pem.Block{Type: "EC PRIVATE KEY", Bytes: privB} 63 | privB = pem.EncodeToMemory(&b) 64 | 65 | return cert, certB, privB, err 66 | } 67 | 
-------------------------------------------------------------------------------- /atomrpc/types.go: -------------------------------------------------------------------------------- 1 | package atomrpc 2 | 3 | import . "github.com/kwonalbert/atom/crypto" 4 | 5 | type DealArgs struct { 6 | Uid int // the unique id of group 7 | Idx int 8 | Deal *ThresholdDeal 9 | } 10 | 11 | type DealReply struct { 12 | } 13 | 14 | type ResponseArgs struct { 15 | Uid int // the unique id of group 16 | Resp *ThresholdResponse 17 | } 18 | 19 | type ResponseReply struct { 20 | } 21 | 22 | // basic required info for most rpc calls 23 | type ArgInfo struct { 24 | Round int 25 | Level int 26 | Gid int 27 | Cur int // an index in the group, NOT server id 28 | Group []int // list of indices, NOT server ids 29 | } 30 | 31 | type SubmitArgs struct { 32 | Id int // client id 33 | Ciphertexts []Ciphertext 34 | EncProofs []EncProof 35 | ArgInfo 36 | } 37 | 38 | type SubmitReply struct { 39 | } 40 | 41 | type CommitArgs struct { 42 | Id int // client id 43 | Comms []Commitment 44 | ArgInfo 45 | } 46 | 47 | type CommitReply struct { 48 | } 49 | 50 | type CollectArgs struct { 51 | Id int 52 | Ciphertexts []Ciphertext 53 | ArgInfo 54 | } 55 | 56 | type CollectReply struct { 57 | } 58 | 59 | type ShuffleArgs struct { 60 | Ciphertexts []Ciphertext 61 | ArgInfo 62 | } 63 | 64 | type VerifyShuffleArgs struct { 65 | Old []Ciphertext 66 | New []Ciphertext 67 | Proof ShufProof 68 | ArgInfo 69 | } 70 | 71 | type VerifyShuffleReply struct { 72 | } 73 | 74 | type ShuffleReply struct { 75 | } 76 | 77 | type ReencryptArgs struct { 78 | Batches [][]Ciphertext 79 | ArgInfo 80 | } 81 | 82 | type ReencryptReply struct { 83 | } 84 | 85 | type VerifyReencryptArgs struct { 86 | Old [][]Ciphertext 87 | New [][]Ciphertext 88 | Proofs [][]ReencProof 89 | ArgInfo 90 | } 91 | 92 | type VerifyReencryptReply struct { 93 | } 94 | 95 | type ProofOKArgs struct { 96 | OK bool 97 | ArgInfo 98 | } 99 | 100 | type ProofOKReply struct 
{ 101 | } 102 | 103 | type FinalizeArgs struct { 104 | Plaintexts [][]byte // used only for verifiable mode 105 | Inners []InnerCiphertext // used only for trap mode 106 | Traps []Trap 107 | ArgInfo 108 | } 109 | 110 | type FinalizeReply struct { 111 | } 112 | 113 | type ReportArgs struct { 114 | Round int 115 | Sid int 116 | Uid int 117 | CorrectHash bool 118 | CorrectTraps bool 119 | NoDups bool 120 | NumTraps int 121 | NumMsgs int 122 | } 123 | 124 | type ReportReply struct { 125 | Priv *PrivateKey 126 | } 127 | 128 | type DBArgs struct { 129 | Round int 130 | NumGroups int 131 | Msgs [][]byte 132 | } 133 | -------------------------------------------------------------------------------- /client/client.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "bytes" 5 | "crypto/rand" 6 | "crypto/tls" 7 | "encoding/binary" 8 | "log" 9 | "net/rpc" 10 | "time" 11 | 12 | . "github.com/kwonalbert/atom/atomrpc" 13 | . "github.com/kwonalbert/atom/common" 14 | . 
"github.com/kwonalbert/atom/crypto" 15 | 16 | "github.com/kwonalbert/atom/directory" 17 | ) 18 | 19 | type Client struct { 20 | id int 21 | 22 | params SystemParameter 23 | network [][]*Group 24 | 25 | dirAddrs []string 26 | dirServers []*rpc.Client 27 | dbServer *rpc.Client 28 | directory *directory.Directory 29 | publicKeys []*PublicKey 30 | 31 | start time.Time 32 | 33 | tlsConfig *tls.Config 34 | } 35 | 36 | func NewClient(id int, dirAddrs []string, dbAddr string) (*Client, error) { 37 | _, tlsConfig := AtomTLSConfig() 38 | 39 | dirServers := make([]*rpc.Client, len(dirAddrs)) 40 | for d, dirAddr := range dirAddrs { 41 | conn, err := tls.Dial("tcp", dirAddr, tlsConfig) 42 | if err != nil { 43 | return nil, err 44 | } 45 | dirServers[d] = rpc.NewClient(conn) 46 | } 47 | 48 | conn, err := tls.Dial("tcp", dbAddr, tlsConfig) 49 | if err != nil { 50 | return nil, err 51 | } 52 | dbServer := rpc.NewClient(conn) 53 | 54 | c := &Client{ 55 | id: id, 56 | dirAddrs: dirAddrs, 57 | dirServers: dirServers, 58 | dbServer: dbServer, 59 | 60 | tlsConfig: tlsConfig, 61 | } 62 | 63 | return c, nil 64 | } 65 | 66 | // primary function used by clients 67 | func (c *Client) Submit(gid, round int, plaintexts [][]byte) { 68 | msgs := c.generateMessages(round, plaintexts) 69 | 70 | // commit the traps first 71 | if c.params.Mode == TRAP_MODE { 72 | traps := c.generateTraps(c.id) 73 | trapMsgs := c.generateTrapMsgs(traps, len(msgs[0])) 74 | msgs = append(msgs, trapMsgs...) 
75 | 76 | if c.id == 0 { 77 | log.Println("Committing traps") 78 | } 79 | cargs := c.generateCommitArgs(c.id, traps) 80 | c.commit(c.id, cargs) 81 | } 82 | 83 | submitArgs := c.generateSubmitArgs(gid, round, msgs) 84 | c.submit(gid, submitArgs) 85 | c.start = time.Now() 86 | } 87 | 88 | func (c *Client) Setup() { 89 | var keys [][]*PublicKey 90 | c.directory, c.params, c.publicKeys, keys = directory.GetGroupKeys(c.dirServers) 91 | 92 | var seed [SEED_LEN]byte 93 | for _, dirServer := range c.dirServers { 94 | var val [SEED_LEN]byte 95 | err := dirServer.Call("DirectoryRPC.Randomness", 0, &val) 96 | if err != nil { 97 | log.Fatal("Randomness err:", err) 98 | } 99 | Xor(val[:], seed[:]) 100 | } 101 | network := GenerateGroups(seed, c.params.NetType, c.params.NumServers, 102 | c.params.NumGroups, c.params.PerGroup, 103 | c.params.NumLevels, c.publicKeys) 104 | c.network = network 105 | 106 | for level := range keys { 107 | for gid := range keys[level] { 108 | c.network[level][gid].GroupKey = keys[level][gid] 109 | } 110 | } 111 | 112 | } 113 | 114 | func (c *Client) GenRandPlaintexts() [][]byte { 115 | plaintexts := make([][]byte, c.params.NumMsgs) 116 | for p := range plaintexts { 117 | plaintexts[p] = make([]byte, c.params.MsgSize) 118 | rand.Read(plaintexts[p]) 119 | } 120 | return plaintexts 121 | } 122 | 123 | func (c *Client) DownloadMsgs(round int) ([][]byte, error) { 124 | args := DBArgs{ 125 | Round: round, 126 | NumGroups: c.params.NumGroups, 127 | } 128 | var res [][]byte 129 | err := c.dbServer.Call("DB.Read", &args, &res) 130 | if err != nil { 131 | return nil, err 132 | } 133 | log.Println("Done with client", c.id, ".", time.Since(c.start), ". 
#msgs: ", c.params.NumMsgs) 134 | return res, nil 135 | } 136 | 137 | func (c *Client) generateRandomMsgs() []Message { 138 | numPts := c.params.MsgSize / PickLen() 139 | if c.params.MsgSize%PickLen() != 0 { 140 | numPts += 1 141 | } 142 | return GenRandMsgs(c.params.NumMsgs, numPts) 143 | } 144 | 145 | func (c *Client) generateMessages(round int, plaintexts [][]byte) []Message { 146 | if c.params.Mode == TRAP_MODE { 147 | buf := new(bytes.Buffer) 148 | err := binary.Write(buf, binary.LittleEndian, uint32(round)) 149 | if err != nil { 150 | log.Fatal("Could not write round") 151 | } 152 | 153 | trusteeKey := LoadPubKey(c.directory.RoundKeys[round]) 154 | 155 | inners := make([]InnerCiphertext, len(plaintexts)) 156 | for i := range inners { 157 | inners[i] = CCA2Encrypt(plaintexts[i], buf.Bytes(), trusteeKey) 158 | } 159 | msgs := make([]Message, len(inners)) 160 | for i := range inners { 161 | cMessage := GenMsg(inners[i].C) 162 | msgs[i] = append([]*Point{inners[i].R}, cMessage...) 163 | } 164 | return msgs 165 | } else { 166 | return GenMsgs(plaintexts) 167 | } 168 | } 169 | 170 | func (c *Client) generateTraps(gid int) []Trap { 171 | traps := make([]Trap, c.params.NumMsgs) 172 | for t := range traps { 173 | traps[t] = GenTrap(gid) 174 | } 175 | return traps 176 | } 177 | 178 | func (c *Client) generateTrapMsgs(traps []Trap, msgSize int) []Message { 179 | numPts := c.params.MsgSize / PickLen() 180 | if c.params.MsgSize%PickLen() != 0 { 181 | numPts += 1 182 | } 183 | 184 | msgs := make([]Message, c.params.NumMsgs) 185 | var err error 186 | for t := range traps { 187 | msgs[t], err = TrapToMessage(traps[t], numPts) 188 | if err != nil { 189 | log.Fatal("trap err:", err) 190 | } 191 | } 192 | return msgs 193 | } 194 | 195 | func (c *Client) generateSubmitArgs(gid, round int, msgs []Message) *SubmitArgs { 196 | group := c.network[0][gid] 197 | info := ArgInfo{ 198 | Round: round, 199 | Level: 0, 200 | Gid: gid, 201 | Cur: 0, 202 | Group: Xrange(c.params.Threshold), 
203 | } 204 | 205 | ciphertexts := make([]Ciphertext, len(msgs)) 206 | proofs := make([]EncProof, len(msgs)) 207 | for c := range ciphertexts { 208 | ciphertexts[c], proofs[c] = ProveEncrypt(group.GroupKey, msgs[c]) 209 | } 210 | 211 | args := SubmitArgs{ 212 | Id: c.id, 213 | Ciphertexts: ciphertexts, 214 | EncProofs: proofs, 215 | ArgInfo: info, 216 | } 217 | return &args 218 | } 219 | 220 | func (c *Client) generateCommitArgs(gid int, traps []Trap) *CommitArgs { 221 | info := ArgInfo{ 222 | Round: 0, 223 | Level: 0, 224 | Gid: gid, 225 | Cur: 0, 226 | Group: Xrange(c.params.Threshold), 227 | } 228 | 229 | comms := make([]Commitment, len(traps)) 230 | for t := range traps { 231 | comms[t] = Commit(traps[t]) 232 | } 233 | 234 | args := CommitArgs{ 235 | Id: c.id, 236 | Comms: comms, 237 | ArgInfo: info, 238 | } 239 | return &args 240 | } 241 | 242 | func (c *Client) submit(gid int, args *SubmitArgs) { 243 | group := c.network[0][gid] 244 | for _, idx := range args.Group { 245 | addr := c.directory.Servers[group.Members[idx]] 246 | conn, err := tls.Dial("tcp", addr, c.tlsConfig) 247 | if err != nil { 248 | log.Fatal(err) 249 | } 250 | server := rpc.NewClient(conn) 251 | if err != nil { 252 | log.Fatal(err) 253 | } 254 | err = AtomRPC(server, "ServerRPC.Submit", args, nil, DEFAULT_TIMEOUT) 255 | if err != nil { 256 | log.Fatal(err) 257 | } 258 | server.Close() 259 | } 260 | } 261 | 262 | func (c *Client) commit(gid int, args *CommitArgs) { 263 | group := c.network[0][gid] 264 | 265 | for _, idx := range args.Group { 266 | addr := c.directory.Servers[group.Members[idx]] 267 | conn, err := tls.Dial("tcp", addr, c.tlsConfig) 268 | if err != nil { 269 | log.Fatal(err) 270 | } 271 | server := rpc.NewClient(conn) 272 | if err != nil { 273 | log.Fatal(err) 274 | } 275 | err = AtomRPC(server, "ServerRPC.Commit", args, nil, DEFAULT_TIMEOUT) 276 | if err != nil { 277 | log.Fatal(err) 278 | } 279 | server.Close() 280 | } 281 | } 282 | 
-------------------------------------------------------------------------------- /cmd/client/client.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "log" 6 | 7 | "github.com/kwonalbert/atom/client" 8 | ) 9 | 10 | var ( 11 | dirAddr = flag.String("dirAddr", "127.0.0.1:8000", "Directory address") 12 | dbAddr = flag.String("dbAddr", "127.0.0.1:10001", "Database address") 13 | id = flag.Int("id", 0, "Public ID of the client") 14 | ) 15 | 16 | func main() { 17 | log.SetFlags(log.LstdFlags | log.Lshortfile) 18 | flag.Parse() 19 | 20 | c, err := client.NewClient(*id, []string{*dirAddr}, *dbAddr) 21 | if err != nil { 22 | log.Fatal("Could not start client:", err) 23 | } 24 | 25 | if *id == 0 { 26 | log.Println("Setting up clients..") 27 | } 28 | c.Setup() 29 | 30 | if *id == 0 { 31 | log.Println("Sending msg") 32 | } 33 | 34 | c.Submit(*id, 0, c.GenRandPlaintexts()) 35 | 36 | if *id == 0 { 37 | log.Println("Done sending") 38 | } 39 | 40 | c.DownloadMsgs(0) 41 | } 42 | -------------------------------------------------------------------------------- /cmd/db/db.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "log" 6 | "os" 7 | "os/signal" 8 | "strconv" 9 | "strings" 10 | "syscall" 11 | 12 | "github.com/kwonalbert/atom/db" 13 | ) 14 | 15 | var ( 16 | addr = flag.String("dbAddr", "127.0.0.1:10001", "Database address") 17 | ) 18 | 19 | func main() { 20 | log.SetFlags(log.LstdFlags | log.Lshortfile) 21 | flag.Parse() 22 | 23 | port, err := strconv.Atoi(strings.Split(*addr, ":")[1]) 24 | if err != nil { 25 | log.Fatal(err) 26 | } 27 | 28 | _, err = db.NewDB(port) 29 | if err != nil { 30 | log.Fatal("Could not start db:", err) 31 | } 32 | 33 | kill := make(chan os.Signal) 34 | signal.Notify(kill, syscall.SIGINT, syscall.SIGTERM) 35 | <-kill 36 | } 37 | 
-------------------------------------------------------------------------------- /cmd/directory/directory.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "log" 6 | "os" 7 | "os/signal" 8 | "strconv" 9 | "strings" 10 | "syscall" 11 | 12 | . "github.com/kwonalbert/atom/common" 13 | "github.com/kwonalbert/atom/directory" 14 | ) 15 | 16 | var ( 17 | id = flag.Int("id", 0, "unique id") 18 | addr = flag.String("dirAddr", "127.0.0.1:8000", "Directory address") 19 | perGroup = flag.Int("perGroup", 1, "# of servers per group") 20 | numServers = flag.Int("numServers", 0, "# of servers") 21 | numClients = flag.Int("numClients", 0, "# of clients") 22 | numGroups = flag.Int("numGroups", 0, "# of groups") 23 | numTrustees = flag.Int("numTrustees", 0, "# of trustees") 24 | numMsgs = flag.Int("numMsgs", 1, "# of msgs per group") 25 | msgSize = flag.Int("msgSize", 1, "size of the message in group elements") 26 | mode = flag.Int("mode", TRAP_MODE, "Operation mode") 27 | net = flag.Int("net", BUTTERFLY, "Network topology") 28 | branch = flag.Int("branch", 2, "Branching factor for padding network") 29 | ) 30 | 31 | func main() { 32 | log.SetFlags(log.LstdFlags | log.Lshortfile) 33 | flag.Parse() 34 | 35 | port, err := strconv.Atoi(strings.Split(*addr, ":")[1]) 36 | if err != nil { 37 | log.Fatal(err) 38 | } 39 | _, err = directory.NewDirectory(*id, port, *mode, *net, 40 | *numServers, *numGroups, *perGroup, *numTrustees, 41 | *numMsgs, *msgSize, *perGroup-1, 42 | *numClients) 43 | if err != nil { 44 | log.Fatal("Directory err:", err) 45 | } 46 | 47 | kill := make(chan os.Signal) 48 | signal.Notify(kill, syscall.SIGINT, syscall.SIGTERM) 49 | <-kill 50 | } 51 | -------------------------------------------------------------------------------- /cmd/keygen/keygen.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | 
"flag" 6 | "log" 7 | "os" 8 | 9 | "github.com/kwonalbert/atom/crypto" 10 | ) 11 | 12 | var ( 13 | serverKeys = flag.String("serverKeys", "server_keys.json", "Key file") 14 | trusteeKeys = flag.String("trusteeKeys", "trustee_keys.json", "Key file") 15 | numServers = flag.Int("numServers", 0, "# of servers") 16 | numTrustees = flag.Int("numTrustees", 0, "# of trustees") 17 | ) 18 | 19 | func main() { 20 | flag.Parse() 21 | 22 | serverFile, err := os.Create(*serverKeys) 23 | if err != nil { 24 | log.Fatal("file err:", err) 25 | } 26 | trusteeFile, err := os.Create(*trusteeKeys) 27 | if err != nil { 28 | log.Fatal("file err:", err) 29 | } 30 | 31 | sks := make([]crypto.HexKeyPair, *numServers) 32 | for s := 0; s < *numServers; s++ { 33 | key := crypto.GenKey() 34 | sks[s] = crypto.DumpKey(key) 35 | } 36 | 37 | tks := make([]crypto.HexKeyPair, *numTrustees) 38 | for t := 0; t < *numTrustees; t++ { 39 | key := crypto.GenKey() 40 | tks[t] = crypto.DumpKey(key) 41 | } 42 | 43 | sb, err := json.MarshalIndent(sks, "", " ") 44 | if err != nil { 45 | log.Fatal("failed marshaling keys:", err) 46 | } 47 | tb, err := json.MarshalIndent(tks, "", " ") 48 | if err != nil { 49 | log.Fatal("failed marshaling keys:", err) 50 | } 51 | 52 | serverFile.Write(sb) 53 | trusteeFile.Write(tb) 54 | 55 | serverFile.Close() 56 | trusteeFile.Close() 57 | } 58 | -------------------------------------------------------------------------------- /cmd/server/server.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "log" 6 | "os" 7 | "os/signal" 8 | "syscall" 9 | 10 | "github.com/kwonalbert/atom/server" 11 | ) 12 | 13 | var ( 14 | keyFile = flag.String("keyFile", "keys/server_keys.json", "Server key file") 15 | dirAddr = flag.String("dirAddr", "127.0.0.1:8000", "Directory address") 16 | dbAddr = flag.String("dbAddr", "127.0.0.1:10001", "Database address") 17 | addr = flag.String("addr", "127.0.0.1:8001", "Public address 
of server") 18 | id = flag.Int("id", 0, "Public ID of the server") 19 | ) 20 | 21 | func main() { 22 | log.SetFlags(log.LstdFlags | log.Lshortfile) 23 | flag.Parse() 24 | 25 | kill := make(chan os.Signal) 26 | 27 | s, err := server.NewServer(*addr, *id, *keyFile, []string{*dirAddr}, *dbAddr) 28 | if err != nil { 29 | log.Fatal("Could not start server:", err) 30 | } 31 | 32 | signal.Notify(kill, syscall.SIGINT, syscall.SIGTERM) 33 | 34 | s.Setup() 35 | 36 | for { 37 | select { 38 | case <-kill: 39 | s.Close() 40 | os.Exit(0) 41 | } 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /cmd/trustee/trustee.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "log" 6 | "os" 7 | 8 | "github.com/kwonalbert/atom/trustee" 9 | ) 10 | 11 | var ( 12 | keyFile = flag.String("keyFile", "keys/server_keys.json", "Server key file") 13 | dirAddr = flag.String("dirAddr", "127.0.0.1:8000", "Directory address") 14 | addr = flag.String("addr", "127.0.0.1:8001", "Public address of server") 15 | id = flag.Int("id", 0, "Public ID of the server") 16 | ) 17 | 18 | func main() { 19 | log.SetFlags(log.LstdFlags | log.Lshortfile) 20 | flag.Parse() 21 | 22 | t, err := trustee.NewTrustee(*addr, *id, *keyFile, []string{*dirAddr}) 23 | if err != nil { 24 | log.Fatal("Trustee err:", err) 25 | } 26 | 27 | t.Setup() 28 | t.RegisterRound() 29 | 30 | kill := make(chan os.Signal) 31 | <-kill 32 | } 33 | -------------------------------------------------------------------------------- /common/group_gen.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "sort" 5 | 6 | . 
"github.com/kwonalbert/atom/crypto" 7 | ) 8 | 9 | func GenRandomGroup(numServers, numGroups, perGroup int, rand *Reader) []int { 10 | group := make([]int, perGroup) 11 | for s := range group { 12 | group[s] = -1 13 | } 14 | 15 | for s := range group { 16 | for { 17 | // ensure no duplicate servers in the group 18 | idx := rand.UInt() % numServers 19 | if !IsMember(idx, group) { 20 | group[s] = idx 21 | break 22 | } 23 | } 24 | } 25 | return group 26 | } 27 | 28 | func GenerateGroups(seed [SEED_LEN]byte, netType, 29 | numServers, numGroups, perGroup, numLevels int, 30 | publicKeys []*PublicKey) [][]*Group { 31 | rand := NewRandReader(seed[:]) 32 | 33 | // replicate the groups across levels 34 | // NOTE: wouldn't replicate it for the throughput maximized version 35 | baseGroups := make([][]int, numGroups) 36 | for gid := range baseGroups { 37 | baseGroups[gid] = GenRandomGroup(numServers, numGroups, perGroup, rand) 38 | sort.Ints(baseGroups[gid]) 39 | baseGroups[gid] = append(baseGroups[gid][gid%perGroup:], 40 | baseGroups[gid][:gid%perGroup]...) 
41 | } 42 | 43 | usedUids := make(map[int]bool) 44 | 45 | groupss := make([][]*Group, numLevels) 46 | for level := range groupss { 47 | groupss[level] = make([]*Group, numGroups) 48 | for gid := range groupss[level] { 49 | members := baseGroups[gid] 50 | memberKeys := make([]*PublicKey, len(members)) 51 | for m, member := range members { 52 | memberKeys[m] = publicKeys[member] 53 | } 54 | 55 | // make sure no duplicate uids 56 | uid := -1 57 | for { 58 | uid = rand.UInt() 59 | if _, ok := usedUids[uid]; ok { 60 | continue 61 | } else { 62 | usedUids[uid] = true 63 | break 64 | } 65 | } 66 | 67 | groupss[level][gid] = &Group{ 68 | Members: members, 69 | MemberKeys: memberKeys, 70 | Level: level, 71 | Gid: gid, 72 | Uid: uid, 73 | AdjList: nil, 74 | } 75 | } 76 | } 77 | 78 | if netType == BUTTERFLY { 79 | //one full butterfly is log of # of groups 80 | oneButterfly := Log2(numGroups) 81 | for gid := 0; gid < numGroups; gid++ { 82 | for level := 0; level < numLevels; level++ { 83 | // Establish the cross-connections between Shuffle groups 84 | nextGroup := 0 85 | // Butterfly connection 86 | shift := uint(level % oneButterfly) 87 | if (uint(gid)>>shift)&1 == 0 { 88 | nextGroup = gid + (1 << shift) 89 | } else { 90 | nextGroup = gid - (1 << shift) 91 | } 92 | if level < numLevels-1 { 93 | groupss[level][gid].AdjList = 94 | []*Group{groupss[level+1][gid], 95 | groupss[level+1][nextGroup]} 96 | } else { 97 | groupss[level][gid].AdjList = 98 | []*Group{nil, nil} 99 | } 100 | } 101 | } 102 | } else if netType == SQUARE { 103 | // set to 10 for production 104 | for level := 0; level < numLevels; level++ { 105 | for gid := 0; gid < numGroups; gid++ { 106 | var adjList []*Group = nil 107 | if level < numLevels-1 { 108 | adjList = make([]*Group, numGroups) 109 | for neighbor := range adjList { 110 | adjList[neighbor] = groupss[level+1][neighbor] 111 | } 112 | } 113 | groupss[level][gid].AdjList = adjList 114 | } 115 | } 116 | } 117 | 118 | return groupss 119 | } 120 | 
-------------------------------------------------------------------------------- /common/group_gen_test.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | . "github.com/kwonalbert/atom/crypto" 8 | ) 9 | 10 | func TestGenerateGroups(t *testing.T) { 11 | numServers := 32 12 | numGroups := 32 13 | perGroup := 32 14 | 15 | _, pubs, _ := GenKeys(numServers) 16 | groupss := GenerateGroups(SEED, SQUARE, numServers, numGroups, 17 | perGroup, 10, pubs) 18 | for level := range groupss { 19 | for gid := range groupss[level] { 20 | // add gid above and uncomment below to see the result 21 | fmt.Println(groupss[level][gid].Members) 22 | } 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /common/lib.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "math" 5 | ) 6 | 7 | // implementes utility and helper functions for atom 8 | 9 | func MemberByteSlice(b []byte, bs [][]byte) bool { 10 | for i := range bs { 11 | if ByteSliceEqual(b, bs[i]) { 12 | return true 13 | } 14 | } 15 | return false 16 | } 17 | 18 | func ByteSliceEqual(s1, s2 []byte) bool { 19 | if len(s1) != len(s2) { 20 | return false 21 | } 22 | for s := range s1 { 23 | if s1[s] != s2[s] { 24 | return false 25 | } 26 | } 27 | return true 28 | } 29 | 30 | func IsMember(val int, set []int) bool { 31 | for _, v := range set { 32 | if val == v { 33 | return true 34 | } 35 | } 36 | return false 37 | } 38 | 39 | // log base 2 40 | func Log2(val int) int { 41 | return int(math.Log2(float64(val))) 42 | } 43 | 44 | // Create a range slice 45 | func Xrange(extent int) []int { 46 | result := make([]int, extent) 47 | for i := range result { 48 | result[i] = i 49 | } 50 | return result 51 | } 52 | -------------------------------------------------------------------------------- /common/types.go: 
-------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/kwonalbert/atom/crypto" 7 | ) 8 | 9 | // "public randomness" 10 | const SEED_LEN = 16 11 | 12 | var SEED = [SEED_LEN]byte{2} 13 | 14 | const DEFAULT_TIMEOUT = 5 * time.Second 15 | 16 | type SystemParameter struct { 17 | Mode int // ver_mode or trap_mode 18 | NetType int // butterlfy or squareroot 19 | 20 | NumServers int // total number of servers available 21 | NumGroups int // number of groups per level 22 | PerGroup int // number of servers per group 23 | NumTrustees int // number of trsutees in trap mode 24 | NumLevels int // number of levels 25 | 26 | NumMsgs int // number of msgs per group 27 | MsgSize int // number of bytes of plaintext msg 28 | 29 | Threshold int // threshold, if it's used 30 | } 31 | 32 | // network nodes 33 | type Group struct { 34 | Members []int // members of the current group (server ids) 35 | MemberKeys []*crypto.PublicKey // members' long term keys 36 | GroupKey *crypto.PublicKey // final group key 37 | Level int // current level 38 | Gid int // gid in the current level 39 | Uid int // a unique id for this group 40 | AdjList []*Group // adjacency list 41 | } 42 | 43 | const ( 44 | VER_MODE = 0 45 | TRAP_MODE = 1 46 | ) 47 | 48 | const ( 49 | BUTTERFLY = 0 50 | SQUARE = 1 51 | ) 52 | -------------------------------------------------------------------------------- /common/xor.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package common 6 | 7 | import ( 8 | "runtime" 9 | "unsafe" 10 | ) 11 | 12 | const wordSize = int(unsafe.Sizeof(uintptr(0))) 13 | const supportsUnaligned = runtime.GOARCH == "386" || runtime.GOARCH == "amd64" 14 | 15 | // fastXORBytes xors in bulk. 
It only works on architectures that 16 | // support unaligned read/writes. 17 | func fastXORBytes(dst, a, b []byte) int { 18 | n := len(a) 19 | if len(b) < n { 20 | n = len(b) 21 | } 22 | 23 | w := n / wordSize 24 | if w > 0 { 25 | dw := *(*[]uintptr)(unsafe.Pointer(&dst)) 26 | aw := *(*[]uintptr)(unsafe.Pointer(&a)) 27 | bw := *(*[]uintptr)(unsafe.Pointer(&b)) 28 | for i := 0; i < w; i++ { 29 | dw[i] = aw[i] ^ bw[i] 30 | } 31 | } 32 | 33 | for i := (n - n%wordSize); i < n; i++ { 34 | dst[i] = a[i] ^ b[i] 35 | } 36 | 37 | return n 38 | } 39 | 40 | func safeXORBytes(dst, a, b []byte) int { 41 | n := len(a) 42 | if len(b) < n { 43 | n = len(b) 44 | } 45 | for i := 0; i < n; i++ { 46 | dst[i] = a[i] ^ b[i] 47 | } 48 | return n 49 | } 50 | 51 | // xorBytes xors the bytes in a and b. The destination is assumed to have enough 52 | // space. Returns the number of bytes xor'd. 53 | func xorBytes(dst, a, b []byte) int { 54 | if supportsUnaligned { 55 | return fastXORBytes(dst, a, b) 56 | } else { 57 | // TODO(hanwen): if (dst, a, b) have common alignment 58 | // we could still try fastXORBytes. It is not clear 59 | // how often this happens, and it's only worth it if 60 | // the block encryption itself is hardware 61 | // accelerated. 62 | return safeXORBytes(dst, a, b) 63 | } 64 | } 65 | 66 | // fastXORWords XORs multiples of 4 or 8 bytes (depending on architecture.) 67 | // The arguments are assumed to be of equal length. 
68 | func fastXORWords(dst, a, b []byte) { 69 | dw := *(*[]uintptr)(unsafe.Pointer(&dst)) 70 | aw := *(*[]uintptr)(unsafe.Pointer(&a)) 71 | bw := *(*[]uintptr)(unsafe.Pointer(&b)) 72 | n := len(b) / wordSize 73 | for i := 0; i < n; i++ { 74 | dw[i] = aw[i] ^ bw[i] 75 | } 76 | } 77 | 78 | func XorWords(dst, a, b []byte) { 79 | if supportsUnaligned { 80 | fastXORWords(dst, a, b) 81 | } else { 82 | safeXORBytes(dst, a, b) 83 | } 84 | } 85 | 86 | func Xor(a, dst []byte) { 87 | XorWords(dst, a, dst) 88 | } 89 | -------------------------------------------------------------------------------- /crypto/cca2.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "errors" 5 | "log" 6 | 7 | "github.com/dedis/kyber/util/random" 8 | 9 | "golang.org/x/crypto/nacl/secretbox" 10 | "golang.org/x/crypto/sha3" 11 | ) 12 | 13 | func CCA2Encrypt(plaintext []byte, nonce []byte, X *PublicKey) InnerCiphertext { 14 | rnd := random.New() 15 | 16 | r := SUITE.Scalar().Pick(rnd) 17 | R := SUITE.Point().Mul(r, nil) 18 | shared := SUITE.Point().Mul(r, X.p) 19 | 20 | sharedBytes, err := shared.MarshalBinary() 21 | if err != nil { 22 | log.Fatal("Could not marshal rand") 23 | } 24 | pubBytes, err := X.p.MarshalBinary() 25 | if err != nil { 26 | log.Fatal("Could not marshal trustee key") 27 | } 28 | key_nonce := append(pubBytes, sharedBytes...) 
29 | 30 | var key [32]byte 31 | sha3.ShakeSum128(key[:], key_nonce) 32 | var nonce24 [24]byte 33 | copy(nonce24[:], nonce) 34 | 35 | ciphertext := secretbox.Seal(nil, plaintext, &nonce24, &key) 36 | return InnerCiphertext{ 37 | R: &Point{R}, 38 | C: ciphertext, 39 | } 40 | } 41 | 42 | func CCA2Decrypt(inner InnerCiphertext, nonce []byte, 43 | x *PrivateKey, X *PublicKey) ([]byte, error) { 44 | 45 | shared := SUITE.Point().Mul(x.s, inner.R.p) 46 | sharedBytes, err := shared.MarshalBinary() 47 | if err != nil { 48 | log.Fatal("Could not marshal rand") 49 | } 50 | pubBytes, err := X.p.MarshalBinary() 51 | if err != nil { 52 | log.Fatal("Could not marshal trustee key") 53 | } 54 | key_nonce := append(pubBytes, sharedBytes...) 55 | 56 | var key [32]byte 57 | sha3.ShakeSum128(key[:], key_nonce) 58 | var nonce24 [24]byte 59 | copy(nonce24[:], nonce) 60 | 61 | msg, auth := secretbox.Open(nil, inner.C, &nonce24, &key) 62 | if !auth { 63 | return nil, errors.New("Misauthenticated msg") 64 | } 65 | return msg, nil 66 | } 67 | -------------------------------------------------------------------------------- /crypto/crypto.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "encoding/binary" 5 | "runtime" 6 | "sync" 7 | 8 | "github.com/dedis/kyber/group/edwards25519" 9 | "github.com/dedis/kyber/util/random" 10 | "github.com/dedis/kyber/xof/blake2xb" 11 | 12 | "golang.org/x/crypto/sha3" 13 | ) 14 | 15 | var SUITE = edwards25519.NewBlakeSHA256Ed25519WithRand(blake2xb.New(nil)) 16 | var refPt = SUITE.Point() 17 | 18 | // Basic ElGamal encryption 19 | func Encrypt(X *PublicKey, msg Message) Ciphertext { 20 | rnd := random.New() 21 | R := make([]*Point, len(msg)) 22 | C := make([]*Point, len(msg)) 23 | for idx := range msg { 24 | r := SUITE.Scalar().Pick(rnd) 25 | R[idx] = &Point{SUITE.Point().Mul(r, nil)} 26 | C[idx] = &Point{SUITE.Point().Add(msg[idx].p, SUITE.Point().Mul(r, X.p))} 27 | } 28 | return 
Ciphertext{ 29 | R: R, 30 | C: C, 31 | Y: nil, 32 | } 33 | } 34 | 35 | // Basic ElGamal decryption 36 | func Decrypt(x *PrivateKey, c Ciphertext) Message { 37 | msg := make([]*Point, len(c.C)) 38 | for idx := range c.C { 39 | blind := SUITE.Point().Mul(x.s, c.R[idx].p) 40 | msg[idx] = &Point{blind.Sub(c.C[idx].p, blind)} 41 | } 42 | return msg 43 | } 44 | 45 | // Reblind ciphertext c using publickey X 46 | func Reblind(X *PublicKey, c Ciphertext) Ciphertext { 47 | rnd := random.New() 48 | for idx := range c.C { 49 | r := SUITE.Scalar().Pick(rnd) 50 | newR := SUITE.Point().Mul(r, nil) 51 | newBlind := SUITE.Point().Mul(r, X.p) 52 | 53 | c.R[idx].p = c.R[idx].p.Add(c.R[idx].p, newR) 54 | c.C[idx].p = c.C[idx].p.Add(c.C[idx].p, newBlind) 55 | } 56 | return c 57 | } 58 | 59 | // Decrypt ciphertext c using x and reencrypt using XBar 60 | func Reencrypt(x *PrivateKey, XBar *PublicKey, c Ciphertext) Ciphertext { 61 | if c.Y == nil { 62 | c.Y = c.R 63 | c.R = make([]*Point, len(c.Y)) 64 | for idx := range c.R { 65 | c.R[idx] = &Point{SUITE.Point().Null()} 66 | } 67 | } 68 | rnd := random.New() 69 | 70 | ciphertext := Ciphertext{ 71 | R: make([]*Point, len(c.R)), 72 | C: make([]*Point, len(c.C)), 73 | Y: make([]*Point, len(c.Y)), 74 | } 75 | for idx := range c.C { 76 | blind := SUITE.Point().Mul(x.s, c.Y[idx].p) 77 | ctmp := blind.Sub(c.C[idx].p, blind) 78 | 79 | rBar := SUITE.Scalar().Pick(rnd) 80 | newR := SUITE.Point().Mul(rBar, nil) 81 | newR = newR.Add(c.R[idx].p, newR) 82 | 83 | newBlind := SUITE.Point().Mul(rBar, XBar.p) 84 | newC := ctmp.Add(ctmp, newBlind) 85 | 86 | ciphertext.R[idx] = &Point{newR} 87 | ciphertext.C[idx] = &Point{newC} 88 | ciphertext.Y[idx] = c.Y[idx] 89 | } 90 | return ciphertext 91 | } 92 | 93 | func ReencryptBatches(priv *PrivateKey, pubKeys []*PublicKey, batches [][]Ciphertext) [][]Ciphertext { 94 | numBatches := len(batches) 95 | batchSize := len(batches[0]) 96 | k := numBatches * batchSize 97 | ciphertexts := make([]Ciphertext, k) 98 | pubs 
:= make([]*PublicKey, k) 99 | idx := 0 100 | for b := range batches { 101 | for i := range batches[b] { 102 | ciphertexts[idx] = batches[b][i] 103 | pubs[idx] = pubKeys[b] 104 | idx++ 105 | } 106 | } 107 | 108 | chunks := runtime.NumCPU() 109 | div := k / chunks 110 | if k < chunks { 111 | div = 1 112 | chunks = k 113 | } else if k%chunks != 0 { 114 | div++ 115 | } 116 | 117 | wg := new(sync.WaitGroup) 118 | wg.Add(chunks) 119 | for d := 0; d < chunks; d++ { 120 | start := d * div 121 | end := (d + 1) * div 122 | if end > k { 123 | end = k 124 | } 125 | go func(start, end int) { 126 | defer wg.Done() 127 | for i := start; i < end; i++ { 128 | ciphertexts[i] = Reencrypt(priv, pubs[i], ciphertexts[i]) 129 | } 130 | }(start, end) 131 | } 132 | wg.Wait() 133 | 134 | result := make([][]Ciphertext, len(batches)) 135 | idx = 0 136 | for b := range batches { 137 | result[b] = make([]Ciphertext, len(batches[b])) 138 | for i := range batches[b] { 139 | result[b][i] = ciphertexts[idx] 140 | idx++ 141 | 142 | } 143 | } 144 | 145 | return result 146 | } 147 | 148 | func Commit(trap Trap) Commitment { 149 | buf := make([]byte, 8) 150 | gid := uint64(trap.Gid) 151 | binary.PutUvarint(buf, gid) 152 | b := append(buf, trap.Nonce...) 153 | return sha3.Sum256(b) 154 | } 155 | 156 | func VerifyCommitment(trap Trap, comm Commitment) bool { 157 | buf := make([]byte, 8) 158 | gid := uint64(trap.Gid) 159 | binary.PutUvarint(buf, gid) 160 | b := append(buf, trap.Nonce...) 
161 | res := sha3.Sum256(b) 162 | if res == comm { 163 | return true 164 | } else { 165 | return false 166 | } 167 | } 168 | -------------------------------------------------------------------------------- /crypto/crypto_helper.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "crypto/cipher" 5 | "crypto/rand" 6 | "encoding/binary" 7 | "fmt" 8 | "log" 9 | 10 | "github.com/dedis/kyber" 11 | "github.com/dedis/kyber/util/random" 12 | ) 13 | 14 | func PickLen() int { 15 | return refPt.EmbedLen() 16 | } 17 | 18 | func compareArray(arr1, arr2 []byte) bool { 19 | if len(arr1) != len(arr2) { 20 | return false 21 | } 22 | for i := range arr1 { 23 | if arr1[i] != arr2[i] { 24 | return false 25 | } 26 | } 27 | return true 28 | } 29 | 30 | func GenMsg(msg []byte) Message { 31 | plaintext := append(msg, byte(MSG)) 32 | 33 | for { 34 | var pts []*Point 35 | var pt kyber.Point 36 | done := false 37 | i := 0 38 | l := PickLen() 39 | for !done { 40 | start := i * l 41 | end := (i + 1) * l 42 | if end > len(plaintext) { 43 | end = len(plaintext) 44 | } 45 | pt = SUITE.Point().Embed(plaintext[start:end], random.New()) 46 | pts = append(pts, &Point{pt}) 47 | i++ 48 | done = end == len(plaintext) 49 | } 50 | res, ty, err := ExtractPlaintext(pts) 51 | if err == nil && ty == MSG && compareArray(msg, res) { 52 | return pts 53 | } 54 | fmt.Println("fail msg!") 55 | } 56 | } 57 | 58 | func GenMsgs(plaintexts [][]byte) []Message { 59 | msgs := make([]Message, len(plaintexts)) 60 | for m := range msgs { 61 | msgs[m] = GenMsg(plaintexts[m]) 62 | } 63 | return msgs 64 | } 65 | 66 | func GenRandMsg(numPts int) Message { 67 | msg := make([]*Point, numPts) 68 | rnd := random.New() 69 | for m := range msg { 70 | tmp := SUITE.Scalar().Pick(rnd) 71 | msg[m] = &Point{SUITE.Point().Mul(tmp, nil)} 72 | } 73 | return msg 74 | } 75 | 76 | func GenPoints(numPts int) []*Point { 77 | msg := make([]*Point, numPts) 78 | rnd := 
random.New() 79 | for m := range msg { 80 | tmp := SUITE.Scalar().Pick(rnd) 81 | msg[m] = &Point{SUITE.Point().Mul(tmp, nil)} 82 | } 83 | return msg 84 | } 85 | 86 | func GenRandMsgs(numMsg, numPts int) []Message { 87 | msgs := make([]Message, numMsg) 88 | for m := range msgs { 89 | msgs[m] = GenRandMsg(numPts) 90 | } 91 | return msgs 92 | } 93 | 94 | func GenTrap(gid int) Trap { 95 | buf := make([]byte, NONCE_LEN) 96 | n, err := rand.Read(buf) 97 | if n != NONCE_LEN { 98 | log.Fatal("Could not read enough rand bytes") 99 | } else if err != nil { 100 | log.Fatal("Could not read rand bytes:", err) 101 | } 102 | return Trap{Gid: gid, Nonce: buf} 103 | } 104 | 105 | func TrapToMessage(trap Trap, numPts int) (Message, error) { 106 | buf, err := (&trap).MarshalBinary() 107 | if err != nil { 108 | return nil, err 109 | } 110 | buf = append(buf, byte(TRAP)) 111 | for { 112 | pt := SUITE.Point().Embed(buf, random.New()) 113 | res, err := pt.Data() 114 | if err != nil || !compareArray(buf, res) { 115 | fmt.Println("fail trap!") 116 | continue 117 | } 118 | msg := make([]*Point, numPts) 119 | for m := range msg { 120 | msg[m] = &Point{pt} 121 | } 122 | return msg, nil 123 | } 124 | } 125 | 126 | func ExtractMessages(ciphertexts []Ciphertext) []Message { 127 | msgs := make([]Message, len(ciphertexts)) 128 | for c, ciphertext := range ciphertexts { 129 | msgs[c] = ciphertext.C 130 | } 131 | return msgs 132 | } 133 | 134 | func ExtractPlaintext(msg Message) ([]byte, MsgType, error) { 135 | var plaintext []byte 136 | for m := range msg { 137 | p, err := msg[m].p.Data() 138 | if err != nil { 139 | return nil, MsgType(byte(OTHER)), err 140 | } 141 | plaintext = append(plaintext, p...) 
142 | } 143 | msgType := plaintext[len(plaintext)-1] 144 | if msgType == TRAP { 145 | plaintext = plaintext[:4+NONCE_LEN] 146 | } else if msgType == MSG { 147 | plaintext = plaintext[:len(plaintext)-1] 148 | } 149 | return plaintext, MsgType(msgType), nil 150 | } 151 | 152 | func ExtractPlaintexts(msgs []Message) ([][]byte, []MsgType, error) { 153 | plaintexts := make([][]byte, len(msgs)) 154 | msgType := make([]MsgType, len(msgs)) 155 | var err error 156 | for m := range msgs { 157 | plaintexts[m], msgType[m], err = ExtractPlaintext(msgs[m]) 158 | if err != nil { 159 | return nil, nil, err 160 | } 161 | } 162 | return plaintexts, msgType, nil 163 | } 164 | 165 | func ExtractInnerAndTraps(msgs []Message) ([]InnerCiphertext, []Trap, error) { 166 | var inners []InnerCiphertext 167 | var traps []Trap 168 | for m := range msgs { 169 | // check the last byte of the last msg for the type 170 | p, err := msgs[m][len(msgs[m])-1].p.Data() 171 | if err != nil { 172 | return nil, nil, err 173 | } 174 | msgType := p[len(p)-1] 175 | 176 | if msgType == TRAP { 177 | trap := new(Trap) 178 | err = trap.UnmarshalBinary(p) 179 | if err != nil { 180 | return nil, nil, err 181 | } 182 | traps = append(traps, *trap) 183 | } else { 184 | c, _, err := ExtractPlaintext(msgs[m][1:]) 185 | if err != nil { 186 | return nil, nil, err 187 | } 188 | inner := InnerCiphertext{ 189 | R: msgs[m][0], 190 | C: c, 191 | } 192 | inners = append(inners, inner) 193 | } 194 | } 195 | return inners, traps, nil 196 | } 197 | 198 | func CopyPubs(pubs []*PublicKey) []*PublicKey { 199 | cp := make([]*PublicKey, len(pubs)) 200 | for i := range pubs { 201 | b, _ := pubs[i].MarshalBinary() 202 | cp[i] = new(PublicKey) 203 | cp[i].UnmarshalBinary(b) 204 | } 205 | return cp 206 | } 207 | 208 | // randUint64 chooses a uniform random uint64 209 | func randUint64(rand cipher.Stream) uint64 { 210 | b := random.Bits(64, false, rand) 211 | return binary.BigEndian.Uint64(b) 212 | } 213 | 
-------------------------------------------------------------------------------- /crypto/crypto_test.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import "testing" 4 | 5 | var size = 1 6 | var num = 5 7 | 8 | func isEqualMessage(m1, m2 Message) bool { 9 | same := true 10 | for i := range m1 { 11 | same = same && m1[i].Equal(m2[i]) 12 | } 13 | return same 14 | } 15 | 16 | func isMemberMessage(msg Message, msgs []Message) bool { 17 | for m := range msgs { 18 | if isEqualMessage(msg, msgs[m]) { 19 | return true 20 | } 21 | } 22 | return false 23 | } 24 | 25 | func TestComebineKeys(t *testing.T) { 26 | _, pubs, secs := GenKeys(N) 27 | cSec := CombinePrivateKeys(secs) 28 | cPub := CombinePublicKeys(pubs) 29 | if !PubFromPriv(cSec).Equal(cPub) { 30 | t.Error("Mismatched combined keys") 31 | } 32 | } 33 | 34 | func TestEncryptDecrypt(t *testing.T) { 35 | key := GenKey() 36 | x, X := key.Priv, key.Pub 37 | msg := GenRandMsg(size) 38 | 39 | ciphertext := Encrypt(X, msg) 40 | res := Decrypt(x, ciphertext) 41 | 42 | for m := range msg { 43 | if !msg[m].Equal(res[m]) { 44 | t.Error("Mismatched plaintext message.") 45 | } 46 | } 47 | } 48 | 49 | func TestReblind(t *testing.T) { 50 | key := GenKey() 51 | x, X := key.Priv, key.Pub 52 | msg := GenRandMsg(size) 53 | 54 | ciphertext := Encrypt(X, msg) 55 | 56 | reblinded := Reblind(X, ciphertext) 57 | 58 | exp := Decrypt(x, ciphertext) 59 | res := Decrypt(x, reblinded) 60 | 61 | for m := range exp { 62 | if !msg[m].Equal(res[m]) { 63 | t.Error("Mismatched plaintext message.") 64 | } 65 | } 66 | } 67 | 68 | func TestReencrypt(t *testing.T) { 69 | key1 := GenKey() 70 | x1, X1 := key1.Priv, key1.Pub 71 | key2 := GenKey() 72 | x2, X2 := key2.Priv, key2.Pub 73 | Xcombined := CombinePublicKeys([]*PublicKey{X1, X2}) 74 | 75 | key3 := GenKey() 76 | y, Y := key3.Priv, key3.Pub 77 | 78 | msg := GenRandMsg(size) 79 | ciphertext := Encrypt(Xcombined, msg) 80 | 81 | reenc := 
Reencrypt(x1, Y, ciphertext) 82 | reenc = Reencrypt(x2, Y, reenc) 83 | 84 | res := Decrypt(y, reenc) 85 | for m := range msg { 86 | if !msg[m].Equal(res[m]) { 87 | t.Error("Mismatched plaintext message.") 88 | } 89 | } 90 | } 91 | 92 | func TestShuffle(t *testing.T) { 93 | key := GenKey() 94 | x, X := key.Priv, key.Pub 95 | 96 | msgs := GenRandMsgs(num, size) 97 | ciphertexts := make([]Ciphertext, num) 98 | for m := range msgs { 99 | ciphertexts[m] = Encrypt(X, msgs[m]) 100 | } 101 | 102 | shuffled := Shuffle(X, ciphertexts) 103 | 104 | results := make([]Message, num) 105 | for r := range results { 106 | results[r] = Decrypt(x, shuffled[r]) 107 | } 108 | 109 | for r := range results { 110 | if !isMemberMessage(results[r], msgs) { 111 | t.Error("Missing message.") 112 | } 113 | } 114 | } 115 | 116 | func TestCCA2(t *testing.T) { 117 | nonce := make([]byte, 24) 118 | key := GenKey() 119 | x, X := key.Priv, key.Pub 120 | 121 | msg := make([]byte, 160) 122 | copy(msg, []byte("Hello World")) 123 | inner := CCA2Encrypt(msg, nonce, X) 124 | 125 | decMsg, err := CCA2Decrypt(inner, nonce, x, X) 126 | if err != nil { 127 | t.Error(err) 128 | } else { 129 | if len(msg) != len(decMsg) { 130 | t.Error("Msg length mismatch") 131 | } 132 | for m := range msg { 133 | if msg[m] != decMsg[m] { 134 | t.Error("Msg mismatch") 135 | } 136 | } 137 | } 138 | 139 | } 140 | 141 | func BenchmarkEncrypt(b *testing.B) { 142 | key := GenKey() 143 | _, X := key.Priv, key.Pub 144 | msg := GenRandMsg(size) 145 | 146 | b.ResetTimer() 147 | for i := 0; i < b.N; i++ { 148 | Encrypt(X, msg) 149 | } 150 | } 151 | 152 | func BenchmarkReencrypt(b *testing.B) { 153 | key1 := GenKey() 154 | x1, X1 := key1.Priv, key1.Pub 155 | key2 := GenKey() 156 | _, X2 := key2.Priv, key2.Pub 157 | Xcombined := CombinePublicKeys([]*PublicKey{X1, X2}) 158 | 159 | key3 := GenKey() 160 | _, Y := key3.Priv, key3.Pub 161 | 162 | msg := GenRandMsg(size) 163 | ciphertext := Encrypt(Xcombined, msg) 164 | 165 | b.ResetTimer() 166 
| for i := 0; i < b.N; i++ { 167 | Reencrypt(x1, Y, ciphertext) 168 | } 169 | } 170 | 171 | func BenchmarkShuffle1024(b *testing.B) { 172 | numPts := 1024 173 | key := GenKey() 174 | _, X := key.Priv, key.Pub 175 | 176 | msgs := GenRandMsgs(numPts, 1) 177 | ciphertexts := make([]Ciphertext, numPts) 178 | for m := range msgs { 179 | ciphertexts[m] = Encrypt(X, msgs[m]) 180 | } 181 | 182 | b.ResetTimer() 183 | for i := 0; i < b.N; i++ { 184 | Shuffle(X, ciphertexts) 185 | } 186 | } 187 | 188 | func BenchmarkShuffle2048(b *testing.B) { 189 | numPts := 2048 190 | key := GenKey() 191 | _, X := key.Priv, key.Pub 192 | 193 | msgs := GenRandMsgs(numPts, 1) 194 | ciphertexts := make([]Ciphertext, numPts) 195 | for m := range msgs { 196 | ciphertexts[m] = Encrypt(X, msgs[m]) 197 | } 198 | 199 | b.ResetTimer() 200 | for i := 0; i < b.N; i++ { 201 | Shuffle(X, ciphertexts) 202 | } 203 | } 204 | -------------------------------------------------------------------------------- /crypto/encoding.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "reflect" 7 | 8 | "github.com/dedis/kyber" 9 | dkg "github.com/dedis/kyber/share/dkg/pedersen" 10 | vss "github.com/dedis/kyber/share/vss/pedersen" 11 | "github.com/dedis/protobuf" 12 | ) 13 | 14 | func (p *Point) MarshalBinary() ([]byte, error) { 15 | return p.p.MarshalBinary() 16 | } 17 | 18 | func (p *Point) UnmarshalBinary(data []byte) error { 19 | p.p = SUITE.Point() 20 | return p.p.UnmarshalBinary(data) 21 | } 22 | 23 | func (s *Scalar) MarshalBinary() ([]byte, error) { 24 | return s.s.MarshalBinary() 25 | } 26 | 27 | func (s *Scalar) UnmarshalBinary(data []byte) error { 28 | s.s = SUITE.Scalar() 29 | return s.s.UnmarshalBinary(data) 30 | } 31 | 32 | func (k *PrivateKey) MarshalBinary() ([]byte, error) { 33 | return k.s.MarshalBinary() 34 | } 35 | 36 | func (k *PrivateKey) UnmarshalBinary(data []byte) error { 37 | k.s = 
SUITE.Scalar() 38 | return k.s.UnmarshalBinary(data) 39 | } 40 | 41 | func (k *PublicKey) MarshalBinary() ([]byte, error) { 42 | return k.p.MarshalBinary() 43 | } 44 | 45 | func (k *PublicKey) UnmarshalBinary(data []byte) error { 46 | k.p = SUITE.Point() 47 | return k.p.UnmarshalBinary(data) 48 | } 49 | 50 | func writeUint32(buf *bytes.Buffer, val uint32) error { 51 | err := binary.Write(buf, binary.LittleEndian, val) 52 | if err != nil { 53 | return err 54 | } 55 | return nil 56 | } 57 | 58 | func readUint32(buf *bytes.Buffer) (uint32, error) { 59 | var tmp uint32 60 | err := binary.Read(buf, binary.LittleEndian, &tmp) 61 | if err != nil { 62 | return 0, err 63 | } 64 | return tmp, nil 65 | } 66 | 67 | func writeBytes(buf *bytes.Buffer, msg []byte) error { 68 | err := writeUint32(buf, uint32(len(msg))) 69 | if err != nil { 70 | return err 71 | } 72 | _, err = buf.Write(msg) 73 | if err != nil { 74 | return err 75 | } 76 | return nil 77 | } 78 | 79 | func readBytes(buf *bytes.Buffer) ([]byte, error) { 80 | size, err := readUint32(buf) 81 | res := make([]byte, size) 82 | _, err = buf.Read(res) 83 | if err != nil { 84 | return nil, err 85 | } 86 | return res, nil 87 | } 88 | 89 | func (d *ThresholdDeal) MarshalBinary() ([]byte, error) { 90 | buf := new(bytes.Buffer) 91 | deal := d.D 92 | writeUint32(buf, deal.Index) 93 | b := deal.Deal.DHKey 94 | err := writeBytes(buf, b) 95 | if err != nil { 96 | return nil, err 97 | } 98 | err = writeBytes(buf, deal.Deal.Signature) 99 | if err != nil { 100 | return nil, err 101 | } 102 | err = writeBytes(buf, deal.Deal.Nonce) 103 | if err != nil { 104 | return nil, err 105 | } 106 | err = writeBytes(buf, deal.Deal.Cipher) 107 | if err != nil { 108 | return nil, err 109 | } 110 | return buf.Bytes(), nil 111 | } 112 | 113 | func (d *ThresholdDeal) UnmarshalBinary(data []byte) error { 114 | buf := bytes.NewBuffer(data) 115 | deal := new(dkg.Deal) 116 | var err error 117 | deal.Index, err = readUint32(buf) 118 | if err != nil { 119 | 
return err 120 | } 121 | deal.Deal = new(vss.EncryptedDeal) 122 | b, err := readBytes(buf) 123 | if err != nil { 124 | return err 125 | } 126 | deal.Deal.DHKey = b 127 | 128 | deal.Deal.Signature, err = readBytes(buf) 129 | if err != nil { 130 | return err 131 | } 132 | deal.Deal.Nonce, err = readBytes(buf) 133 | if err != nil { 134 | return err 135 | } 136 | deal.Deal.Cipher, err = readBytes(buf) 137 | d.D = deal 138 | return err 139 | } 140 | 141 | func (r *ThresholdResponse) MarshalBinary() ([]byte, error) { 142 | buf, err := protobuf.Encode(r.R) 143 | cp := make([]byte, len(buf)) 144 | copy(cp, buf) 145 | return cp, err 146 | } 147 | 148 | func (r *ThresholdResponse) UnmarshalBinary(data []byte) error { 149 | cp := make([]byte, len(data)) 150 | copy(cp, data) 151 | r.R = new(dkg.Response) 152 | constructors := make(protobuf.Constructors) 153 | var point kyber.Point 154 | var secret kyber.Scalar 155 | constructors[reflect.TypeOf(&point).Elem()] = func() interface{} { return SUITE.Point() } 156 | constructors[reflect.TypeOf(&secret).Elem()] = func() interface{} { return SUITE.Scalar() } 157 | return protobuf.DecodeWithConstructors(cp, r.R, constructors) 158 | } 159 | 160 | func (t *Trap) MarshalBinary() ([]byte, error) { 161 | buf := new(bytes.Buffer) 162 | err := binary.Write(buf, binary.LittleEndian, uint32(t.Gid)) 163 | if err != nil { 164 | return nil, err 165 | } 166 | _, err = buf.Write(t.Nonce) 167 | if err != nil { 168 | return nil, err 169 | } 170 | return buf.Bytes(), nil 171 | } 172 | 173 | func (t *Trap) UnmarshalBinary(data []byte) error { 174 | buf := bytes.NewBuffer(data) 175 | var tmp uint32 176 | err := binary.Read(buf, binary.LittleEndian, &tmp) 177 | if err != nil { 178 | return err 179 | } 180 | t.Gid = int(tmp) 181 | 182 | nonce := make([]byte, NONCE_LEN) 183 | _, err = buf.Read(nonce) 184 | if err != nil { 185 | return err 186 | } 187 | t.Nonce = nonce 188 | return nil 189 | } 190 | 191 | func (i *InnerCiphertext) MarshalBinary() ([]byte, 
error) { 192 | r, err := i.R.p.MarshalBinary() 193 | if err != nil { 194 | return nil, err 195 | } 196 | return append(r, i.C...), nil 197 | } 198 | 199 | func (i *InnerCiphertext) UnmarshalBinary(data []byte) error { 200 | R := SUITE.Point() 201 | err := R.UnmarshalBinary(data[:R.MarshalSize()]) 202 | if err != nil { 203 | return err 204 | } 205 | i.R = &Point{R} 206 | c := make([]byte, len(data)-R.MarshalSize()) 207 | copy(c, data[R.MarshalSize():]) 208 | i.C = c 209 | return nil 210 | } 211 | -------------------------------------------------------------------------------- /crypto/encoding_test.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "bytes" 5 | "encoding/gob" 6 | "log" 7 | "testing" 8 | ) 9 | 10 | func TestEncodeKey(t *testing.T) { 11 | keyPair := GenKey() 12 | 13 | var network bytes.Buffer // Stand-in for the network. 14 | // Create an encoder and send a value. 15 | enc := gob.NewEncoder(&network) 16 | err := enc.Encode(keyPair) 17 | if err != nil { 18 | t.Error(err) 19 | } 20 | 21 | // Create a decoder and receive a value. 22 | dec := gob.NewDecoder(&network) 23 | var res KeyPair 24 | err = dec.Decode(&res) 25 | if err != nil { 26 | log.Fatal("decode err:", err) 27 | } 28 | if !keyPair.Priv.s.Equal(res.Priv.s) || 29 | !keyPair.Pub.p.Equal(res.Pub.p) { 30 | log.Fatal("Failed to encode key") 31 | } 32 | } 33 | 34 | func TestEncodeMsg(t *testing.T) { 35 | msg := GenPoints(size) 36 | 37 | var network bytes.Buffer // Stand-in for the network. 38 | // Create an encoder and send a value. 39 | enc := gob.NewEncoder(&network) 40 | err := enc.Encode(msg) 41 | if err != nil { 42 | t.Error(err) 43 | } 44 | 45 | // Create a decoder and receive a value. 
46 | dec := gob.NewDecoder(&network) 47 | var res []*Point 48 | err = dec.Decode(&res) 49 | if err != nil { 50 | log.Fatal("decode err:", err) 51 | } 52 | 53 | for m := range msg { 54 | if !msg[m].Equal(res[m]) { 55 | t.Error("Failed to encode msg") 56 | } 57 | } 58 | } 59 | 60 | func TestEncodeThreshold(t *testing.T) { 61 | keys, pubs, _ := GenKeys(N) 62 | ts := make([]*Threshold, N) 63 | sendss := make([][]*ThresholdDeal, N) 64 | recvss := make([][]*ThresholdResponse, N) 65 | for i := range ts { 66 | cp := CopyPubs(pubs) // done to avoid weird race condition 67 | ts[i] = NewThreshold(i, T, keys[i], cp) 68 | sendss[i] = make([]*ThresholdDeal, N) 69 | recvss[i] = make([]*ThresholdResponse, N) 70 | } 71 | 72 | deal := ts[0].GetDeal(1) 73 | b, _ := deal.MarshalBinary() 74 | dcp := new(ThresholdDeal) 75 | dcp.UnmarshalBinary(b) 76 | //fmt.Println(deal.D.Deal, dcp.D.Deal) 77 | } 78 | -------------------------------------------------------------------------------- /crypto/keys.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "encoding/hex" 5 | "encoding/json" 6 | "io/ioutil" 7 | "log" 8 | "os" 9 | 10 | "github.com/dedis/kyber/util/random" 11 | ) 12 | 13 | func GenKey() *KeyPair { 14 | x := SUITE.Scalar().Pick(random.New()) 15 | X := SUITE.Point().Mul(x, nil) 16 | return &KeyPair{ 17 | Priv: &PrivateKey{x}, 18 | Pub: &PublicKey{X}, 19 | } 20 | } 21 | 22 | func NullKey() *PublicKey { 23 | return &PublicKey{SUITE.Point().Null()} 24 | } 25 | 26 | func PubFromPriv(priv *PrivateKey) *PublicKey { 27 | return &PublicKey{SUITE.Point().Mul(priv.s, nil)} 28 | } 29 | 30 | func GenKeys(N int) ([]*KeyPair, []*PublicKey, []*PrivateKey) { 31 | keys := make([]*KeyPair, N) 32 | pubs := make([]*PublicKey, N) 33 | privs := make([]*PrivateKey, N) 34 | for k := range keys { 35 | keys[k] = GenKey() 36 | pubs[k] = keys[k].Pub 37 | privs[k] = keys[k].Priv 38 | } 39 | return keys, pubs, privs 40 | } 41 | 42 | // The 
combined private key for a bunch of nodes 43 | func CombinePrivateKeys(privs []*PrivateKey) *PrivateKey { 44 | t := SUITE.Scalar().Zero() 45 | for i := range privs { 46 | t = t.Add(t, privs[i].s) 47 | } 48 | return &PrivateKey{t} 49 | } 50 | 51 | // The combined public key for a bunch of nodes 52 | func CombinePublicKeys(pubs []*PublicKey) *PublicKey { 53 | h := SUITE.Point().Null() 54 | for i := range pubs { 55 | h = h.Add(h, pubs[i].p) 56 | } 57 | return &PublicKey{h} 58 | } 59 | 60 | func ReadKeys(fn string) ([]HexKeyPair, error) { 61 | file, err := os.Open(fn) 62 | if err != nil { 63 | return nil, err 64 | } 65 | bs, err := ioutil.ReadAll(file) 66 | if err != nil { 67 | return nil, err 68 | } 69 | 70 | var keys []HexKeyPair 71 | err = json.Unmarshal(bs, &keys) 72 | if err != nil { 73 | return nil, err 74 | } 75 | return keys, nil 76 | } 77 | 78 | func DumpPrivKey(priv *PrivateKey) string { 79 | b, err := priv.s.MarshalBinary() 80 | if err != nil { 81 | log.Fatal("secret key err:", err) 82 | } 83 | return hex.EncodeToString(b) 84 | } 85 | 86 | func DumpPubKey(pub *PublicKey) string { 87 | b, err := pub.p.MarshalBinary() 88 | if err != nil { 89 | log.Fatal("public key err:", err) 90 | } 91 | return hex.EncodeToString(b) 92 | } 93 | 94 | func DumpKey(keyPair *KeyPair) HexKeyPair { 95 | return HexKeyPair{ 96 | Priv: DumpPrivKey(keyPair.Priv), 97 | Pub: DumpPubKey(keyPair.Pub), 98 | } 99 | } 100 | 101 | func LoadPrivKey(priv string) *PrivateKey { 102 | pb, err := hex.DecodeString(priv) 103 | if err != nil { 104 | log.Fatal("Loading malformed keys", err) 105 | } 106 | privKey := SUITE.Scalar() 107 | err = privKey.UnmarshalBinary(pb) 108 | if err != nil { 109 | log.Fatal("Loading malformed keys", err) 110 | } 111 | return &PrivateKey{privKey} 112 | } 113 | 114 | func LoadPubKey(pub string) *PublicKey { 115 | pb, err := hex.DecodeString(pub) 116 | if err != nil { 117 | log.Fatal("Loading malformed keys", err) 118 | } 119 | pubKey := SUITE.Point() 120 | err = 
pubKey.UnmarshalBinary(pb) 121 | if err != nil { 122 | log.Fatal("Loading malformed keys", err) 123 | } 124 | return &PublicKey{pubKey} 125 | } 126 | 127 | func LoadKey(key HexKeyPair) *KeyPair { 128 | return &KeyPair{ 129 | Priv: LoadPrivKey(key.Priv), 130 | Pub: LoadPubKey(key.Pub), 131 | } 132 | } 133 | -------------------------------------------------------------------------------- /crypto/nizk.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "errors" 5 | "log" 6 | "runtime" 7 | "sync" 8 | 9 | "golang.org/x/crypto/sha3" 10 | 11 | "github.com/dedis/kyber" 12 | "github.com/dedis/kyber/proof" 13 | "github.com/dedis/kyber/shuffle" 14 | "github.com/dedis/kyber/util/random" 15 | ) 16 | 17 | func ProveEncrypt(X *PublicKey, msg Message) (Ciphertext, EncProof) { 18 | rnd := random.New() 19 | R := make([]*Point, len(msg)) 20 | C := make([]*Point, len(msg)) 21 | proof := EncProof{ 22 | S: make([]*Point, len(msg)), 23 | U: make([]*Scalar, len(msg)), 24 | } 25 | Xbin, _ := X.MarshalBinary() 26 | for idx := range msg { 27 | r := SUITE.Scalar().Pick(rnd) 28 | R[idx] = &Point{SUITE.Point().Mul(r, nil)} 29 | C[idx] = &Point{SUITE.Point().Add(msg[idx].p, SUITE.Point().Mul(r, X.p))} 30 | 31 | s := SUITE.Scalar().Pick(rnd) 32 | S := SUITE.Point().Mul(s, nil) 33 | 34 | Cbin, _ := C[idx].MarshalBinary() 35 | sbin, _ := S.MarshalBinary() 36 | inp := append(Cbin, sbin...) 37 | inp = append(inp, Xbin...) 
38 | tbin := sha3.Sum256(inp) 39 | t := SUITE.Scalar().SetBytes(tbin[:]) 40 | u := s.Add(s, t.Mul(t, r)) 41 | proof.S[idx] = &Point{S} 42 | proof.U[idx] = &Scalar{u} 43 | } 44 | return Ciphertext{ 45 | R: R, 46 | C: C, 47 | Y: nil, 48 | }, proof 49 | } 50 | 51 | func VerifyEncrypt(X *PublicKey, c Ciphertext, proof EncProof) error { 52 | Xbin, _ := X.MarshalBinary() 53 | for idx := range c.C { 54 | U := SUITE.Point().Mul(proof.U[idx].s, nil) 55 | S := proof.S[idx].p 56 | 57 | Cbin, _ := c.C[idx].MarshalBinary() 58 | sbin, _ := S.MarshalBinary() 59 | inp := append(Cbin, sbin...) 60 | inp = append(inp, Xbin...) 61 | tbin := sha3.Sum256(inp) 62 | t := SUITE.Scalar().SetBytes(tbin[:]) 63 | R := SUITE.Point().Mul(t, c.R[idx].p) 64 | R = R.Add(S, R) 65 | if !U.Equal(R) { 66 | return errors.New("Encproof verify failed") 67 | } 68 | } 69 | return nil 70 | } 71 | 72 | func ProveReencrypt(x *PrivateKey, XBar *PublicKey, c Ciphertext) (Ciphertext, ReencProof) { 73 | if c.Y == nil { 74 | c.Y = c.R 75 | c.R = make([]*Point, len(c.Y)) 76 | for idx := range c.R { 77 | c.R[idx] = &Point{SUITE.Point().Null()} 78 | } 79 | } 80 | rnd := random.New() 81 | 82 | proofs := make([][]byte, len(c.C)) 83 | 84 | p := proof.Rep("Y'-Y", "-h", "X", "r", "B") 85 | 86 | ciphertext := Ciphertext{ 87 | R: make([]*Point, len(c.R)), 88 | C: make([]*Point, len(c.C)), 89 | Y: make([]*Point, len(c.Y)), 90 | } 91 | 92 | negx := SUITE.Scalar().Neg(x.s) 93 | for idx := range c.C { 94 | pub := map[string]kyber.Point{"B": XBar.p} 95 | sec := map[string]kyber.Scalar{"-h": negx} 96 | 97 | blind := SUITE.Point().Mul(x.s, c.Y[idx].p) 98 | ctmp := blind.Sub(c.C[idx].p, blind) 99 | 100 | rBar := SUITE.Scalar().Pick(rnd) 101 | newR := SUITE.Point().Mul(rBar, nil) 102 | newR = newR.Add(c.R[idx].p, newR) 103 | 104 | newBlind := SUITE.Point().Mul(rBar, XBar.p) 105 | newC := ctmp.Add(ctmp, newBlind) 106 | 107 | ciphertext.R[idx] = &Point{newR} 108 | ciphertext.C[idx] = &Point{newC} 109 | ciphertext.Y[idx] = c.Y[idx] 110 
| 111 | pub["X"] = c.Y[idx].p 112 | sec["r"] = rBar 113 | pub["Y'-Y"] = SUITE.Point().Sub(ctmp, c.C[idx].p) 114 | prover := p.Prover(SUITE, sec, pub, nil) 115 | proof, err := proof.HashProve(SUITE, "Decrypt", prover) 116 | if err != nil { 117 | log.Fatal("Proof gen err:", err) 118 | } 119 | proofs[idx] = proof 120 | } 121 | return ciphertext, proofs 122 | } 123 | 124 | func ProveReencryptBatches(priv *PrivateKey, neighborKeys []*PublicKey, batches [][]Ciphertext) ([][]Ciphertext, [][]ReencProof) { 125 | numBatches := len(batches) 126 | batchSize := len(batches[0]) 127 | k := numBatches * batchSize 128 | ciphertexts := make([]Ciphertext, k) 129 | proofs := make([]ReencProof, k) 130 | pubs := make([]*PublicKey, k) 131 | idx := 0 132 | for b := range batches { 133 | for i := range batches[b] { 134 | ciphertexts[idx] = batches[b][i] 135 | pubs[idx] = neighborKeys[b] 136 | idx++ 137 | } 138 | } 139 | 140 | chunks := runtime.NumCPU() 141 | div := k / chunks 142 | if k < chunks { 143 | div = 1 144 | chunks = k 145 | } else if k%chunks != 0 { 146 | div++ 147 | } 148 | 149 | wg := new(sync.WaitGroup) 150 | wg.Add(chunks) 151 | for d := 0; d < chunks; d++ { 152 | start := d * div 153 | end := (d + 1) * div 154 | if end > k { 155 | end = k 156 | } 157 | go func(start, end int) { 158 | defer wg.Done() 159 | for i := start; i < end; i++ { 160 | ciphertexts[i], proofs[i] = ProveReencrypt(priv, pubs[i], ciphertexts[i]) 161 | } 162 | }(start, end) 163 | } 164 | wg.Wait() 165 | 166 | resultc := make([][]Ciphertext, len(batches)) 167 | resultp := make([][]ReencProof, len(batches)) 168 | 169 | idx = 0 170 | for b := range batches { 171 | resultc[b] = make([]Ciphertext, len(batches[b])) 172 | resultp[b] = make([]ReencProof, len(batches[b])) 173 | for i := range batches[b] { 174 | resultc[b][i] = ciphertexts[idx] 175 | resultp[b][i] = proofs[idx] 176 | idx++ 177 | } 178 | } 179 | 180 | return resultc, resultp 181 | } 182 | 183 | func VerifyReencrypt(X *PublicKey, old, new Ciphertext, 
proofs ReencProof) error { 184 | p := proof.Rep("Y'-Y", "-h", "X", "r", "B") 185 | for idx := range new.C { 186 | pub := map[string]kyber.Point{"B": X.p} 187 | pub["X"] = new.Y[idx].p 188 | pub["Y'-Y"] = SUITE.Point().Sub(new.C[idx].p, old.C[idx].p) 189 | verifier := p.Verifier(SUITE, pub) 190 | err := proof.HashVerify(SUITE, "Decrypt", verifier, proofs[idx]) 191 | if err != nil { 192 | return err 193 | } 194 | } 195 | return nil 196 | } 197 | 198 | func VerifyReencryptBatches(ob, nb [][]Ciphertext, proofs [][]ReencProof, neighborKeys []*PublicKey) bool { 199 | for b := range nb { 200 | for c := range nb[b] { 201 | err := VerifyReencrypt(neighborKeys[b], ob[b][c], nb[b][c], proofs[b][c]) 202 | if err != nil { 203 | log.Println("Incorrect reencrypt proof:", err) 204 | return false 205 | } 206 | } 207 | } 208 | return true 209 | } 210 | 211 | // Reblind and also return reblinding factors for prove shuffle 212 | func reblind(X *PublicKey, c Ciphertext) (Ciphertext, []*Scalar) { 213 | rnd := random.New() 214 | blinds := make([]*Scalar, len(c.C)) 215 | nc := Ciphertext{ 216 | R: make([]*Point, len(c.R)), 217 | C: make([]*Point, len(c.C)), 218 | } 219 | for idx := range c.C { 220 | r := SUITE.Scalar().Pick(rnd) 221 | blinds[idx] = &Scalar{r} 222 | newR := SUITE.Point().Mul(r, nil) 223 | newBlind := SUITE.Point().Mul(r, X.p) 224 | 225 | nc.R[idx] = &Point{SUITE.Point().Add(c.R[idx].p, newR)} 226 | nc.C[idx] = &Point{SUITE.Point().Add(c.C[idx].p, newBlind)} 227 | } 228 | return nc, blinds 229 | } 230 | 231 | func ProveShuffle(X *PublicKey, cs []Ciphertext) ([]Ciphertext, ShufProof) { 232 | rnd := random.New() 233 | k := len(cs) 234 | 235 | ciphertexts := make([]Ciphertext, k) 236 | tmp := make([]Ciphertext, k) 237 | blinds := make([][]*Scalar, k) 238 | 239 | chunks := runtime.NumCPU() 240 | div := k / chunks 241 | if k < chunks { 242 | div = 1 243 | chunks = k 244 | } else if k%chunks != 0 { 245 | div += 1 246 | } 247 | 248 | wg := new(sync.WaitGroup) 249 | wg.Add(chunks) 
250 | 251 | for d := 0; d < chunks; d++ { 252 | start := d * div 253 | end := (d + 1) * div 254 | if end > k { 255 | end = k 256 | } 257 | go func(start, end int) { 258 | defer wg.Done() 259 | for i := start; i < end; i++ { 260 | tmp[i], blinds[i] = reblind(X, cs[i]) 261 | } 262 | }(start, end) 263 | } 264 | wg.Wait() 265 | 266 | pi := make([]int, k) 267 | for i := 0; i < k; i++ { // Initialize a trivial permutation 268 | pi[i] = i 269 | } 270 | for i := k - 1; i > 0; i-- { // Shuffle by random swaps 271 | j := int(randUint64(rnd) % uint64(i+1)) 272 | if j != i { 273 | t := pi[j] 274 | pi[j] = pi[i] 275 | pi[i] = t 276 | } 277 | } 278 | 279 | for i := range cs { 280 | ciphertexts[i] = tmp[pi[i]] 281 | } 282 | 283 | proofs := make([][]byte, len(cs[0].C)) 284 | wg.Add(len(cs[0].C)) 285 | for idx := range cs[0].C { 286 | go func(idx int) { 287 | defer wg.Done() 288 | ps := shuffle.PairShuffle{} 289 | ps.Init(SUITE, k) 290 | 291 | R := make([]kyber.Point, k) 292 | C := make([]kyber.Point, k) 293 | r := make([]kyber.Scalar, k) 294 | for c := range cs { 295 | R[c] = cs[c].R[idx].p 296 | C[c] = cs[c].C[idx].p 297 | r[c] = blinds[c][idx].s 298 | } 299 | 300 | prover := func(ctx proof.ProverContext) error { 301 | return ps.Prove(pi, nil, X.p, r, R, C, rnd, ctx) 302 | } 303 | proof, err := proof.HashProve(SUITE, "PairShuffle", prover) 304 | if err != nil { 305 | log.Fatal("Error creating proof:", err) 306 | } 307 | proofs[idx] = proof 308 | }(idx) 309 | } 310 | wg.Wait() 311 | return ciphertexts, proofs 312 | } 313 | 314 | func VerifyShuffle(X *PublicKey, oc, nc []Ciphertext, proofs ShufProof) error { 315 | if len(oc) != len(nc) { 316 | return errors.New("Mismatching length") 317 | } 318 | k := len(nc) 319 | 320 | errChan := make(chan error) 321 | for idx := range oc[0].C { 322 | go func(idx int) { 323 | ps := shuffle.PairShuffle{} 324 | ps.Init(SUITE, len(nc)) 325 | 326 | R := make([]kyber.Point, k) 327 | C := make([]kyber.Point, k) 328 | Rbar := make([]kyber.Point, k) 329 
| Cbar := make([]kyber.Point, k) 330 | for c := range oc { 331 | R[c] = oc[c].R[idx].p 332 | C[c] = oc[c].C[idx].p 333 | Rbar[c] = nc[c].R[idx].p 334 | Cbar[c] = nc[c].C[idx].p 335 | } 336 | 337 | verifier := func(ctx proof.VerifierContext) error { 338 | return ps.Verify(nil, X.p, R, C, Rbar, Cbar, ctx) 339 | } 340 | errChan <- proof.HashVerify(SUITE, "PairShuffle", verifier, proofs[idx]) 341 | }(idx) 342 | } 343 | 344 | for _ = range oc[0].C { 345 | err := <-errChan 346 | if err != nil { 347 | return err 348 | } 349 | } 350 | 351 | return nil 352 | } 353 | -------------------------------------------------------------------------------- /crypto/nizk_test.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import "testing" 4 | 5 | func TestNIZKEncrypt(t *testing.T) { 6 | key := GenKey() 7 | x, X := key.Priv, key.Pub 8 | msg := GenRandMsg(size) 9 | 10 | c, proof := ProveEncrypt(X, msg) 11 | 12 | err := VerifyEncrypt(X, c, proof) 13 | if err != nil { 14 | t.Error("Prove fail:", err) 15 | } 16 | 17 | res := Decrypt(x, c) 18 | for m := range msg { 19 | if !msg[m].Equal(res[m]) { 20 | t.Error("Mismatched plaintext message.") 21 | } 22 | } 23 | } 24 | 25 | func TestNIZKRencrypt(t *testing.T) { 26 | key1 := GenKey() 27 | x1, X1 := key1.Priv, key1.Pub 28 | key2 := GenKey() 29 | x2, X2 := key2.Priv, key2.Pub 30 | 31 | msg := GenRandMsg(size) 32 | ciphertext := Encrypt(X1, msg) 33 | 34 | reenc, proofs := ProveReencrypt(x1, X2, ciphertext) 35 | 36 | err := VerifyReencrypt(X2, ciphertext, reenc, proofs) 37 | if err != nil { 38 | t.Error("Prove fail:", err) 39 | } 40 | 41 | res := Decrypt(x2, reenc) 42 | for m := range msg { 43 | if !msg[m].Equal(res[m]) { 44 | t.Error("Mismatched plaintext message.") 45 | } 46 | } 47 | } 48 | 49 | func TestNIZKShuffle(t *testing.T) { 50 | key := GenKey() 51 | x, X := key.Priv, key.Pub 52 | 53 | msgs := GenRandMsgs(3, size) 54 | ciphertexts := make([]Ciphertext, 3) 55 | for m := range 
msgs { 56 | ciphertexts[m] = Encrypt(X, msgs[m]) 57 | } 58 | 59 | shuffled, proofs := ProveShuffle(X, ciphertexts) 60 | 61 | results := make([]Message, 3) 62 | for r := range results { 63 | results[r] = Decrypt(x, shuffled[r]) 64 | } 65 | 66 | for r := range results { 67 | if !isMemberMessage(results[r], msgs) { 68 | t.Error("Missing message.") 69 | } 70 | } 71 | 72 | err := VerifyShuffle(X, ciphertexts, shuffled, proofs) 73 | if err != nil { 74 | t.Error("VerifyShuffle failed:", err) 75 | } 76 | } 77 | 78 | func BenchmarkEncryptProve(b *testing.B) { 79 | key := GenKey() 80 | _, X := key.Priv, key.Pub 81 | msg := GenRandMsg(1) 82 | 83 | b.ResetTimer() 84 | for i := 0; i < b.N; i++ { 85 | ProveEncrypt(X, msg) 86 | } 87 | } 88 | 89 | func BenchmarkEncryptVerify(b *testing.B) { 90 | key := GenKey() 91 | _, X := key.Priv, key.Pub 92 | msg := GenRandMsg(1) 93 | 94 | c, proof := ProveEncrypt(X, msg) 95 | 96 | b.ResetTimer() 97 | for i := 0; i < b.N; i++ { 98 | VerifyEncrypt(X, c, proof) 99 | } 100 | } 101 | 102 | func BenchmarkRencryptProve(b *testing.B) { 103 | key1 := GenKey() 104 | x1, X1 := key1.Priv, key1.Pub 105 | key2 := GenKey() 106 | _, X2 := key2.Priv, key2.Pub 107 | 108 | msg := GenRandMsg(1) 109 | ciphertext := Encrypt(X1, msg) 110 | 111 | b.ResetTimer() 112 | for i := 0; i < b.N; i++ { 113 | ProveReencrypt(x1, X2, ciphertext) 114 | } 115 | } 116 | 117 | func BenchmarkRencryptVerify(b *testing.B) { 118 | key1 := GenKey() 119 | x1, X1 := key1.Priv, key1.Pub 120 | key2 := GenKey() 121 | _, X2 := key2.Priv, key2.Pub 122 | 123 | msg := GenRandMsg(1) 124 | ciphertext := Encrypt(X1, msg) 125 | 126 | reenc, proofs := ProveReencrypt(x1, X2, ciphertext) 127 | 128 | b.ResetTimer() 129 | for i := 0; i < b.N; i++ { 130 | VerifyReencrypt(X2, ciphertext, reenc, proofs) 131 | } 132 | } 133 | 134 | func BenchmarkShuffleProve1024(b *testing.B) { 135 | numPts := 1024 136 | key := GenKey() 137 | _, X := key.Priv, key.Pub 138 | 139 | msgs := GenRandMsgs(numPts, 1) 140 | 
ciphertexts := make([]Ciphertext, numPts) 141 | for m := range msgs { 142 | ciphertexts[m] = Encrypt(X, msgs[m]) 143 | } 144 | 145 | b.ResetTimer() 146 | for i := 0; i < b.N; i++ { 147 | ProveShuffle(X, ciphertexts) 148 | } 149 | } 150 | 151 | func BenchmarkShuffleProve2048(b *testing.B) { 152 | numPts := 2048 153 | key := GenKey() 154 | _, X := key.Priv, key.Pub 155 | 156 | msgs := GenRandMsgs(numPts, 1) 157 | ciphertexts := make([]Ciphertext, numPts) 158 | for m := range msgs { 159 | ciphertexts[m] = Encrypt(X, msgs[m]) 160 | } 161 | 162 | b.ResetTimer() 163 | for i := 0; i < b.N; i++ { 164 | ProveShuffle(X, ciphertexts) 165 | } 166 | } 167 | 168 | func BenchmarkShuffleVerify1024(b *testing.B) { 169 | numPts := 1024 170 | key := GenKey() 171 | _, X := key.Priv, key.Pub 172 | 173 | msgs := GenRandMsgs(numPts, 1) 174 | ciphertexts := make([]Ciphertext, numPts) 175 | for m := range msgs { 176 | ciphertexts[m] = Encrypt(X, msgs[m]) 177 | } 178 | 179 | shuffled, proofs := ProveShuffle(X, ciphertexts) 180 | 181 | b.ResetTimer() 182 | for i := 0; i < b.N; i++ { 183 | VerifyShuffle(X, ciphertexts, shuffled, proofs) 184 | } 185 | } 186 | 187 | func BenchmarkShuffleVerify2048(b *testing.B) { 188 | numPts := 2048 189 | key := GenKey() 190 | _, X := key.Priv, key.Pub 191 | 192 | msgs := GenRandMsgs(numPts, 1) 193 | ciphertexts := make([]Ciphertext, numPts) 194 | for m := range msgs { 195 | ciphertexts[m] = Encrypt(X, msgs[m]) 196 | } 197 | 198 | shuffled, proofs := ProveShuffle(X, ciphertexts) 199 | 200 | b.ResetTimer() 201 | for i := 0; i < b.N; i++ { 202 | VerifyShuffle(X, ciphertexts, shuffled, proofs) 203 | } 204 | } 205 | -------------------------------------------------------------------------------- /crypto/rand.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "encoding/binary" 5 | "io" 6 | "log" 7 | 8 | "golang.org/x/crypto/sha3" 9 | ) 10 | 11 | type Reader struct { 12 | io.Reader 13 | } 14 | 
15 | func NewRandReader(seed []byte) *Reader { 16 | h := sha3.NewShake128() 17 | h.Write(seed) 18 | return &Reader{h} 19 | } 20 | 21 | func (r *Reader) UInt() int { 22 | buf := make([]byte, 8) 23 | _, err := r.Read(buf) 24 | if err != nil { 25 | log.Fatal(err) 26 | } 27 | tmp, _ := binary.Uvarint(buf) 28 | return int(tmp) 29 | } 30 | -------------------------------------------------------------------------------- /crypto/shuffle.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "runtime" 5 | "sync" 6 | 7 | "github.com/dedis/kyber/util/random" 8 | ) 9 | 10 | // Reblind the messages and randomly permute them 11 | func Shuffle(X *PublicKey, cs []Ciphertext) []Ciphertext { 12 | rnd := random.New() 13 | k := len(cs) 14 | 15 | ciphertexts := make([]Ciphertext, k) 16 | tmp := make([]Ciphertext, k) 17 | 18 | chunks := runtime.NumCPU() 19 | div := k / chunks 20 | if k < chunks { 21 | div = 1 22 | chunks = k 23 | } else if k%chunks != 0 { 24 | div++ 25 | } 26 | 27 | wg := new(sync.WaitGroup) 28 | wg.Add(chunks) 29 | for d := 0; d < chunks; d++ { 30 | start := d * div 31 | end := (d + 1) * div 32 | if end > k { 33 | end = k 34 | } 35 | go func(start, end int) { 36 | defer wg.Done() 37 | for i := start; i < end; i++ { 38 | tmp[i] = Reblind(X, cs[i]) 39 | } 40 | }(start, end) 41 | } 42 | wg.Wait() 43 | 44 | // for i := range cs { 45 | // tmp[i] = Reblind(X, cs[i]) 46 | // } 47 | 48 | pi := make([]int, k) 49 | for i := 0; i < k; i++ { // Initialize a trivial permutation 50 | pi[i] = i 51 | } 52 | for i := k - 1; i > 0; i-- { // Shuffle by random swaps 53 | j := int(randUint64(rnd) % uint64(i+1)) 54 | if j != i { 55 | t := pi[j] 56 | pi[j] = pi[i] 57 | pi[i] = t 58 | } 59 | } 60 | 61 | for i := range cs { 62 | ciphertexts[i] = tmp[pi[i]] 63 | } 64 | return ciphertexts 65 | } 66 | -------------------------------------------------------------------------------- /crypto/threshold.go: 
-------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "errors" 5 | "log" 6 | "sync" 7 | 8 | "github.com/dedis/kyber" 9 | dkg "github.com/dedis/kyber/share/dkg/pedersen" 10 | ) 11 | 12 | type Threshold struct { 13 | N int 14 | t int 15 | myIdx int 16 | keyPair *KeyPair 17 | groupKey PublicKey 18 | 19 | keyGen *dkg.DistKeyGenerator 20 | deals map[int]*ThresholdDeal 21 | secret *dkg.DistKeyShare 22 | 23 | dealCnt int // channel to count number of deals received 24 | respCnt chan bool // channel to count number of deals received 25 | dealCond *sync.Cond 26 | } 27 | 28 | func NewThreshold(myIdx, T int, key *KeyPair, longPubs []*PublicKey) *Threshold { 29 | N := len(longPubs) 30 | cpPub := CopyPubs(longPubs) 31 | myGroupKeys := make([]kyber.Point, N) 32 | for i := 0; i < N; i++ { 33 | myGroupKeys[i] = cpPub[i].p 34 | } 35 | 36 | keyGen, err := dkg.NewDistKeyGenerator(SUITE, key.Priv.s, 37 | myGroupKeys, T) 38 | if err != nil { 39 | log.Fatal("Could not create dist key gen", err) 40 | } 41 | deals, err := keyGen.Deals() 42 | if err != nil { 43 | log.Fatal("Could not create deals", err) 44 | } 45 | 46 | tdeals := make(map[int]*ThresholdDeal) 47 | for i := range deals { 48 | tdeals[i] = &ThresholdDeal{deals[i]} 49 | } 50 | 51 | t := &Threshold{ 52 | t: T, 53 | N: N, 54 | 55 | myIdx: myIdx, 56 | keyPair: key, 57 | 58 | keyGen: keyGen, 59 | deals: tdeals, 60 | 61 | dealCnt: 0, 62 | respCnt: make(chan bool, N*N), 63 | dealCond: sync.NewCond(new(sync.Mutex)), 64 | } 65 | 66 | return t 67 | } 68 | 69 | func (t *Threshold) AddDeal(deal *ThresholdDeal) (*ThresholdResponse, error) { 70 | t.dealCond.L.Lock() 71 | defer t.dealCond.L.Unlock() 72 | resp, err := t.keyGen.ProcessDeal(deal.D) 73 | if err != nil { 74 | return nil, err 75 | } 76 | t.dealCnt += 1 77 | if t.dealCnt == t.N-1 { 78 | t.dealCond.Broadcast() 79 | } 80 | return &ThresholdResponse{resp}, nil 81 | } 82 | 83 | func (t *Threshold) GetDeal(i int) 
*ThresholdDeal { 84 | return t.deals[i] 85 | } 86 | 87 | func (t *Threshold) AddResponse(resp *ThresholdResponse) error { 88 | t.dealCond.L.Lock() 89 | for t.dealCnt < t.N-1 { 90 | t.dealCond.Wait() 91 | } 92 | just, err := t.keyGen.ProcessResponse(resp.R) 93 | if just != nil { 94 | return errors.New("Justification not null") 95 | } else if err != nil { 96 | return err 97 | } 98 | t.dealCond.L.Unlock() 99 | t.respCnt <- true 100 | return nil 101 | } 102 | 103 | // Joint verifiable secret sharing setup 104 | func (t *Threshold) JVSS() error { 105 | // Wait until it receives enough resps 106 | for i := 0; i < (t.N-1)*(t.N-1); i++ { 107 | <-t.respCnt 108 | } 109 | 110 | var err error 111 | t.secret, err = t.keyGen.DistKeyShare() 112 | if err != nil { 113 | return err 114 | } 115 | t.groupKey = PublicKey{t.secret.Public()} 116 | return nil 117 | } 118 | 119 | func (t *Threshold) PublicKey() *PublicKey { 120 | return &t.groupKey 121 | } 122 | 123 | // Given threshold group in terms of the index within the group, 124 | // compute the lagrangian, and relevant point 125 | func (t *Threshold) Lagrange(group []int) *PrivateKey { 126 | numer := SUITE.Scalar().One() 127 | denom := SUITE.Scalar().One() 128 | xServer := SUITE.Scalar().SetInt64(1 + int64(t.myIdx)) 129 | for i := range group { 130 | if group[i] == t.myIdx { 131 | continue 132 | } 133 | numer = numer.Mul(numer, SUITE.Scalar().SetInt64(1+int64(group[i]))) 134 | xj := SUITE.Scalar().SetInt64(1 + int64(group[i])) 135 | xj = xj.Sub(xj, xServer) 136 | denom = denom.Mul(denom, xj) 137 | } 138 | numer = numer.Div(numer, denom) 139 | key := SUITE.Scalar().Mul(numer, t.secret.Share.V) 140 | return &PrivateKey{key} 141 | } 142 | -------------------------------------------------------------------------------- /crypto/threshold_test.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "log" 5 | "testing" 6 | ) 7 | 8 | var M int = 1 9 | var N = 5 10 | var T 
= 4 11 | 12 | func sendDeal(deal *ThresholdDeal, dst *Threshold, all []*Threshold) { 13 | resp, _ := dst.AddDeal(deal) 14 | for _, t := range all { 15 | if dst != t { 16 | err := t.AddResponse(resp) 17 | if err != nil { 18 | log.Fatal(err) 19 | } 20 | } 21 | } 22 | } 23 | 24 | func TestThresholdSharing(t *testing.T) { 25 | keys, pubs, _ := GenKeys(N) 26 | ts := make([]*Threshold, N) 27 | sendss := make([][]*ThresholdDeal, N) 28 | recvss := make([][]*ThresholdResponse, N) 29 | for i := range ts { 30 | ts[i] = NewThreshold(i, T, keys[i], pubs) 31 | sendss[i] = make([]*ThresholdDeal, N) 32 | recvss[i] = make([]*ThresholdResponse, N) 33 | } 34 | 35 | errs := make(chan error) 36 | 37 | for i := range ts { 38 | for j := range ts { 39 | if i == j { 40 | continue 41 | } 42 | deal := ts[i].GetDeal(j) 43 | go sendDeal(deal, ts[j], ts) 44 | } 45 | } 46 | 47 | for i := range ts { 48 | go func(i int) { 49 | errs <- ts[i].JVSS() 50 | }(i) 51 | } 52 | 53 | for i := range ts { 54 | err := <-errs 55 | if err != nil { 56 | t.Error("Person", i, err) 57 | } 58 | } 59 | 60 | groupKey := ts[0].PublicKey() 61 | 62 | msg := GenRandMsg(5) 63 | nullKey := &PublicKey{SUITE.Point().Null()} 64 | 65 | ciphertext := Encrypt(groupKey, msg) 66 | 67 | fullGroup := make([]int, T) 68 | for i := 0; i < T; i++ { 69 | fullGroup[i] = i 70 | } 71 | 72 | for i := 0; i < T; i++ { 73 | share := ts[i].Lagrange(fullGroup) 74 | ciphertext = Reencrypt(share, nullKey, ciphertext) 75 | } 76 | 77 | for i := range ciphertext.C { 78 | if !ciphertext.C[i].Equal(msg[i]) { 79 | t.Error("Data corrupted!") 80 | } 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /crypto/types.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "github.com/dedis/kyber" 5 | dkg "github.com/dedis/kyber/share/dkg/pedersen" 6 | ) 7 | 8 | const ( 9 | NONCE_LEN = 16 10 | ) 11 | 12 | const ( 13 | MSG = 0 14 | TRAP = 1 15 | 
OTHER = 2 16 | ) 17 | 18 | type MsgType byte 19 | 20 | type Point struct { 21 | p kyber.Point 22 | } 23 | 24 | func (p1 *Point) Equal(p2 *Point) bool { 25 | return p1.p.Equal(p2.p) 26 | } 27 | 28 | func (p *Point) String() string { 29 | return p.p.String() 30 | } 31 | 32 | func (p1 *PublicKey) Equal(p2 *PublicKey) bool { 33 | return p1.p.Equal(p2.p) 34 | } 35 | 36 | func (p *PublicKey) String() string { 37 | return p.p.String() 38 | } 39 | 40 | type Scalar struct { 41 | s kyber.Scalar 42 | } 43 | 44 | func (s1 *Scalar) Equal(s2 *Scalar) bool { 45 | return s1.s.Equal(s2.s) 46 | } 47 | 48 | func (s *Scalar) String() string { 49 | return s.s.String() 50 | } 51 | 52 | func (s1 *PrivateKey) Equal(s2 *PrivateKey) bool { 53 | return s1.s.Equal(s2.s) 54 | } 55 | 56 | func (s *PrivateKey) String() string { 57 | return s.s.String() 58 | } 59 | 60 | type Message []*Point 61 | 62 | func (m1 Message) Equal(m2 Message) bool { 63 | ok := true 64 | for m := range m1 { 65 | ok = ok && m1[m].p.Equal(m2[m].p) 66 | } 67 | return ok 68 | } 69 | 70 | type PrivateKey Scalar 71 | type PublicKey Point 72 | 73 | type KeyPair struct { 74 | Priv *PrivateKey 75 | Pub *PublicKey 76 | } 77 | 78 | func (k *KeyPair) String() string { 79 | return "(" + k.Priv.String() + ", " + k.Pub.String() + ")" 80 | } 81 | 82 | type TrusteeKey struct { 83 | Round int // trustee keys are per round 84 | KeyPair 85 | } 86 | 87 | type HexKeyPair struct { 88 | Priv string 89 | Pub string 90 | } 91 | 92 | type InnerCiphertext struct { 93 | R *Point 94 | C []byte 95 | } 96 | 97 | // single user-submitted ciphertext 98 | type Ciphertext struct { 99 | R []*Point 100 | C []*Point 101 | Y []*Point 102 | } 103 | 104 | type Trap struct { 105 | Gid int 106 | Nonce []byte 107 | } 108 | 109 | type Commitment [32]byte 110 | 111 | func (c Commitment) String() string { 112 | buf := make([]byte, 32) 113 | copy(buf, c[:]) 114 | return string(buf) 115 | } 116 | 117 | type EncProof struct { 118 | S []*Point 119 | U []*Scalar 120 | } 
121 | 122 | type ShufProof [][]byte 123 | 124 | type ReencProof [][]byte 125 | 126 | type ThresholdDeal struct { 127 | D *dkg.Deal 128 | } 129 | 130 | type ThresholdResponse struct { 131 | R *dkg.Response 132 | } 133 | -------------------------------------------------------------------------------- /db/db.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | // really simple place to gather up the data from all servers 4 | 5 | import ( 6 | "crypto/tls" 7 | "fmt" 8 | "net" 9 | "net/rpc" 10 | "sync" 11 | 12 | "github.com/kwonalbert/atom/atomrpc" 13 | ) 14 | 15 | type DB struct { 16 | entries map[int]*entry 17 | 18 | condLock *sync.Mutex 19 | conds map[int]*sync.Cond 20 | 21 | listener net.Listener 22 | tlsConfig *tls.Config 23 | } 24 | 25 | type entry struct { 26 | numGroups int // number of groups who sent you msg in a round 27 | msgs [][]byte 28 | } 29 | 30 | func NewDB(port int) (*DB, error) { 31 | _, tlsConfig := atomrpc.AtomTLSConfig() 32 | 33 | l, err := tls.Listen("tcp", fmt.Sprintf(":%d", port), tlsConfig) 34 | if err != nil { 35 | return nil, err 36 | } 37 | 38 | db := &DB{ 39 | entries: make(map[int]*entry), 40 | 41 | condLock: new(sync.Mutex), 42 | conds: make(map[int]*sync.Cond), 43 | 44 | listener: l, 45 | tlsConfig: tlsConfig, 46 | } 47 | 48 | rpcServer := rpc.NewServer() 49 | rpcServer.Register(db) 50 | go rpcServer.Accept(l) 51 | 52 | return db, nil 53 | } 54 | 55 | // create an entry if it hasn't been created before 56 | func (db *DB) createEntry(round int) { 57 | db.condLock.Lock() 58 | defer db.condLock.Unlock() 59 | if _, ok := db.conds[round]; !ok { 60 | db.conds[round] = sync.NewCond(new(sync.Mutex)) 61 | } 62 | if _, ok := db.entries[round]; !ok { 63 | db.entries[round] = &entry{ 64 | numGroups: 0, 65 | msgs: nil, 66 | } 67 | } 68 | } 69 | 70 | func (db *DB) Write(args *atomrpc.DBArgs, _ *int) error { 71 | db.createEntry(args.Round) 72 | db.conds[args.Round].L.Lock() 73 | entry := 
db.entries[args.Round] 74 | entry.msgs = append(entry.msgs, args.Msgs...) 75 | entry.numGroups++ 76 | if entry.numGroups == args.NumGroups { 77 | db.conds[args.Round].Broadcast() 78 | } 79 | db.conds[args.Round].L.Unlock() 80 | return nil 81 | } 82 | 83 | func (db *DB) Read(args *atomrpc.DBArgs, resp *[][]byte) error { 84 | db.createEntry(args.Round) 85 | db.conds[args.Round].L.Lock() 86 | for db.entries[args.Round].numGroups < args.NumGroups { 87 | db.conds[args.Round].Wait() 88 | } 89 | *resp = db.entries[args.Round].msgs 90 | db.conds[args.Round].L.Unlock() 91 | return nil 92 | } 93 | 94 | func (db *DB) Close() { 95 | if db.listener != nil { 96 | db.listener.Close() 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /db/db_test.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "crypto/rand" 5 | "testing" 6 | 7 | "github.com/kwonalbert/atom/atomrpc" 8 | ) 9 | 10 | func TestDB(t *testing.T) { 11 | db, err := NewDB(10001) 12 | if err != nil { 13 | t.Error(err) 14 | } 15 | 16 | res := make(chan [][]byte) 17 | go func() { // client 18 | args := atomrpc.DBArgs{ 19 | Round: 0, 20 | NumGroups: 1, 21 | } 22 | var resp [][]byte 23 | err := db.Read(&args, &resp) 24 | if err != nil { 25 | t.Error(err) 26 | } 27 | res <- resp 28 | }() 29 | 30 | msgs := make([][]byte, 64) 31 | for m := range msgs { 32 | msgs[m] = make([]byte, 160) 33 | rand.Read(msgs[m]) 34 | } 35 | 36 | go func(msgs [][]byte) { // server 37 | args := atomrpc.DBArgs{ 38 | Round: 0, 39 | NumGroups: 1, 40 | Msgs: msgs, 41 | } 42 | err := db.Write(&args, nil) 43 | if err != nil { 44 | t.Error(err) 45 | } 46 | }(msgs) 47 | 48 | result := <-res 49 | for r := range result { 50 | for i := range result[r] { 51 | if msgs[r][i] != result[r][i] { 52 | t.Error("Msg mismatch") 53 | } 54 | } 55 | } 56 | } 57 | -------------------------------------------------------------------------------- 
/directory/directory.go: -------------------------------------------------------------------------------- 1 | package directory 2 | 3 | import ( 4 | "crypto/tls" 5 | "errors" 6 | "fmt" 7 | "net" 8 | "net/rpc" 9 | "sync" 10 | "time" 11 | 12 | . "github.com/kwonalbert/atom/atomrpc" 13 | . "github.com/kwonalbert/atom/common" 14 | ) 15 | 16 | type Directory struct { 17 | id int 18 | port int 19 | 20 | // used to wait for server+trustee reg 21 | wg *sync.WaitGroup 22 | done *sync.WaitGroup 23 | 24 | // used to wait for group reg 25 | gwg *sync.WaitGroup 26 | gdone *sync.WaitGroup 27 | 28 | listener net.Listener 29 | 30 | tlsCert *tls.Certificate 31 | tlsConfig *tls.Config 32 | 33 | // Exported fields; represents a logical directory 34 | SystemParameter 35 | Round int 36 | 37 | Servers []string 38 | Keys []string 39 | Certificates [][][]byte 40 | 41 | Trustees []string 42 | TrusteeKeys []string 43 | TrusteeCerts [][][]byte 44 | 45 | GroupKeys [][]string // uid to group key 46 | RoundKeys map[int]string // round to per round key 47 | } 48 | 49 | type DirectoryRPC struct { 50 | d *Directory 51 | } 52 | 53 | type Registration struct { 54 | Round int 55 | Addr string 56 | Level int // only relevant for group registration 57 | Id int 58 | Key string 59 | Certificate [][]byte 60 | } 61 | 62 | func (d *DirectoryRPC) Directory(_ *int, dir *Directory) error { 63 | d.d.wg.Wait() 64 | *dir = *(d.d) 65 | d.d.done.Done() 66 | return nil 67 | } 68 | 69 | func (d *DirectoryRPC) DirectoryWithGroupKeys(_ *int, dir *Directory) error { 70 | d.d.gwg.Wait() 71 | *dir = *(d.d) 72 | d.d.gdone.Done() 73 | return nil 74 | } 75 | 76 | func (d *DirectoryRPC) Register(reg *Registration, _ *int) error { 77 | d.d.Servers[reg.Id] = reg.Addr 78 | d.d.Keys[reg.Id] = reg.Key 79 | d.d.Certificates[reg.Id] = reg.Certificate 80 | d.d.wg.Done() 81 | return nil 82 | } 83 | 84 | func (d *DirectoryRPC) RegisterGroup(reg *Registration, _ *int) error { 85 | d.d.GroupKeys[reg.Level][reg.Id] = reg.Key 86 | 
d.d.gwg.Done() 87 | return nil 88 | } 89 | 90 | func (d *DirectoryRPC) RegisterRound(reg *Registration, _ *int) error { 91 | d.d.gwg.Done() 92 | if key, ok := d.d.RoundKeys[reg.Round]; !ok { 93 | d.d.RoundKeys[reg.Round] = reg.Key 94 | return nil 95 | } else { 96 | if key != reg.Key { 97 | return errors.New("Mismatching round key registration") 98 | } else { 99 | return nil 100 | } 101 | } 102 | } 103 | 104 | func (d *DirectoryRPC) RegisterTrustee(reg *Registration, _ *int) error { 105 | // TODO: Authenticate client somehow.. 106 | d.d.Trustees[reg.Id] = reg.Addr 107 | d.d.TrusteeKeys[reg.Id] = reg.Key 108 | d.d.TrusteeCerts[reg.Id] = reg.Certificate 109 | d.d.wg.Done() 110 | return nil 111 | } 112 | 113 | func (d *DirectoryRPC) Ping(_ *int, _ *int) error { 114 | return nil 115 | } 116 | 117 | func (d *DirectoryRPC) Randomness(_ *int, seed *[SEED_LEN]byte) error { 118 | *seed = SEED 119 | return nil 120 | } 121 | 122 | func (d *Directory) Close() { 123 | d.done.Wait() 124 | d.gdone.Wait() 125 | // Hopefully a second is enough to send back the last reply 126 | time.Sleep(1 * time.Second) 127 | 128 | if d.listener != nil { 129 | d.listener.Close() 130 | } 131 | } 132 | 133 | func NewDirectory(id, port, mode, netType, 134 | numServers, numGroups, perGroup, numTrustees, 135 | numMsgs, msgSize, threshold, 136 | numClients int) (*Directory, error) { 137 | 138 | tlsCert, tlsConfig := AtomTLSConfig() 139 | 140 | numLevels := 10 141 | if netType == BUTTERFLY { 142 | numLevels = Log2(numGroups) * Log2(numGroups) 143 | } 144 | 145 | p := SystemParameter{ 146 | Mode: mode, 147 | NetType: netType, 148 | 149 | NumServers: numServers, 150 | NumGroups: numGroups, 151 | PerGroup: perGroup, 152 | NumLevels: numLevels, 153 | NumTrustees: numTrustees, 154 | 155 | NumMsgs: numMsgs, 156 | MsgSize: msgSize, 157 | 158 | Threshold: threshold, 159 | } 160 | 161 | l, err := tls.Listen("tcp", fmt.Sprintf(":%d", port), tlsConfig) 162 | if err != nil { 163 | return nil, err 164 | } 165 | 166 | 
d := &Directory{ 167 | id: id, 168 | port: port, 169 | 170 | wg: new(sync.WaitGroup), 171 | done: new(sync.WaitGroup), 172 | 173 | gwg: new(sync.WaitGroup), 174 | gdone: new(sync.WaitGroup), 175 | 176 | listener: l, 177 | 178 | tlsCert: tlsCert, 179 | tlsConfig: tlsConfig, 180 | 181 | Round: 0, 182 | SystemParameter: p, 183 | 184 | Servers: make([]string, numServers), 185 | Keys: make([]string, numServers), 186 | Certificates: make([][][]byte, numServers), 187 | 188 | Trustees: make([]string, numTrustees), 189 | TrusteeKeys: make([]string, numTrustees), 190 | TrusteeCerts: make([][][]byte, numTrustees), 191 | 192 | GroupKeys: make([][]string, numLevels), 193 | RoundKeys: make(map[int]string), 194 | } 195 | 196 | for level := range d.GroupKeys { 197 | d.GroupKeys[level] = make([]string, numGroups) 198 | } 199 | 200 | d.wg.Add(numServers + numTrustees) 201 | d.gwg.Add(numGroups*numLevels + numTrustees) 202 | 203 | d.done.Add(numServers + numTrustees) 204 | d.gdone.Add(numClients + numServers) 205 | 206 | rpcServer := rpc.NewServer() 207 | rpcServer.Register(&DirectoryRPC{d}) 208 | go rpcServer.Accept(l) 209 | 210 | return d, nil 211 | } 212 | -------------------------------------------------------------------------------- /directory/helper.go: -------------------------------------------------------------------------------- 1 | package directory 2 | 3 | import ( 4 | "log" 5 | "net/rpc" 6 | 7 | . "github.com/kwonalbert/atom/common" 8 | . 
"github.com/kwonalbert/atom/crypto" 9 | ) 10 | 11 | func GetDirectory(dirServers []*rpc.Client) (*Directory, SystemParameter, []*PublicKey) { 12 | // TODO: actually check consensus 13 | var res *Directory 14 | var params SystemParameter 15 | for _, dirServer := range dirServers { 16 | var direc Directory 17 | err := dirServer.Call("DirectoryRPC.Directory", 0, &direc) 18 | if err != nil { 19 | log.Fatal("Directory err:", err) 20 | } 21 | res = &direc 22 | params = direc.SystemParameter 23 | } 24 | 25 | publicKeys := make([]*PublicKey, len(res.Keys)) 26 | for i, pub := range res.Keys { 27 | publicKeys[i] = LoadPubKey(pub) 28 | } 29 | return res, params, publicKeys 30 | } 31 | 32 | func GetGroupKeys(dirServers []*rpc.Client) (*Directory, SystemParameter, []*PublicKey, [][]*PublicKey) { 33 | var res *Directory 34 | var params SystemParameter 35 | 36 | // TODO: actually check consensus 37 | for _, dirServer := range dirServers { 38 | 39 | var direc Directory 40 | err := dirServer.Call("DirectoryRPC.DirectoryWithGroupKeys", 0, &direc) 41 | if err != nil { 42 | log.Fatal("Directory err:", err) 43 | } 44 | 45 | res = &direc 46 | params = direc.SystemParameter 47 | } 48 | 49 | publicKeys := make([]*PublicKey, len(res.Keys)) 50 | for i, pub := range res.Keys { 51 | publicKeys[i] = LoadPubKey(pub) 52 | } 53 | 54 | keys := make([][]*PublicKey, len(res.GroupKeys)) 55 | for level := range res.GroupKeys { 56 | keys[level] = make([]*PublicKey, len(res.GroupKeys[level])) 57 | for gid := range res.GroupKeys[level] { 58 | key := LoadPubKey(res.GroupKeys[level][gid]) 59 | keys[level][gid] = key 60 | } 61 | } 62 | return res, params, publicKeys, keys 63 | } 64 | -------------------------------------------------------------------------------- /scripts/run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | import os 4 | import subprocess 5 | import threading 6 | import sys 7 | import time 8 | import argparse 9 | import json 

parser = argparse.ArgumentParser(
    description='Start an Atom deployment (directory, db, trustees, '
                'servers and clients) locally or on EC2 instances.')
parser.add_argument('--inst', metavar='instances', type=str, default="",
                    help='file containing json of all available EC2 instances')
parser.add_argument('--port', metavar='port', type=int, default=8000,
                    help='starting port number for directories and servers')
parser.add_argument('--servers', metavar='#servers', type=int,
                    help='number of physical servers')
parser.add_argument('--gsize', metavar='size', type=int,
                    help='size of each group')
parser.add_argument('--groups', metavar='#groups', type=int,
                    help='number of groups')
parser.add_argument('--clients', metavar='#clients', type=int,
                    help='number of clients')
parser.add_argument('--trustees', metavar='#trustees', type=int,
                    help='number of trustees')
parser.add_argument('--msgs', metavar='#msgs', type=int,
                    help='number of msgs per group')
parser.add_argument('--msize', metavar='size', type=int,
                    help='size of the message')
parser.add_argument('--type', metavar='type', type=int,
                    help='type of network')
parser.add_argument('--mode', metavar='mode', type=int,
                    help='mode of operation')
# Optional: the value is only forwarded to the directory when given.
parser.add_argument('--branch', metavar='branch', type=int, default=None,
                    help='value forwarded to the directory as --branch; '
                         '-1 means use --groups (omitted if not given)')

flags = vars(parser.parse_args(sys.argv[1:]))
if flags['branch'] == -1:
    flags['branch'] = flags['groups']

aws = not flags['inst'] == ''

if aws:
    root = []  # instances tagged "Root": run the directory, db and trustees
    ips = []   # untagged instances: run the servers and clients
    with open(flags['inst']) as inst_file:
        insts = json.load(inst_file)
        for r in insts['Reservations']:
            for inst in r['Instances']:
                ip = inst.get('PrivateIpAddress')
                if ip is None:
                    continue
                try:
                    tag = inst["Tags"][0]["Value"]
                except (KeyError, IndexError, TypeError):
                    ips.append(ip)
                    continue
                # Instances carrying some other tag are deliberately unused.
                if tag == "Root":
                    root.append(ip)
    print("number of servers in pool:", len(ips), len(root))


storage = '/tmp'
gopath = os.getenv('GOPATH')
src_dir = 'github.com/kwonalbert/atom'

for cmd in ['db', 'directory', 'trustee', 'server', 'client']:
    os.system('go install -tags experimental %s/cmd/%s' % (src_dir, cmd))

flag_dir_addr = "--dirAddr 127.0.0.1:%d" % flags['port']
if aws:
    flag_dir_addr = "--dirAddr %s:%d" % (root[0], flags['port'])
flags['port'] += 1

flag_db_addr = "--dbAddr 127.0.0.1:%d" % flags['port']
if aws:
    flag_db_addr = "--dbAddr %s:%d" % (root[0], flags['port'])
flags['port'] += 1

flag_num_clients = "--numClients %d" % flags['clients']
flag_per_group = "--perGroup %d" % flags['gsize']
flag_num_servers = "--numServers %d" % flags['servers']
flag_num_groups = "--numGroups %d" % flags['groups']
flag_num_trustees = "--numTrustees %d" % flags['trustees']
flag_num_msgs = "--numMsgs %d" % flags['msgs']
flag_msg_size = "--msgSize %d" % flags['msize']
flag_mode = "--mode %d" % flags['mode']
flag_net = "--net %d" % flags['type']
flag_server_keys = "--keyFile %s/src/%s/keys/server_keys.json" % (gopath, src_dir)
flag_trustee_keys = "--keyFile %s/src/%s/keys/trustee_keys.json" % (gopath, src_dir)
# Empty (and dropped from the command line) unless --branch was given.
flag_branch = "--branch %d" % flags['branch'] if flags['branch'] is not None else ""

def localhost(c):
    os.system(c)

def remotehost(dest, c):
    os.system("ssh -o StrictHostKeyChecking=no -i ~/.ssh/emerald.pem %s '%s'" % (dest, c))

dir_flags = " ".join(f for f in [flag_dir_addr,
                                 flag_per_group,
                                 flag_num_servers,
                                 flag_num_clients,
                                 flag_num_groups,
                                 flag_num_trustees,
                                 flag_num_msgs,
                                 flag_msg_size,
                                 flag_mode,
                                 flag_net,
                                 flag_branch] if f)
c = '%s/bin/directory %s' % (gopath, dir_flags)
if aws:
    directory = threading.Thread(target=remotehost, args=(root[0], c,))
else:
    directory = threading.Thread(target=localhost, args=(c,))
directory.start() 115 | 116 | time.sleep(1) 117 | 118 | c = '%s/bin/db %s' % (gopath, flag_db_addr) 119 | if aws: 120 | db = threading.Thread(target=remotehost, args=(root[0], c,)) 121 | else : 122 | db = threading.Thread(target=localhost, args=(c,)) 123 | db.start() 124 | 125 | ts = [] 126 | for i in range(flags['trustees']): 127 | flag_id = "--id %d" % i 128 | flag_trustee_addr = "--addr 127.0.0.1:%d" % (flags['port']) 129 | if aws: 130 | flag_trustee_addr = "--addr %s:%d" % (root[0], flags['port']) 131 | flags['port'] += 1 132 | trustee_flags = " ".join([flag_trustee_keys, 133 | flag_dir_addr, 134 | flag_trustee_addr, 135 | flag_id]) 136 | c = '%s/bin/trustee %s' % (gopath, trustee_flags) 137 | if aws: 138 | t = threading.Thread(target=remotehost, args=(root[0], c,)) 139 | else: 140 | t = threading.Thread(target=localhost, args=(c,)) 141 | t.start() 142 | ts.append(t) 143 | 144 | time.sleep(1) 145 | 146 | print("Starting servers...") 147 | ss = [] 148 | for i in range(flags['servers']): 149 | flag_addr = "--addr 127.0.0.1:%d" % (flags['port']+i) 150 | if aws: 151 | flag_addr = "--addr %s:%d" % (ips[i%len(ips)], flags['port']+i) 152 | flag_id = "--id %d" % i 153 | serv_flags = " ".join([flag_server_keys, 154 | flag_dir_addr, 155 | flag_db_addr, 156 | flag_addr, 157 | flag_id]) 158 | c = '%s/bin/server %s' % (gopath, serv_flags) 159 | if aws: 160 | t = threading.Thread(target=remotehost, args=(ips[i%len(ips)], c,)) 161 | else: 162 | t = threading.Thread(target=localhost, args=(c,)) 163 | t.start() 164 | ss.append(t) 165 | 166 | time.sleep(0.5) 167 | 168 | print("Starting clients...") 169 | cs = [] 170 | for i in range(flags['clients']): 171 | flag_id = "--id %d" % i 172 | client_flags = " ".join([flag_dir_addr, 173 | flag_db_addr, 174 | flag_id]) 175 | 176 | c = '%s/bin/client %s' % (gopath, client_flags) 177 | if aws: 178 | t = threading.Thread(target=remotehost, args=(ips[i%len(ips)], c,)) 179 | else: 180 | t = threading.Thread(target=localhost, args=(c,)) 181 | 
t.start() 182 | cs.append(t) 183 | 184 | print("Waiting for completion...") 185 | for t in cs: 186 | t.join() 187 | 188 | if not aws: 189 | os.system('killall trustee') 190 | os.system('killall server') 191 | os.system('killall client') 192 | os.system('killall directory') 193 | os.system('killall db') 194 | -------------------------------------------------------------------------------- /server/helper.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | 7 | "golang.org/x/crypto/sha3" 8 | 9 | . "github.com/kwonalbert/atom/crypto" 10 | ) 11 | 12 | func selectGroup(inner InnerCiphertext, numGroups int) int { 13 | hash := sha3.Sum256(inner.C) 14 | var gid uint64 15 | binary.Read(bytes.NewBuffer(hash[:]), binary.LittleEndian, &gid) 16 | return int(gid % uint64(numGroups)) 17 | } 18 | 19 | func memberCommitment(comm Commitment, comms []Commitment) bool { 20 | for c := range comms { 21 | if comm == comms[c] { 22 | return true 23 | } 24 | } 25 | return false 26 | } 27 | -------------------------------------------------------------------------------- /server/member.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "log" 5 | "sync" 6 | 7 | . 
"github.com/kwonalbert/atom/common" 8 | atomcrypto "github.com/kwonalbert/atom/crypto" 9 | ) 10 | 11 | // Member of a group 12 | type Member struct { 13 | sid int // server id 14 | idx int // index in the group 15 | group *Group 16 | 17 | params SystemParameter 18 | 19 | share *atomcrypto.Threshold // threshold share 20 | 21 | roundLock *sync.Mutex 22 | roundInit map[int]bool 23 | 24 | finalizeLock *sync.Mutex 25 | 26 | collectBuf map[int][]atomcrypto.Ciphertext 27 | collectLock map[int]*sync.Cond 28 | 29 | commitBuf map[int][]atomcrypto.Commitment 30 | commitLock map[int]*sync.Cond 31 | 32 | resInnerBuf map[int]chan []atomcrypto.InnerCiphertext 33 | resTrapBuf map[int]chan []atomcrypto.Trap 34 | 35 | shufOK map[int]chan bool 36 | reencOK map[int]chan bool 37 | 38 | shufOld map[int][]atomcrypto.Ciphertext 39 | reencOld map[int][][]atomcrypto.Ciphertext 40 | } 41 | 42 | func NewMember(sid int, key *atomcrypto.KeyPair, params SystemParameter, group *Group) *Member { 43 | groupSize := len(group.Members) 44 | useThreshold := params.Threshold < groupSize 45 | 46 | idx := -1 47 | for i := range group.Members { 48 | if group.Members[i] == sid { 49 | idx = i 50 | break 51 | } 52 | } 53 | 54 | var share *atomcrypto.Threshold = nil 55 | if useThreshold { 56 | share = atomcrypto.NewThreshold(idx, params.Threshold, 57 | key, group.MemberKeys) 58 | } 59 | 60 | m := &Member{ 61 | sid: sid, 62 | idx: idx, 63 | group: group, 64 | 65 | params: params, 66 | 67 | share: share, 68 | 69 | roundLock: new(sync.Mutex), 70 | roundInit: make(map[int]bool), 71 | finalizeLock: new(sync.Mutex), 72 | 73 | collectBuf: make(map[int][]atomcrypto.Ciphertext), 74 | collectLock: make(map[int]*sync.Cond), 75 | 76 | commitBuf: make(map[int][]atomcrypto.Commitment), 77 | commitLock: make(map[int]*sync.Cond), 78 | 79 | resInnerBuf: make(map[int]chan []atomcrypto.InnerCiphertext), 80 | resTrapBuf: make(map[int]chan []atomcrypto.Trap), 81 | 82 | shufOK: make(map[int]chan bool), 83 | reencOK: 
make(map[int]chan bool), 84 | 85 | shufOld: make(map[int][]atomcrypto.Ciphertext), 86 | reencOld: make(map[int][][]atomcrypto.Ciphertext), 87 | } 88 | return m 89 | } 90 | 91 | func (m *Member) genMemberKey() { 92 | var groupKey *atomcrypto.PublicKey = nil 93 | if m.share != nil { 94 | m.share.JVSS() 95 | groupKey = m.share.PublicKey() 96 | } else { 97 | groupKey = atomcrypto.CombinePublicKeys(m.group.MemberKeys) 98 | } 99 | m.group.GroupKey = groupKey 100 | } 101 | 102 | func (m *Member) ciphertexts(round int) []atomcrypto.Ciphertext { 103 | if m.params.Mode == TRAP_MODE { 104 | m.collectLock[round].L.Lock() 105 | for len(m.collectBuf[round]) < 2*m.params.NumMsgs { 106 | m.collectLock[round].Wait() 107 | } 108 | m.collectLock[round].L.Unlock() 109 | } else { 110 | m.collectLock[round].L.Lock() 111 | for len(m.collectBuf[round]) < m.params.NumMsgs { 112 | m.collectLock[round].Wait() 113 | } 114 | m.collectLock[round].L.Unlock() 115 | } 116 | return m.collectBuf[round] 117 | } 118 | 119 | func (m *Member) commitWait(round int) { 120 | m.commitLock[round].L.Lock() 121 | for len(m.commitBuf[round]) < m.params.NumMsgs { 122 | m.commitLock[round].Wait() 123 | } 124 | m.commitLock[round].L.Unlock() 125 | } 126 | 127 | func (m *Member) startRound(round int) { 128 | if m.roundStarted(round) { 129 | return 130 | } 131 | m.roundLock.Lock() 132 | defer m.roundLock.Unlock() 133 | 134 | m.roundInit[round] = true 135 | m.collectBuf[round] = nil 136 | m.collectLock[round] = sync.NewCond(new(sync.Mutex)) 137 | m.commitLock[round] = sync.NewCond(new(sync.Mutex)) 138 | if m.params.Mode == TRAP_MODE { 139 | m.commitBuf[round] = nil 140 | } else { 141 | m.shufOK[round] = make(chan bool, m.params.Threshold) 142 | m.reencOK[round] = make(chan bool, m.params.Threshold) 143 | } 144 | } 145 | 146 | func (m *Member) roundStarted(round int) bool { 147 | m.roundLock.Lock() 148 | defer m.roundLock.Unlock() 149 | _, ok := m.roundInit[round] 150 | return ok 151 | } 152 | 153 | func (m *Member) 
collect(round int, id int, ciphertexts []atomcrypto.Ciphertext) { 154 | m.collectLock[round].L.Lock() 155 | m.collectBuf[round] = append(m.collectBuf[round], ciphertexts...) 156 | m.collectLock[round].Signal() 157 | m.collectLock[round].L.Unlock() 158 | } 159 | 160 | func (m *Member) collectCommitment(round int, id int, comms []atomcrypto.Commitment) { 161 | m.commitLock[round].L.Lock() 162 | m.commitBuf[round] = append(m.commitBuf[round], comms...) 163 | m.commitLock[round].Signal() 164 | m.commitLock[round].L.Unlock() 165 | } 166 | 167 | func (m *Member) verifyShuffle(old, new []atomcrypto.Ciphertext, proof atomcrypto.ShufProof) bool { 168 | err := atomcrypto.VerifyShuffle(m.group.GroupKey, old, new, proof) 169 | if err != nil { 170 | log.Println("Incorrect shuffle proof:", err) 171 | return false 172 | } 173 | return true 174 | } 175 | 176 | func (m *Member) queueShufOK(round int, ok bool) { 177 | m.shufOK[round] <- ok 178 | } 179 | 180 | func (m *Member) dequeShufOK(round int) bool { 181 | ok := <-m.shufOK[round] 182 | return ok 183 | } 184 | 185 | func (m *Member) shuffle(ciphertexts []atomcrypto.Ciphertext) []atomcrypto.Ciphertext { 186 | return atomcrypto.Shuffle(m.group.GroupKey, ciphertexts) 187 | } 188 | 189 | func (m *Member) proveShuffle(ciphertexts []atomcrypto.Ciphertext) ([]atomcrypto.Ciphertext, atomcrypto.ShufProof) { 190 | return atomcrypto.ProveShuffle(m.group.GroupKey, ciphertexts) 191 | } 192 | 193 | func (m *Member) divide(cs []atomcrypto.Ciphertext) [][]atomcrypto.Ciphertext { 194 | numNeighbors := len(m.group.AdjList) 195 | if numNeighbors == 0 { 196 | numNeighbors = 1 // last level, there are no neighbors 197 | } 198 | batches := make([][]atomcrypto.Ciphertext, numNeighbors) 199 | // TODO: assumes even sized batches; make it more general 200 | batchSize := len(cs) / numNeighbors 201 | for b := range batches { 202 | batches[b] = cs[b*batchSize : (b+1)*batchSize] 203 | } 204 | return batches 205 | } 206 | 207 | func (m *Member) neighborKeys(n 
int) []*atomcrypto.PublicKey { 208 | neighborKeys := make([]*atomcrypto.PublicKey, n) 209 | if len(m.group.AdjList) > 0 { // special case for last level 210 | for n, neighbor := range m.group.AdjList { 211 | neighborKeys[n] = neighbor.GroupKey 212 | } 213 | } else { 214 | neighborKeys[0] = atomcrypto.NullKey() 215 | } 216 | return neighborKeys 217 | } 218 | 219 | func (m *Member) verifyReencrypt(ob, nb [][]atomcrypto.Ciphertext, proofs [][]atomcrypto.ReencProof) bool { 220 | return atomcrypto.VerifyReencryptBatches(ob, nb, 221 | proofs, m.neighborKeys(len(nb))) 222 | } 223 | 224 | func (m *Member) queueReencOK(round int, ok bool) { 225 | m.reencOK[round] <- ok 226 | } 227 | 228 | func (m *Member) dequeReencOK(round int) bool { 229 | ok := <-m.reencOK[round] 230 | return ok 231 | } 232 | 233 | // decrypt using priv and reencrypt the message for neighbors 234 | func (m *Member) reencrypt(round int, priv *atomcrypto.PrivateKey, 235 | batches [][]atomcrypto.Ciphertext) [][]atomcrypto.Ciphertext { 236 | return atomcrypto.ReencryptBatches(priv, m.neighborKeys(len(batches)), batches) 237 | } 238 | 239 | // decrypt using priv and reencrypt the message for neighbors 240 | func (m *Member) proveReencrypt(round int, priv *atomcrypto.PrivateKey, 241 | batches [][]atomcrypto.Ciphertext) ([][]atomcrypto.Ciphertext, [][]atomcrypto.ReencProof) { 242 | return atomcrypto.ProveReencryptBatches(priv, m.neighborKeys(len(batches)), batches) 243 | } 244 | 245 | func (m *Member) startFinalize(round int) { 246 | m.finalizeLock.Lock() 247 | defer m.finalizeLock.Unlock() 248 | m.resInnerBuf[round] = make(chan []atomcrypto.InnerCiphertext, m.params.NumGroups) 249 | m.resTrapBuf[round] = make(chan []atomcrypto.Trap, m.params.NumGroups) 250 | } 251 | 252 | func (m *Member) finalizeStarted(round int) bool { 253 | m.finalizeLock.Lock() 254 | defer m.finalizeLock.Unlock() 255 | _, ok := m.resInnerBuf[round] 256 | return ok 257 | } 258 | 259 | func (m *Member) collectResult(round int, inners 
[]atomcrypto.InnerCiphertext, traps []atomcrypto.Trap) { 260 | m.resInnerBuf[round] <- inners 261 | m.resTrapBuf[round] <- traps 262 | } 263 | 264 | func (m *Member) results(round int) ([]atomcrypto.InnerCiphertext, []atomcrypto.Trap) { 265 | var inners []atomcrypto.InnerCiphertext 266 | var traps []atomcrypto.Trap 267 | 268 | for i := 0; i < m.params.NumGroups; i++ { 269 | tmpi := <-m.resInnerBuf[round] 270 | inners = append(inners, tmpi...) 271 | 272 | tmpt := <-m.resTrapBuf[round] 273 | traps = append(traps, tmpt...) 274 | } 275 | 276 | return inners, traps 277 | } 278 | 279 | func (m *Member) commitments(round int) []atomcrypto.Commitment { 280 | return m.commitBuf[round] 281 | } 282 | 283 | func (m *Member) setShuffleOld(round int, old []atomcrypto.Ciphertext) { 284 | m.shufOld[round] = old 285 | } 286 | 287 | func (m *Member) shuffleOld(round int) []atomcrypto.Ciphertext { 288 | return m.shufOld[round] 289 | } 290 | 291 | func (m *Member) setReencryptOld(round int, old [][]atomcrypto.Ciphertext) { 292 | m.reencOld[round] = old 293 | } 294 | 295 | func (m *Member) reencryptOld(round int) [][]atomcrypto.Ciphertext { 296 | return m.reencOld[round] 297 | } 298 | -------------------------------------------------------------------------------- /server/server.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "bytes" 5 | "crypto/tls" 6 | "encoding/binary" 7 | "fmt" 8 | "log" 9 | "net" 10 | "net/rpc" 11 | "strconv" 12 | "strings" 13 | "sync" 14 | "time" 15 | 16 | "github.com/kwonalbert/atom/directory" 17 | 18 | . "github.com/kwonalbert/atom/atomrpc" 19 | . "github.com/kwonalbert/atom/common" 20 | . 
"github.com/kwonalbert/atom/crypto" 21 | ) 22 | 23 | type ServerRPC struct { 24 | s *Server 25 | } 26 | 27 | type Server struct { 28 | id int 29 | addr string 30 | port int 31 | 32 | params SystemParameter 33 | 34 | dirAddrs []string 35 | dirServers []*rpc.Client 36 | dbServer *rpc.Client 37 | servers []*rpc.Client 38 | directory *directory.Directory 39 | publicKeys []*PublicKey 40 | 41 | trustees []*rpc.Client 42 | 43 | network [][]*Group 44 | partOf [][]*Group 45 | members map[int]*Member // maps a unique group id (not gid) to a member 46 | 47 | keyPair *KeyPair 48 | connected *sync.WaitGroup 49 | 50 | listener net.Listener 51 | 52 | tlsCert *tls.Certificate 53 | tlsConfig *tls.Config 54 | 55 | start time.Time 56 | slock *sync.Mutex 57 | } 58 | 59 | func NewServer(addr string, id int, keyFile string, 60 | dirAddrs []string, dbAddr string) (*Server, error) { 61 | port, err := strconv.Atoi(strings.Split(addr, ":")[1]) 62 | if err != nil { 63 | log.Fatal(err) 64 | } 65 | 66 | tlsCert, tlsConfig := AtomTLSConfig() 67 | 68 | if id == 0 { 69 | fmt.Println("Server started") 70 | } 71 | 72 | // read key pair from a file, or generate a new key pair 73 | var keyPair *KeyPair 74 | if keyFile == "" { 75 | keyPair = GenKey() 76 | } else { 77 | serverKeys, err := ReadKeys(keyFile) 78 | if err != nil { 79 | return nil, err 80 | } 81 | keyPair = LoadKey(serverKeys[id]) 82 | } 83 | 84 | dirServers := make([]*rpc.Client, len(dirAddrs)) 85 | for d, dirAddr := range dirAddrs { 86 | conn, err := tls.Dial("tcp", dirAddr, tlsConfig) 87 | if err != nil { 88 | return nil, err 89 | } 90 | dirServers[d] = rpc.NewClient(conn) 91 | var tmp int 92 | err = dirServers[d].Call("DirectoryRPC.Ping", &tmp, &tmp) 93 | if err != nil { 94 | return nil, err 95 | } 96 | } 97 | 98 | conn, err := tls.Dial("tcp", dbAddr, tlsConfig) 99 | if err != nil { 100 | return nil, err 101 | } 102 | dbServer := rpc.NewClient(conn) 103 | 104 | connected := new(sync.WaitGroup) 105 | connected.Add(1) 106 | 107 | s := 
&Server{ 108 | id: id, 109 | addr: addr, 110 | port: port, 111 | 112 | dirAddrs: dirAddrs, 113 | dbServer: dbServer, 114 | dirServers: dirServers, 115 | 116 | keyPair: keyPair, 117 | 118 | connected: connected, 119 | 120 | tlsCert: tlsCert, 121 | tlsConfig: tlsConfig, 122 | 123 | slock: new(sync.Mutex), 124 | } 125 | 126 | return s, nil 127 | } 128 | 129 | func (s *Server) Setup() { 130 | s.registerServer() 131 | if s.id == 0 { 132 | log.Println("Registered server") 133 | } 134 | 135 | s.getDirectory() 136 | if s.id == 0 { 137 | log.Println("Got directory") 138 | } 139 | 140 | s.accept() 141 | if s.id == 0 { 142 | log.Println("Server started") 143 | } 144 | 145 | s.connectServers() 146 | if s.id == 0 { 147 | log.Println("Connected servers") 148 | } 149 | 150 | s.genMemberKeys() 151 | if s.id == 0 { 152 | log.Println("Generated member key") 153 | } 154 | 155 | s.setupGroupKeys() 156 | if s.id == 0 { 157 | log.Println("Generated group key") 158 | } 159 | } 160 | 161 | func (s *Server) Close() { 162 | if s.listener != nil { 163 | s.listener.Close() 164 | } 165 | 166 | for _, dirServer := range s.dirServers { 167 | dirServer.Close() 168 | } 169 | 170 | for _, serv := range s.servers { 171 | if serv != nil { 172 | serv.Close() 173 | } 174 | } 175 | } 176 | 177 | func (s *Server) accept() { 178 | l, e := tls.Listen("tcp", fmt.Sprintf(":%d", s.port), s.tlsConfig) 179 | if e != nil { 180 | log.Fatal("listen error:", e) 181 | } 182 | s.listener = l 183 | 184 | rpcServer := rpc.NewServer() 185 | rpcServer.Register(&ServerRPC{s}) 186 | go rpcServer.Accept(l) 187 | } 188 | 189 | func (s *Server) registerServer() { 190 | pub := DumpPubKey(s.keyPair.Pub) 191 | for _, dirServer := range s.dirServers { 192 | reg := &directory.Registration{ 193 | Addr: s.addr, 194 | Id: s.id, 195 | Key: pub, 196 | Certificate: s.tlsCert.Certificate, 197 | } 198 | err := dirServer.Call("DirectoryRPC.Register", reg, nil) 199 | if err != nil { 200 | log.Fatal("Register err:", err) 201 | } 202 | } 203 
| } 204 | 205 | func (s *Server) getDirectory() { 206 | s.directory, s.params, s.publicKeys = directory.GetDirectory(s.dirServers) 207 | if s.params.Mode == TRAP_MODE { 208 | s.trustees = make([]*rpc.Client, len(s.directory.Trustees)) 209 | for t, tAddr := range s.directory.Trustees { 210 | conn, err := tls.Dial("tcp", tAddr, s.tlsConfig) 211 | if err != nil { 212 | log.Fatal("Could not dial trustee") 213 | } 214 | s.trustees[t] = rpc.NewClient(conn) 215 | } 216 | } 217 | s.genGroups() 218 | } 219 | 220 | func (s *Server) genGroups() { 221 | var seed [SEED_LEN]byte 222 | for _, dirServer := range s.dirServers { 223 | var val [SEED_LEN]byte 224 | err := dirServer.Call("DirectoryRPC.Randomness", 0, &val) 225 | if err != nil { 226 | log.Fatal("Randomness err:", err) 227 | } 228 | Xor(val[:], seed[:]) 229 | } 230 | network := GenerateGroups(seed, s.params.NetType, s.params.NumServers, 231 | s.params.NumGroups, s.params.PerGroup, 232 | s.params.NumLevels, s.publicKeys) 233 | s.network = network 234 | 235 | s.partOf = make([][]*Group, len(network)) 236 | s.members = make(map[int]*Member) 237 | for level := range s.partOf { 238 | s.partOf[level] = make([]*Group, len(network[level])) 239 | for gid, node := range network[level] { 240 | if !IsMember(s.id, node.Members) { 241 | s.partOf[level][gid] = nil 242 | continue 243 | } 244 | group := network[level][gid] 245 | s.partOf[level][gid] = group 246 | s.members[group.Uid] = NewMember(s.id, s.keyPair, 247 | s.params, group) 248 | } 249 | } 250 | } 251 | 252 | func (s *Server) callGroup(servers []*rpc.Client, group *Group) []*rpc.Client { 253 | // connect to everyone in this group 254 | for _, member := range group.Members { 255 | if servers[member] != nil { 256 | continue 257 | } 258 | 259 | // all servers must be online during inital seup, so retry 260 | retry := 1 261 | for retry != 0 { 262 | conn, err := tls.Dial("tcp", s.directory.Servers[member], s.tlsConfig) 263 | if err == nil { 264 | retry = 0 265 | servers[member] = 
rpc.NewClient(conn) 266 | } 267 | } 268 | } 269 | return servers 270 | } 271 | 272 | func (s *Server) connectServers() { 273 | servers := make([]*rpc.Client, s.params.NumServers) 274 | for level := range s.partOf { 275 | for _, group := range s.partOf[level] { 276 | if group == nil { 277 | continue 278 | } 279 | s.callGroup(servers, group) 280 | 281 | // connect to all servers in neighboring group 282 | for _, neighbor := range group.AdjList { 283 | s.callGroup(servers, neighbor) 284 | } 285 | } 286 | } 287 | 288 | for _, group := range s.network[len(s.network)-1] { 289 | s.callGroup(servers, group) 290 | } 291 | 292 | s.servers = servers 293 | s.connected.Done() 294 | } 295 | 296 | func (s *Server) genMemberKeys() { 297 | if s.params.Threshold < s.params.PerGroup { 298 | for _, member := range s.members { 299 | for gidx, other := range member.group.Members { 300 | if member.sid == other { 301 | continue 302 | } 303 | 304 | //TODO: This used to be done in parallel, 305 | // but when moved to kyber, the encoding 306 | // seems to not work with RPC calls, so 307 | // now it's sequential.. 
308 | args := DealArgs{ 309 | Uid: member.group.Uid, 310 | Idx: member.idx, 311 | Deal: member.share.GetDeal(gidx), 312 | } 313 | 314 | var reply DealReply 315 | err := s.servers[other].Call("ServerRPC.Deal", &args, &reply) 316 | if err != nil { 317 | log.Fatal("Deal fail:", err) 318 | } 319 | } 320 | } 321 | } 322 | 323 | for _, member := range s.members { 324 | member.genMemberKey() 325 | } 326 | } 327 | 328 | func (s *Server) addDealSendResponse(args *DealArgs) { 329 | s.connected.Wait() 330 | 331 | member := s.members[args.Uid] 332 | resp, err := member.share.AddDeal(args.Deal) 333 | if err != nil { 334 | log.Fatal("failed to add deal:") 335 | } 336 | 337 | for _, other := range member.group.Members { 338 | if member.sid == other { 339 | continue 340 | } 341 | args := ResponseArgs{ 342 | Uid: member.group.Uid, 343 | Resp: resp, 344 | } 345 | 346 | var reply ResponseReply 347 | err := s.servers[other].Call("ServerRPC.Response", &args, &reply) 348 | if err != nil { 349 | log.Fatal("Deal fail:", err) 350 | } 351 | } 352 | 353 | } 354 | 355 | func (s *Server) setupGroupKeys() { 356 | for level := range s.partOf { 357 | for gid := range s.partOf[level] { 358 | group := s.partOf[level][gid] 359 | if group == nil { 360 | continue 361 | } 362 | 363 | if group.Members[0] != s.id { 364 | // currently trusting the first server 365 | continue 366 | } 367 | 368 | pub := DumpPubKey(group.GroupKey) 369 | for _, dirServer := range s.dirServers { 370 | reg := &directory.Registration{ 371 | Level: group.Level, 372 | Id: group.Gid, 373 | Key: pub, 374 | } 375 | err := dirServer.Call("DirectoryRPC.RegisterGroup", reg, nil) 376 | if err != nil { 377 | log.Fatal("Register err:", err) 378 | } 379 | } 380 | } 381 | } 382 | 383 | var keys [][]*PublicKey 384 | s.directory, s.params, _, keys = directory.GetGroupKeys(s.dirServers) 385 | for level := range keys { 386 | for gid := range keys[level] { 387 | s.network[level][gid].GroupKey = keys[level][gid] 388 | } 389 | } 390 | } 391 | 392 | 
func (s *Server) collect(args *CollectArgs) { 393 | uid := s.partOf[args.Level][args.Gid].Uid 394 | member := s.members[uid] 395 | 396 | newArgs := &ShuffleArgs{ 397 | Ciphertexts: member.ciphertexts(args.Round), 398 | ArgInfo: args.ArgInfo, 399 | } 400 | 401 | // entry groups need to wait for commitments too 402 | if s.params.Mode == TRAP_MODE && args.Level == 0 { 403 | member.commitWait(args.Round) 404 | } 405 | 406 | if args.Cur == member.idx { 407 | go s.shuffle(newArgs) 408 | } 409 | } 410 | 411 | func (s *Server) shuffle(args *ShuffleArgs) { 412 | if args.ArgInfo.Gid == 0 { // just print the first level 413 | log.Println("shuffle:", args.ArgInfo) 414 | } 415 | 416 | uid := s.partOf[args.Level][args.Gid].Uid 417 | member := s.members[uid] 418 | 419 | // for NIZK mode, any server other than the first 420 | // should collect ok from other servers 421 | if s.params.Mode == VER_MODE && args.Cur != args.Group[0] { 422 | for i := 0; i < len(args.Group)-2; i++ { 423 | ok := member.dequeShufOK(args.Round) 424 | if !ok { 425 | log.Fatal("Bad shuffle proof given") 426 | } 427 | } 428 | } 429 | 430 | var res []Ciphertext 431 | var proof ShufProof 432 | if s.params.Mode == TRAP_MODE { 433 | res = member.shuffle(args.Ciphertexts) 434 | } else if s.params.Mode == VER_MODE { 435 | res, proof = member.proveShuffle(args.Ciphertexts) 436 | } 437 | 438 | last := args.Group[len(args.Group)-1] == member.idx 439 | 440 | info := ArgInfo{ 441 | Round: args.Round, 442 | Level: args.Level, 443 | Gid: args.Gid, 444 | Group: args.Group, 445 | } 446 | 447 | if s.params.Mode == VER_MODE { 448 | // ask all other servers to verify 449 | for _, idx := range args.Group { 450 | if idx == args.Cur { 451 | continue 452 | } 453 | info.Cur = args.ArgInfo.Cur 454 | newArgs := VerifyShuffleArgs{ 455 | Old: args.Ciphertexts, 456 | New: res, 457 | Proof: proof, 458 | ArgInfo: info, 459 | } 460 | 461 | next := member.group.Members[idx] 462 | var reply ShuffleReply 463 | err := AtomRPC(s.servers[next], 
"ServerRPC.VerifyShuffle", 464 | &newArgs, &reply, DEFAULT_TIMEOUT) 465 | if err != nil { 466 | log.Fatal("Verify shuffle request:", err) 467 | } 468 | } 469 | } 470 | 471 | if !last { // shuffle and send to next server 472 | idx := -1 473 | for i := range args.Group { 474 | if args.Group[i] == member.idx { 475 | idx = i 476 | break 477 | } 478 | } 479 | nextIdx := args.Group[(idx+1)%len(args.Group)] 480 | next := member.group.Members[nextIdx] 481 | info.Cur = nextIdx 482 | newArgs := ShuffleArgs{ 483 | Ciphertexts: res, 484 | ArgInfo: info, 485 | } 486 | 487 | var reply ShuffleReply 488 | err := AtomRPC(s.servers[next], "ServerRPC.Shuffle", 489 | &newArgs, &reply, DEFAULT_TIMEOUT) 490 | if err != nil { 491 | log.Fatal("Shuffle request:", err) 492 | } 493 | } else { // divide and send back to first server 494 | nextIdx := args.Group[0] 495 | next := member.group.Members[nextIdx] 496 | info.Cur = nextIdx 497 | batches := member.divide(res) 498 | newArgs := ReencryptArgs{ 499 | Batches: batches, 500 | ArgInfo: info, 501 | } 502 | 503 | var reply ReencryptReply 504 | err := AtomRPC(s.servers[next], "ServerRPC.Reencrypt", 505 | &newArgs, &reply, DEFAULT_TIMEOUT) 506 | if err != nil { 507 | log.Fatal("Reencrypt request:", err) 508 | } 509 | } 510 | } 511 | 512 | func (s *Server) verifyShuffle(args *VerifyShuffleArgs) { 513 | uid := s.partOf[args.Level][args.Gid].Uid 514 | member := s.members[uid] 515 | 516 | // TODO: also check args.Old == currently collected 517 | ok := member.verifyShuffle(args.Old, args.New, args.Proof) 518 | 519 | idx := -1 520 | for i := range args.Group { 521 | if args.Group[i] == args.Cur { 522 | idx = i 523 | break 524 | } 525 | } 526 | 527 | nextIdx := args.Group[(idx+1)%len(args.Group)] 528 | next := member.group.Members[nextIdx] 529 | 530 | if nextIdx == member.idx { 531 | if ok { 532 | return 533 | } else { 534 | log.Fatal("Bad shuffle proof") 535 | } 536 | } 537 | 538 | newArgs := ProofOKArgs{ 539 | OK: ok, 540 | ArgInfo: args.ArgInfo, 541 
| } 542 | var reply ProofOKReply 543 | err := AtomRPC(s.servers[next], "ServerRPC.ShuffleOK", 544 | &newArgs, &reply, DEFAULT_TIMEOUT) 545 | if err != nil { 546 | log.Fatal("Shuffle ok request:", err) 547 | } 548 | } 549 | 550 | func (s *Server) reencrypt(args *ReencryptArgs) { 551 | if args.ArgInfo.Gid == 0 { // just print the first level 552 | log.Println("reencrypt:", args.ArgInfo) 553 | } 554 | 555 | uid := s.partOf[args.Level][args.Gid].Uid 556 | member := s.members[uid] 557 | 558 | priv := s.keyPair.Priv 559 | if member.share != nil { 560 | priv = member.share.Lagrange(args.Group) 561 | } 562 | 563 | if s.params.Mode == VER_MODE && args.Cur != args.Group[0] { 564 | for i := 0; i < len(args.Group)-2; i++ { 565 | ok := member.dequeReencOK(args.Round) 566 | if !ok { 567 | log.Fatal("Bad shuffle proof given") 568 | } 569 | } 570 | } 571 | 572 | var res [][]Ciphertext 573 | var proof [][]ReencProof 574 | if s.params.Mode == TRAP_MODE { 575 | res = member.reencrypt(args.Round, priv, args.Batches) 576 | } else if s.params.Mode == VER_MODE { 577 | res, proof = member.proveReencrypt(args.Round, priv, args.Batches) 578 | } 579 | 580 | last := args.Group[len(args.Group)-1] == member.idx 581 | 582 | idx := -1 583 | for i := range args.Group { 584 | if args.Group[i] == member.idx { 585 | idx = i 586 | break 587 | } 588 | } 589 | nextIdx := args.Group[(idx+1)%len(args.Group)] 590 | next := member.group.Members[nextIdx] 591 | 592 | info := ArgInfo{ 593 | Round: args.Round, 594 | Level: args.Level, 595 | Gid: args.Gid, 596 | Cur: nextIdx, 597 | Group: args.Group, 598 | } 599 | 600 | if s.params.Mode == VER_MODE { 601 | // ask all other servers to verify 602 | for _, idx := range args.Group { 603 | if idx == args.Cur { 604 | continue 605 | } 606 | info.Cur = args.ArgInfo.Cur 607 | newArgs := VerifyReencryptArgs{ 608 | Old: args.Batches, 609 | New: res, 610 | Proofs: proof, 611 | ArgInfo: info, 612 | } 613 | 614 | next := member.group.Members[idx] 615 | var reply ShuffleReply 
616 | err := AtomRPC(s.servers[next], "ServerRPC.VerifyReencrypt", 617 | &newArgs, &reply, DEFAULT_TIMEOUT) 618 | if err != nil { 619 | log.Fatal("Verify reencrypt request:", err) 620 | } 621 | } 622 | } 623 | 624 | if !last { // reencrypt and send to next server 625 | newArgs := ReencryptArgs{ 626 | Batches: res, 627 | ArgInfo: info, 628 | } 629 | 630 | var reply ShuffleReply 631 | err := AtomRPC(s.servers[next], "ServerRPC.Reencrypt", 632 | &newArgs, &reply, DEFAULT_TIMEOUT) 633 | if err != nil { 634 | log.Fatal(err) 635 | } 636 | } else if args.Level == s.params.NumLevels-1 { // last level 637 | // FINISH PROTOCOL 638 | msgs := ExtractMessages(res[0]) 639 | if s.params.Mode == VER_MODE { 640 | plaintexts, _, err := ExtractPlaintexts(msgs) 641 | if err != nil { 642 | log.Fatal(err) 643 | } 644 | 645 | info := ArgInfo{ 646 | Round: args.Round, 647 | Level: args.Level, 648 | Gid: args.Gid, 649 | Group: args.Group, 650 | } 651 | newArgs := FinalizeArgs{ 652 | Plaintexts: plaintexts, 653 | ArgInfo: info, 654 | } 655 | for _, other := range member.group.Members { 656 | var reply FinalizeReply 657 | err := AtomRPC(s.servers[other], "ServerRPC.Finalize", 658 | &newArgs, &reply, DEFAULT_TIMEOUT) 659 | if err != nil { 660 | log.Fatal(err) 661 | } 662 | } 663 | } else { 664 | inners, traps, err := ExtractInnerAndTraps(msgs) 665 | if err != nil { 666 | log.Fatal(err) 667 | } 668 | 669 | innerDivs := make([][]InnerCiphertext, s.params.NumGroups) 670 | for i := range inners { 671 | gid := selectGroup(inners[i], s.params.NumGroups) 672 | innerDivs[gid] = append(innerDivs[gid], inners[i]) 673 | } 674 | 675 | trapDivs := make([][]Trap, s.params.NumGroups) 676 | for t := range traps { 677 | gid := traps[t].Gid 678 | trapDivs[gid] = append(trapDivs[gid], traps[t]) 679 | } 680 | 681 | // the first layer servers are responsible for verifying 682 | // since they know the commitments 683 | for _, group := range s.network[0] { 684 | info := ArgInfo{ 685 | Round: args.Round, 686 | 
Level: 0, 687 | Gid: group.Gid, 688 | Group: args.Group, 689 | } 690 | 691 | newArgs := FinalizeArgs{ 692 | Inners: innerDivs[group.Gid], 693 | Traps: trapDivs[group.Gid], 694 | ArgInfo: info, 695 | } 696 | for _, idx := range args.Group { 697 | other := group.Members[idx] 698 | var reply FinalizeReply 699 | err := AtomRPC(s.servers[other], "ServerRPC.Finalize", 700 | &newArgs, &reply, DEFAULT_TIMEOUT) 701 | if err != nil { 702 | log.Fatal(err) 703 | } 704 | } 705 | 706 | } 707 | } 708 | } else { // send to neighbors 709 | for n, neighbor := range member.group.AdjList { 710 | info := ArgInfo{ 711 | Round: args.Round, 712 | Level: args.Level + 1, 713 | Gid: neighbor.Gid, 714 | Cur: 0, 715 | Group: Xrange(s.params.Threshold), 716 | } 717 | 718 | for r := range res[n] { // no need for Y any more 719 | res[n][r].Y = nil 720 | } 721 | 722 | for _, idx := range info.Group { 723 | if s.params.Mode == TRAP_MODE && idx != info.Cur { 724 | continue 725 | } 726 | newArgs := CollectArgs{ 727 | Id: member.group.Gid, 728 | Ciphertexts: res[n], 729 | ArgInfo: info, 730 | } 731 | 732 | next := neighbor.Members[idx] 733 | 734 | var reply ReencryptReply 735 | err := AtomRPC(s.servers[next], "ServerRPC.Collect", 736 | &newArgs, &reply, DEFAULT_TIMEOUT) 737 | if err != nil { 738 | log.Fatal(err) 739 | } 740 | } 741 | } 742 | } 743 | } 744 | 745 | func (s *Server) verifyReencrypt(args *VerifyReencryptArgs) { 746 | uid := s.partOf[args.Level][args.Gid].Uid 747 | member := s.members[uid] 748 | 749 | // also check args.Old == currently collected 750 | ok := member.verifyReencrypt(args.Old, args.New, args.Proofs) 751 | 752 | idx := -1 753 | for i := range args.Group { 754 | if args.Group[i] == member.idx { 755 | idx = i 756 | break 757 | } 758 | } 759 | nextIdx := args.Group[(idx+1)%len(args.Group)] 760 | next := member.group.Members[nextIdx] 761 | 762 | newArgs := ProofOKArgs{ 763 | OK: ok, 764 | ArgInfo: args.ArgInfo, 765 | } 766 | var reply ProofOKReply 767 | err := 
AtomRPC(s.servers[next], "ServerRPC.ReencryptOK", 768 | &newArgs, &reply, DEFAULT_TIMEOUT) 769 | if err != nil { 770 | log.Fatal(err) 771 | } 772 | } 773 | 774 | func (s *Server) finalize(args *FinalizeArgs) { 775 | if args.ArgInfo.Gid == 0 { // just print the first level 776 | log.Println("finalize:", args.ArgInfo) 777 | } 778 | 779 | uid := s.partOf[args.Level][args.Gid].Uid 780 | member := s.members[uid] 781 | last := args.Group[len(args.Group)-1] == member.idx 782 | 783 | if s.params.Mode == VER_MODE { 784 | if last { 785 | args := DBArgs{ 786 | Round: args.Round, 787 | NumGroups: s.params.NumGroups, 788 | Msgs: args.Plaintexts, 789 | } 790 | err := s.dbServer.Call("DB.Write", args, nil) 791 | if err != nil { 792 | log.Fatal("DB Write error:", err) 793 | } 794 | } 795 | if member.idx == args.Group[0] { 796 | log.Println("Done with group", member.group.Gid, ":", time.Since(s.start), ". #msgs: ", s.params.NumMsgs) 797 | } 798 | return 799 | } 800 | 801 | // everything below is for trap mode only 802 | inners, traps := member.results(args.Round) 803 | 804 | // check that each inner msg is expected to be here 805 | // and there are no duplicates 806 | dups := make(map[string]bool) 807 | noDups := true 808 | correctHash := true 809 | for i := range inners { 810 | gid := selectGroup(inners[i], s.params.NumGroups) 811 | correctHash = correctHash && (gid == member.group.Gid) 812 | 813 | str := string(inners[i].C) 814 | if _, ok := dups[str]; !ok { 815 | dups[str] = true 816 | } else { 817 | noDups = false 818 | } 819 | } 820 | 821 | // check all traps are there 822 | comms := make([]Commitment, len(traps)) 823 | for t := range traps { 824 | comms[t] = Commit(traps[t]) 825 | } 826 | expComms := member.commitments(args.Round) 827 | correctTraps := len(expComms) == len(comms) 828 | for _, comm := range expComms { 829 | correctTraps = correctTraps && memberCommitment(comm, comms) 830 | } 831 | 832 | // report to all trustees 833 | newArgs := ReportArgs{ 834 | Round: 
args.Round, 835 | Sid: s.id, 836 | Uid: uid, 837 | CorrectHash: correctHash, 838 | CorrectTraps: correctTraps, 839 | NoDups: noDups, 840 | NumTraps: len(traps), 841 | NumMsgs: len(inners), 842 | } 843 | 844 | privs := make([]*PrivateKey, len(s.trustees)) 845 | for t, trustee := range s.trustees { 846 | var reply ReportReply 847 | err := trustee.Call("TrusteeRPC.Report", &newArgs, &reply) 848 | if err != nil { 849 | log.Fatal("Could not get keys from trustee:", err) 850 | } 851 | privs[t] = reply.Priv 852 | } 853 | priv := CombinePrivateKeys(privs) 854 | pub := LoadPubKey(s.directory.RoundKeys[args.Round]) 855 | 856 | buf := new(bytes.Buffer) 857 | err := binary.Write(buf, binary.LittleEndian, uint32(args.Round)) 858 | if err != nil { 859 | log.Fatal("Could not write round") 860 | } 861 | nonce := buf.Bytes() 862 | 863 | plaintexts := make([][]byte, len(inners)) 864 | for i := range inners { 865 | plaintexts[i], err = CCA2Decrypt(inners[i], nonce, priv, pub) 866 | if err != nil { 867 | log.Fatal("CCA2 Decrypt fail:", err) 868 | } 869 | } 870 | 871 | dbArgs := DBArgs{ 872 | Round: args.Round, 873 | NumGroups: s.params.NumGroups, 874 | Msgs: plaintexts, 875 | } 876 | err = s.dbServer.Call("DB.Write", dbArgs, nil) 877 | if err != nil { 878 | log.Fatal("DB Write error:", err) 879 | } 880 | 881 | if member.idx == args.Group[0] { 882 | log.Println("Done with group", member.group.Gid, ":", time.Since(s.start), ". 
#msgs: ", s.params.NumMsgs) 883 | } 884 | } 885 | 886 | func (s *ServerRPC) Deal(args *DealArgs, _ *DealReply) error { 887 | go s.s.addDealSendResponse(args) 888 | return nil 889 | } 890 | 891 | func (s *ServerRPC) Response(args *ResponseArgs, _ *ResponseReply) error { 892 | member := s.s.members[args.Uid] 893 | return member.share.AddResponse(args.Resp) 894 | } 895 | 896 | func (s *ServerRPC) Submit(args *SubmitArgs, _ *SubmitReply) error { 897 | if args.Level == 0 { 898 | s.s.slock.Lock() 899 | s.s.start = time.Now() 900 | s.s.slock.Unlock() 901 | } 902 | 903 | uid := s.s.partOf[args.Level][args.Gid].Uid 904 | member := s.s.members[uid] 905 | 906 | for c := range args.Ciphertexts { 907 | err := VerifyEncrypt(member.group.GroupKey, 908 | args.Ciphertexts[c], args.EncProofs[c]) 909 | if err != nil { 910 | return err 911 | } 912 | } 913 | 914 | if s.s.params.Mode == TRAP_MODE && member.idx != args.Cur { 915 | // TODO: send the verification result to all servers 916 | return nil 917 | } 918 | 919 | started := member.roundStarted(args.Round) 920 | if !started { 921 | member.startRound(args.Round) 922 | newArgs := &CollectArgs{ 923 | Id: args.Id, 924 | Ciphertexts: args.Ciphertexts, 925 | ArgInfo: args.ArgInfo, 926 | } 927 | go s.s.collect(newArgs) 928 | } 929 | 930 | member.collect(args.Round, args.Id, args.Ciphertexts) 931 | return nil 932 | } 933 | 934 | func (s *ServerRPC) Commit(args *CommitArgs, _ *CommitReply) error { 935 | uid := s.s.partOf[args.Level][args.Gid].Uid 936 | member := s.s.members[uid] 937 | 938 | started := member.roundStarted(args.Round) 939 | if !started { 940 | member.startRound(args.Round) 941 | newArgs := &CollectArgs{ 942 | Id: args.Id, 943 | ArgInfo: args.ArgInfo, 944 | } 945 | go s.s.collect(newArgs) 946 | } 947 | 948 | member.collectCommitment(args.Round, args.Id, args.Comms) 949 | return nil 950 | } 951 | 952 | func (s *ServerRPC) Collect(args *CollectArgs, _ *CollectReply) error { 953 | uid := s.s.partOf[args.Level][args.Gid].Uid 954 | 
member := s.s.members[uid] 955 | 956 | started := member.roundStarted(args.Round) 957 | if !started { 958 | member.startRound(args.Round) 959 | go s.s.collect(args) 960 | } 961 | member.collect(args.Round, args.Id, args.Ciphertexts) 962 | return nil 963 | } 964 | 965 | func (s *ServerRPC) Shuffle(args *ShuffleArgs, _ *ShuffleReply) error { 966 | go s.s.shuffle(args) 967 | return nil 968 | } 969 | 970 | func (s *ServerRPC) VerifyShuffle(args *VerifyShuffleArgs, _ *VerifyShuffleReply) error { 971 | go s.s.verifyShuffle(args) 972 | return nil 973 | } 974 | 975 | func (s *ServerRPC) ShuffleOK(args *ProofOKArgs, _ *ProofOKReply) error { 976 | uid := s.s.partOf[args.Level][args.Gid].Uid 977 | member := s.s.members[uid] 978 | // TODO: actually check if the person sending this is the right server 979 | member.queueShufOK(args.Round, args.OK) 980 | return nil 981 | } 982 | 983 | func (s *ServerRPC) Reencrypt(args *ReencryptArgs, _ *ReencryptReply) error { 984 | go s.s.reencrypt(args) 985 | return nil 986 | } 987 | 988 | func (s *ServerRPC) VerifyReencrypt(args *VerifyReencryptArgs, _ *VerifyReencryptReply) error { 989 | go s.s.verifyReencrypt(args) 990 | return nil 991 | } 992 | 993 | func (s *ServerRPC) ReencryptOK(args *ProofOKArgs, _ *ProofOKReply) error { 994 | uid := s.s.partOf[args.Level][args.Gid].Uid 995 | member := s.s.members[uid] 996 | // TODO: actually check if the person sending this is the right server 997 | member.queueReencOK(args.Round, args.OK) 998 | return nil 999 | } 1000 | 1001 | func (s *ServerRPC) Finalize(args *FinalizeArgs, _ *FinalizeReply) error { 1002 | uid := s.s.partOf[args.Level][args.Gid].Uid 1003 | member := s.s.members[uid] 1004 | 1005 | if s.s.params.Mode == TRAP_MODE { 1006 | started := member.finalizeStarted(args.Round) 1007 | if !started { 1008 | member.startFinalize(args.Round) 1009 | go s.s.finalize(args) 1010 | } 1011 | member.collectResult(args.Round, args.Inners, args.Traps) 1012 | } else { 1013 | go s.s.finalize(args) 1014 | } 
1015 | return nil 1016 | } 1017 | 1018 | func (s *ServerRPC) Ping(_ *int, _ *int) error { 1019 | return nil 1020 | } 1021 | -------------------------------------------------------------------------------- /server/server_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/kwonalbert/atom/common" 7 | "github.com/kwonalbert/atom/crypto" 8 | ) 9 | 10 | func BenchmarkMixing(b *testing.B) { 11 | keyPair := crypto.GenKey() 12 | group := &common.Group{ 13 | GroupKey: keyPair.Pub, 14 | } 15 | member := Member{ 16 | group: group, 17 | } 18 | numMsgs := 16384 19 | ciphertexts := make([]crypto.Ciphertext, numMsgs) 20 | msgs := crypto.GenRandMsgs(numMsgs, 1) 21 | for c := range ciphertexts { 22 | ciphertexts[c] = crypto.Encrypt(keyPair.Pub, msgs[c]) 23 | } 24 | b.ResetTimer() 25 | for i := 0; i < b.N; i++ { 26 | member.shuffle(ciphertexts) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /trustee/trustee.go: -------------------------------------------------------------------------------- 1 | package trustee 2 | 3 | import ( 4 | "crypto/tls" 5 | "errors" 6 | "fmt" 7 | "log" 8 | "net" 9 | "net/rpc" 10 | "strconv" 11 | "strings" 12 | 13 | "github.com/kwonalbert/atom/directory" 14 | 15 | . "github.com/kwonalbert/atom/atomrpc" 16 | . "github.com/kwonalbert/atom/common" 17 | . 
"github.com/kwonalbert/atom/crypto" 18 | ) 19 | 20 | type TrusteeRPC struct { 21 | t *Trustee 22 | } 23 | 24 | type Trustee struct { 25 | addr string 26 | id int 27 | port int 28 | 29 | round int 30 | roundKeys map[int]*KeyPair 31 | roundGood map[int]chan bool 32 | 33 | params SystemParameter 34 | NumReports int // number of expected reports 35 | 36 | keyPair *KeyPair 37 | 38 | dirAddrs []string 39 | dirServers []*rpc.Client 40 | directory *directory.Directory 41 | publicKeys []*PublicKey 42 | 43 | reports map[int]chan *ReportArgs 44 | 45 | listener net.Listener 46 | 47 | tlsCert *tls.Certificate 48 | tlsConfig *tls.Config 49 | 50 | Priv *PrivateKey 51 | } 52 | 53 | func NewTrustee(addr string, id int, keyFile string, dirAddrs []string) (*Trustee, error) { 54 | tlsCert, tlsConfig := AtomTLSConfig() 55 | 56 | port, err := strconv.Atoi(strings.Split(addr, ":")[1]) 57 | if err != nil { 58 | log.Fatal(err) 59 | } 60 | 61 | // read key pair from a file, or generate a new key pair 62 | var keyPair *KeyPair 63 | if keyFile == "" { 64 | keyPair = GenKey() 65 | } else { 66 | serverKeys, err := ReadKeys(keyFile) 67 | if err != nil { 68 | return nil, err 69 | } 70 | keyPair = LoadKey(serverKeys[id]) 71 | if err != nil { 72 | return nil, err 73 | } 74 | } 75 | 76 | l, err := tls.Listen("tcp", fmt.Sprintf(":%d", port), tlsConfig) 77 | if err != nil { 78 | return nil, err 79 | } 80 | 81 | dirServers := make([]*rpc.Client, len(dirAddrs)) 82 | for d, dirAddr := range dirAddrs { 83 | conn, err := tls.Dial("tcp", dirAddr, tlsConfig) 84 | if err != nil { 85 | return nil, err 86 | } 87 | dirServers[d] = rpc.NewClient(conn) 88 | } 89 | 90 | t := &Trustee{ 91 | addr: addr, 92 | id: id, 93 | port: port, 94 | 95 | round: 0, 96 | roundKeys: make(map[int]*KeyPair), 97 | roundGood: make(map[int]chan bool), 98 | 99 | keyPair: keyPair, 100 | 101 | dirAddrs: dirAddrs, 102 | dirServers: dirServers, 103 | 104 | listener: l, 105 | 106 | reports: make(map[int]chan *ReportArgs), 107 | 108 | tlsCert: 
tlsCert, 109 | tlsConfig: tlsConfig, 110 | } 111 | rpcServer := rpc.NewServer() 112 | rpcServer.Register(&TrusteeRPC{t}) 113 | go rpcServer.Accept(l) 114 | return t, nil 115 | } 116 | 117 | func (t *Trustee) Setup() { 118 | if t.id == 0 { 119 | fmt.Println("Registering trustees") 120 | } 121 | t.registerTrustee() 122 | 123 | if t.id == 0 { 124 | fmt.Println("Getting directory") 125 | } 126 | t.getDirectory() 127 | } 128 | 129 | func (t *Trustee) returnResult(round int, res bool) { 130 | for i := 0; i < t.NumReports; i++ { 131 | t.roundGood[round] <- res 132 | } 133 | } 134 | 135 | func (t *Trustee) checkReports(round int) { 136 | totalTraps := make(map[int]int) 137 | totalMsgs := make(map[int]int) 138 | for i := 0; i < t.NumReports; i++ { 139 | report := <-t.reports[round] 140 | if !report.CorrectHash || !report.CorrectTraps || !report.NoDups { 141 | t.returnResult(round, false) 142 | } 143 | if _, ok := totalTraps[report.Uid]; !ok { 144 | totalTraps[report.Uid] = report.NumTraps 145 | totalMsgs[report.Uid] = report.NumMsgs 146 | } 147 | if totalTraps[report.Uid] != report.NumTraps || 148 | totalMsgs[report.Uid] != report.NumMsgs { 149 | t.returnResult(round, false) 150 | } 151 | } 152 | 153 | sumTraps := 0 154 | sumMsgs := 0 155 | for uid := range totalTraps { 156 | sumTraps += totalTraps[uid] 157 | sumMsgs += totalMsgs[uid] 158 | } 159 | if sumTraps != sumMsgs { 160 | t.returnResult(round, false) 161 | } 162 | t.returnResult(round, true) 163 | } 164 | 165 | func (t *Trustee) Close() { 166 | if t.listener != nil { 167 | t.listener.Close() 168 | } 169 | } 170 | 171 | func (t *Trustee) registerTrustee() { 172 | pub := DumpPubKey(t.keyPair.Pub) 173 | for _, dirServer := range t.dirServers { 174 | reg := &directory.Registration{ 175 | Addr: t.addr, 176 | Id: t.id, 177 | Key: pub, 178 | Certificate: t.tlsCert.Certificate, 179 | } 180 | err := dirServer.Call("DirectoryRPC.RegisterTrustee", reg, nil) 181 | if err != nil { 182 | log.Fatal("Register err:", err) 183 | } 184 
| } 185 | } 186 | 187 | func (t *Trustee) getDirectory() { 188 | // TODO: actually check consensus 189 | for _, dirServer := range t.dirServers { 190 | 191 | var direc directory.Directory 192 | err := dirServer.Call("DirectoryRPC.Directory", 0, &direc) 193 | if err != nil { 194 | log.Fatal("Directory err:", err) 195 | } 196 | 197 | t.directory = &direc 198 | t.params = direc.SystemParameter 199 | t.NumReports = t.params.NumGroups * t.params.Threshold 200 | } 201 | 202 | publicKeys := make([]*PublicKey, len(t.directory.TrusteeKeys)) 203 | for i, pub := range t.directory.TrusteeKeys { 204 | publicKeys[i] = LoadPubKey(pub) 205 | } 206 | t.publicKeys = publicKeys 207 | } 208 | 209 | func (t *Trustee) RegisterRound() { 210 | // TODO: actually generate per round keys and share it, 211 | // instead of using long-term key every round 212 | // currently not using threshold for trustees 213 | round := t.round 214 | t.reports[round] = make(chan *ReportArgs, t.NumReports) 215 | t.roundKeys[round] = t.keyPair 216 | t.roundGood[round] = make(chan bool, t.NumReports) 217 | t.Priv = t.keyPair.Priv 218 | 219 | roundKey := CombinePublicKeys(t.publicKeys) 220 | roundPub := DumpPubKey(roundKey) 221 | 222 | for _, dirServer := range t.dirServers { 223 | reg := &directory.Registration{ 224 | Round: t.round, 225 | Id: t.id, 226 | Key: roundPub, 227 | } 228 | err := dirServer.Call("DirectoryRPC.RegisterRound", reg, nil) 229 | if err != nil { 230 | log.Fatal("Register err:", err) 231 | } 232 | } 233 | 234 | go t.checkReports(round) 235 | 236 | t.round += 1 237 | } 238 | 239 | func (t *TrusteeRPC) Report(report *ReportArgs, reply *ReportReply) error { 240 | t.t.reports[report.Round] <- report 241 | ok := <-t.t.roundGood[report.Round] 242 | if ok { 243 | *reply = ReportReply{ 244 | Priv: t.t.keyPair.Priv, 245 | } 246 | return nil 247 | } else { 248 | return errors.New("Report failed") 249 | } 250 | 251 | } 252 | -------------------------------------------------------------------------------- 
/trustee/trustee_test.go: -------------------------------------------------------------------------------- 1 | package trustee 2 | 3 | import ( 4 | "crypto/tls" 5 | "fmt" 6 | "log" 7 | "net/rpc" 8 | "sync" 9 | "testing" 10 | "time" 11 | 12 | "github.com/kwonalbert/atom/directory" 13 | 14 | . "github.com/kwonalbert/atom/atomrpc" 15 | . "github.com/kwonalbert/atom/common" 16 | . "github.com/kwonalbert/atom/crypto" 17 | ) 18 | 19 | var addr string = "127.0.0.1:%d" 20 | var dirPort = 8000 21 | var port = 8001 22 | 23 | var testNet = SQUARE 24 | var testMode = TRAP_MODE 25 | 26 | var numServers = 0 27 | var numGroups = 1 28 | var perGroup = 3 29 | var numTrustees = 2 30 | 31 | var numMsgs = 4 32 | var msgSize = 5 33 | var threshold = perGroup 34 | var numClients = 0 35 | 36 | func setup() (*directory.Directory, []*Trustee, error) { 37 | dir, err := directory.NewDirectory(0, dirPort, testMode, testNet, 38 | numServers, numGroups, perGroup, numTrustees, 39 | numMsgs, msgSize, threshold, 40 | numClients) 41 | if err != nil { 42 | return nil, nil, err 43 | } 44 | go dir.Close() 45 | time.Sleep(100 * time.Millisecond) 46 | 47 | trustees := make([]*Trustee, numTrustees) 48 | 49 | dirAddrs := []string{fmt.Sprintf(addr, dirPort)} 50 | 51 | wg := new(sync.WaitGroup) 52 | for i := range trustees { 53 | wg.Add(1) 54 | go func(i int) { 55 | defer wg.Done() 56 | var err error 57 | trustees[i], err = NewTrustee(fmt.Sprintf(addr, port+i), i, 58 | "", dirAddrs) 59 | if err != nil { 60 | log.Fatal("Trustee creation err:", err) 61 | } 62 | 63 | trustees[i].Setup() 64 | 65 | if i == 0 { 66 | fmt.Println("Generating per round keys") 67 | } 68 | trustees[i].RegisterRound() 69 | 70 | if i == 0 { 71 | fmt.Println("Finished trustee setup") 72 | } 73 | }(i) 74 | } 75 | wg.Wait() 76 | return dir, trustees, nil 77 | } 78 | 79 | func TestReport(t *testing.T) { 80 | _, trustees, err := setup() 81 | if err != nil { 82 | t.Error(err) 83 | } 84 | 85 | // correct report 86 | report := ReportArgs{ 87 | 
Round: 0, 88 | Sid: 0, 89 | Uid: 0, 90 | CorrectHash: true, 91 | CorrectTraps: true, 92 | NumTraps: 128, 93 | NumMsgs: 128, 94 | } 95 | 96 | _, tlsConfig := AtomTLSConfig() 97 | 98 | replies := make([]*PrivateKey, numGroups*perGroup) 99 | repliesLock := make([]*sync.Mutex, numGroups*perGroup) 100 | for r := range replies { 101 | repliesLock[r] = new(sync.Mutex) 102 | } 103 | 104 | wg := new(sync.WaitGroup) 105 | for u := range trustees { // report to each trustee 106 | for i := 0; i < numGroups*perGroup; i++ { // each member reports 107 | wg.Add(1) 108 | go func(i, u int) { 109 | defer wg.Done() 110 | conn, err := tls.Dial("tcp", fmt.Sprintf(addr, port+u), tlsConfig) 111 | if err != nil { 112 | t.Error(err) 113 | } 114 | trustee := rpc.NewClient(conn) 115 | var reply ReportReply 116 | err = trustee.Call("TrusteeRPC.Report", &report, &reply) 117 | if err != nil { 118 | t.Error(err) 119 | } 120 | repliesLock[i].Lock() 121 | if replies[i] == nil { 122 | replies[i] = reply.Priv 123 | } else { 124 | tmp := []*PrivateKey{replies[i], reply.Priv} 125 | replies[i] = CombinePrivateKeys(tmp) 126 | } 127 | repliesLock[i].Unlock() 128 | trustee.Close() 129 | }(i, u) 130 | } 131 | wg.Wait() 132 | } 133 | 134 | for r := range replies { 135 | for u := range trustees { 136 | exp := CombinePublicKeys(trustees[u].publicKeys) 137 | res := PubFromPriv(replies[r]) 138 | if !res.Equal(exp) { 139 | t.Error("Failed to recover trustee keys") 140 | } 141 | } 142 | } 143 | } 144 | --------------------------------------------------------------------------------