├── .github ├── CODEOWNERS └── workflows │ └── go.yml ├── .gitignore ├── LICENSE ├── README.md ├── aw.go ├── aw_suite_test.go ├── aw_test.go ├── channel ├── channel.go ├── channel_suite_test.go ├── channel_test.go ├── client.go ├── client_test.go ├── filter.go ├── filter_test.go ├── opt.go └── opt_test.go ├── codec ├── codec.go ├── codec_suite_test.go ├── codec_test.go ├── gcm.go ├── gcm_test.go ├── length_prefix.go ├── length_prefix_test.go ├── plain.go └── plain_test.go ├── dht ├── dht_suite_test.go ├── dhtutil │ └── dht_util.go ├── resolver.go ├── resolver_test.go ├── table.go └── table_test.go ├── examples ├── chatroom │ └── chatroom.go ├── dat │ └── dat.go └── fuzz │ ├── fuzz │ └── fuzz.go ├── go.mod ├── go.sum ├── handshake ├── ecies.go ├── ecies_test.go ├── filter.go ├── filter_test.go ├── handshake.go ├── handshake_suite_test.go ├── handshake_test.go ├── once.go └── once_test.go ├── peer ├── gossip.go ├── gossip_test.go ├── opt.go ├── opt_test.go ├── peer.go ├── peer_suite_test.go ├── peer_test.go ├── peerdiscovery.go ├── peerdiscovery_test.go ├── sync.go └── sync_test.go ├── policy ├── allow.go ├── allow_test.go ├── policy.go ├── policy_suite_test.go ├── policy_test.go ├── timeout.go └── timeout_test.go ├── tcp ├── tcp.go ├── tcp_suite_test.go └── tcp_test.go ├── transport ├── transport.go ├── transport_suite_test.go └── transport_test.go └── wire ├── addr.go ├── addr_test.go ├── error.go ├── sigaddr.go ├── sigaddr_test.go ├── wire.go ├── wire_suite_test.go └── wire_test.go /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @tok-kkk @loongy 2 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: go 2 | on: [push] 3 | jobs: 4 | 5 | test: 6 | runs-on: ubuntu-latest 7 | if: "!contains(github.event.head_commit.message, '[skip ci]')" 8 | steps: 9 | 10 | - name: Set up Go 1.16 11 | uses: actions/setup-go@v1 12 | with: 13 | go-version: 1.16 14 | id: go 15 | 16 | - name: Check out code into the Go module directory 17 | uses: actions/checkout@v1 18 | 19 | - name: Caching modules 20 | uses: actions/cache@v1 21 | with: 22 | path: ~/go/pkg/mod 23 | key: ${{ runner.os }}-go-aw-${{ hashFiles('**/go.sum') }} 24 | 25 | - name: Get dependencies 26 | run: | 27 | export PATH=$PATH:$(go env GOPATH)/bin 28 | go get -u github.com/onsi/ginkgo/ginkgo 29 | go get -u github.com/onsi/gomega/... 30 | go get -u golang.org/x/lint/golint 31 | go get -u github.com/loongy/covermerge 32 | go get -u github.com/mattn/goveralls 33 | cd $GITHUB_WORKSPACE 34 | go vet ./... 35 | golint ./... 36 | - name: Run tests and report test coverage 37 | env: 38 | COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }} 39 | run: | 40 | export PATH=$PATH:$(go env GOPATH)/bin 41 | cd $GITHUB_WORKSPACE 42 | CI=true ginkgo --v --race --cover --coverprofile coverprofile.out ./... 
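# Merge the per-package cover profiles produced by ginkgo and upload the combined result to Coveralls.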
43 | covermerge \ 44 | channel/coverprofile.out \ 45 | codec/coverprofile.out \ 46 | dht/coverprofile.out \ 47 | handshake/coverprofile.out \ 48 | peer/coverprofile.out \ 49 | policy/coverprofile.out \ 50 | tcp/coverprofile.out \ 51 | transport/coverprofile.out \ 52 | wire/coverprofile.out > coverprofile.out 53 | goveralls -coverprofile=coverprofile.out -service=github 54 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Go template 3 | # Binaries for programs and plugins 4 | *.exe 5 | *.exe~ 6 | *.dll 7 | *.so 8 | *.dylib 9 | 10 | # Test binary, built with `go test -c` 11 | *.test 12 | *.coverprofile 13 | 14 | # Output of the go coverage tool, specifically when used with LiteIDE 15 | *.out 16 | 17 | # Dependency directories (remove the comment below to include it) 18 | # vendor/ 19 | 20 | # IDE folders 21 | /.idea 22 | 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Dappbase Pte. Ltd. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # `🌪 airwave` 2 | 3 | [![GoDoc](https://godoc.org/github.com/renproject/aw?status.svg)](https://godoc.org/github.com/renproject/aw) 4 | ![](https://github.com/renproject/aw/workflows/go/badge.svg) 5 | ![Go Report](https://goreportcard.com/badge/github.com/renproject/aw) 6 | [![Coverage Status](https://coveralls.io/repos/github/renproject/aw/badge.svg?branch=master)](https://coveralls.io/github/renproject/aw?branch=master) 7 | [![License: MIT](https://img.shields.io/badge/license-MIT-green.svg)](https://opensource.org/licenses/MIT) 8 | 9 | A flexible P2P networking library for upgradable distributed systems. The core mission of `airwave` is to provide a simple P2P interface that can support a wide variety of different algorithms, with a focus on backwards compatibility.
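The sketch below shows how the lower-level packages in this repository (`channel`, `codec`, `handshake`, `tcp`, and `policy`) fit together. It is adapted from the package tests and is illustrative only: it uses the insecure handshake and plain codecs for brevity, the dial address and the remote identity are placeholders, and it assumes a remote peer is already listening (see the `channel` and `tcp` test suites for the listening side).

```go
package main

import (
	"context"
	"log"
	"net"
	"time"

	"github.com/renproject/aw/channel"
	"github.com/renproject/aw/codec"
	"github.com/renproject/aw/handshake"
	"github.com/renproject/aw/policy"
	"github.com/renproject/aw/tcp"
	"github.com/renproject/aw/wire"
	"github.com/renproject/id"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Local identity, and the identity of the remote peer (assumed to be
	// known out-of-band for this sketch).
	privKey := id.NewPrivKey()
	self := privKey.Signatory()
	remote := id.NewPrivKey().Signatory()

	// A channel.Client multiplexes one Channel per bound remote peer.
	client := channel.NewClient(channel.DefaultOptions(), self)
	client.Bind(remote)
	defer client.Unbind(remote)

	// Dial the remote peer (the address is a placeholder), run the handshake,
	// and attach the resulting connection and codecs to the Channel bound
	// above.
	go func() {
		err := tcp.Dial(ctx, "127.0.0.1:19231", func(conn net.Conn) {
			enc, dec, r, err := handshake.Insecure(self)(
				conn,
				codec.LengthPrefixEncoder(codec.PlainEncoder, codec.PlainEncoder),
				codec.LengthPrefixDecoder(codec.PlainDecoder, codec.PlainDecoder),
			)
			if err != nil {
				log.Printf("handshake: %v", err)
				return
			}
			if err := client.Attach(ctx, r, conn, enc, dec); err != nil {
				log.Printf("attach: %v", err)
			}
		}, func(err error) {
			log.Printf("dial: %v", err)
		}, policy.ConstantTimeout(time.Second))
		if err != nil {
			log.Printf("dial: %v", err)
		}
	}()

	// Receive messages from all bound peers, and cast one message to the
	// remote peer.
	client.Receive(ctx, func(from id.Signatory, packet wire.Packet) error {
		log.Printf("received %v bytes from %v", len(packet.Msg.Data), from)
		return nil
	})
	if err := client.Send(ctx, remote, wire.Msg{Data: []byte("hello")}); err != nil {
		log.Printf("send: %v", err)
	}

	// Give the background dial and attach a moment to complete (sketch only).
	time.Sleep(time.Second)
}
```

The `handshake` package also provides an ECIES-based handshake, and the `codec` package a GCM codec, for authenticated and encrypted sessions.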
The P2P interface supports: 10 | 11 | - Peer discovery 12 | - Handshake 13 | - Casting (send to one) 14 | - Multicasting (send to many) 15 | - Broadcasting (send to everyone) 16 | 17 | ### Handshake 18 | 19 | Airwave uses a 3-way handshake to authorize peers in the network. The process is as follows: 20 | 21 | ![](docs/handshake.svg) 22 | 23 | On connect, the client sends its signed public key. The server validates the signature, generates a random challenge, and sends back the signed challenge (encrypted with the client's public key) together with the server's public key. The client validates the server's signature, decrypts the challenge, re-encrypts it with the server's public key, signs it, and sends it back. 24 | 25 | Built with ❤ by Ren. 26 | -------------------------------------------------------------------------------- /aw.go: -------------------------------------------------------------------------------- 1 | package aw 2 | -------------------------------------------------------------------------------- /aw_suite_test.go: -------------------------------------------------------------------------------- 1 | package aw_test 2 | 3 | import ( 4 | "testing" 5 | 6 | . "github.com/onsi/ginkgo" 7 | . "github.com/onsi/gomega" 8 | ) 9 | 10 | func TestAirwave(t *testing.T) { 11 | RegisterFailHandler(Fail) 12 | RunSpecs(t, "Airwave Suite") 13 | } 14 | -------------------------------------------------------------------------------- /aw_test.go: -------------------------------------------------------------------------------- 1 | package aw_test 2 | -------------------------------------------------------------------------------- /channel/channel_suite_test.go: -------------------------------------------------------------------------------- 1 | package channel_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "net" 8 | "testing" 9 | "time" 10 | 11 | "github.com/renproject/aw/channel" 12 | "github.com/renproject/aw/codec" 13 | "github.com/renproject/aw/handshake" 14 | "github.com/renproject/aw/policy" 15 | "github.com/renproject/aw/tcp" 16 | "github.com/renproject/id" 17 | 18 | . "github.com/onsi/ginkgo" 19 | . 
"github.com/onsi/gomega" 20 | ) 21 | 22 | func TestChannel(t *testing.T) { 23 | RegisterFailHandler(Fail) 24 | RunSpecs(t, "Channel suite") 25 | } 26 | 27 | func listen(ctx context.Context, attacher channel.Attacher, self, other id.Signatory) int { 28 | ip := "127.0.0.1" 29 | listener, port, err := tcp.ListenerWithAssignedPort(ctx, ip) 30 | Expect(err).ToNot(HaveOccurred()) 31 | go func() { 32 | defer GinkgoRecover() 33 | Expect(tcp.ListenWithListener( 34 | ctx, 35 | listener, 36 | func(conn net.Conn) { 37 | log.Printf("accepted: %v", conn.RemoteAddr()) 38 | enc, dec, remote, err := handshake.Insecure(self)( 39 | conn, 40 | codec.LengthPrefixEncoder(codec.PlainEncoder, codec.PlainEncoder), 41 | codec.LengthPrefixDecoder(codec.PlainDecoder, codec.PlainDecoder), 42 | ) 43 | if err != nil { 44 | log.Printf("handshake: %v", err) 45 | return 46 | } 47 | if !other.Equal(&remote) { 48 | log.Printf("handshake: expected %v, got %v", other, remote) 49 | return 50 | } 51 | if err := attacher.Attach(ctx, remote, conn, enc, dec); err != nil { 52 | log.Printf("attach listener: %v", err) 53 | return 54 | } 55 | }, 56 | func(err error) { 57 | log.Printf("listen: %v", err) 58 | }, 59 | nil, 60 | )).To(Equal(context.Canceled)) 61 | }() 62 | 63 | return port 64 | } 65 | 66 | func dial(ctx context.Context, attacher channel.Attacher, self, other id.Signatory, port int, retry time.Duration) { 67 | go func() { 68 | defer GinkgoRecover() 69 | for { 70 | select { 71 | case <-ctx.Done(): 72 | return 73 | default: 74 | } 75 | // Dial a connection in the background. We do it in the 76 | // background so that we can dial new connections later (to 77 | // replace this one, and verify that channels behave as 78 | // expected under these conditions). 79 | go func() { 80 | defer GinkgoRecover() 81 | Expect(tcp.Dial( 82 | ctx, 83 | fmt.Sprintf("127.0.0.1:%v", port), 84 | func(conn net.Conn) { 85 | log.Printf("dialed: %v", conn.RemoteAddr()) 86 | enc, dec, remote, err := handshake.Insecure(self)( 87 | conn, 88 | codec.LengthPrefixEncoder(codec.PlainEncoder, codec.PlainEncoder), 89 | codec.LengthPrefixDecoder(codec.PlainDecoder, codec.PlainDecoder), 90 | ) 91 | if err != nil { 92 | log.Printf("handshake: %v", err) 93 | return 94 | } 95 | if !other.Equal(&remote) { 96 | log.Printf("handshake: expected %v, got %v", other, remote) 97 | return 98 | } 99 | if err := attacher.Attach(ctx, remote, conn, enc, dec); err != nil { 100 | log.Printf("attach dialer: %v", err) 101 | return 102 | } 103 | }, 104 | func(err error) { 105 | log.Printf("dial: %v", err) 106 | }, 107 | policy.ConstantTimeout(100*time.Millisecond), 108 | )).To(Succeed()) 109 | }() 110 | // After some duration, dial again. This will create an 111 | // entirely new connection, and replace the previous 112 | // connection. 113 | <-time.After(retry) 114 | } 115 | }() 116 | } 117 | -------------------------------------------------------------------------------- /channel/channel_test.go: -------------------------------------------------------------------------------- 1 | package channel_test 2 | 3 | import ( 4 | "context" 5 | "encoding/binary" 6 | "log" 7 | "math/rand" 8 | "time" 9 | 10 | "github.com/renproject/aw/channel" 11 | "github.com/renproject/aw/wire" 12 | "github.com/renproject/id" 13 | 14 | . "github.com/onsi/ginkgo" 15 | . 
"github.com/onsi/gomega" 16 | ) 17 | 18 | var _ = Describe("Channels", func() { 19 | 20 | run := func(ctx context.Context, remote id.Signatory) (*channel.Channel, <-chan wire.Packet, chan<- wire.Msg) { 21 | inbound, outbound := make(chan wire.Packet), make(chan wire.Msg) 22 | ch := channel.New( 23 | channel.DefaultOptions().WithDrainTimeout(1500*time.Millisecond), 24 | remote, 25 | inbound, 26 | outbound) 27 | go func() { 28 | defer GinkgoRecover() 29 | if err := ch.Run(ctx); err != nil { 30 | log.Printf("run: %v", err) 31 | return 32 | } 33 | }() 34 | return ch, inbound, outbound 35 | } 36 | 37 | sink := func(outbound chan<- wire.Msg, n uint64) <-chan struct{} { 38 | quit := make(chan struct{}) 39 | go func() { 40 | defer GinkgoRecover() 41 | defer close(quit) 42 | timeout := time.After(30 * time.Second) 43 | for iter := uint64(0); iter < n; iter++ { 44 | time.Sleep(time.Millisecond) 45 | data := [8]byte{} 46 | binary.BigEndian.PutUint64(data[:], iter) 47 | select { 48 | case outbound <- wire.Msg{Data: data[:]}: 49 | case <-timeout: 50 | Expect(func() { panic("sink timeout") }).ToNot(Panic()) 51 | } 52 | } 53 | }() 54 | return quit 55 | } 56 | 57 | stream := func(inbound <-chan wire.Packet, n uint64, inOrder bool) <-chan struct{} { 58 | quit := make(chan struct{}) 59 | go func() { 60 | defer GinkgoRecover() 61 | defer close(quit) 62 | timeout := time.After(30 * time.Second) 63 | max := uint64(0) 64 | received := make(map[uint64]int, n) 65 | for iter := uint64(0); iter < n; iter++ { 66 | select { 67 | case msg := <-inbound: 68 | data := binary.BigEndian.Uint64(msg.Msg.Data) 69 | if data > max { 70 | max = data 71 | } 72 | received[data]++ 73 | if inOrder { 74 | Expect(data).To(Equal(iter)) 75 | } 76 | if rand.Int()%1000 == 0 { 77 | log.Printf("stream %v/%v", len(received), max+1) 78 | } 79 | case <-timeout: 80 | Expect(func() { panic("stream timeout") }).ToNot(Panic()) 81 | } 82 | } 83 | for msg, count := range received { 84 | if count > 1 { 85 | log.Printf("duplicate %v (%v)", msg, count) 86 | } 87 | Expect(count).To(BeNumerically(">=", 1)) 88 | } 89 | Expect(len(received)).To(Equal(int(n))) 90 | }() 91 | return quit 92 | } 93 | 94 | Context("when a connection is attached before sending messages", func() { 95 | It("should send and receive all message in order", func() { 96 | ctx, cancel := context.WithCancel(context.Background()) 97 | defer cancel() 98 | 99 | localPrivKey := id.NewPrivKey() 100 | remotePrivKey := id.NewPrivKey() 101 | localCh, localInbound, localOutbound := run(ctx, remotePrivKey.Signatory()) 102 | remoteCh, remoteInbound, remoteOutbound := run(ctx, localPrivKey.Signatory()) 103 | 104 | // Remote channel will listen for incoming connections. 105 | port := listen(ctx, remoteCh, remotePrivKey.Signatory(), localPrivKey.Signatory()) 106 | // Local channel will dial the listener (and re-dial once per 107 | // minute; so it should not impact the test, which is expected 108 | // to complete in less than one minute). 109 | dial(ctx, localCh, localPrivKey.Signatory(), remotePrivKey.Signatory(), port, time.Minute) 110 | 111 | // Wait for the connections to be attached before beginning to 112 | // send/receive messages. 113 | time.Sleep(time.Second) 114 | 115 | // Number of messages that we will test. 116 | n := uint64(1000) 117 | // Send and receive messages in both direction; from local to 118 | // remote, and from remote to local. 
119 | q1 := sink(localOutbound, n) 120 | q2 := stream(remoteInbound, n, true) 121 | q3 := sink(remoteOutbound, n) 122 | q4 := stream(localInbound, n, true) 123 | 124 | <-q1 125 | <-q2 126 | <-q3 127 | <-q4 128 | }) 129 | }) 130 | 131 | Context("when a connection is attached after sending messages", func() { 132 | It("should send and receive all message in order", func() { 133 | ctx, cancel := context.WithCancel(context.Background()) 134 | defer cancel() 135 | 136 | localPrivKey := id.NewPrivKey() 137 | remotePrivKey := id.NewPrivKey() 138 | localCh, localInbound, localOutbound := run(ctx, remotePrivKey.Signatory()) 139 | remoteCh, remoteInbound, remoteOutbound := run(ctx, localPrivKey.Signatory()) 140 | 141 | // Number of messages that we will test. 142 | n := uint64(1000) 143 | // Send and receive messages in both direction; from local to 144 | // remote, and from remote to local. 145 | q1 := sink(localOutbound, n) 146 | q2 := stream(remoteInbound, n, true) 147 | q3 := sink(remoteOutbound, n) 148 | q4 := stream(localInbound, n, true) 149 | 150 | // Wait for some messages to begin being sent/received before 151 | // attaching network connections. 152 | time.Sleep(time.Second) 153 | 154 | // Remote channel will listen for incoming connections. 155 | port := listen(ctx, remoteCh, remotePrivKey.Signatory(), localPrivKey.Signatory()) 156 | // Local channel will dial the listener (and re-dial once per 157 | // minute; so it should not impact the test, which is expected 158 | // to complete in less than one minute). 159 | dial(ctx, localCh, localPrivKey.Signatory(), remotePrivKey.Signatory(), port, time.Minute) 160 | 161 | <-q1 162 | <-q2 163 | <-q3 164 | <-q4 165 | }) 166 | }) 167 | 168 | Context("when a connection is replaced while sending messages", func() { 169 | Context("when draining connections in the background", func() { 170 | It("should send and receive all messages out of order", func() { 171 | ctx, cancel := context.WithCancel(context.Background()) 172 | defer cancel() 173 | 174 | localPrivKey := id.NewPrivKey() 175 | remotePrivKey := id.NewPrivKey() 176 | localCh, localInbound, localOutbound := run(ctx, remotePrivKey.Signatory()) 177 | remoteCh, remoteInbound, remoteOutbound := run(ctx, localPrivKey.Signatory()) 178 | 179 | // Number of messages that we will test. This number is higher than 180 | // in other tests, because we need sending/receiving to take long 181 | // enough that replacements will happen. 182 | n := uint64(10000) 183 | // Send and receive messages in both direction; from local to 184 | // remote, and from remote to local. 185 | q1 := sink(localOutbound, n) 186 | q2 := stream(remoteInbound, n, false) 187 | q3 := sink(remoteOutbound, n) 188 | q4 := stream(localInbound, n, false) 189 | 190 | // Remote channel will listen for incoming connections. 191 | port := listen(ctx, remoteCh, remotePrivKey.Signatory(), localPrivKey.Signatory()) 192 | // Local channel will dial the listener (and re-dial once per 193 | // second). 194 | dial(ctx, localCh, localPrivKey.Signatory(), remotePrivKey.Signatory(), port, time.Second) 195 | 196 | // Wait for sinking and streaming to finish. 
197 | <-q1 198 | <-q2 199 | <-q3 200 | <-q4 201 | }) 202 | }) 203 | }) 204 | }) 205 | -------------------------------------------------------------------------------- /channel/client.go: -------------------------------------------------------------------------------- 1 | package channel 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "net" 8 | "sync" 9 | 10 | "github.com/renproject/aw/codec" 11 | "github.com/renproject/aw/wire" 12 | "github.com/renproject/id" 13 | 14 | "go.uber.org/zap" 15 | ) 16 | 17 | type receiver struct { 18 | ctx context.Context 19 | f func(id.Signatory, wire.Packet) error 20 | } 21 | 22 | type sharedChannel struct { 23 | // ch defines a channel that is bound to a remote peer. 24 | ch *Channel 25 | // rc defines a reference-counter that tracks the number of references 26 | // currently bound. 27 | rc uint64 28 | // cancel the running channel. 29 | cancel context.CancelFunc 30 | // inbound channel receives messages from the remote peer to which the 31 | // channel is bound. 32 | inbound <-chan wire.Packet 33 | // outbound channel is sent messages that are destined for the remote peer 34 | // to which the channel is bound. 35 | outbound chan<- wire.Msg 36 | } 37 | 38 | type Msg struct { 39 | wire.Packet 40 | From id.Signatory 41 | } 42 | 43 | type Client struct { 44 | opts Options 45 | self id.Signatory 46 | 47 | sharedChannelsMu *sync.RWMutex 48 | sharedChannels map[id.Signatory]*sharedChannel 49 | 50 | inbound chan Msg 51 | receivers chan receiver 52 | receiversRunningMu *sync.Mutex 53 | receiversRunning bool 54 | } 55 | 56 | func NewClient(opts Options, self id.Signatory) *Client { 57 | return &Client{ 58 | opts: opts, 59 | self: self, 60 | 61 | sharedChannelsMu: new(sync.RWMutex), 62 | sharedChannels: map[id.Signatory]*sharedChannel{}, 63 | 64 | inbound: make(chan Msg), 65 | receivers: make(chan receiver), 66 | receiversRunningMu: new(sync.Mutex), 67 | receiversRunning: false, 68 | } 69 | } 70 | 71 | func (client *Client) Bind(remote id.Signatory) { 72 | client.sharedChannelsMu.Lock() 73 | defer client.sharedChannelsMu.Unlock() 74 | 75 | shared, ok := client.sharedChannels[remote] 76 | if ok { 77 | shared.rc++ 78 | return 79 | } 80 | 81 | inbound := make(chan wire.Packet, client.opts.InboundBufferSize) 82 | outbound := make(chan wire.Msg, client.opts.OutboundBufferSize) 83 | 84 | ctx, cancel := context.WithCancel(context.Background()) 85 | ch := New(client.opts, remote, inbound, outbound) 86 | go func() { 87 | if err := ch.Run(ctx); err != nil { 88 | if !errors.Is(err, context.Canceled) { 89 | client.opts.Logger.Error("run", zap.Error(err)) 90 | } 91 | } 92 | }() 93 | go func() { 94 | for { 95 | select { 96 | case <-ctx.Done(): 97 | return 98 | case packet := <-inbound: 99 | select { 100 | case <-ctx.Done(): 101 | return 102 | case client.inbound <- Msg{Packet: packet, From: remote}: 103 | } 104 | } 105 | } 106 | }() 107 | 108 | client.sharedChannels[remote] = &sharedChannel{ 109 | ch: ch, 110 | rc: 1, 111 | cancel: cancel, 112 | inbound: inbound, 113 | outbound: outbound, 114 | } 115 | } 116 | 117 | func (client *Client) Unbind(remote id.Signatory) { 118 | client.sharedChannelsMu.Lock() 119 | defer client.sharedChannelsMu.Unlock() 120 | 121 | shared, ok := client.sharedChannels[remote] 122 | if !ok { 123 | return 124 | } 125 | 126 | shared.rc-- 127 | if shared.rc == 0 { 128 | shared.cancel() 129 | delete(client.sharedChannels, remote) 130 | } 131 | } 132 | 133 | func (client *Client) IsBound(remote id.Signatory) bool { 134 | 
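// Note: the shared channel map is indexed directly below, so IsBound assumes that Bind has been called for the remote at least once; querying a remote that was never bound would dereference a nil entry.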
client.sharedChannelsMu.RLock() 135 | defer client.sharedChannelsMu.RUnlock() 136 | 137 | return client.sharedChannels[remote].rc > 0 138 | } 139 | 140 | // Attach a network connection, encoder, and decoder to the Channel associated 141 | // with a remote peer without incrementing the reference-counter of the Channel. 142 | // An error is returned if no Channel is associated with the remote peer. As 143 | // with the Attach method that is exposed directly by a Channel, this method is 144 | // blocking. 145 | func (client *Client) Attach(ctx context.Context, remote id.Signatory, conn net.Conn, enc codec.Encoder, dec codec.Decoder) error { 146 | client.sharedChannelsMu.RLock() 147 | shared, ok := client.sharedChannels[remote] 148 | if !ok { 149 | client.sharedChannelsMu.RUnlock() 150 | return fmt.Errorf("attach: no connection to %v", remote) 151 | } 152 | client.sharedChannelsMu.RUnlock() 153 | 154 | client.opts.Logger.Debug("attach", zap.String("self", client.self.String()), zap.String("remote", remote.String()), zap.String("addr", conn.RemoteAddr().String())) 155 | if err := shared.ch.Attach(ctx, remote, conn, enc, dec); err != nil { 156 | return fmt.Errorf("attach: %w", err) 157 | } 158 | return nil 159 | } 160 | 161 | func (client *Client) Send(ctx context.Context, remote id.Signatory, msg wire.Msg) error { 162 | client.sharedChannelsMu.RLock() 163 | shared, ok := client.sharedChannels[remote] 164 | if !ok { 165 | client.sharedChannelsMu.RUnlock() 166 | return fmt.Errorf("channel not found: %v", remote) 167 | } 168 | client.sharedChannelsMu.RUnlock() 169 | 170 | select { 171 | case <-ctx.Done(): 172 | return fmt.Errorf("sending message %w", ctx.Err()) 173 | case shared.outbound <- msg: 174 | return nil 175 | } 176 | } 177 | 178 | func (client *Client) Receive(ctx context.Context, f func(id.Signatory, wire.Packet) error) { 179 | client.receiversRunningMu.Lock() 180 | if client.receiversRunning { 181 | client.receiversRunningMu.Unlock() 182 | select { 183 | case <-ctx.Done(): 184 | case client.receivers <- receiver{ctx: ctx, f: f}: 185 | } 186 | return 187 | } 188 | client.receiversRunning = true 189 | client.receiversRunningMu.Unlock() 190 | 191 | go func() { 192 | receivers := []receiver{} 193 | 194 | for { 195 | select { 196 | case receiver := <-client.receivers: 197 | // A new receiver has been registered. 198 | receivers = append(receivers, receiver) 199 | case msg := <-client.inbound: 200 | marker := 0 201 | for _, receiver := range receivers { 202 | select { 203 | case <-receiver.ctx.Done(): 204 | // Do nothing. This will implicitly mark it for 205 | // deletion. 206 | default: 207 | if err := receiver.f(msg.From, msg.Packet); err != nil { 208 | // When a channel is killed, its context will be 209 | // cancelled, its underlying network connections 210 | // will be dropped, and sending will fail. A killed 211 | // channel can only be revived by completely 212 | // unbinding all references, and binding a new 213 | // reference. 214 | client.opts.Logger.Error("filter", zap.String("remote", msg.From.String()), zap.Error(err)) 215 | client.sharedChannelsMu.Lock() 216 | if shared, ok := client.sharedChannels[msg.From]; ok { 217 | shared.cancel() 218 | } 219 | client.sharedChannelsMu.Unlock() 220 | } 221 | receivers[marker] = receiver 222 | marker++ 223 | } 224 | } 225 | // Delete everything that was marked for deletion. 
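// Receivers whose contexts were cancelled were not copied forward above, so every slot at index >= marker is stale: zero those slots so that their callbacks can be garbage collected, then truncate the slice to the surviving receivers.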
226 | for del := marker; del < len(receivers); del++ { 227 | receivers[del] = receiver{} 228 | } 229 | receivers = receivers[:marker] 230 | } 231 | if len(receivers) == 0 { 232 | break 233 | } 234 | } 235 | 236 | client.receiversRunningMu.Lock() 237 | client.receiversRunning = false 238 | client.receiversRunningMu.Unlock() 239 | }() 240 | 241 | select { 242 | case <-ctx.Done(): 243 | case client.receivers <- receiver{ctx: ctx, f: f}: 244 | } 245 | } 246 | -------------------------------------------------------------------------------- /channel/client_test.go: -------------------------------------------------------------------------------- 1 | package channel_test 2 | 3 | import ( 4 | "context" 5 | "encoding/binary" 6 | "time" 7 | 8 | "github.com/renproject/aw/channel" 9 | "github.com/renproject/aw/wire" 10 | "github.com/renproject/id" 11 | 12 | . "github.com/onsi/ginkgo" 13 | . "github.com/onsi/gomega" 14 | ) 15 | 16 | var _ = Describe("Client", func() { 17 | 18 | sink := func(ctx context.Context, client *channel.Client, remote id.Signatory, n uint64) <-chan struct{} { 19 | quit := make(chan struct{}) 20 | go func() { 21 | defer GinkgoRecover() 22 | defer close(quit) 23 | ctx, cancel := context.WithTimeout(ctx, 30*time.Second) 24 | defer cancel() 25 | for iter := uint64(0); iter < n; iter++ { 26 | data := [8]byte{} 27 | binary.BigEndian.PutUint64(data[:], iter) 28 | err := client.Send(ctx, remote, wire.Msg{Data: data[:]}) 29 | Expect(err).ToNot(HaveOccurred()) 30 | } 31 | }() 32 | return quit 33 | } 34 | 35 | stream := func(ctx context.Context, client *channel.Client, n uint64) <-chan struct{} { 36 | quit := make(chan struct{}) 37 | go func() { 38 | defer GinkgoRecover() 39 | defer close(quit) 40 | defer time.Sleep(time.Millisecond) // Wait for the receiver to be shutdown. 
41 | ctx, cancel := context.WithTimeout(ctx, 30*time.Second) 42 | defer cancel() 43 | receiver := make(chan wire.Msg) 44 | client.Receive(ctx, func(signatory id.Signatory, packet wire.Packet) error { 45 | receiver <- packet.Msg 46 | return nil 47 | }) 48 | for iter := uint64(0); iter < n; iter++ { 49 | time.Sleep(time.Millisecond) 50 | select { 51 | case <-ctx.Done(): 52 | Expect(ctx.Err()).ToNot(HaveOccurred()) 53 | case msg := <-receiver: 54 | data := binary.BigEndian.Uint64(msg.Data) 55 | Expect(data).To(Equal(iter)) 56 | } 57 | } 58 | }() 59 | return quit 60 | } 61 | 62 | Context("when binding and attaching", func() { 63 | It("should send and receive all messages in order", func() { 64 | ctx, cancel := context.WithCancel(context.Background()) 65 | defer cancel() 66 | 67 | localPrivKey := id.NewPrivKey() 68 | remotePrivKey := id.NewPrivKey() 69 | 70 | local := channel.NewClient( 71 | channel.DefaultOptions(), 72 | localPrivKey.Signatory()) 73 | local.Bind(remotePrivKey.Signatory()) 74 | defer local.Unbind(remotePrivKey.Signatory()) 75 | Expect(local.IsBound(remotePrivKey.Signatory())).To(BeTrue()) 76 | 77 | remote := channel.NewClient( 78 | channel.DefaultOptions(), 79 | remotePrivKey.Signatory()) 80 | remote.Bind(localPrivKey.Signatory()) 81 | defer remote.Unbind(localPrivKey.Signatory()) 82 | Expect(remote.IsBound(localPrivKey.Signatory())).To(BeTrue()) 83 | 84 | port := listen(ctx, remote, remotePrivKey.Signatory(), localPrivKey.Signatory()) 85 | dial(ctx, local, localPrivKey.Signatory(), remotePrivKey.Signatory(), port, time.Minute) 86 | 87 | n := uint64(5000) 88 | q1 := sink(ctx, local, remotePrivKey.Signatory(), n) 89 | q2 := stream(ctx, remote, n) 90 | q3 := sink(ctx, remote, localPrivKey.Signatory(), n) 91 | q4 := stream(ctx, local, n) 92 | 93 | <-q1 94 | <-q2 95 | <-q3 96 | <-q4 97 | }) 98 | }) 99 | 100 | Context("when binding and unbinding while attached", func() { 101 | It("should send and receive all messages in order", func() { 102 | ctx, cancel := context.WithCancel(context.Background()) 103 | defer cancel() 104 | 105 | localPrivKey := id.NewPrivKey() 106 | remotePrivKey := id.NewPrivKey() 107 | 108 | local := channel.NewClient( 109 | channel.DefaultOptions(), 110 | localPrivKey.Signatory()) 111 | local.Bind(remotePrivKey.Signatory()) 112 | defer local.Unbind(remotePrivKey.Signatory()) 113 | Expect(local.IsBound(remotePrivKey.Signatory())).To(BeTrue()) 114 | 115 | remote := channel.NewClient( 116 | channel.DefaultOptions(), 117 | remotePrivKey.Signatory()) 118 | remote.Bind(localPrivKey.Signatory()) 119 | defer remote.Unbind(localPrivKey.Signatory()) 120 | Expect(remote.IsBound(localPrivKey.Signatory())).To(BeTrue()) 121 | 122 | port := listen(ctx, remote, remotePrivKey.Signatory(), localPrivKey.Signatory()) 123 | dial(ctx, local, localPrivKey.Signatory(), remotePrivKey.Signatory(), port, time.Minute) 124 | 125 | go func() { 126 | remote := remotePrivKey.Signatory() 127 | for { 128 | select { 129 | case <-ctx.Done(): 130 | return 131 | default: 132 | local.Bind(remote) 133 | time.Sleep(time.Millisecond) 134 | local.Unbind(remote) 135 | time.Sleep(time.Millisecond) 136 | } 137 | } 138 | }() 139 | 140 | go func() { 141 | local := localPrivKey.Signatory() 142 | for { 143 | select { 144 | case <-ctx.Done(): 145 | return 146 | default: 147 | remote.Bind(local) 148 | time.Sleep(time.Millisecond) 149 | remote.Unbind(local) 150 | time.Sleep(time.Millisecond) 151 | } 152 | } 153 | }() 154 | 155 | n := uint64(5000) 156 | q1 := sink(ctx, local, remotePrivKey.Signatory(), n) 157 | q2 
:= stream(ctx, remote, n) 158 | q3 := sink(ctx, remote, localPrivKey.Signatory(), n) 159 | q4 := stream(ctx, local, n) 160 | 161 | <-q1 162 | <-q2 163 | <-q3 164 | <-q4 165 | }) 166 | }) 167 | 168 | Context("when sending before binding", func() { 169 | It("should return an error", func() { 170 | ctx, cancel := context.WithCancel(context.Background()) 171 | defer cancel() 172 | 173 | remotePrivKey := id.NewPrivKey() 174 | localPrivKey := id.NewPrivKey() 175 | local := channel.NewClient( 176 | channel.DefaultOptions(), 177 | localPrivKey.Signatory()) 178 | Expect(local.Send(ctx, remotePrivKey.Signatory(), wire.Msg{})).To(HaveOccurred()) 179 | }) 180 | }) 181 | 182 | Context("when attaching before binding", func() { 183 | It("should return an error", func() { 184 | ctx, cancel := context.WithCancel(context.Background()) 185 | defer cancel() 186 | 187 | remotePrivKey := id.NewPrivKey() 188 | localPrivKey := id.NewPrivKey() 189 | local := channel.NewClient( 190 | channel.DefaultOptions(), 191 | localPrivKey.Signatory()) 192 | Expect(local.Attach(ctx, remotePrivKey.Signatory(), nil, nil, nil)).To(HaveOccurred()) 193 | }) 194 | }) 195 | }) 196 | -------------------------------------------------------------------------------- /channel/filter.go: -------------------------------------------------------------------------------- 1 | package channel 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/renproject/aw/wire" 7 | "github.com/renproject/id" 8 | ) 9 | 10 | // A Filter is used to drop messages, and their respective channels, when the 11 | // messages are unexpected or malicious. 12 | type Filter interface { 13 | Filter(id.Signatory, wire.Msg) bool 14 | } 15 | 16 | // FilterFunc is a wrapper around a function that implements the Filter 17 | // interface. 18 | type FilterFunc func(id.Signatory, wire.Msg) bool 19 | 20 | func (f FilterFunc) Filter(from id.Signatory, msg wire.Msg) bool { 21 | return f(from, msg) 22 | } 23 | 24 | // A SyncFilter is used to filter synchronisation messages. If the local peer 25 | // is not expecting to receive a synchronisation message, then the message will 26 | // be filtered and the respective channel will be dropped. 27 | type SyncFilter struct { 28 | expectingMu *sync.RWMutex 29 | expecting map[string]int 30 | } 31 | 32 | // NewSyncFilter returns a new filter where all content IDs are denied by 33 | // default, and no allowances exist. 34 | func NewSyncFilter() *SyncFilter { 35 | return &SyncFilter{ 36 | expectingMu: new(sync.RWMutex), 37 | expecting: make(map[string]int, 1000), 38 | } 39 | } 40 | 41 | // Allow synchronisation messages for the given content ID. Every call to Allow 42 | // must be eventually followed by a call to Deny. By default, all content IDs 43 | // are denied until Allow is called. 44 | func (f *SyncFilter) Allow(contentID []byte) { 45 | f.expectingMu.Lock() 46 | defer f.expectingMu.Unlock() 47 | 48 | f.expecting[string(contentID)]++ 49 | } 50 | 51 | // Deny synchronisation messages for the given content ID. Denying a content ID 52 | // reverses one call to Allow. If there are no calls to Allow, this method does 53 | // nothing. 
54 | func (f *SyncFilter) Deny(contentID []byte) { 55 | f.expectingMu.Lock() 56 | defer f.expectingMu.Unlock() 57 | 58 | contentIDAsString := string(contentID) 59 | if f.expecting[contentIDAsString] == 1 { 60 | delete(f.expecting, contentIDAsString) 61 | return 62 | } 63 | f.expecting[contentIDAsString]-- 64 | } 65 | 66 | // Filter returns true if the message is not a synchronisation message, or the 67 | // content ID is not expected. 68 | func (f *SyncFilter) Filter(from id.Signatory, msg wire.Msg) bool { 69 | if msg.Type != wire.MsgTypeSync { 70 | return true 71 | } 72 | 73 | f.expectingMu.RLock() 74 | defer f.expectingMu.RUnlock() 75 | 76 | _, ok := f.expecting[string(msg.Data)] 77 | if !ok { 78 | return true 79 | } 80 | return false 81 | } 82 | -------------------------------------------------------------------------------- /channel/filter_test.go: -------------------------------------------------------------------------------- 1 | package channel_test 2 | -------------------------------------------------------------------------------- /channel/opt.go: -------------------------------------------------------------------------------- 1 | package channel 2 | 3 | import ( 4 | "time" 5 | 6 | "go.uber.org/zap" 7 | "golang.org/x/time/rate" 8 | ) 9 | 10 | var ( 11 | DefaultDrainTimeout = 30 * time.Second 12 | DefaultMaxMessageSize = 4 * 1024 * 1024 // 4MB 13 | DefaultRateLimit = rate.Limit(1024 * 1024) // 1MB per second 14 | DefaultInboundBufferSize = 0 15 | DefaultOutboundBufferSize = 0 16 | ) 17 | 18 | // Options for parameterizing the behaviour of a Channel. 19 | type Options struct { 20 | Logger *zap.Logger 21 | DrainTimeout time.Duration 22 | MaxMessageSize int 23 | RateLimit rate.Limit 24 | InboundBufferSize int 25 | OutboundBufferSize int 26 | } 27 | 28 | // DefaultOptions returns Options with sane defaults. 29 | func DefaultOptions() Options { 30 | logger, err := zap.NewDevelopment() 31 | if err != nil { 32 | panic(err) 33 | } 34 | return Options{ 35 | Logger: logger, 36 | DrainTimeout: DefaultDrainTimeout, 37 | MaxMessageSize: DefaultMaxMessageSize, 38 | RateLimit: DefaultRateLimit, 39 | InboundBufferSize: DefaultInboundBufferSize, 40 | OutboundBufferSize: DefaultOutboundBufferSize, 41 | } 42 | } 43 | 44 | // WithLogger sets the Logger used for logging all errors, warnings, information, 45 | // debug traces, and so on. 46 | func (opts Options) WithLogger(logger *zap.Logger) Options { 47 | opts.Logger = logger 48 | return opts 49 | } 50 | 51 | // WithDrainTimeout sets the timeout used by the Channel when draining replaced 52 | // connections. If a Channel does not see a message on a draining connection 53 | // before the timeout, then the draining connection is dropped and closed, and 54 | // all future messages sent to the connection will be lost. 55 | func (opts Options) WithDrainTimeout(timeout time.Duration) Options { 56 | opts.DrainTimeout = timeout 57 | return opts 58 | } 59 | 60 | // WithMaxMessageSize sets the maximum number of bytes that a channel will read 61 | // at one time. This number restricts the maximum message size that remote peers 62 | // can send, defines the buffer size used for unmarshalling messages, and 63 | // defines the rate limit burst. 64 | func (opts Options) WithMaxMessageSize(maxMessageSize int) Options { 65 | opts.MaxMessageSize = maxMessageSize 66 | return opts 67 | } 68 | 69 | // WithRateLimit sets the bytes-per-second rate limit that will be enforced on 70 | // all network connections. 
If a network connection exceeds this limit, then the 71 | // connection will be closed, and a new one will need to be established. 72 | func (opts Options) WithRateLimit(rateLimit rate.Limit) Options { 73 | opts.RateLimit = rateLimit 74 | return opts 75 | } 76 | 77 | // WithInboundBufferSize defines the number of inbound messages that can be 78 | // buffered in memory before back-pressure will prevent the buffering of new 79 | // inbound messages. 80 | func (opts Options) WithInboundBufferSize(size int) Options { 81 | opts.InboundBufferSize = size 82 | return opts 83 | } 84 | 85 | // WithOutboundBufferSize defines the number of outbound messages that can be 86 | // buffered in memory before back-pressure will prevent the buffering of new 87 | // outbound messages. 88 | func (opts Options) WithOutboundBufferSize(size int) Options { 89 | opts.OutboundBufferSize = size 90 | return opts 91 | } 92 | -------------------------------------------------------------------------------- /channel/opt_test.go: -------------------------------------------------------------------------------- 1 | package channel_test 2 | -------------------------------------------------------------------------------- /codec/codec.go: -------------------------------------------------------------------------------- 1 | package codec 2 | 3 | import ( 4 | "io" 5 | ) 6 | 7 | // An Encoder is a function that encodes a byte slice into an I/O writer. It 8 | // returns the number of bytes written, and errors that happen. 9 | type Encoder func(w io.Writer, buf []byte) (int, error) 10 | 11 | // A Decoder is a function that decodes bytes from an I/O reader into a byte 12 | // slice. It returns the number of bytes read, and errors that happen. 13 | type Decoder func(r io.Reader, buf []byte) (int, error) 14 | -------------------------------------------------------------------------------- /codec/codec_suite_test.go: -------------------------------------------------------------------------------- 1 | package codec_test 2 | 3 | import ( 4 | "testing" 5 | 6 | . "github.com/onsi/ginkgo" 7 | . 
"github.com/onsi/gomega" 8 | ) 9 | 10 | func TestCodec(t *testing.T) { 11 | RegisterFailHandler(Fail) 12 | RunSpecs(t, "Codec Suite") 13 | } 14 | -------------------------------------------------------------------------------- /codec/codec_test.go: -------------------------------------------------------------------------------- 1 | package codec_test 2 | -------------------------------------------------------------------------------- /codec/gcm.go: -------------------------------------------------------------------------------- 1 | package codec 2 | 3 | import ( 4 | "bytes" 5 | "crypto/aes" 6 | "crypto/cipher" 7 | "encoding/binary" 8 | "fmt" 9 | "io" 10 | "math" 11 | 12 | "github.com/renproject/id" 13 | ) 14 | 15 | type gcmNonce struct { 16 | // top and bottom together represent the top 32 bits and bottom 64 bits of a 96 bit unsigned integer 17 | top uint32 18 | bottom uint64 19 | countDown bool 20 | } 21 | 22 | func (nonce gcmNonce) next() { 23 | if nonce.countDown { 24 | nonce.pred() 25 | } else { 26 | nonce.succ() 27 | } 28 | 29 | } 30 | 31 | func (nonce gcmNonce) succ() { 32 | nonce.bottom++ 33 | // If bottom overflows, increment top by 1 34 | if nonce.bottom == 0 { 35 | nonce.top++ 36 | } 37 | } 38 | 39 | func (nonce gcmNonce) pred() { 40 | nonce.bottom-- 41 | // If bottom underflows, decrement top by 1 42 | if nonce.bottom == math.MaxUint64 { 43 | nonce.top-- 44 | } 45 | } 46 | 47 | // A GCMSession stores the state of a GCM authenticated/encrypted session. This 48 | // includes the read/write nonces, memory buffers, and the GCM cipher itself. 49 | type GCMSession struct { 50 | gcm cipher.AEAD 51 | readNonce gcmNonce 52 | writeNonce gcmNonce 53 | } 54 | 55 | // NewGCMSession accepts a symmetric secret key and returns a new GCMSession 56 | // that is configured using the symmetric secret key. 
57 | func NewGCMSession(key [32]byte, self, remote id.Signatory) (*GCMSession, error) { 58 | block, err := aes.NewCipher(key[:]) 59 | if err != nil { 60 | return &GCMSession{}, fmt.Errorf("creating aes cipher: %v", err) 61 | } 62 | gcm, err := cipher.NewGCM(block) 63 | if err != nil { 64 | return &GCMSession{}, fmt.Errorf("creating gcm cipher: %v", err) 65 | } 66 | 67 | gcmSession := &GCMSession{ 68 | gcm: gcm, 69 | readNonce: gcmNonce{}, 70 | writeNonce: gcmNonce{}, 71 | } 72 | 73 | if bytes.Compare(self[:], remote[:]) < 0 { 74 | gcmSession.writeNonce.top = math.MaxUint32 75 | gcmSession.writeNonce.bottom = math.MaxUint64 76 | gcmSession.writeNonce.countDown = true 77 | } else { 78 | gcmSession.readNonce.top = math.MaxUint32 79 | gcmSession.readNonce.bottom = math.MaxUint64 80 | gcmSession.readNonce.countDown = true 81 | } 82 | return gcmSession, nil 83 | } 84 | 85 | // GCMEncoder accepts a GCMSession and an encoder, and wraps the encoder with data encryption. 86 | func GCMEncoder(session *GCMSession, enc Encoder) Encoder { 87 | return func(w io.Writer, buf []byte) (int, error) { 88 | nonceBuf := [12]byte{} 89 | binary.BigEndian.PutUint32(nonceBuf[:4], session.writeNonce.top) 90 | binary.BigEndian.PutUint64(nonceBuf[4:], session.writeNonce.bottom) 91 | session.writeNonce.next() 92 | encoded := session.gcm.Seal(nil, nonceBuf[:], buf, nil) 93 | _, err := enc(w, encoded) 94 | if err != nil { 95 | return 0, fmt.Errorf("encoding sealed data: %v", err) 96 | } 97 | return len(buf), nil 98 | } 99 | } 100 | 101 | // GCMDecoder accepts a GCMSession and a decoder, and wraps the decoder with data decryption. 102 | func GCMDecoder(session *GCMSession, dec Decoder) Decoder { 103 | return func(r io.Reader, buf []byte) (int, error) { 104 | extendedSize := len(buf) + 16 105 | if cap(buf) < extendedSize { 106 | return 0, fmt.Errorf("decoding data: buffer too small, expected buffer capacity %v, got buffer capacity %v", extendedSize, cap(buf)) 107 | } 108 | buf = buf[:extendedSize] 109 | n, err := dec(r, buf) 110 | if err != nil { 111 | return n, fmt.Errorf("decoding data: %v", err) 112 | } 113 | nonceBuf := [12]byte{} 114 | binary.BigEndian.PutUint32(nonceBuf[:4], session.readNonce.top) 115 | binary.BigEndian.PutUint64(nonceBuf[4:], session.readNonce.bottom) 116 | session.readNonce.next() 117 | decrypted, err := session.gcm.Open(nil, nonceBuf[:], buf[:n], nil) 118 | 119 | if err != nil { 120 | return 0, fmt.Errorf("opening sealed data: %v", err) 121 | } 122 | copy(buf, decrypted) 123 | 124 | return len(decrypted), nil 125 | } 126 | } 127 | -------------------------------------------------------------------------------- /codec/gcm_test.go: -------------------------------------------------------------------------------- 1 | package codec_test 2 | 3 | import ( 4 | "bytes" 5 | . "github.com/onsi/ginkgo" 6 | . "github.com/onsi/gomega" 7 | "github.com/renproject/aw/codec" 8 | "github.com/renproject/id" 9 | "math/rand" 10 | ) 11 | 12 | var _ = Describe("GCM Codec", func() { 13 | Context("when encoding and decoding a message using a GCM encoder and decoder", func() { 14 | It("should successfully transmit message in both directions", func() { 15 | var readerWriter1 bytes.Buffer 16 | var readerWriter2 bytes.Buffer 17 | data1 := "Hi there from 1!" 18 | data2 := "Hi there from 2!" 
19 | var key [32]byte 20 | rand.Read(key[:]) 21 | privKey1 := id.NewPrivKey() 22 | privKey2 := id.NewPrivKey() 23 | gcmSession1, err := codec.NewGCMSession(key, id.NewSignatory(privKey1.PubKey()), id.NewSignatory(privKey2.PubKey())) 24 | Expect(err).To(BeNil()) 25 | gcmSession2, err := codec.NewGCMSession(key, id.NewSignatory(privKey2.PubKey()), id.NewSignatory(privKey1.PubKey())) 26 | Expect(err).To(BeNil()) 27 | 28 | enc1 := codec.LengthPrefixEncoder(codec.PlainEncoder, codec.GCMEncoder(gcmSession1, codec.PlainEncoder)) 29 | n1, err1 := enc1(&readerWriter1, []byte(data1)) 30 | enc2 := codec.LengthPrefixEncoder(codec.PlainEncoder, codec.GCMEncoder(gcmSession2, codec.PlainEncoder)) 31 | n2, err2 := enc2(&readerWriter2, []byte(data2)) 32 | Expect(err1).To(BeNil()) 33 | Expect(n1).To(Equal(16)) 34 | Expect(err2).To(BeNil()) 35 | Expect(n2).To(Equal(16)) 36 | 37 | var buf1 [4086]byte 38 | var buf2 [4086]byte 39 | dec1 := codec.LengthPrefixDecoder(codec.PlainDecoder, codec.GCMDecoder(gcmSession1, codec.PlainDecoder)) 40 | n1, err1 = dec1(&readerWriter2, buf1[:]) 41 | dec2 := codec.LengthPrefixDecoder(codec.PlainDecoder, codec.GCMDecoder(gcmSession2, codec.PlainDecoder)) 42 | n2, err2 = dec2(&readerWriter1, buf2[:]) 43 | Expect(err1).To(BeNil()) 44 | Expect(n1).To(Equal(16)) 45 | Expect(err2).To(BeNil()) 46 | Expect(n2).To(Equal(16)) 47 | 48 | Expect(string(buf1[:n1])).To(Equal("Hi there from 2!")) 49 | Expect(string(buf2[:n2])).To(Equal("Hi there from 1!")) 50 | 51 | }) 52 | }) 53 | }) 54 | -------------------------------------------------------------------------------- /codec/length_prefix.go: -------------------------------------------------------------------------------- 1 | package codec 2 | 3 | import ( 4 | "encoding/binary" 5 | "fmt" 6 | "io" 7 | ) 8 | 9 | // LengthPrefixEncoder returns an Encoder that prefixes all data with a uint32 10 | // length. The returned Encoder wraps two other Encoders, one that is used to 11 | // encode the length prefix, and one that is used to encode the actual data. 12 | func LengthPrefixEncoder(prefixEnc Encoder, bodyEnc Encoder) Encoder { 13 | return func(w io.Writer, buf []byte) (int, error) { 14 | prefix := uint32(len(buf)) 15 | prefixBytes := [4]byte{} 16 | binary.BigEndian.PutUint32(prefixBytes[:], prefix) 17 | if _, err := prefixEnc(w, prefixBytes[:]); err != nil { 18 | return 0, fmt.Errorf("encoding data length: %w", err) 19 | } 20 | n, err := bodyEnc(w, buf) 21 | if err != nil { 22 | return n, fmt.Errorf("encoding data: %w", err) 23 | } 24 | return n, nil 25 | } 26 | } 27 | 28 | // LengthPrefixDecoder returns an Decoder that assumes all data is prefixed with 29 | // a uint32 length. The returned Decoder wraps two other Decoders, one that is 30 | // used to decode the length prefix, and one that is used to decode the actual 31 | // data. 
32 | func LengthPrefixDecoder(prefixDec Decoder, bodyDec Decoder) Decoder { 33 | return func(r io.Reader, buf []byte) (int, error) { 34 | prefixBytes := [4]byte{} 35 | if _, err := prefixDec(r, prefixBytes[:]); err != nil { 36 | return 0, fmt.Errorf("decoding data length: %w", err) 37 | } 38 | prefix := binary.BigEndian.Uint32(prefixBytes[:]) 39 | if uint32(len(buf)) < prefix { 40 | return 0, fmt.Errorf("decoding data length: expected %v, got %v", len(buf), prefix) 41 | } 42 | n, err := bodyDec(r, buf[:prefix]) 43 | if err != nil { 44 | return n, fmt.Errorf("decoding data: %w", err) 45 | } 46 | return n, nil 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /codec/length_prefix_test.go: -------------------------------------------------------------------------------- 1 | package codec_test 2 | 3 | import ( 4 | "bytes" 5 | . "github.com/onsi/ginkgo" 6 | . "github.com/onsi/gomega" 7 | "github.com/renproject/aw/codec" 8 | ) 9 | 10 | var _ = Describe("Length Prefix Codec", func() { 11 | Context("when encoding and decoding a message using a length prefix encoder and decoder", func() { 12 | It("should successfully transmit message", func() { 13 | var readerWriter bytes.Buffer 14 | data := "Hi there!" 15 | 16 | enc := codec.LengthPrefixEncoder(codec.PlainEncoder, codec.PlainEncoder) 17 | n, err := enc(&readerWriter, []byte(data)) 18 | Expect(n).To(Equal(9)) 19 | Expect(err).To(BeNil()) 20 | 21 | var buf [4086]byte 22 | dec := codec.LengthPrefixDecoder(codec.PlainDecoder, codec.PlainDecoder) 23 | n, err = dec(&readerWriter, buf[:9]) 24 | Expect(n).To(Equal(9)) 25 | Expect(err).To(BeNil()) 26 | 27 | Expect(string(buf[:n])).To(Equal("Hi there!")) 28 | }) 29 | }) 30 | 31 | Context("when decoding a message using a length prefix decoder with a buffer larger than the message", func() { 32 | It("should not return an EOF error", func() { 33 | var readerWriter bytes.Buffer 34 | data := "Hi there!" 35 | 36 | enc := codec.LengthPrefixEncoder(codec.PlainEncoder, codec.PlainEncoder) 37 | n, err := enc(&readerWriter, []byte(data)) 38 | Expect(n).To(Equal(9)) 39 | Expect(err).To(BeNil()) 40 | 41 | var buf [4086]byte 42 | dec := codec.LengthPrefixDecoder(codec.PlainDecoder, codec.PlainDecoder) 43 | n, err = dec(&readerWriter, buf[:]) 44 | Expect(n).To(Equal(9)) 45 | Expect(err).To(BeNil()) 46 | 47 | Expect(string(buf[:n])).To(Equal("Hi there!")) 48 | }) 49 | }) 50 | }) 51 | -------------------------------------------------------------------------------- /codec/plain.go: -------------------------------------------------------------------------------- 1 | package codec 2 | 3 | import ( 4 | "io" 5 | ) 6 | 7 | // PlainEncoder writes data directly to the IO writer without modification. The 8 | // entire buffer is written. 9 | func PlainEncoder(w io.Writer, buf []byte) (int, error) { 10 | return w.Write(buf) 11 | } 12 | 13 | // PlainDecoder reads data directly from the IO reader without modification. The 14 | // entire buffer will be filled by reading data from the IO reader. This means 15 | // that the buffer must be of the right length with respect to the data that is 16 | // being read. 
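// It is typically composed with LengthPrefixDecoder, which slices the buffer down to the length given by the received prefix before delegating to this decoder.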
17 | func PlainDecoder(r io.Reader, buf []byte) (int, error) { 18 | n, err := io.ReadFull(r, buf) 19 | return n, err 20 | } 21 | -------------------------------------------------------------------------------- /codec/plain_test.go: -------------------------------------------------------------------------------- 1 | package codec_test 2 | 3 | import ( 4 | "bytes" 5 | "github.com/renproject/aw/codec" 6 | "io" 7 | 8 | . "github.com/onsi/ginkgo" 9 | . "github.com/onsi/gomega" 10 | ) 11 | 12 | var _ = Describe("Plain Codec", func() { 13 | Context("when encoding and decoding a message using a plain encoder and decoder", func() { 14 | It("should successfully transmit message", func() { 15 | var readerWriter bytes.Buffer 16 | data := "Hi there!" 17 | 18 | n, err := codec.PlainEncoder(&readerWriter, []byte(data)) 19 | Expect(n).To(Equal(9)) 20 | Expect(err).To(BeNil()) 21 | 22 | var buf [4086]byte 23 | n, err = codec.PlainDecoder(&readerWriter, buf[:9]) 24 | Expect(n).To(Equal(9)) 25 | Expect(err).To(BeNil()) 26 | 27 | Expect(string(buf[:n])).To(Equal("Hi there!")) 28 | }) 29 | }) 30 | 31 | Context("when decoding a message using a plain decoder with a buffer larger than the message", func() { 32 | It("should return an EOF error", func() { 33 | var readerWriter bytes.Buffer 34 | data := "Hi there!" 35 | 36 | n, err := codec.PlainEncoder(&readerWriter, []byte(data)) 37 | Expect(n).To(Equal(9)) 38 | Expect(err).To(BeNil()) 39 | 40 | var buf [4086]byte 41 | n, err = codec.PlainDecoder(&readerWriter, buf[:]) 42 | Expect(n).To(Equal(9)) 43 | Expect(err).To(Equal(io.ErrUnexpectedEOF)) 44 | 45 | Expect(string(buf[:n])).To(Equal("Hi there!")) 46 | }) 47 | }) 48 | }) 49 | -------------------------------------------------------------------------------- /dht/dht_suite_test.go: -------------------------------------------------------------------------------- 1 | package dht_test 2 | 3 | import ( 4 | "testing" 5 | 6 | . "github.com/onsi/ginkgo" 7 | . "github.com/onsi/gomega" 8 | ) 9 | 10 | func TestDHT(t *testing.T) { 11 | RegisterFailHandler(Fail) 12 | RunSpecs(t, "DHT Suite") 13 | } 14 | -------------------------------------------------------------------------------- /dht/dhtutil/dht_util.go: -------------------------------------------------------------------------------- 1 | package dhtutil 2 | 3 | import ( 4 | "github.com/renproject/id" 5 | "math/rand" 6 | "reflect" 7 | "sort" 8 | "time" 9 | ) 10 | 11 | var defaultContent = [5]byte{10, 20, 30, 40, 50} 12 | 13 | func RandomContent() []byte { 14 | r := rand.New(rand.NewSource(time.Now().UnixNano())) 15 | 16 | length := rand.Intn(20) + 10 17 | content := make([]byte, length) 18 | 19 | _, err := r.Read(content) 20 | if err != nil { 21 | return defaultContent[:] 22 | } 23 | return content 24 | } 25 | 26 | // IsSorted checks if the given list of addresses are sorted in order of their 27 | // XOR distance from our own address. 28 | func IsSorted(identity id.Signatory, addrs []id.Signatory) bool { 29 | sortedAddrs := make([]id.Signatory, len(addrs)) 30 | copy(sortedAddrs, addrs) 31 | SortAddrs(identity, sortedAddrs) 32 | return reflect.DeepEqual(addrs, sortedAddrs) 33 | } 34 | 35 | // SortAddrs in order of their XOR distance from our own address. 
36 | func SortAddrs(identity id.Signatory, addrs []id.Signatory) { 37 | sort.Slice(addrs, func(i, j int) bool { 38 | fstSignatory := addrs[i] 39 | sndSignatory := addrs[j] 40 | return isCloser(identity, fstSignatory, sndSignatory) 41 | }) 42 | } 43 | 44 | // SortSignatories in order of their XOR distance from our own address. 45 | func SortSignatories(identity id.Signatory, signatories []id.Signatory) { 46 | sort.Slice(signatories, func(i, j int) bool { 47 | return isCloser(identity, signatories[i], signatories[j]) 48 | }) 49 | } 50 | 51 | func isCloser(identity, fst, snd id.Signatory) bool { 52 | for b := 0; b < 32; b++ { 53 | d1 := identity[b] ^ fst[b] 54 | d2 := identity[b] ^ snd[b] 55 | if d1 < d2 { 56 | return true 57 | } 58 | if d2 < d1 { 59 | return false 60 | } 61 | } 62 | return false 63 | } 64 | -------------------------------------------------------------------------------- /dht/resolver.go: -------------------------------------------------------------------------------- 1 | package dht 2 | 3 | import ( 4 | "sync" 5 | ) 6 | 7 | // The ContentResolver interface is used to insert and query content. 8 | type ContentResolver interface { 9 | // Insert content with a specific content ID. Usually, the content ID will 10 | // store information about the type and the hash of the content. 11 | InsertContent(contentID, content []byte) 12 | 13 | // QueryContent returns the content associated with a content ID. If there 14 | // is no associated content, it returns false. Otherwise, it returns true. 15 | QueryContent(contentID []byte) (content []byte, contentOk bool) 16 | } 17 | 18 | var ( 19 | // DefaultDoubleCacheContentResolverCapacity defines the default in-memory 20 | // cache capacity (in bytes) for the double-cache content resolver. 21 | DefaultDoubleCacheContentResolverCapacity = 16 * 1024 * 1024 // 16 MB 22 | ) 23 | 24 | // DoubleCacheContentResolverOptions for parameterising the behaviour of the 25 | // DoubleCacheContentResolver. 26 | type DoubleCacheContentResolverOptions struct { 27 | Capacity int 28 | } 29 | 30 | // DefaultDoubleCacheContentResolverOptions returns the default 31 | // DoubleCacheContentResolverOptions. 32 | func DefaultDoubleCacheContentResolverOptions() DoubleCacheContentResolverOptions { 33 | return DoubleCacheContentResolverOptions{ 34 | Capacity: DefaultDoubleCacheContentResolverCapacity, 35 | } 36 | } 37 | 38 | // WithCapacity sets the maximum in-memory cache capacity (in bytes). This 39 | // capacity accounts for the fact that the double-cache content resolver has two 40 | // in-memory buffers. For example, if the capacity is set to 2 MB, then the 41 | // double-cache content resolver is guaranteed to consume, at most, 2 MB of 42 | // memory, but will only be able to cache 1 MB of data. 43 | func (opts DoubleCacheContentResolverOptions) WithCapacity(capacity int) DoubleCacheContentResolverOptions { 44 | opts.Capacity = capacity / 2 45 | return opts 46 | } 47 | 48 | // The DoubleCacheContentResolver uses the double-cache technique to implement a 49 | // fast in-memory cache. The cache can optionally wrap around another 50 | // content-resolver (which can be responsible for more persistent content 51 | // resolution). 52 | type DoubleCacheContentResolver struct { 53 | opts DoubleCacheContentResolverOptions 54 | next ContentResolver 55 | 56 | cacheMu *sync.Mutex 57 | cacheFrontSize int 58 | cacheFront map[string][]byte // Front is used to add new content until the max capacity is reached. 
59 | cacheBack map[string][]byte // Back is used to read old content that has been rotated from the front. 60 | } 61 | 62 | // NewDoubleCacheContentResolver returns a new double-cache content resolver 63 | // that is wrapped around another content-resolver. 64 | func NewDoubleCacheContentResolver(opts DoubleCacheContentResolverOptions, next ContentResolver) *DoubleCacheContentResolver { 65 | return &DoubleCacheContentResolver{ 66 | opts: opts, 67 | next: next, 68 | 69 | cacheMu: new(sync.Mutex), 70 | cacheFrontSize: 0, 71 | cacheFront: make(map[string][]byte, 0), 72 | cacheBack: make(map[string][]byte, 0), 73 | } 74 | } 75 | 76 | // InsertContent into the double-cache content resolver. If the front cache is 77 | // full, it will be rotated to the back, the current back cache will be dropped, 78 | // and a new front cache will be created. This method will also insert the 79 | // content into the next content resolver (if one exists). 80 | func (r *DoubleCacheContentResolver) InsertContent(id, content []byte) { 81 | r.cacheMu.Lock() 82 | defer r.cacheMu.Unlock() 83 | 84 | // We cannot cache something that is greater than the maximum capacity. 85 | if len(content) > r.opts.Capacity { 86 | if r.next != nil { 87 | r.next.InsertContent(id, content) 88 | } 89 | return 90 | } 91 | 92 | // If the capacity has been exceeded, move the front cache to the back and 93 | // reset the front cache. 94 | if r.cacheFrontSize+len(content) > r.opts.Capacity { 95 | r.cacheBack = r.cacheFront 96 | r.cacheFrontSize = 0 97 | r.cacheFront = make(map[string][]byte, 0) 98 | } 99 | 100 | // Insert the content into the front cache and the next resolver (if it 101 | // exists). 102 | r.cacheFrontSize += len(content) 103 | r.cacheFront[string(id)] = content 104 | 105 | if r.next != nil { 106 | r.next.InsertContent(id, content) 107 | } 108 | } 109 | 110 | // QueryContent returns the content associated with the given content ID. If the 111 | // content is not found in the double-cache content resolver, the next content 112 | // resolver will be checked (if one exists). 113 | func (r *DoubleCacheContentResolver) QueryContent(id []byte) ([]byte, bool) { 114 | r.cacheMu.Lock() 115 | defer r.cacheMu.Unlock() 116 | 117 | // Check both caches for the content. 118 | if content, ok := r.cacheFront[string(id)]; ok { 119 | return content, ok 120 | } 121 | if content, ok := r.cacheBack[string(id)]; ok { 122 | return content, ok 123 | } 124 | 125 | // If the content has not been found, check the next resolver. 126 | if r.next != nil { 127 | return r.next.QueryContent(id) 128 | } 129 | return nil, false 130 | } 131 | 132 | // CallbackContentResolver implements the ContentResolver interface by delegating 133 | // all logic to callback functions. This is useful when defining an 134 | // implementation inline. 135 | type CallbackContentResolver struct { 136 | InsertContentCallback func([]byte, []byte) 137 | QueryContentCallback func([]byte) ([]byte, bool) 138 | } 139 | 140 | // InsertContent will delegate the implementation to the InsertContentCallback. 141 | // If the callback is nil, then this method will do nothing. 142 | func (r CallbackContentResolver) InsertContent(id, content []byte) { 143 | if r.InsertContentCallback != nil { 144 | r.InsertContentCallback(id, content) 145 | } 146 | } 147 | 148 | // QueryContent will delegate the implementation to the QueryContentCallback. If 149 | // the callback is nil, then this method will return false.
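//
// Illustrative caller-side sketch (not part of the original source) of wiring
// the callbacks to a hypothetical persistent store; `store` and its Put/Get
// methods are assumptions, not part of this package:
//
//	resolver := dht.CallbackContentResolver{
//		InsertContentCallback: func(id, content []byte) { store.Put(id, content) },
//		QueryContentCallback:  func(id []byte) ([]byte, bool) { return store.Get(id) },
//	}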
150 | func (r CallbackContentResolver) QueryContent(id []byte) ([]byte, bool) { 151 | if r.QueryContentCallback != nil { 152 | return r.QueryContentCallback(id) 153 | } 154 | return nil, false 155 | } 156 | -------------------------------------------------------------------------------- /dht/resolver_test.go: -------------------------------------------------------------------------------- 1 | package dht_test 2 | 3 | import ( 4 | "crypto/sha256" 5 | "testing/quick" 6 | "time" 7 | 8 | "github.com/renproject/aw/dht" 9 | "github.com/renproject/aw/dht/dhtutil" 10 | "github.com/renproject/id" 11 | 12 | . "github.com/onsi/ginkgo" 13 | . "github.com/onsi/gomega" 14 | ) 15 | 16 | var _ = Describe("Double-cache Content Resolver", func() { 17 | Context("when inserting content", func() { 18 | It("should be able to query it", func() { 19 | resolver := dht.NewDoubleCacheContentResolver( 20 | dht.DefaultDoubleCacheContentResolverOptions(), 21 | nil, 22 | ) 23 | 24 | f := func(contentType uint8, content []byte) bool { 25 | hash := id.Hash(sha256.Sum256(content)) 26 | resolver.InsertContent(hash[:], content) 27 | 28 | newContent, ok := resolver.QueryContent(hash[:]) 29 | Expect(ok).To(BeTrue()) 30 | Expect(newContent).To(Equal(content)) 31 | return true 32 | } 33 | Expect(quick.Check(f, nil)).To(Succeed()) 34 | }) 35 | 36 | It("should ignore content that is too big", func() { 37 | capacity := 19 38 | resolver := dht.NewDoubleCacheContentResolver( 39 | dht.DefaultDoubleCacheContentResolverOptions(). 40 | WithCapacity(capacity), 41 | nil, 42 | ) 43 | 44 | // Fill cache with data that is too big. 45 | content := [10]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09} 46 | hash := id.NewHash(content[:]) 47 | resolver.InsertContent(hash[:], content[:]) 48 | 49 | _, ok := resolver.QueryContent(hash[:]) 50 | Expect(ok).To(BeFalse()) 51 | }) 52 | 53 | It("should drop old values", func() { 54 | capacity := 20 55 | resolver := dht.NewDoubleCacheContentResolver( 56 | dht.DefaultDoubleCacheContentResolverOptions(). 57 | WithCapacity(capacity), 58 | nil, 59 | ) 60 | 61 | // Fill cache with data. 62 | content := [10]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09} 63 | hash := id.NewHash(content[:]) 64 | resolver.InsertContent(hash[:], content[:]) 65 | 66 | // Add more data. 67 | newContent := [10]byte{0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19} 68 | newHash := id.NewHash(newContent[:]) 69 | resolver.InsertContent(newHash[:], newContent[:]) 70 | 71 | // Both chunks of data should be present. 72 | _, ok := resolver.QueryContent(hash[:]) 73 | Expect(ok).To(BeTrue()) 74 | _, ok = resolver.QueryContent(newHash[:]) 75 | Expect(ok).To(BeTrue()) 76 | 77 | // Add even more data. 78 | newerContent := [10]byte{0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29} 79 | newerHash := id.NewHash(newerContent[:]) 80 | resolver.InsertContent(newerHash[:], newerContent[:]) 81 | 82 | // Verify the two latest chunks exist, and that the rest has been 83 | // rotated out.
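// (Clarifying note, not part of the original test: WithCapacity(20) halves the
// budget, so each internal cache holds at most 10 bytes. The front cache holds
// exactly one 10-byte chunk here, and every further insert rotates it to the
// back, so the third insert evicts the first chunk entirely.)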
84 | _, ok = resolver.QueryContent(hash[:]) 85 | Expect(ok).To(BeFalse()) 86 | _, ok = resolver.QueryContent(newHash[:]) 87 | Expect(ok).To(BeTrue()) 88 | _, ok = resolver.QueryContent(newerHash[:]) 89 | Expect(ok).To(BeTrue()) 90 | }) 91 | }) 92 | 93 | Context("when querying content that does not exist", func() { 94 | It("should return false", func() { 95 | resolver := dht.NewDoubleCacheContentResolver( 96 | dht.DefaultDoubleCacheContentResolverOptions(), 97 | nil, 98 | ) 99 | 100 | f := func(contentType uint8, content []byte) bool { 101 | hash := id.Hash(sha256.Sum256(content)) 102 | newContent, ok := resolver.QueryContent(hash[:]) 103 | Expect(ok).To(BeFalse()) 104 | Expect(len(newContent)).To(Equal(0)) 105 | return true 106 | } 107 | Expect(quick.Check(f, nil)).To(Succeed()) 108 | }) 109 | }) 110 | 111 | Context("when using an inner resolver", func() { 112 | It("should forward calls to it", func() { 113 | insertCh := make(chan []byte) 114 | queryCh := make(chan []byte) 115 | 116 | resolver := dht.NewDoubleCacheContentResolver( 117 | dht.DefaultDoubleCacheContentResolverOptions(), 118 | dht.CallbackContentResolver{ 119 | InsertContentCallback: func(id []byte, data []byte) { 120 | insertCh <- id 121 | }, 122 | QueryContentCallback: func(id []byte) ([]byte, bool) { 123 | queryCh <- id 124 | return []byte{}, true 125 | }, 126 | }, 127 | ) 128 | 129 | // Insert and wait on the channel to make sure the inner 130 | // resolver received the message. 131 | hash := id.Hash(sha256.Sum256(dhtutil.RandomContent())) 132 | go resolver.InsertContent(hash[:], nil) 133 | newHash := <-insertCh 134 | Expect(newHash).To(Equal(hash[:])) 135 | 136 | // Get and wait on the channel to make sure the inner resolver 137 | // received the message. 138 | hash = sha256.Sum256(dhtutil.RandomContent()) 139 | go resolver.QueryContent(hash[:]) 140 | 141 | newHash = <-queryCh 142 | Expect(newHash).To(Equal(hash[:])) 143 | 144 | // Ensure the channels receive no additional messages. 
145 | select { 146 | case <-insertCh: 147 | Fail("unexpected insert message") 148 | case <-queryCh: 149 | Fail("unexpected content message") 150 | case <-time.After(time.Second): 151 | } 152 | }) 153 | }) 154 | }) 155 | 156 | var _ = Describe("Callback Content Resolver", func() { 157 | Context("when callbacks are not defined", func() { 158 | It("should not panic", func() { 159 | hash := id.Hash{} 160 | Expect(func() { dht.CallbackContentResolver{}.InsertContent(hash[:], []byte{}) }).ToNot(Panic()) 161 | Expect(func() { dht.CallbackContentResolver{}.QueryContent(hash[:]) }).ToNot(Panic()) 162 | }) 163 | }) 164 | 165 | Context("when callbacks are defined", func() { 166 | It("should delegate to the callback", func() { 167 | hash := id.Hash{} 168 | cond1 := false 169 | cond2 := false 170 | 171 | resolver := dht.CallbackContentResolver{ 172 | InsertContentCallback: func([]byte, []byte) { 173 | cond1 = true 174 | }, 175 | QueryContentCallback: func([]byte) ([]byte, bool) { 176 | cond2 = true 177 | return nil, false 178 | }, 179 | } 180 | resolver.InsertContent(hash[:], []byte{}) 181 | resolver.QueryContent(hash[:]) 182 | 183 | Expect(cond1).To(BeTrue()) 184 | Expect(cond2).To(BeTrue()) 185 | }) 186 | }) 187 | }) 188 | -------------------------------------------------------------------------------- /dht/table.go: -------------------------------------------------------------------------------- 1 | package dht 2 | 3 | import ( 4 | "math/rand" 5 | "sort" 6 | "sync" 7 | "time" 8 | 9 | "github.com/renproject/aw/wire" 10 | "github.com/renproject/id" 11 | ) 12 | 13 | // Force InMemTable to implement the Table interface. 14 | var _ Table = &InMemTable{} 15 | 16 | type Expiry struct { 17 | minimumExpiryAge time.Duration 18 | timestamp time.Time 19 | } 20 | 21 | // A Table is responsible for keeping track of peers, their network addresses, 22 | // and the subnet to which they belong. 23 | type Table interface { 24 | // Self returns the local peer. It does not return the network address of 25 | // the local peer, because it can change frequently, and is not guaranteed 26 | // to exist. 27 | Self() id.Signatory 28 | 29 | // AddPeer to the table with an associated network address. 30 | AddPeer(id.Signatory, wire.Address) 31 | // DeletePeer from the table. 32 | DeletePeer(id.Signatory) 33 | // PeerAddress returns the network address associated with the given peer. 34 | PeerAddress(id.Signatory) (wire.Address, bool) 35 | 36 | // Peers returns the n closest peers to the local peer, using XORing as the 37 | // measure of distance between two peers. 38 | Peers(int) []id.Signatory 39 | // RandomPeers returns n random peer IDs, using either partial permutation 40 | // or Floyd's sampling algorithm. 41 | RandomPeers(int) []id.Signatory 42 | // NumPeers returns the total number of peers with associated network 43 | // addresses in the table. 44 | NumPeers() int 45 | 46 | // HandleExpired returns whether a signatory has expired. It checks whether 47 | // an Expiry exists for the signatory and, if it does, whether it has expired. 48 | // If it has expired, the peer is deleted from the table. 49 | HandleExpired(id.Signatory) bool 50 | // AddExpiry to the table with the given duration if no existing expiry is found. 51 | AddExpiry(id.Signatory, time.Duration) 52 | // DeleteExpiry from the table. 53 | DeleteExpiry(id.Signatory) 54 | 55 | // AddSubnet to the table. This returns a subnet hash that can be used to 56 | // read/delete the subnet. It is the merkle root hash of the peers in the 57 | // subnet.
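// (Clarifying note, not part of the original comment: the hash is computed from
// the signatories exactly as supplied by the caller, so peers that call
// AddSubnet with the same list derive the same subnet ID; see the InMemTable
// implementation below.)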
58 | AddSubnet([]id.Signatory) id.Hash 59 | // DeleteSubnet from the table. If the subnet is in the table, it is 60 | // removed. 61 | DeleteSubnet(id.Hash) 62 | // Subnet returns the peers in the subnet associated with the given hash. 63 | Subnet(id.Hash) []id.Signatory 64 | } 65 | 66 | // InMemTable implements the Table using in-memory storage. 67 | type InMemTable struct { 68 | self id.Signatory 69 | 70 | sortedMu *sync.RWMutex 71 | sorted []id.Signatory 72 | 73 | addrsBySignatoryMu *sync.Mutex 74 | addrsBySignatory map[id.Signatory]wire.Address 75 | 76 | expiryBySignatoryMu *sync.Mutex 77 | expiryBySignatory map[id.Signatory]Expiry 78 | 79 | subnetsByHashMu *sync.Mutex 80 | subnetsByHash map[id.Hash][]id.Signatory 81 | 82 | randObj *rand.Rand 83 | } 84 | 85 | func NewInMemTable(self id.Signatory) *InMemTable { 86 | return &InMemTable{ 87 | self: self, 88 | 89 | sortedMu: new(sync.RWMutex), 90 | sorted: []id.Signatory{}, 91 | 92 | addrsBySignatoryMu: new(sync.Mutex), 93 | addrsBySignatory: map[id.Signatory]wire.Address{}, 94 | 95 | expiryBySignatoryMu: new(sync.Mutex), 96 | expiryBySignatory: map[id.Signatory]Expiry{}, 97 | 98 | subnetsByHashMu: new(sync.Mutex), 99 | subnetsByHash: map[id.Hash][]id.Signatory{}, 100 | 101 | randObj: rand.New(rand.NewSource(time.Now().UnixNano())), 102 | } 103 | } 104 | 105 | func (table *InMemTable) Self() id.Signatory { 106 | return table.self 107 | } 108 | 109 | func (table *InMemTable) AddPeer(peerID id.Signatory, peerAddr wire.Address) { 110 | table.sortedMu.Lock() 111 | table.addrsBySignatoryMu.Lock() 112 | 113 | defer table.sortedMu.Unlock() 114 | defer table.addrsBySignatoryMu.Unlock() 115 | 116 | if table.self.Equal(&peerID) { 117 | return 118 | } 119 | 120 | _, ok := table.addrsBySignatory[peerID] 121 | 122 | // Insert into the map to allow for address lookup using the signatory. 123 | table.addrsBySignatory[peerID] = peerAddr 124 | 125 | // Insert into the sorted signatories list based on its XOR distance from our 126 | // own address. 127 | if !ok { 128 | i := sort.Search(len(table.sorted), func(i int) bool { 129 | return table.isCloser(peerID, table.sorted[i]) 130 | }) 131 | table.sorted = append(table.sorted, id.Signatory{}) 132 | copy(table.sorted[i+1:], table.sorted[i:]) 133 | table.sorted[i] = peerID 134 | } 135 | } 136 | 137 | func (table *InMemTable) DeletePeer(peerID id.Signatory) { 138 | table.sortedMu.Lock() 139 | table.addrsBySignatoryMu.Lock() 140 | 141 | defer table.sortedMu.Unlock() 142 | defer table.addrsBySignatoryMu.Unlock() 143 | 144 | // Delete from the map. 145 | delete(table.addrsBySignatory, peerID) 146 | 147 | // Delete from the sorted list. 148 | numAddrs := len(table.sorted) 149 | i := sort.Search(numAddrs, func(i int) bool { 150 | return table.isCloser(peerID, table.sorted[i]) 151 | }) 152 | 153 | removeIndex := i - 1 154 | if removeIndex >= 0 { 155 | table.sorted = append(table.sorted[:removeIndex], table.sorted[removeIndex+1:]...) 156 | } 157 | } 158 | 159 | func (table *InMemTable) PeerAddress(peerID id.Signatory) (wire.Address, bool) { 160 | table.addrsBySignatoryMu.Lock() 161 | defer table.addrsBySignatoryMu.Unlock() 162 | 163 | addr, ok := table.addrsBySignatory[peerID] 164 | return addr, ok 165 | } 166 | 167 | // Peers returns the n closest peer IDs. 168 | func (table *InMemTable) Peers(n int) []id.Signatory { 169 | table.sortedMu.RLock() 170 | defer table.sortedMu.RUnlock() 171 | 172 | if n <= 0 { 173 | // For values of n that are less than, or equal to, zero, return an 174 | // empty list.
We could panic instead, but this is a reasonable and 175 | // unsurprising alternative. 176 | return []id.Signatory{} 177 | } 178 | 179 | sigs := make([]id.Signatory, min(n, len(table.sorted))) 180 | copy(sigs, table.sorted) 181 | return sigs 182 | } 183 | 184 | // RandomPeers returns n random peer IDs. 185 | func (table *InMemTable) RandomPeers(n int) []id.Signatory { 186 | table.sortedMu.RLock() 187 | defer table.sortedMu.RUnlock() 188 | m := len(table.sorted) 189 | 190 | if n <= 0 { 191 | // For values of n that are less than, or equal to, zero, return an 192 | // empty list. We could panic instead, but this is a reasonable and 193 | // unsurprising alternative. 194 | return []id.Signatory{} 195 | } 196 | if n >= m { 197 | sigs := make([]id.Signatory, m) 198 | copy(sigs, table.sorted) 199 | return sigs 200 | } 201 | 202 | // Use the first n elements of a permutation of the entire list of peer IDs. 203 | // This is used only if the sorted array (array of length m) is sufficiently 204 | // small or the number of random elements to be selected (n) is sufficiently 205 | // large in comparison to m. 206 | if m <= 10000 || n >= m/50.0 { 207 | shuffled := make([]id.Signatory, n) 208 | indexPerm := rand.Perm(m) 209 | for i := 0; i < n; i++ { 210 | shuffled[i] = table.sorted[indexPerm[i]] 211 | } 212 | return shuffled 213 | } 214 | 215 | // Otherwise, use Floyd's sampling algorithm to select n random elements. 216 | set := make(map[int]struct{}, n) 217 | randomSelection := make([]id.Signatory, 0, n) 218 | for i := m - n; i < m; i++ { 219 | index := table.randObj.Intn(i) 220 | if _, ok := set[index]; !ok { 221 | set[index] = struct{}{} 222 | randomSelection = append(randomSelection, table.sorted[index]) 223 | continue 224 | } 225 | set[i] = struct{}{} 226 | randomSelection = append(randomSelection, table.sorted[i]) 227 | } 228 | return randomSelection 229 | } 230 | 231 | func (table *InMemTable) NumPeers() int { 232 | table.addrsBySignatoryMu.Lock() 233 | defer table.addrsBySignatoryMu.Unlock() 234 | 235 | return len(table.addrsBySignatory) 236 | } 237 | 238 | func (table *InMemTable) HandleExpired(peerID id.Signatory) bool { 239 | table.expiryBySignatoryMu.Lock() 240 | defer table.expiryBySignatoryMu.Unlock() 241 | expiry, ok := table.expiryBySignatory[peerID] 242 | if !ok { 243 | return false 244 | } 245 | expired := (time.Now().Sub(expiry.timestamp)) > expiry.minimumExpiryAge 246 | if expired { 247 | table.DeletePeer(peerID) 248 | delete(table.expiryBySignatory, peerID) 249 | } 250 | return expired 251 | } 252 | 253 | func (table *InMemTable) AddExpiry(peerID id.Signatory, duration time.Duration) { 254 | table.expiryBySignatoryMu.Lock() 255 | defer table.expiryBySignatoryMu.Unlock() 256 | _, ok := table.PeerAddress(peerID) 257 | if !ok { 258 | return 259 | } 260 | _, ok = table.expiryBySignatory[peerID] 261 | if ok { 262 | return 263 | } 264 | table.expiryBySignatory[peerID] = Expiry{ 265 | minimumExpiryAge: duration, 266 | timestamp: time.Now(), 267 | } 268 | } 269 | 270 | func (table *InMemTable) DeleteExpiry(peerID id.Signatory) { 271 | table.expiryBySignatoryMu.Lock() 272 | defer table.expiryBySignatoryMu.Unlock() 273 | delete(table.expiryBySignatory, peerID) 274 | } 275 | 276 | func (table *InMemTable) AddSubnet(signatories []id.Signatory) id.Hash { 277 | copied := make([]id.Signatory, len(signatories)) 278 | copy(copied, signatories) 279 | 280 | // Sort signatories in order of their XOR distance from the local peer.
This 281 | // allows different peers to easily iterate over the same subnet in a 282 | // different order. 283 | sort.Slice(copied, func(i, j int) bool { 284 | return table.isCloser(copied[i], copied[j]) 285 | }) 286 | 287 | // It is important to note that we are using the unsorted slice for 288 | // computing the merkle root hash. This is done so that everyone can have 289 | // the same subnet ID for a given slice of signatories. 290 | hash := id.NewMerkleHashFromSignatories(signatories) 291 | 292 | table.subnetsByHashMu.Lock() 293 | defer table.subnetsByHashMu.Unlock() 294 | 295 | table.subnetsByHash[hash] = copied 296 | return hash 297 | } 298 | 299 | func (table *InMemTable) DeleteSubnet(hash id.Hash) { 300 | table.subnetsByHashMu.Lock() 301 | defer table.subnetsByHashMu.Unlock() 302 | 303 | delete(table.subnetsByHash, hash) 304 | } 305 | 306 | func (table *InMemTable) Subnet(hash id.Hash) []id.Signatory { 307 | table.subnetsByHashMu.Lock() 308 | defer table.subnetsByHashMu.Unlock() 309 | 310 | subnet, ok := table.subnetsByHash[hash] 311 | if !ok { 312 | return []id.Signatory{} 313 | } 314 | copied := make([]id.Signatory, len(subnet)) 315 | copy(copied, subnet) 316 | return copied 317 | } 318 | 319 | func (table *InMemTable) isCloser(fst, snd id.Signatory) bool { 320 | for b := 0; b < 32; b++ { 321 | d1 := table.self[b] ^ fst[b] 322 | d2 := table.self[b] ^ snd[b] 323 | if d1 < d2 { 324 | return true 325 | } 326 | if d2 < d1 { 327 | return false 328 | } 329 | } 330 | return false 331 | } 332 | 333 | func min(a, b int) int { 334 | if a < b { 335 | return a 336 | } 337 | return b 338 | } 339 | -------------------------------------------------------------------------------- /dht/table_test.go: -------------------------------------------------------------------------------- 1 | package dht_test 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "math/rand" 7 | "strconv" 8 | "testing/quick" 9 | "time" 10 | 11 | "github.com/renproject/aw/dht" 12 | "github.com/renproject/aw/dht/dhtutil" 13 | "github.com/renproject/aw/wire" 14 | "github.com/renproject/id" 15 | 16 | . "github.com/onsi/ginkgo" 17 | . "github.com/onsi/gomega" 18 | ) 19 | 20 | var _ = Describe("DHT", func() { 21 | Describe("Addresses", func() { 22 | Context("when inserting an address", func() { 23 | It("should be able to query it", func() { 24 | table, _ := initDHT() 25 | 26 | f := func(seed int64) bool { 27 | privKey := id.NewPrivKey() 28 | sig := privKey.Signatory() 29 | addr := wire.NewUnsignedAddress(wire.TCP, "172.16.254.1:3000", uint64(time.Now().UnixNano())) 30 | 31 | table.AddPeer(sig, addr) 32 | 33 | signatory := id.NewSignatory((*id.PubKey)(&privKey.PublicKey)) 34 | newAddr, ok := table.PeerAddress(signatory) 35 | Expect(ok).To(BeTrue()) 36 | Expect(newAddr).To(Equal(addr)) 37 | return true 38 | } 39 | Expect(quick.Check(f, nil)).To(Succeed()) 40 | }) 41 | }) 42 | 43 | Context("when deleting an address", func() { 44 | It("should not be able to query it", func() { 45 | table, _ := initDHT() 46 | 47 | f := func(seed int64) bool { 48 | privKey := id.NewPrivKey() 49 | addr := wire.NewUnsignedAddress(wire.TCP, "172.16.254.1:3000", uint64(time.Now().UnixNano())) 50 | 51 | // Try to delete the address prior to inserting to make sure 52 | // it does not panic. 53 | signatory := id.NewSignatory((*id.PubKey)(&privKey.PublicKey)) 54 | table.DeletePeer(signatory) 55 | 56 | // Insert the address. 57 | table.AddPeer(signatory, addr) 58 | 59 | // Delete the address and make sure it no longer exists when 60 | // querying the DHT.
61 | table.DeletePeer(signatory) 62 | 63 | _, ok := table.PeerAddress(signatory) 64 | return !ok 65 | } 66 | Expect(quick.Check(f, nil)).To(Succeed()) 67 | }) 68 | }) 69 | 70 | Context("when querying addresses", func() { 71 | It("should return them in order of their XOR distance", func() { 72 | table, identity := initDHT() 73 | numAddrs := rand.Intn(990) + 10 // [10, 1000) 74 | 75 | // Insert `numAddrs` random addresses into the store. 76 | signatories := make([]id.Signatory, numAddrs) 77 | for i := 0; i < numAddrs; i++ { 78 | privKey := id.NewPrivKey() 79 | sig := privKey.Signatory() 80 | addr := wire.NewUnsignedAddress(wire.TCP, "172.16.254.1:3000", uint64(time.Now().UnixNano())) 81 | 82 | table.AddPeer(sig, addr) 83 | 84 | signatories = append(signatories, sig) 85 | } 86 | 87 | // Check addresses are returned in order of their XOR distance 88 | // from our own address. 89 | numQueriedAddrs := rand.Intn(numAddrs) 90 | queriedAddrs := table.Peers(numQueriedAddrs) 91 | Expect(len(queriedAddrs)).To(Equal(numQueriedAddrs)) 92 | Expect(dhtutil.IsSorted(identity, queriedAddrs)).To(BeTrue()) 93 | 94 | // Delete some addresses and make sure the list is still sorted. 95 | numDeletedAddrs := rand.Intn(numAddrs) 96 | for i := 0; i < numDeletedAddrs; i++ { 97 | signatory := signatories[i] 98 | table.DeletePeer(signatory) 99 | } 100 | 101 | queriedAddrs = table.Peers(numAddrs - numDeletedAddrs) 102 | Expect(len(queriedAddrs)).To(Equal(numAddrs - numDeletedAddrs)) 103 | Expect(dhtutil.IsSorted(identity, queriedAddrs)).To(BeTrue()) 104 | }) 105 | 106 | Context("if there are less than n addresses in the store", func() { 107 | It("should return all the addresses", func() { 108 | table, _ := initDHT() 109 | numAddrs := rand.Intn(100) 110 | 111 | // Insert `numAddrs` random addresses into the store. 112 | for i := 0; i < numAddrs; i++ { 113 | privKey := id.NewPrivKey() 114 | sig := privKey.Signatory() 115 | addr := wire.NewUnsignedAddress(wire.TCP, "172.16.254.1:3000", uint64(time.Now().UnixNano())) 116 | 117 | table.AddPeer(sig, addr) 118 | } 119 | 120 | addrs := table.Peers(100) 121 | Expect(len(addrs)).To(Equal(numAddrs)) 122 | }) 123 | }) 124 | 125 | Context("if there are no addresses in the store", func() { 126 | It("should return no addresses", func() { 127 | table, _ := initDHT() 128 | 129 | addrs := table.Peers(100) 130 | Expect(len(addrs)).To(Equal(0)) 131 | 132 | addrs = table.Peers(0) 133 | Expect(len(addrs)).To(Equal(0)) 134 | }) 135 | }) 136 | }) 137 | 138 | Context("when querying random peers", func() { 139 | It("should return the correct amount", func() { 140 | table, _ := initDHT() 141 | 142 | f := func(seed int64) bool { 143 | numAddrs := rand.Intn(11000) 144 | 145 | // Insert `numAddrs` random addresses into the store. 
146 | for i := 0; i < numAddrs; i++ { 147 | privKey := id.NewPrivKey() 148 | sig := privKey.Signatory() 149 | addr := wire.NewUnsignedAddress(wire.TCP, "172.16.254.1:3000", uint64(time.Now().UnixNano())) 150 | 151 | table.AddPeer(sig, addr) 152 | } 153 | 154 | numRandomAddrs := rand.Intn(numAddrs) 155 | randomAddr := table.RandomPeers(numRandomAddrs) 156 | Expect(len(randomAddr)).To(Equal(numRandomAddrs)) 157 | return true 158 | } 159 | 160 | Expect(quick.Check(f, &quick.Config{MaxCount: 10})).To(Succeed()) 161 | }) 162 | 163 | Context("where the requested number is larger than the number of peers present in table", func() { 164 | It("should return the number of peers in table", func() { 165 | table, _ := initDHT() 166 | numAddrs := rand.Intn(100) 167 | 168 | // Insert `numAddrs` random addresses into the store. 169 | for i := 0; i < numAddrs; i++ { 170 | privKey := id.NewPrivKey() 171 | sig := privKey.Signatory() 172 | addr := wire.NewUnsignedAddress(wire.TCP, "172.16.254.1:3000", uint64(time.Now().UnixNano())) 173 | 174 | table.AddPeer(sig, addr) 175 | } 176 | 177 | randomAddr := table.RandomPeers(numAddrs + rand.Intn(100)) 178 | Expect(len(randomAddr)).To(Equal(numAddrs)) 179 | 180 | }) 181 | }) 182 | 183 | It("should return a unique subset each time", func() { 184 | table, _ := initDHT() 185 | numAddrs := rand.Intn(100) 186 | numRandAddrs := rand.Intn(numAddrs) 187 | 188 | // Insert `numAddrs` random addresses into the store. 189 | for i := 0; i < numAddrs; i++ { 190 | privKey := id.NewPrivKey() 191 | sig := privKey.Signatory() 192 | addr := wire.NewUnsignedAddress(wire.TCP, "172.16.254.1:3000", uint64(time.Now().UnixNano())) 193 | 194 | table.AddPeer(sig, addr) 195 | } 196 | 197 | lists := make([][]id.Signatory, 10) 198 | for i := range lists { 199 | lists[i] = table.RandomPeers(numRandAddrs) 200 | } 201 | 202 | for i := 0; i < 10; i++ { 203 | for j := i + 1; j < 10; j++ { 204 | Expect(lists[i]).To(Not(Equal(lists[j]))) 205 | } 206 | } 207 | }) 208 | 209 | It("should work while deleting peers from the table", func() { 210 | table, _ := initDHT() 211 | numAddrs := rand.Intn(100) 212 | numRandAddrs := rand.Intn(numAddrs) 213 | 214 | // Insert `numAddrs` random addresses into the store. 215 | deletedPeers := make([]id.Signatory, 0, 50) 216 | for i := 0; i < numAddrs; i++ { 217 | privKey := id.NewPrivKey() 218 | sig := privKey.Signatory() 219 | addr := wire.NewUnsignedAddress(wire.TCP, "172.16.254.1:3000", uint64(time.Now().UnixNano())) 220 | table.AddPeer(sig, addr) 221 | if i < numAddrs/2 { 222 | deletedPeers = append(deletedPeers, sig) 223 | } 224 | } 225 | 226 | done := make(chan struct{}, 1) 227 | go func() { 228 | defer close(done) 229 | 230 | for i := range deletedPeers { 231 | table.DeletePeer(deletedPeers[i]) 232 | } 233 | }() 234 | 235 | total := time.Duration(0) 236 | for i := 0; i < 50; i++ { 237 | start := time.Now() 238 | table.RandomPeers(numRandAddrs) 239 | duration := time.Now().Sub(start) 240 | total += duration 241 | } 242 | log.Printf("RandomPeers takes %v on average", total/50) 243 | <-done 244 | }) 245 | }) 246 | 247 | Context("when querying the number of addresses", func() { 248 | It("should return the correct amount", func() { 249 | table, _ := initDHT() 250 | numAddrs := rand.Intn(100) 251 | 252 | // Insert `numAddrs` random addresses into the store. 
253 | for i := 0; i < numAddrs; i++ { 254 | privKey := id.NewPrivKey() 255 | sig := privKey.Signatory() 256 | addr := wire.NewUnsignedAddress(wire.TCP, "172.16.254.1:3000", uint64(time.Now().UnixNano())) 257 | 258 | table.AddPeer(sig, addr) 259 | } 260 | 261 | n := table.NumPeers() 262 | Expect(n).To(Equal(numAddrs)) 263 | }) 264 | }) 265 | 266 | Context("when re-inserting an address", func() { 267 | It("the sorted list of signatories should remain unchanged", func() { 268 | 269 | r := rand.New(rand.NewSource(time.Now().UnixNano())) 270 | f := func(seed int64) bool { 271 | table, _ := initDHT() 272 | numPeers := r.Intn(100) + 1 273 | for i := 0; i < numPeers; i++ { 274 | privKey := id.NewPrivKey() 275 | sig := privKey.Signatory() 276 | ipAddr := fmt.Sprintf("%d.%d.%d.%d:%d", 277 | r.Intn(256), r.Intn(256), r.Intn(256), r.Intn(256), r.Intn(65536)) 278 | addr := wire.NewUnsignedAddress(wire.TCP, ipAddr, uint64(time.Now().UnixNano())) 279 | table.AddPeer(sig, addr) 280 | } 281 | 282 | peers := table.Peers(numPeers + 10) 283 | Expect(len(peers)).To(Equal(numPeers)) 284 | randomSig := peers[r.Intn(numPeers)] 285 | newIPAddr := fmt.Sprintf("%d.%d.%d.%d:%d", 286 | r.Intn(256), r.Intn(256), r.Intn(256), r.Intn(256), r.Intn(65536)) 287 | newAddr := wire.NewUnsignedAddress(wire.TCP, newIPAddr, uint64(time.Now().UnixNano())) 288 | table.AddPeer(randomSig, newAddr) 289 | 290 | newPeers := table.Peers(numPeers + 1) 291 | Expect(len(newPeers)).To(Equal(numPeers)) 292 | Expect(peers).To(Equal(newPeers)) 293 | return true 294 | } 295 | Expect(quick.Check(f, nil)).To(Succeed()) 296 | }) 297 | }) 298 | 299 | Measure("Adding 10000 addresses to distributed hash table", func(b Benchmarker) { 300 | table, _ := initDHT() 301 | signatories := make([]id.Signatory, 0) 302 | for i := 0; i < 10000; i++ { 303 | privKey := id.NewPrivKey() 304 | sig := privKey.Signatory() 305 | signatories = append(signatories, sig) 306 | } 307 | runtime := b.Time("runtime", func() { 308 | for i := 0; i < len(signatories); i++ { 309 | addr := wire.NewUnsignedAddress(wire.TCP, "172.16.254.1:"+strconv.Itoa(i), uint64(time.Now().UnixNano())) 310 | table.AddPeer(signatories[i], addr) 311 | } 312 | }) 313 | Ω(runtime.Seconds()) 314 | }, 10) 315 | 316 | Measure("Removing 10000 addresses from distributed hash table", func(b Benchmarker) { 317 | table, _ := initDHT() 318 | signatories := make([]id.Signatory, 0) 319 | for i := 0; i < 10000; i++ { 320 | privKey := id.NewPrivKey() 321 | sig := privKey.Signatory() 322 | addr := wire.NewUnsignedAddress(wire.TCP, "172.16.254.1:"+strconv.Itoa(i), uint64(time.Now().UnixNano())) 323 | table.AddPeer(sig, addr) 324 | signatories = append(signatories, sig) 325 | } 326 | runtime := b.Time("runtime", func() { 327 | for i := 0; i < len(signatories); i++ { 328 | table.DeletePeer(signatories[i]) 329 | } 330 | }) 331 | Ω(runtime.Seconds()) 332 | }, 10) 333 | }) 334 | 335 | Describe("Subnets", func() { 336 | Context("when adding a subnet", func() { 337 | It("should be able to query it", func() { 338 | table, identity := initDHT() 339 | 340 | // Generate a random number of signatories. 
341 | numSignatories := rand.Intn(100) 342 | signatories := make([]id.Signatory, numSignatories) 343 | for i := 0; i < numSignatories; i++ { 344 | privKey := id.NewPrivKey() 345 | signatories[i] = id.NewSignatory((*id.PubKey)(&privKey.PublicKey)) 346 | } 347 | 348 | hash := table.AddSubnet(signatories) 349 | newSignatories := table.Subnet(hash) 350 | 351 | // Sort the original slice by XOR distance from our address and 352 | // verify it is equal to the result. 353 | dhtutil.SortSignatories(identity, signatories) 354 | Expect(newSignatories).To(Equal(signatories)) 355 | }) 356 | }) 357 | 358 | Context("when deleting a subnet", func() { 359 | It("should not be able to query it", func() { 360 | table, _ := initDHT() 361 | 362 | // Generate a random number of signatories. 363 | numSignatories := rand.Intn(100) 364 | signatories := make([]id.Signatory, numSignatories) 365 | for i := 0; i < numSignatories; i++ { 366 | privKey := id.NewPrivKey() 367 | signatories[i] = id.NewSignatory((*id.PubKey)(&privKey.PublicKey)) 368 | } 369 | 370 | hash := table.AddSubnet(signatories) 371 | table.DeleteSubnet(hash) 372 | 373 | newSignatories := table.Subnet(hash) 374 | Expect(len(newSignatories)).To(Equal(0)) 375 | }) 376 | }) 377 | 378 | Context("when querying a subnet that does not exist", func() { 379 | It("should return an empty list", func() { 380 | table, _ := initDHT() 381 | 382 | data := make([]byte, 32) 383 | _, err := rand.Read(data[:]) 384 | Expect(err).ToNot(HaveOccurred()) 385 | 386 | hash := id.NewHash(data) 387 | signatories := table.Subnet(hash) 388 | Expect(len(signatories)).To(Equal(0)) 389 | }) 390 | }) 391 | }) 392 | }) 393 | 394 | func initDHT() (dht.Table, id.Signatory) { 395 | privKey := id.NewPrivKey() 396 | identity := id.NewSignatory((*id.PubKey)(&privKey.PublicKey)) 397 | return dht.NewInMemTable(identity), identity 398 | } 399 | -------------------------------------------------------------------------------- /examples/chatroom/chatroom.go: -------------------------------------------------------------------------------- 1 | package chatroom 2 | 3 | func main() { 4 | } 5 | -------------------------------------------------------------------------------- /examples/dat/dat.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | func main() { 4 | } 5 | -------------------------------------------------------------------------------- /examples/fuzz/fuzz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/renproject/aw/3828e9da5963d4b9fea66c450a88ae1bfe9631a2/examples/fuzz/fuzz -------------------------------------------------------------------------------- /examples/fuzz/fuzz.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "math/rand" 8 | "time" 9 | 10 | "github.com/renproject/aw/dht" 11 | 12 | "github.com/renproject/aw/channel" 13 | "github.com/renproject/aw/handshake" 14 | "github.com/renproject/aw/peer" 15 | "github.com/renproject/aw/transport" 16 | "github.com/renproject/aw/wire" 17 | "github.com/renproject/id" 18 | "go.uber.org/zap" 19 | ) 20 | 21 | func main() { 22 | loggerConfig := zap.NewProductionConfig() 23 | loggerConfig.Level.SetLevel(zap.PanicLevel) 24 | logger, err := loggerConfig.Build() 25 | if err != nil { 26 | panic(err) 27 | } 28 | 29 | // Number of peers. 30 | n := 200 31 | 32 | // Init options for all peers. 
33 | opts := make([]peer.Options, n) 34 | for i := range opts { 35 | i := i 36 | opts[i] = peer.DefaultOptions().WithLogger(logger) 37 | } 38 | 39 | // Init and run peers. 40 | peers := make([]*peer.Peer, n) 41 | tables := make([]dht.Table, n) 42 | clients := make([]*channel.Client, n) 43 | transports := make([]*transport.Transport, n) 44 | for i := range peers { 45 | self := opts[i].PrivKey.Signatory() 46 | h := handshake.Filter(func(id.Signatory) error { return nil }, handshake.ECIES(opts[i].PrivKey)) 47 | contentResolver := dht.NewDoubleCacheContentResolver(dht.DefaultDoubleCacheContentResolverOptions(), nil) 48 | clients[i] = channel.NewClient( 49 | channel.DefaultOptions(). 50 | WithLogger(logger), 51 | self) 52 | tables[i] = dht.NewInMemTable(self) 53 | transports[i] = transport.New( 54 | transport.DefaultOptions(). 55 | WithLogger(logger). 56 | WithClientTimeout(5*time.Second). 57 | WithOncePoolOptions(handshake.DefaultOncePoolOptions().WithMinimumExpiryAge(10*time.Second)). 58 | WithPort(uint16(3333+i)), 59 | self, 60 | clients[i], 61 | h, 62 | tables[i]) 63 | peers[i] = peer.New( 64 | opts[i], 65 | transports[i]) 66 | peers[i].Receive(context.Background(), func(from id.Signatory, packet wire.Packet) error { 67 | fmt.Printf("%4v: received \"%v\" from %4v\n", opts[i].PrivKey.Signatory(), string(packet.Msg.Data), from) 68 | return nil 69 | }) 70 | peers[i].Resolve(context.Background(), contentResolver) 71 | go func(i int) { 72 | for { 73 | // Randomly crash peers. 74 | func() { 75 | r := rand.New(rand.NewSource(time.Now().UnixNano() + int64(i))) 76 | d := time.Minute * time.Duration(1000+r.Int()%9000) 77 | ctx, cancel := context.WithTimeout(context.Background(), d) 78 | defer cancel() 79 | peers[i].Run(ctx) 80 | }() 81 | } 82 | }(i) 83 | } 84 | 85 | for { 86 | time.Sleep(time.Millisecond * time.Duration(rand.Int()%1000)) 87 | for i := range peers { 88 | j := (i + 1) % len(peers) 89 | fmt.Printf("peer[%v] sending to peer[%v]\n", i, j) 90 | tables[i].AddPeer(peers[j].ID(), wire.NewUnsignedAddress(wire.TCP, fmt.Sprintf("localhost:%v", 3333+int64(j)), uint64(time.Now().UnixNano()))) 91 | func() { 92 | ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) 93 | defer cancel() 94 | if err := peers[i].Send(ctx, peers[j].ID(), wire.Msg{Data: []byte(fmt.Sprintf("hello from %v!", i))}); err != nil { 95 | log.Printf("send: %v", err) 96 | } 97 | }() 98 | } 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/renproject/aw 2 | 3 | go 1.16 4 | 5 | require ( 6 | github.com/ethereum/go-ethereum v1.9.10 7 | github.com/onsi/ginkgo v1.16.4 8 | github.com/onsi/gomega v1.10.1 9 | github.com/renproject/id v0.4.2 10 | github.com/renproject/surge v1.2.5 11 | go.uber.org/multierr v1.6.0 // indirect 12 | go.uber.org/zap v1.16.0 13 | golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 14 | ) 15 | -------------------------------------------------------------------------------- /handshake/ecies.go: -------------------------------------------------------------------------------- 1 | package handshake 2 | 3 | import ( 4 | "bytes" 5 | "crypto/ecdsa" 6 | "crypto/rand" 7 | "fmt" 8 | "io" 9 | "math/big" 10 | "net" 11 | 12 | "github.com/ethereum/go-ethereum/crypto" 13 | "github.com/ethereum/go-ethereum/crypto/ecies" 14 | "github.com/renproject/aw/codec" 15 | "github.com/renproject/id" 16 | ) 17 | 18 | const sizeOfSecretKey = 32 19 | const 
sizeOfEncryptedSecretKey = 145 // 113-byte encryption header + 32-byte secret key 20 | 21 | func ECIES(privKey *id.PrivKey) Handshake { 22 | return func(conn net.Conn, enc codec.Encoder, dec codec.Decoder) (codec.Encoder, codec.Decoder, id.Signatory, error) { 23 | // Channel for passing errors from the writing goroutine to the reading 24 | // goroutine (which has the ability to return the error). 25 | errCh := make(chan error, 1) 26 | 27 | // Channel for passing the remote pubkey to the writing goroutine (which 28 | // operates in parallel to the reading goroutine). 29 | remotePubKeyCh := make(chan id.PubKey, 1) 30 | defer close(remotePubKeyCh) 31 | 32 | // Channel for passing the session key to the writing goroutine (which 33 | // operates in parallel to the reading goroutine). 34 | remoteSecretKeyCh := make(chan []byte, 1) 35 | defer close(remoteSecretKeyCh) 36 | 37 | // A pointer to the pubKey contained in the privKey struct. 38 | localPubKey := privKey.PubKey() 39 | 40 | // Generate a local secret key. We do it here, because it is needed by 41 | // the writing and reading goroutine. 42 | localSecretKey := [sizeOfSecretKey]byte{} 43 | if _, err := rand.Read(localSecretKey[:]); err != nil { 44 | return nil, nil, id.Signatory{}, fmt.Errorf("generate local secret key: %v", err) 45 | } 46 | 47 | // Begin background goroutine for writing information to the network 48 | // connection. 49 | go func() { 50 | defer close(errCh) 51 | 52 | // Write local pubkey so that the remote peer knows how to encrypt 53 | // its secret key and send it back to the local peer. 54 | xBuf := paddedTo32(localPubKey.X) 55 | yBuf := paddedTo32(localPubKey.Y) 56 | if _, err := conn.Write(xBuf[:]); err != nil { 57 | errCh <- fmt.Errorf("write local pubkey x: %v", err) 58 | return 59 | } 60 | if _, err := conn.Write(yBuf[:]); err != nil { 61 | errCh <- fmt.Errorf("write local pubkey y: %v", err) 62 | return 63 | } 64 | 65 | // Encrypt the local secret key using the remote pubkey and write 66 | // it to the remote peer. 67 | remotePubKey, ok := <-remotePubKeyCh 68 | if !ok { 69 | return 70 | } 71 | importedRemotePubKey := ecies.ImportECDSAPublic((*ecdsa.PublicKey)(&remotePubKey)) 72 | encryptedLocalSecretKey, err := ecies.Encrypt(rand.Reader, importedRemotePubKey, localSecretKey[:], nil, nil) 73 | if err != nil { 74 | errCh <- fmt.Errorf("encrypt local secret key: %v", err) 75 | return 76 | } 77 | if _, err := conn.Write(encryptedLocalSecretKey); err != nil { 78 | errCh <- fmt.Errorf("write local secret key: %v", err) 79 | return 80 | } 81 | 82 | // Encrypt the remote secret key using the remote pubkey and write 83 | // it to the remote peer. This allows the remote peer to verify that 84 | // the local peer does have access to the previously asserted local 85 | // pubkey. 86 | remoteSecretKey, ok := <-remoteSecretKeyCh 87 | if !ok { 88 | return 89 | } 90 | encryptedRemoteSecretKey, err := ecies.Encrypt(rand.Reader, importedRemotePubKey, remoteSecretKey, nil, nil) 91 | if err != nil { 92 | errCh <- fmt.Errorf("encrypt remote secret key: %v", err) 93 | return 94 | } 95 | if _, err := conn.Write(encryptedRemoteSecretKey); err != nil { 96 | errCh <- fmt.Errorf("write remote secret key: %v", err) 97 | return 98 | } 99 | }() 100 | 101 | // Read the remote pubkey.
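// (Clarifying note, not part of the original source: the 64 bytes are the
// remote peer's X and Y secp256k1 coordinates, each big-endian and padded to
// 32 bytes, mirroring the xBuf/yBuf writes in the goroutine above.)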
102 | remotePubKeyBuf := [64]byte{} 103 | if _, err := io.ReadFull(conn, remotePubKeyBuf[:]); err != nil { 104 | return nil, nil, id.Signatory{}, fmt.Errorf("read remote pubkey: %v", err) 105 | } 106 | remotePubKey := id.PubKey{ 107 | Curve: crypto.S256(), 108 | X: new(big.Int).SetBytes(remotePubKeyBuf[:32]), 109 | Y: new(big.Int).SetBytes(remotePubKeyBuf[32:]), 110 | } 111 | remotePubKeyCh <- remotePubKey 112 | 113 | // Read the encrypted remote secret key, and then decrypt it. 114 | encryptedRemoteSecretKey := [sizeOfEncryptedSecretKey]byte{} 115 | if _, err := io.ReadFull(conn, encryptedRemoteSecretKey[:]); err != nil { 116 | return nil, nil, id.Signatory{}, fmt.Errorf("read remote secret key: %v", err) 117 | } 118 | remoteSecretKey, err := ecies.ImportECDSA((*ecdsa.PrivateKey)(privKey)).Decrypt(encryptedRemoteSecretKey[:], nil, nil) 119 | if err != nil { 120 | return nil, nil, id.Signatory{}, fmt.Errorf("decrypt remote secret key: %v", err) 121 | } 122 | remoteSecretKeyCh <- remoteSecretKey 123 | 124 | // Read the encrypted local secret back from the remote peer. This 125 | // proves to the local peer that the remote peer has access to its 126 | // previously asserted pubkey. 127 | encryptedLocalSecretKeyCheck := [sizeOfEncryptedSecretKey]byte{} 128 | if _, err := io.ReadFull(conn, encryptedLocalSecretKeyCheck[:]); err != nil { 129 | return nil, nil, id.Signatory{}, fmt.Errorf("read local secret key: %v", err) 130 | } 131 | localSecretKeyCheck, err := ecies.ImportECDSA((*ecdsa.PrivateKey)(privKey)).Decrypt(encryptedLocalSecretKeyCheck[:], nil, nil) 132 | if err != nil { 133 | return nil, nil, id.Signatory{}, fmt.Errorf("decrypt local secret key: %v", err) 134 | } 135 | if !bytes.Equal(localSecretKey[:], localSecretKeyCheck[:]) { 136 | return nil, nil, id.Signatory{}, fmt.Errorf("check local secret key") 137 | } 138 | 139 | // Check whether or not an error happened in the writing goroutine 140 | // (and wait for the writing goroutine to end). 141 | err, ok := <-errCh 142 | if ok { 143 | return nil, nil, id.Signatory{}, err 144 | } 145 | 146 | // Build the session key, and use this to build GCM encoders/decoders. 147 | sessionKey := [sizeOfSecretKey]byte{} 148 | for i := 0; i < sizeOfSecretKey; i++ { 149 | sessionKey[i] = localSecretKey[i] ^ remoteSecretKey[i] 150 | } 151 | 152 | self := id.NewSignatory(localPubKey) 153 | remote := id.NewSignatory(&remotePubKey) 154 | gcmSession, err := codec.NewGCMSession(sessionKey, self, remote) 155 | if err != nil { 156 | return nil, nil, id.Signatory{}, fmt.Errorf("establish gcm session: %v", err) 157 | } 158 | return codec.GCMEncoder(gcmSession, enc), codec.GCMDecoder(gcmSession, dec), remote, nil 159 | } 160 | } 161 | 162 | // paddedTo32 encodes a big integer as big-endian bytes into a 32-byte array. It 163 | // will panic if the big integer is more than 32 bytes. 164 | // Modified from: 165 | // https://github.com/ethereum/go-ethereum/blob/master/common/math/big.go 166 | // 17381ecc6695ea9c2d8e5ee0aee5cf70d59a301a 167 | func paddedTo32(bigint *big.Int) [32]byte { 168 | if bigint.BitLen()/8 > 32 { 169 | panic(fmt.Sprintf("too big: expected n<32, got n=%v", bigint.BitLen()/8)) 170 | } 171 | ret := [32]byte{} 172 | readBits(bigint, ret[:]) 173 | return ret 174 | } 175 | 176 | // readBits encodes the absolute value of bigint as big-endian bytes. Callers 177 | // must ensure that buf has enough space. If buf is too short the result will be 178 | // incomplete.
179 | // Modified from: 180 | // https://github.com/ethereum/go-ethereum/blob/master/common/math/big.go 181 | // 17381ecc6695ea9c2d8e5ee0aee5cf70d59a301a 182 | func readBits(bigint *big.Int, buf []byte) { 183 | i := len(buf) 184 | for _, d := range bigint.Bits() { 185 | for j := 0; j < wordBytes && i > 0; j++ { 186 | i-- 187 | buf[i] = byte(d) 188 | d >>= 8 189 | } 190 | } 191 | } 192 | 193 | const ( 194 | // wordBits is the number of bits in a big word. 195 | wordBits = 32 << (uint64(^big.Word(0)) >> 63) 196 | // wordBytes is the number of bytes in a big word. 197 | wordBytes = wordBits / 8 198 | ) 199 | -------------------------------------------------------------------------------- /handshake/ecies_test.go: -------------------------------------------------------------------------------- 1 | package handshake_test 2 | -------------------------------------------------------------------------------- /handshake/filter.go: -------------------------------------------------------------------------------- 1 | package handshake 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | 7 | "github.com/renproject/aw/codec" 8 | "github.com/renproject/id" 9 | ) 10 | 11 | // Filter accepts a filtering function and a Handshake function, and returns a 12 | // wrapping Handshake function that runs the wrapped Handshake before applying 13 | // the filtering function to the remote peer ID. If the wrapped Handshake 14 | // returns an error, the filtering function will be skipped, and the error will 15 | // be returned. Otherwise, the filtering function will be called and its error 16 | // returned. 17 | func Filter(f func(id.Signatory) error, h Handshake) Handshake { 18 | return func(conn net.Conn, enc codec.Encoder, dec codec.Decoder) (codec.Encoder, codec.Decoder, id.Signatory, error) { 19 | enc, dec, remote, err := h(conn, enc, dec) 20 | if err != nil { 21 | return enc, dec, remote, err 22 | } 23 | if err := f(remote); err != nil { 24 | return enc, dec, remote, fmt.Errorf("filter %v: %v", remote, err) 25 | } 26 | return enc, dec, remote, nil 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /handshake/filter_test.go: -------------------------------------------------------------------------------- 1 | package handshake_test 2 | -------------------------------------------------------------------------------- /handshake/handshake.go: -------------------------------------------------------------------------------- 1 | package handshake 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | 7 | "github.com/renproject/aw/codec" 8 | "github.com/renproject/id" 9 | ) 10 | 11 | const keySize = 32 12 | 13 | const encryptionHeaderSize = 113 14 | 15 | // encryptedKeySize specifies the size in bytes of a single key encrypted using ECIES 16 | const encryptedKeySize = encryptionHeaderSize + keySize 17 | 18 | // Handshake functions accept a connection, an encoder, and a decoder. The encoder 19 | // and decoder are used to establish an authenticated and encrypted connection. 20 | // A new encoder and decoder are returned, which wrap the accepted encoder and 21 | // decoder with any additional functionality required to perform authentication 22 | // and encryption. The identity of the remote peer is also returned. 23 | type Handshake func(net.Conn, codec.Encoder, codec.Decoder) (codec.Encoder, codec.Decoder, id.Signatory, error) 24 | 25 | // Insecure returns a Handshake that does no authentication or encryption.
26 | // During the handshake, the local peer writes its own identity to the 27 | // connection, and then reads the identity of the remote peer. No verification 28 | // of identities is done. Insecure should only be used in private networks. 29 | func Insecure(self id.Signatory) Handshake { 30 | return func(conn net.Conn, enc codec.Encoder, dec codec.Decoder) (codec.Encoder, codec.Decoder, id.Signatory, error) { 31 | if _, err := enc(conn, self[:]); err != nil { 32 | return nil, nil, id.Signatory{}, fmt.Errorf("encoding local id: %v", err) 33 | } 34 | remote := id.Signatory{} 35 | if _, err := dec(conn, remote[:]); err != nil { 36 | return nil, nil, id.Signatory{}, fmt.Errorf("decoding remote id: %v", err) 37 | } 38 | return enc, dec, remote, nil 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /handshake/handshake_suite_test.go: -------------------------------------------------------------------------------- 1 | package handshake_test 2 | 3 | import ( 4 | "testing" 5 | 6 | . "github.com/onsi/ginkgo" 7 | . "github.com/onsi/gomega" 8 | ) 9 | 10 | func TestHandshake(t *testing.T) { 11 | RegisterFailHandler(Fail) 12 | RunSpecs(t, "Handshake suite") 13 | } 14 | -------------------------------------------------------------------------------- /handshake/handshake_test.go: -------------------------------------------------------------------------------- 1 | package handshake_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net" 7 | "sync" 8 | "time" 9 | 10 | "github.com/renproject/aw/codec" 11 | "github.com/renproject/aw/handshake" 12 | "github.com/renproject/aw/policy" 13 | "github.com/renproject/aw/tcp" 14 | "github.com/renproject/id" 15 | 16 | . "github.com/onsi/ginkgo" 17 | . "github.com/onsi/gomega" 18 | ) 19 | 20 | var _ = Describe("Handshake", func() { 21 | Context("connecting a client to a server", func() { 22 | When("the server is online", func() { 23 | It("should connect successfully", func() { 24 | ctx, cancel := context.WithCancel(context.Background()) 25 | defer cancel() 26 | var dialRetry, dialSuccess chan bool = nil, make(chan bool) 27 | portCh := listenOnAssignedPort(ctx) 28 | dial(ctx, <-portCh, dialRetry, dialSuccess) 29 | Expect(<-dialSuccess).Should(BeTrue()) 30 | }) 31 | }) 32 | 33 | When("the server is offline", func() { 34 | It("should retry", func() { 35 | ctx, cancel := context.WithCancel(context.Background()) 36 | defer cancel() 37 | var dialRetry, dialSuccess chan bool = make(chan bool), make(chan bool) 38 | dial(ctx, 3333, dialRetry, dialSuccess) 39 | Expect(<-dialRetry).Should(BeTrue()) 40 | }) 41 | 42 | It("if the server comes online should eventually connect successfully", func() { 43 | ctx, cancel := context.WithCancel(context.Background()) 44 | defer cancel() 45 | var dialRetry, dialSuccess chan bool = make(chan bool), make(chan bool) 46 | port := 3334 47 | dial(ctx, port, dialRetry, dialSuccess) 48 | time.Sleep(500 * time.Millisecond) 49 | listen(ctx, port) 50 | Expect(<-dialRetry).Should(BeTrue()) 51 | Expect(<-dialSuccess).Should(BeTrue()) 52 | }) 53 | }) 54 | }) 55 | }) 56 | 57 | func listen(ctx context.Context, port int) { 58 | go func() { 59 | privKey := id.NewPrivKey() 60 | h := handshake.ECIES(privKey) 61 | 62 | tcp.Listen(ctx, 63 | fmt.Sprintf("127.0.0.1:%v", port), 64 | func(conn net.Conn) { 65 | h(conn, 66 | codec.PlainEncoder, 67 | codec.PlainDecoder, 68 | ) 69 | }, 70 | nil, 71 | nil, 72 | ) 73 | }() 74 | } 75 | 76 | func listenOnAssignedPort(ctx context.Context) <-chan int { 77 | portCh := make(chan 
int, 1) 78 | go func() { 79 | privKey := id.NewPrivKey() 80 | h := handshake.ECIES(privKey) 81 | 82 | ip := "127.0.0.1" 83 | listener, port, err := tcp.ListenerWithAssignedPort(ctx, ip) 84 | Expect(err).ToNot(HaveOccurred()) 85 | portCh <- port 86 | 87 | tcp.ListenWithListener(ctx, 88 | listener, 89 | func(conn net.Conn) { 90 | h(conn, 91 | codec.PlainEncoder, 92 | codec.PlainDecoder, 93 | ) 94 | }, 95 | nil, 96 | nil, 97 | ) 98 | }() 99 | 100 | return portCh 101 | } 102 | 103 | func dial(ctx context.Context, port int, dialRetry, dialSuccess chan bool) { 104 | go func() { 105 | retrySignalOnce := sync.Once{} 106 | privKey := id.NewPrivKey() 107 | h := handshake.ECIES(privKey) 108 | 109 | tcp.Dial(ctx, 110 | fmt.Sprintf("127.0.0.1:%v", port), 111 | func(conn net.Conn) { 112 | _, _, _, err := h(conn, 113 | codec.PlainEncoder, 114 | codec.PlainDecoder) 115 | if err == nil { 116 | dialSuccess <- true 117 | } 118 | }, 119 | func(error) { 120 | retrySignalOnce.Do(func() { 121 | dialRetry <- true 122 | }) 123 | }, 124 | policy.ConstantTimeout(50*time.Millisecond), 125 | ) 126 | }() 127 | } 128 | -------------------------------------------------------------------------------- /handshake/once.go: -------------------------------------------------------------------------------- 1 | package handshake 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "net" 7 | "sync" 8 | "time" 9 | 10 | "github.com/renproject/aw/codec" 11 | "github.com/renproject/aw/wire" 12 | "github.com/renproject/id" 13 | ) 14 | 15 | var DefaultMinimumExpiryAge = time.Minute 16 | 17 | type OncePoolOptions struct { 18 | MinimumExpiryAge time.Duration 19 | } 20 | 21 | func DefaultOncePoolOptions() OncePoolOptions { 22 | return OncePoolOptions{ 23 | MinimumExpiryAge: DefaultMinimumExpiryAge, 24 | } 25 | } 26 | 27 | func (opts OncePoolOptions) WithMinimumExpiryAge(minExpiryAge time.Duration) OncePoolOptions { 28 | opts.MinimumExpiryAge = minExpiryAge 29 | return opts 30 | } 31 | 32 | type onceConn struct { 33 | timestamp time.Time 34 | conn net.Conn 35 | } 36 | 37 | type OncePool struct { 38 | opts OncePoolOptions 39 | 40 | connsMu *sync.Mutex 41 | conns map[id.Signatory]onceConn 42 | } 43 | 44 | func NewOncePool(opts OncePoolOptions) OncePool { 45 | return OncePool{ 46 | opts: opts, 47 | 48 | connsMu: new(sync.Mutex), 49 | conns: map[id.Signatory]onceConn{}, 50 | } 51 | } 52 | 53 | func Once(self id.Signatory, pool *OncePool, h Handshake) Handshake { 54 | return func(conn net.Conn, enc codec.Encoder, dec codec.Decoder) (codec.Encoder, codec.Decoder, id.Signatory, error) { 55 | enc, dec, remote, err := h(conn, enc, dec) 56 | if err != nil { 57 | return enc, dec, remote, fmt.Errorf("handshake error = %v", err) 58 | } 59 | 60 | cmp := bytes.Compare(self[:], remote[:]) 61 | if cmp == 0 { 62 | return enc, dec, remote, nil 63 | } 64 | if cmp < 0 { 65 | keepAlive := [128]byte{} 66 | if _, err := dec(conn, keepAlive[:1]); err != nil { 67 | return enc, dec, remote, fmt.Errorf("decoding keep-alive message: %v", err) 68 | } 69 | if keepAlive[0] == 0x00 { 70 | return nil, nil, remote, wire.NewNegligibleError(fmt.Errorf("kill connection from %v", remote)) 71 | } 72 | 73 | pool.connsMu.Lock() 74 | defer pool.connsMu.Unlock() 75 | 76 | if existingConn, ok := pool.conns[remote]; ok { 77 | // Ignore the error, because we no longer need this connection. 
78 | _ = existingConn.conn.Close() 79 | } 80 | pool.conns[remote] = onceConn{timestamp: time.Now(), conn: conn} 81 | return enc, dec, remote, nil 82 | } 83 | 84 | // Lock, perform non-blocking operations, and then unlock. This allows 85 | // us to avoid doing blocking operations (such as encoding/decoding 86 | // keep-alive messages) while holding the mutex lock. 87 | pool.connsMu.Lock() 88 | existingConn, existingConnIsOk := pool.conns[remote] 89 | existingConnNeedsReplacement := !existingConnIsOk || time.Now().Sub(existingConn.timestamp) > pool.opts.MinimumExpiryAge 90 | if existingConnNeedsReplacement { 91 | pool.conns[remote] = onceConn{timestamp: time.Now(), conn: conn} 92 | } 93 | pool.connsMu.Unlock() 94 | 95 | if !existingConnNeedsReplacement { 96 | _, err := enc(conn, msgKeepAliveFalse) 97 | // Ignore the error, because we no longer need this connection. 98 | _ = conn.Close() 99 | if err != nil { 100 | return enc, dec, remote, fmt.Errorf("encoding keep-alive message 0x00 to %v: %v", remote, err) 101 | } 102 | return enc, dec, remote, wire.NewNegligibleError(fmt.Errorf("kill connection to %v", remote)) 103 | } 104 | 105 | if existingConnIsOk { 106 | // Ignore the error, because we no longer need this connection. 107 | _ = existingConn.conn.Close() 108 | } 109 | if _, err := enc(conn, msgKeepAliveTrue); err != nil { 110 | // An error occurred while writing the "keep alive" message to 111 | // the remote peer. This results in an inconsistent state, so we 112 | // should recover the state by deleting the recently inserted 113 | // connection. This will cause the remote peer to eventually error 114 | // on their connection, and will cause our local peer to 115 | // eventually re-attempt channel creation. 116 | pool.connsMu.Lock() 117 | // Ignore the error, because we no longer need this connection. 118 | _ = pool.conns[remote].conn.Close() 119 | delete(pool.conns, remote) 120 | pool.connsMu.Unlock() 121 | return enc, dec, remote, fmt.Errorf("encoding keep-alive message 0x01 to %v: %v", remote, err) 122 | } 123 | 124 | return enc, dec, remote, nil 125 | } 126 | } 127 | 128 | var ( 129 | msgKeepAliveFalse = []byte{0x00} 130 | msgKeepAliveTrue = []byte{0x01} 131 | ) 132 | -------------------------------------------------------------------------------- /handshake/once_test.go: -------------------------------------------------------------------------------- 1 | package handshake_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net" 7 | "sync/atomic" 8 | "time" 9 | 10 | "github.com/renproject/aw/codec" 11 | "github.com/renproject/aw/handshake" 12 | "github.com/renproject/aw/policy" 13 | "github.com/renproject/aw/tcp" 14 | "github.com/renproject/id" 15 | 16 | . "github.com/onsi/ginkgo" 17 | .
"github.com/onsi/gomega" 18 | ) 19 | 20 | var _ = Describe("Handshake", func() { 21 | Describe("Once", func() { 22 | Context("when a pair of nodes are trying to establish connections both ways", func() { 23 | It("should only maintain one connection", func() { 24 | pool1 := handshake.NewOncePool(handshake.DefaultOncePoolOptions()) 25 | pool2 := handshake.NewOncePool(handshake.DefaultOncePoolOptions()) 26 | 27 | privKey1 := id.NewPrivKey() 28 | privKey2 := id.NewPrivKey() 29 | ctx, cancel := context.WithCancel(context.Background()) 30 | defer cancel() 31 | 32 | handshakeDone1 := make(chan struct{}, 1) 33 | handshakeDone2 := make(chan struct{}, 1) 34 | serverHandshakeDone := make(chan struct{}, 2) 35 | 36 | ip := "127.0.0.1" 37 | portCh1 := make(chan int, 1) 38 | portCh2 := make(chan int, 1) 39 | 40 | var connectionKillCount int64 = 0 41 | go func() { 42 | listener, port, err := tcp.ListenerWithAssignedPort(ctx, ip) 43 | Expect(err).ToNot(HaveOccurred()) 44 | portCh1 <- port 45 | tcp.ListenWithListener(ctx, 46 | listener, 47 | func(conn net.Conn) { 48 | h := handshake.ECIES(privKey1) 49 | h = handshake.Once(privKey1.Signatory(), &pool1, h) 50 | _, _, _, err := h(conn, codec.PlainEncoder, codec.PlainDecoder) 51 | if err != nil { 52 | fmt.Printf("%v - server side \n", err) 53 | atomic.AddInt64(&connectionKillCount, 1) 54 | } 55 | serverHandshakeDone <- struct{}{} 56 | }, 57 | nil, 58 | policy.Max(2), 59 | ) 60 | }() 61 | 62 | go func() { 63 | listener, port, err := tcp.ListenerWithAssignedPort(ctx, ip) 64 | Expect(err).ToNot(HaveOccurred()) 65 | portCh2 <- port 66 | tcp.ListenWithListener(ctx, 67 | listener, 68 | func(conn net.Conn) { 69 | h := handshake.ECIES(privKey2) 70 | h = handshake.Once(privKey2.Signatory(), &pool2, h) 71 | _, _, _, err := h(conn, codec.PlainEncoder, codec.PlainDecoder) 72 | if err != nil { 73 | fmt.Printf("%v - server side \n", err) 74 | atomic.AddInt64(&connectionKillCount, 1) 75 | } 76 | serverHandshakeDone <- struct{}{} 77 | }, 78 | nil, 79 | policy.Max(2), 80 | ) 81 | }() 82 | <-time.After(100 * time.Millisecond) 83 | 84 | go func() { 85 | tcp.Dial(ctx, 86 | fmt.Sprintf("localhost:%v", <-portCh2), 87 | func(conn net.Conn) { 88 | h := handshake.ECIES(privKey1) 89 | h = handshake.Once(privKey1.Signatory(), &pool1, h) 90 | _, _, _, err := h(conn, codec.PlainEncoder, codec.PlainDecoder) 91 | if err != nil { 92 | fmt.Printf("%v - client side 1\n", err) 93 | atomic.AddInt64(&connectionKillCount, 1) 94 | } 95 | handshakeDone1 <- struct{}{} 96 | }, 97 | nil, 98 | policy.ConstantTimeout(time.Second*2), 99 | ) 100 | }() 101 | 102 | tcp.Dial(ctx, 103 | fmt.Sprintf("localhost:%v", <-portCh1), 104 | func(conn net.Conn) { 105 | h := handshake.ECIES(privKey2) 106 | h = handshake.Once(privKey2.Signatory(), &pool2, h) 107 | _, _, _, err := h(conn, codec.PlainEncoder, codec.PlainDecoder) 108 | if err != nil { 109 | fmt.Printf("%v - client side 2\n", err) 110 | atomic.AddInt64(&connectionKillCount, 1) 111 | } 112 | handshakeDone2 <- struct{}{} 113 | }, 114 | nil, 115 | policy.ConstantTimeout(time.Second*2), 116 | ) 117 | 118 | <-handshakeDone1 119 | <-handshakeDone2 120 | <-serverHandshakeDone 121 | <-serverHandshakeDone 122 | 123 | Expect(atomic.LoadInt64(&connectionKillCount)).To(Equal(int64(2))) 124 | 125 | }) 126 | }) 127 | 128 | Context("when a client tries to establish two connections to a server", func() { 129 | It("should only maintain one connection", func() { 130 | pool1 := handshake.NewOncePool(handshake.DefaultOncePoolOptions()) 131 | pool2 := 
handshake.NewOncePool(handshake.DefaultOncePoolOptions()) 132 | 133 | privKey1 := id.NewPrivKey() 134 | privKey2 := id.NewPrivKey() 135 | ctx, cancel := context.WithCancel(context.Background()) 136 | defer cancel() 137 | 138 | handshakeDone1 := make(chan struct{}, 1) 139 | handshakeDone2 := make(chan struct{}, 1) 140 | serverHandshakeDone := make(chan struct{}, 2) 141 | 142 | ip := "127.0.0.1" 143 | portCh := make(chan int, 1) 144 | 145 | var connectionKillCount int64 = 0 146 | go func() { 147 | listener, port, err := tcp.ListenerWithAssignedPort(ctx, ip) 148 | Expect(err).ToNot(HaveOccurred()) 149 | portCh <- port 150 | tcp.ListenWithListener(ctx, 151 | listener, 152 | func(conn net.Conn) { 153 | h := handshake.ECIES(privKey1) 154 | h = handshake.Once(privKey1.Signatory(), &pool1, h) 155 | _, _, _, err := h(conn, codec.PlainEncoder, codec.PlainDecoder) 156 | if err != nil { 157 | fmt.Printf("%v - server side \n", err) 158 | atomic.AddInt64(&connectionKillCount, 1) 159 | } 160 | serverHandshakeDone <- struct{}{} 161 | }, 162 | nil, 163 | policy.Max(2), 164 | ) 165 | }() 166 | port := <-portCh 167 | <-time.After(100 * time.Millisecond) 168 | 169 | go func() { 170 | tcp.Dial(ctx, 171 | fmt.Sprintf("localhost:%v", port), 172 | func(conn net.Conn) { 173 | h := handshake.ECIES(privKey2) 174 | h = handshake.Once(privKey2.Signatory(), &pool2, h) 175 | _, _, _, err := h(conn, codec.PlainEncoder, codec.PlainDecoder) 176 | if err != nil { 177 | fmt.Printf("%v - client side \n", err) 178 | atomic.AddInt64(&connectionKillCount, 1) 179 | } 180 | handshakeDone1 <- struct{}{} 181 | }, 182 | nil, 183 | policy.ConstantTimeout(time.Second*2), 184 | ) 185 | }() 186 | 187 | tcp.Dial(ctx, 188 | fmt.Sprintf("localhost:%v", port), 189 | func(conn net.Conn) { 190 | h := handshake.ECIES(privKey2) 191 | h = handshake.Once(privKey2.Signatory(), &pool2, h) 192 | _, _, _, err := h(conn, codec.PlainEncoder, codec.PlainDecoder) 193 | if err != nil { 194 | fmt.Printf("%v - client side \n", err) 195 | atomic.AddInt64(&connectionKillCount, 1) 196 | } 197 | handshakeDone2 <- struct{}{} 198 | }, 199 | nil, 200 | policy.ConstantTimeout(time.Second*2), 201 | ) 202 | 203 | <-handshakeDone1 204 | <-handshakeDone2 205 | <-serverHandshakeDone 206 | <-serverHandshakeDone 207 | 208 | Expect(atomic.LoadInt64(&connectionKillCount)).To(Equal(int64(2))) 209 | }) 210 | }) 211 | }) 212 | }) 213 | -------------------------------------------------------------------------------- /peer/gossip.go: -------------------------------------------------------------------------------- 1 | package peer 2 | 3 | import ( 4 | "context" 5 | "encoding/base64" 6 | "sync" 7 | 8 | "github.com/renproject/aw/channel" 9 | "github.com/renproject/aw/dht" 10 | "github.com/renproject/aw/transport" 11 | "github.com/renproject/aw/wire" 12 | "github.com/renproject/id" 13 | "go.uber.org/zap" 14 | ) 15 | 16 | type Gossiper struct { 17 | opts GossiperOptions 18 | 19 | filter *channel.SyncFilter 20 | transport *transport.Transport 21 | 22 | subnetsMu *sync.Mutex 23 | subnets map[string]id.Hash 24 | 25 | resolverMu *sync.RWMutex 26 | resolver dht.ContentResolver 27 | } 28 | 29 | func NewGossiper(opts GossiperOptions, filter *channel.SyncFilter, transport *transport.Transport) *Gossiper { 30 | return &Gossiper{ 31 | opts: opts, 32 | 33 | filter: filter, 34 | transport: transport, 35 | 36 | subnetsMu: new(sync.Mutex), 37 | subnets: make(map[string]id.Hash, 1024), 38 | 39 | resolverMu: new(sync.RWMutex), 40 | resolver: nil, 41 | } 42 | } 43 | 44 | func (g *Gossiper) 
Resolve(resolver dht.ContentResolver) { 45 | g.resolverMu.Lock() 46 | defer g.resolverMu.Unlock() 47 | 48 | g.resolver = resolver 49 | } 50 | 51 | func (g *Gossiper) Gossip(ctx context.Context, contentID []byte, subnet *id.Hash) { 52 | if subnet == nil { 53 | subnet = &DefaultSubnet 54 | } 55 | 56 | recipients := []id.Signatory{} 57 | if subnet.Equal(&DefaultSubnet) { 58 | recipients = g.transport.Table().Peers(g.opts.Alpha) 59 | } else { 60 | if recipients = g.transport.Table().Subnet(*subnet); len(recipients) > g.opts.Alpha { 61 | recipients = recipients[:g.opts.Alpha] 62 | } 63 | } 64 | 65 | msg := wire.Msg{Version: wire.MsgVersion1, To: *subnet, Type: wire.MsgTypePush, Data: contentID} 66 | wg := new(sync.WaitGroup) 67 | for i := range recipients { 68 | recipient := recipients[i] 69 | wg.Add(1) 70 | go func() { 71 | defer wg.Done() 72 | 73 | innerContext, cancel := context.WithTimeout(ctx, g.opts.Timeout) 74 | defer cancel() 75 | 76 | // Ignore the error, cause random recipient could be offline. 77 | _ = g.transport.Send(innerContext, recipient, msg) 78 | }() 79 | } 80 | wg.Wait() 81 | } 82 | 83 | func (g *Gossiper) DidReceiveMessage(from id.Signatory, msg wire.Msg) error { 84 | switch msg.Type { 85 | case wire.MsgTypePush: 86 | g.didReceivePush(from, msg) 87 | case wire.MsgTypePull: 88 | g.didReceivePull(from, msg) 89 | case wire.MsgTypeSync: 90 | // TODO: Fix Channel to gracefully handle the error returned if a message is filtered 91 | if g.filter.Filter(from, msg) { 92 | return nil 93 | } 94 | g.didReceiveSync(from, msg) 95 | } 96 | return nil 97 | } 98 | 99 | func (g *Gossiper) didReceivePush(from id.Signatory, msg wire.Msg) { 100 | if len(msg.Data) == 0 { 101 | return 102 | } 103 | 104 | // Check whether the content is already known. This can cause performance 105 | // bottle-necks if the content resolver is slow. 106 | g.resolverMu.RLock() 107 | if g.resolver == nil { 108 | g.resolverMu.RUnlock() 109 | return 110 | } 111 | if _, ok := g.resolver.QueryContent(msg.Data); ok { 112 | g.resolverMu.RUnlock() 113 | return 114 | } 115 | g.resolverMu.RUnlock() 116 | 117 | ctx, cancel := context.WithTimeout(context.Background(), g.opts.Timeout) 118 | 119 | // Later, we will probably receive a synchronisation message for the content 120 | // associated with this push. We store the subnet now, so that we know how 121 | // to propagate the content later. 122 | g.subnetsMu.Lock() 123 | g.subnets[string(msg.Data)] = msg.To 124 | g.subnetsMu.Unlock() 125 | 126 | // We are expecting a synchronisation message, because we are about to send 127 | // out a pull message. So, we need to allow the content in the filter. 128 | g.filter.Allow(msg.Data) 129 | 130 | // Cleanup after the synchronisation timeout has passed. This prevents 131 | // memory leaking in the filter and in the subnets map. It means that until 132 | // the timeout passes, we will be accepting synchronisation messages for 133 | // this content ID. 
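// NOTE (illustrative sketch): the exchange implemented by this file is,
// roughly:
//
//	A -> B : push(contentID)             // handled by didReceivePush on B
//	B -> A : pull(contentID)             // handled by didReceivePull on A
//	A -> B : sync(contentID, content)    // handled by didReceiveSync on B
//
// The filter.Allow call above, paired with the delayed filter.Deny below, is
// what lets the otherwise unsolicited sync message through the channel
// filter for a bounded window of time.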
134 | go func() { 135 | <-ctx.Done() 136 | cancel() 137 | 138 | g.subnetsMu.Lock() 139 | delete(g.subnets, string(msg.Data)) 140 | g.subnetsMu.Unlock() 141 | 142 | g.filter.Deny(msg.Data) 143 | }() 144 | 145 | if err := g.transport.Send(ctx, from, wire.Msg{ 146 | Version: wire.MsgVersion1, 147 | Type: wire.MsgTypePull, 148 | To: id.Hash(from), 149 | Data: msg.Data, 150 | }); err != nil { 151 | g.opts.Logger.Error("pull", zap.String("peer", from.String()), zap.String("id", base64.RawURLEncoding.EncodeToString(msg.Data)), zap.Error(err)) 152 | return 153 | } 154 | } 155 | 156 | func (g *Gossiper) didReceivePull(from id.Signatory, msg wire.Msg) { 157 | if len(msg.Data) == 0 { 158 | return 159 | } 160 | 161 | var content []byte 162 | var contentOk bool 163 | func() { 164 | g.resolverMu.RLock() 165 | defer g.resolverMu.RUnlock() 166 | 167 | if g.resolver == nil { 168 | return 169 | } 170 | content, contentOk = g.resolver.QueryContent(msg.Data) 171 | }() 172 | if !contentOk { 173 | g.opts.Logger.Debug("content not found", zap.String("peer", from.String()), zap.String("id", base64.RawURLEncoding.EncodeToString(msg.Data))) 174 | return 175 | } 176 | 177 | ctx, cancel := context.WithTimeout(context.Background(), g.opts.Timeout) 178 | defer cancel() 179 | 180 | if err := g.transport.Send(ctx, from, wire.Msg{ 181 | Version: wire.MsgVersion1, 182 | To: id.Hash(from), 183 | Type: wire.MsgTypeSync, 184 | Data: msg.Data, 185 | SyncData: content, 186 | }); err != nil { 187 | g.opts.Logger.Error("sync", zap.String("peer", from.String()), zap.String("id", base64.RawURLEncoding.EncodeToString(msg.Data)), zap.Error(err)) 188 | } 189 | return 190 | } 191 | 192 | func (g *Gossiper) didReceiveSync(from id.Signatory, msg wire.Msg) { 193 | g.resolverMu.RLock() 194 | if g.resolver == nil { 195 | g.resolverMu.RUnlock() 196 | return 197 | } 198 | 199 | _, alreadySeenContent := g.resolver.QueryContent(msg.Data) 200 | if alreadySeenContent { 201 | g.resolverMu.RUnlock() 202 | return 203 | } 204 | if len(msg.Data) == 0 || len(msg.SyncData) == 0 { 205 | g.resolverMu.RUnlock() 206 | return 207 | } 208 | 209 | // We are relying on the correctness of the channel filtering to ensure that 210 | // no synchronisation messages reach the gossiper unless the gossiper (or 211 | // the synchroniser) have allowed them. 212 | g.resolver.InsertContent(msg.Data, msg.SyncData) 213 | g.resolverMu.RUnlock() 214 | 215 | g.subnetsMu.Lock() 216 | subnet, ok := g.subnets[string(msg.Data)] 217 | g.subnetsMu.Unlock() 218 | 219 | if !ok { 220 | // The gossip has taken too long, and the subnet was removed from the 221 | // map to preserve memory. Gossiping cannot continue. 222 | return 223 | } 224 | 225 | ctx, cancel := context.WithTimeout(context.Background(), g.opts.Timeout) 226 | defer cancel() 227 | 228 | g.Gossip(ctx, msg.Data, &subnet) 229 | } 230 | -------------------------------------------------------------------------------- /peer/gossip_test.go: -------------------------------------------------------------------------------- 1 | package peer_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "os" 8 | "strings" 9 | "time" 10 | 11 | "github.com/renproject/aw/peer" 12 | "github.com/renproject/aw/wire" 13 | "github.com/renproject/id" 14 | "go.uber.org/zap" 15 | "go.uber.org/zap/zapcore" 16 | 17 | . "github.com/onsi/ginkgo" 18 | . 
"github.com/onsi/gomega" 19 | ) 20 | 21 | var _ = Describe("Gossip", func() { 22 | Context("When a node is gossipping with peers", func() { 23 | It("should sync content correctly", func() { 24 | 25 | // Number of peers 26 | n := 4 27 | opts, peers, tables, contentResolvers, _, _ := setup(n) 28 | 29 | for i := range peers { 30 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 31 | defer cancel() 32 | go peers[i].Run(ctx) 33 | tables[i].AddPeer(opts[(i+1)%n].PrivKey.Signatory(), 34 | wire.NewUnsignedAddress(wire.TCP, 35 | fmt.Sprintf("%v:%v", "localhost", uint16(3333+i+1)), uint64(time.Now().UnixNano()))) 36 | tables[(i+1)%n].AddPeer(opts[i].PrivKey.Signatory(), 37 | wire.NewUnsignedAddress(wire.TCP, 38 | fmt.Sprintf("%v:%v", "localhost", uint16(3333+i)), uint64(time.Now().UnixNano()))) 39 | } 40 | for i := range peers { 41 | msgHello := fmt.Sprintf("Hi from %v", peers[i].ID().String()) 42 | contentID := id.NewHash([]byte(msgHello)) 43 | contentResolvers[i].InsertContent(contentID[:], []byte(msgHello)) 44 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 45 | defer cancel() 46 | peers[i].Gossip(ctx, contentID[:], &peer.DefaultSubnet) 47 | } 48 | <-time.After(5 * time.Second) 49 | for i := range peers { 50 | for j := range peers { 51 | msgHello := fmt.Sprintf("Hi from %v", peers[j].ID().String()) 52 | contentID := id.NewHash([]byte(msgHello)) 53 | _, ok := contentResolvers[i].QueryContent(contentID[:]) 54 | Expect(ok).To(BeTrue()) 55 | } 56 | } 57 | }) 58 | 59 | It("should not send to itself", func() { 60 | // Custom logger that writes the error logs to a file 61 | highPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool { 62 | return lvl >= zapcore.ErrorLevel 63 | }) 64 | lowPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool { 65 | return lvl < zapcore.ErrorLevel 66 | }) 67 | 68 | consoleDebugging := zapcore.Lock(os.Stdout) 69 | consoleErrors, err := os.Create("err") 70 | if err != nil { 71 | panic(err) 72 | } 73 | defer os.Remove("err") 74 | consoleEncoder := zapcore.NewConsoleEncoder(zap.NewDevelopmentEncoderConfig()) 75 | 76 | core := zapcore.NewTee( 77 | zapcore.NewCore(consoleEncoder, consoleErrors, highPriority), 78 | zapcore.NewCore(consoleEncoder, consoleDebugging, lowPriority), 79 | ) 80 | logger := zap.New(core) 81 | opts, peers, tables, contentResolvers, _, _ := setupWithLogger(1, logger) 82 | 83 | // Add the peer's own address to the peer table. If this regression 84 | // has been fixed, then the call to AddPeer should actually return 85 | // early and not add an entry to the table. 86 | tables[0].AddPeer(opts[0].PrivKey.Signatory(), 87 | wire.NewUnsignedAddress(wire.TCP, 88 | fmt.Sprintf("%v:%v", "localhost", uint16(3333)), uint64(time.Now().UnixNano()))) 89 | 90 | for i := range peers { 91 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 92 | defer cancel() 93 | go peers[i].Run(ctx) 94 | } 95 | 96 | for i := range peers { 97 | msgHello := fmt.Sprintf("Hi from %v", peers[i].ID().String()) 98 | contentID := id.NewHash([]byte(msgHello)) 99 | contentResolvers[i].InsertContent(contentID[:], []byte(msgHello)) 100 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 101 | defer cancel() 102 | peers[i].Gossip(ctx, contentID[:], &peer.DefaultSubnet) 103 | } 104 | 105 | // FIXME(ross): If we don't wait long enough, the peer will not 106 | // have sent a message to itself yet. 
The time this takes should be 107 | // very small; the fact that this can still have not occurred after 108 | // a whole second indicates that something is wrong. 109 | <-time.After(5 * time.Second) 110 | 111 | buf := make([]byte, 1024) 112 | n, err := consoleErrors.ReadAt(buf, 0) 113 | if err != nil && err != io.EOF { 114 | panic(err) 115 | } 116 | fmt.Printf("%s\n", string(buf[:n])) 117 | 118 | // If a peer sends a message to themself, there will be a decoding 119 | // error. 120 | Expect(strings.Contains(string(buf[:n]), "message authentication failed")).To(BeFalse()) 121 | }) 122 | }) 123 | }) 124 | -------------------------------------------------------------------------------- /peer/opt.go: -------------------------------------------------------------------------------- 1 | package peer 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/renproject/id" 7 | "go.uber.org/zap" 8 | ) 9 | 10 | type SyncerOptions struct { 11 | Logger *zap.Logger 12 | Alpha int 13 | WiggleTimeout time.Duration 14 | } 15 | 16 | func DefaultSyncerOptions() SyncerOptions { 17 | logger, err := zap.NewDevelopment() 18 | if err != nil { 19 | panic(err) 20 | } 21 | return SyncerOptions{ 22 | Logger: logger, 23 | Alpha: DefaultAlpha, 24 | WiggleTimeout: DefaultTimeout, 25 | } 26 | } 27 | 28 | func (opts SyncerOptions) WithLogger(logger *zap.Logger) SyncerOptions { 29 | opts.Logger = logger 30 | return opts 31 | } 32 | 33 | func (opts SyncerOptions) WithAlpha(alpha int) SyncerOptions { 34 | opts.Alpha = alpha 35 | return opts 36 | } 37 | 38 | func (opts SyncerOptions) WithWiggleTimeout(timeout time.Duration) SyncerOptions { 39 | opts.WiggleTimeout = timeout 40 | return opts 41 | } 42 | 43 | type GossiperOptions struct { 44 | Logger *zap.Logger 45 | Alpha int 46 | Timeout time.Duration 47 | } 48 | 49 | func DefaultGossiperOptions() GossiperOptions { 50 | logger, err := zap.NewDevelopment() 51 | if err != nil { 52 | panic(err) 53 | } 54 | return GossiperOptions{ 55 | Logger: logger, 56 | Alpha: DefaultAlpha, 57 | Timeout: DefaultTimeout, 58 | } 59 | } 60 | 61 | func (opts GossiperOptions) WithLogger(logger *zap.Logger) GossiperOptions { 62 | opts.Logger = logger 63 | return opts 64 | } 65 | 66 | func (opts GossiperOptions) WithAlpha(alpha int) GossiperOptions { 67 | opts.Alpha = alpha 68 | return opts 69 | } 70 | 71 | func (opts GossiperOptions) WithTimeout(timeout time.Duration) GossiperOptions { 72 | opts.Timeout = timeout 73 | return opts 74 | } 75 | 76 | type DiscoveryOptions struct { 77 | Logger *zap.Logger 78 | Alpha int 79 | MaxExpectedPeers int 80 | PingTimePeriod time.Duration 81 | } 82 | 83 | func DefaultDiscoveryOptions() DiscoveryOptions { 84 | logger, err := zap.NewDevelopment() 85 | if err != nil { 86 | panic(err) 87 | } 88 | return DiscoveryOptions{ 89 | Logger: logger, 90 | Alpha: DefaultAlpha, 91 | MaxExpectedPeers: DefaultAlpha, 92 | PingTimePeriod: DefaultTimeout, 93 | } 94 | } 95 | 96 | func (opts DiscoveryOptions) WithLogger(logger *zap.Logger) DiscoveryOptions { 97 | opts.Logger = logger 98 | return opts 99 | } 100 | 101 | func (opts DiscoveryOptions) WithAlpha(alpha int) DiscoveryOptions { 102 | opts.Alpha = alpha 103 | return opts 104 | } 105 | 106 | func (opts DiscoveryOptions) WithMaxExpectedPeers(max int) DiscoveryOptions { 107 | opts.MaxExpectedPeers = max 108 | return opts 109 | } 110 | 111 | func (opts DiscoveryOptions) WithPingTimePeriod(period time.Duration) DiscoveryOptions { 112 | opts.PingTimePeriod = period 113 | return opts 114 | } 115 | 116 | type Options struct { 117 | SyncerOptions 
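// NOTE (illustrative sketch): an Options value is normally built with the
// With* methods defined in this file, for example (logger and privKey are
// assumed to exist):
//
//	opts := peer.DefaultOptions().
//		WithLogger(logger).
//		WithPrivKey(privKey).
//		WithGossiperOptions(peer.DefaultGossiperOptions().WithAlpha(10))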
118 | GossiperOptions 119 | DiscoveryOptions 120 | 121 | Logger *zap.Logger 122 | PrivKey *id.PrivKey 123 | } 124 | 125 | func DefaultOptions() Options { 126 | logger, err := zap.NewDevelopment() 127 | if err != nil { 128 | panic(err) 129 | } 130 | privKey := id.NewPrivKey() 131 | return Options{ 132 | SyncerOptions: DefaultSyncerOptions(), 133 | GossiperOptions: DefaultGossiperOptions(), 134 | DiscoveryOptions: DefaultDiscoveryOptions(), 135 | 136 | Logger: logger, 137 | PrivKey: privKey, 138 | } 139 | } 140 | 141 | func (opts Options) WithSyncerOptions(syncerOptions SyncerOptions) Options { 142 | opts.SyncerOptions = syncerOptions 143 | return opts 144 | } 145 | 146 | func (opts Options) WithGossiperOptions(gossiperOptions GossiperOptions) Options { 147 | opts.GossiperOptions = gossiperOptions 148 | return opts 149 | } 150 | 151 | func (opts Options) WithDiscoveryOptions(discoveryOptions DiscoveryOptions) Options { 152 | opts.DiscoveryOptions = discoveryOptions 153 | return opts 154 | } 155 | 156 | func (opts Options) WithLogger(logger *zap.Logger) Options { 157 | opts.Logger = logger 158 | return opts 159 | } 160 | 161 | func (opts Options) WithPrivKey(privKey *id.PrivKey) Options { 162 | opts.PrivKey = privKey 163 | return opts 164 | } 165 | -------------------------------------------------------------------------------- /peer/opt_test.go: -------------------------------------------------------------------------------- 1 | package peer_test 2 | -------------------------------------------------------------------------------- /peer/peer.go: -------------------------------------------------------------------------------- 1 | package peer 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "time" 8 | 9 | "github.com/renproject/aw/channel" 10 | "github.com/renproject/aw/dht" 11 | "github.com/renproject/aw/transport" 12 | "github.com/renproject/aw/wire" 13 | "github.com/renproject/id" 14 | ) 15 | 16 | var ( 17 | DefaultSubnet = id.Hash{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF} 18 | DefaultAlpha = 5 19 | DefaultTimeout = time.Second 20 | DefaultGossipTimeout = 3 * time.Second 21 | ) 22 | 23 | var ( 24 | ErrPeerNotFound = errors.New("peer not found") 25 | ) 26 | 27 | type Peer struct { 28 | opts Options 29 | transport *transport.Transport 30 | syncer *Syncer 31 | gossiper *Gossiper 32 | discoveryClient *DiscoveryClient 33 | } 34 | 35 | func New(opts Options, transport *transport.Transport) *Peer { 36 | filter := channel.NewSyncFilter() 37 | return &Peer{ 38 | opts: opts, 39 | transport: transport, 40 | syncer: NewSyncer(opts.SyncerOptions, filter, transport), 41 | gossiper: NewGossiper(opts.GossiperOptions, filter, transport), 42 | discoveryClient: NewDiscoveryClient(opts.DiscoveryOptions, transport), 43 | } 44 | } 45 | 46 | func (p *Peer) ID() id.Signatory { 47 | return p.opts.PrivKey.Signatory() 48 | } 49 | 50 | func (p *Peer) Syncer() *Syncer { 51 | return p.syncer 52 | } 53 | 54 | func (p *Peer) Gossiper() *Gossiper { 55 | return p.gossiper 56 | } 57 | 58 | func (p *Peer) Transport() *transport.Transport { 59 | return p.transport 60 | } 61 | 62 | func (p *Peer) Link(remote id.Signatory) { 63 | p.transport.Link(remote) 64 | } 65 | 66 | func (p *Peer) Unlink(remote id.Signatory) { 67 | p.transport.Unlink(remote) 68 | } 69 | 70 | func (p *Peer) Ping(ctx context.Context) error { 71 | return fmt.Errorf("unimplemented") 72 | } 73 | 74 | func 
(p *Peer) Send(ctx context.Context, to id.Signatory, msg wire.Msg) error { 75 | return p.transport.Send(ctx, to, msg) 76 | } 77 | 78 | func (p *Peer) Sync(ctx context.Context, contentID []byte, hint *id.Signatory) ([]byte, error) { 79 | return p.syncer.Sync(ctx, contentID, hint) 80 | } 81 | 82 | func (p *Peer) Gossip(ctx context.Context, contentID []byte, subnet *id.Hash) { 83 | p.gossiper.Gossip(ctx, contentID, subnet) 84 | } 85 | 86 | func (p *Peer) DiscoverPeers(ctx context.Context) { 87 | p.discoveryClient.DiscoverPeers(ctx) 88 | } 89 | 90 | func (p *Peer) Run(ctx context.Context) { 91 | p.transport.Receive(ctx, func(from id.Signatory, packet wire.Packet) error { 92 | // TODO(ross): Think about merging the syncer and the gossiper. 93 | if err := p.syncer.DidReceiveMessage(from, packet.Msg); err != nil { 94 | return err 95 | } 96 | if err := p.gossiper.DidReceiveMessage(from, packet.Msg); err != nil { 97 | return err 98 | } 99 | if err := p.discoveryClient.DidReceiveMessage(from, packet.IPAddr, packet.Msg); err != nil { 100 | return err 101 | } 102 | return nil 103 | }) 104 | p.transport.Run(ctx) 105 | } 106 | 107 | func (p *Peer) Receive(ctx context.Context, f func(id.Signatory, wire.Packet) error) { 108 | p.transport.Receive(ctx, f) 109 | } 110 | 111 | func (p *Peer) Resolve(ctx context.Context, contentResolver dht.ContentResolver) { 112 | p.gossiper.Resolve(contentResolver) 113 | } 114 | -------------------------------------------------------------------------------- /peer/peer_suite_test.go: -------------------------------------------------------------------------------- 1 | package peer 2 | 3 | import ( 4 | "testing" 5 | 6 | . "github.com/onsi/ginkgo" 7 | . "github.com/onsi/gomega" 8 | ) 9 | 10 | func TestPeer(t *testing.T) { 11 | RegisterFailHandler(Fail) 12 | RunSpecs(t, "Peer Test Suite") 13 | } 14 | -------------------------------------------------------------------------------- /peer/peer_test.go: -------------------------------------------------------------------------------- 1 | package peer_test 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "github.com/renproject/aw/channel" 8 | "github.com/renproject/aw/dht" 9 | "github.com/renproject/aw/handshake" 10 | "github.com/renproject/aw/peer" 11 | "github.com/renproject/aw/transport" 12 | "github.com/renproject/id" 13 | "go.uber.org/zap" 14 | ) 15 | 16 | func setup(numPeers int) ([]peer.Options, []*peer.Peer, []dht.Table, []dht.ContentResolver, []*channel.Client, []*transport.Transport) { 17 | loggerConfig := zap.NewProductionConfig() 18 | loggerConfig.Level.SetLevel(zap.ErrorLevel) 19 | logger, err := loggerConfig.Build() 20 | if err != nil { 21 | panic(err) 22 | } 23 | 24 | return setupWithLogger(numPeers, logger) 25 | } 26 | 27 | func setupWithLogger(numPeers int, logger *zap.Logger) ([]peer.Options, []*peer.Peer, []dht.Table, []dht.ContentResolver, []*channel.Client, []*transport.Transport) { 28 | // Init options for all peers. 
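// NOTE (illustrative sketch): this helper condenses the wiring a non-test
// caller would also perform: construct a transport, wrap it in a peer,
// attach a content resolver, and run it (transport, resolver, logger, ctx,
// and contentID are assumed to exist):
//
//	p := peer.New(peer.DefaultOptions().WithLogger(logger), transport)
//	p.Resolve(context.Background(), resolver)
//	go p.Run(ctx)
//	p.Gossip(ctx, contentID, &peer.DefaultSubnet)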
29 | opts := make([]peer.Options, numPeers) 30 | for i := range opts { 31 | i := i 32 | opts[i] = peer.DefaultOptions().WithLogger(logger) 33 | } 34 | 35 | peers := make([]*peer.Peer, numPeers) 36 | tables := make([]dht.Table, numPeers) 37 | contentResolvers := make([]dht.ContentResolver, numPeers) 38 | clients := make([]*channel.Client, numPeers) 39 | transports := make([]*transport.Transport, numPeers) 40 | for i := range peers { 41 | self := opts[i].PrivKey.Signatory() 42 | h := handshake.Filter(func(id.Signatory) error { return nil }, handshake.ECIES(opts[i].PrivKey)) 43 | clients[i] = channel.NewClient( 44 | channel.DefaultOptions(). 45 | WithLogger(logger), 46 | self) 47 | tables[i] = dht.NewInMemTable(self) 48 | contentResolvers[i] = dht.NewDoubleCacheContentResolver(dht.DefaultDoubleCacheContentResolverOptions(), nil) 49 | transports[i] = transport.New( 50 | transport.DefaultOptions(). 51 | WithLogger(logger). 52 | WithClientTimeout(5*time.Second). 53 | WithOncePoolOptions(handshake.DefaultOncePoolOptions().WithMinimumExpiryAge(10*time.Second)). 54 | WithPort(uint16(3333+i)), 55 | self, 56 | clients[i], 57 | h, 58 | tables[i]) 59 | peers[i] = peer.New( 60 | opts[i], 61 | transports[i]) 62 | peers[i].Resolve(context.Background(), contentResolvers[i]) 63 | } 64 | return opts, peers, tables, contentResolvers, clients, transports 65 | } 66 | -------------------------------------------------------------------------------- /peer/peerdiscovery.go: -------------------------------------------------------------------------------- 1 | package peer 2 | 3 | import ( 4 | "context" 5 | "encoding/binary" 6 | "fmt" 7 | "net" 8 | "time" 9 | 10 | "github.com/renproject/aw/transport" 11 | "github.com/renproject/aw/wire" 12 | "github.com/renproject/id" 13 | "github.com/renproject/surge" 14 | 15 | "go.uber.org/zap" 16 | ) 17 | 18 | type DiscoveryClient struct { 19 | opts DiscoveryOptions 20 | 21 | transport *transport.Transport 22 | } 23 | 24 | func NewDiscoveryClient(opts DiscoveryOptions, transport *transport.Transport) *DiscoveryClient { 25 | return &DiscoveryClient{ 26 | opts: opts, 27 | transport: transport, 28 | } 29 | } 30 | 31 | func (dc *DiscoveryClient) DiscoverPeers(ctx context.Context) { 32 | var pingData [2]byte 33 | binary.LittleEndian.PutUint16(pingData[:], dc.transport.Port()) 34 | 35 | msg := wire.Msg{ 36 | Version: wire.MsgVersion1, 37 | Type: wire.MsgTypePing, 38 | Data: pingData[:], 39 | } 40 | 41 | ticker := time.NewTicker(dc.opts.PingTimePeriod) 42 | defer ticker.Stop() 43 | 44 | alpha := dc.opts.Alpha 45 | sendDuration := dc.opts.PingTimePeriod / time.Duration(alpha) 46 | Outer: 47 | for { 48 | for _, sig := range dc.transport.Table().Peers(alpha) { 49 | err := func() error { 50 | innerCtx, innerCancel := context.WithTimeout(ctx, sendDuration) 51 | defer innerCancel() 52 | msg.To = id.Hash(sig) 53 | return dc.transport.Send(innerCtx, sig, msg) 54 | }() 55 | if err != nil { 56 | dc.opts.Logger.Debug("pinging", zap.Error(err)) 57 | if err == context.Canceled || err == context.DeadlineExceeded { 58 | break 59 | } 60 | } 61 | select { 62 | case <-ticker.C: 63 | continue Outer 64 | default: 65 | } 66 | } 67 | select { 68 | case <-ctx.Done(): 69 | return 70 | case <-ticker.C: 71 | } 72 | } 73 | } 74 | 75 | func (dc *DiscoveryClient) DidReceiveMessage(from id.Signatory, ipAddr net.Addr, msg wire.Msg) error { 76 | switch msg.Type { 77 | case wire.MsgTypePing: 78 | if err := dc.didReceivePing(from, ipAddr, msg); err != nil { 79 | return err 80 | } 81 | case wire.MsgTypePingAck: 82 | if err 
:= dc.didReceivePingAck(from, msg); err != nil { 83 | return err 84 | } 85 | } 86 | return nil 87 | } 88 | 89 | func (dc *DiscoveryClient) didReceivePing(from id.Signatory, ipAddr net.Addr, msg wire.Msg) error { 90 | ctx, cancel := context.WithCancel(context.Background()) 91 | defer cancel() 92 | 93 | if dataLen := len(msg.Data); dataLen != 2 { 94 | return fmt.Errorf("malformed port received in ping message. expected: 2 bytes, received: %v bytes", dataLen) 95 | } 96 | port := binary.LittleEndian.Uint16(msg.Data) 97 | 98 | dc.transport.Table().AddPeer( 99 | from, 100 | wire.NewUnsignedAddress(wire.TCP, fmt.Sprintf("%v:%v", ipAddr.(*net.TCPAddr).IP.String(), port), uint64(time.Now().UnixNano())), 101 | ) 102 | 103 | peers := dc.transport.Table().Peers(dc.opts.MaxExpectedPeers) 104 | addrAndSig := make([]wire.SignatoryAndAddress, 0, len(peers)) 105 | for _, sig := range peers { 106 | addr, addrOk := dc.transport.Table().PeerAddress(sig) 107 | if !addrOk { 108 | dc.opts.Logger.DPanic("acking ping", zap.String("peer", "does not exist in table")) 109 | continue 110 | } 111 | sigAndAddr := wire.SignatoryAndAddress{Signatory: sig, Address: addr} 112 | addrAndSig = append(addrAndSig, sigAndAddr) 113 | } 114 | 115 | addrAndSigBytes, err := surge.ToBinary(addrAndSig) 116 | if err != nil { 117 | return fmt.Errorf("bad ping ack: %v", err) 118 | } 119 | response := wire.Msg{ 120 | Version: wire.MsgVersion1, 121 | Type: wire.MsgTypePingAck, 122 | To: id.Hash(from), 123 | Data: addrAndSigBytes, 124 | } 125 | if err := dc.transport.Send(ctx, from, response); err != nil { 126 | dc.opts.Logger.Debug("acking ping", zap.Error(err)) 127 | } 128 | return nil 129 | } 130 | 131 | func (dc *DiscoveryClient) didReceivePingAck(from id.Signatory, msg wire.Msg) error { 132 | slice := []wire.SignatoryAndAddress{} 133 | err := surge.FromBinary(&slice, msg.Data) 134 | if err != nil { 135 | return fmt.Errorf("bad ping ack: %v", err) 136 | } 137 | 138 | for _, x := range slice { 139 | dc.transport.Table().AddPeer(x.Signatory, x.Address) 140 | } 141 | return nil 142 | } 143 | -------------------------------------------------------------------------------- /peer/peerdiscovery_test.go: -------------------------------------------------------------------------------- 1 | package peer_test 2 | 3 | import ( 4 | "context" 5 | "encoding/binary" 6 | "fmt" 7 | "go.uber.org/zap" 8 | "time" 9 | 10 | "github.com/renproject/aw/dht" 11 | "github.com/renproject/aw/peer" 12 | "github.com/renproject/aw/transport" 13 | "github.com/renproject/aw/wire" 14 | "github.com/renproject/id" 15 | 16 | . "github.com/onsi/ginkgo" 17 | . 
"github.com/onsi/gomega" 18 | ) 19 | 20 | func testPeerDiscovery(n int, peers []*peer.Peer, tables []dht.Table, transports []*transport.Transport) context.CancelFunc { 21 | time.Sleep(time.Second) 22 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 23 | 24 | for i := range peers { 25 | go peers[i].DiscoverPeers(ctx) 26 | } 27 | <-ctx.Done() 28 | 29 | for i := range peers { 30 | Expect(tables[i].NumPeers()).To(Equal(n - 1)) 31 | for j := range peers { 32 | if i != j { 33 | self := transports[j].Self() 34 | addr, ok := tables[i].PeerAddress(transports[j].Self()) 35 | if !ok { 36 | fmt.Printf("Sig not found: %v\n", self) 37 | for _, k := range tables[i].Peers(10) { 38 | sig := id.Signatory{} 39 | copy(sig[:], k[:]) 40 | x, _ := tables[i].PeerAddress(sig) 41 | fmt.Printf("Sig in table: %v, Addr: %v\n", sig, x) 42 | } 43 | } 44 | Expect(ok).To(BeTrue()) 45 | Expect(addr.Value).To(Or( 46 | Equal(fmt.Sprintf("127.0.0.1:%v", uint16(3333+j))), 47 | Equal(fmt.Sprintf("localhost:%v", uint16(3333+j))), 48 | Equal(fmt.Sprintf(":%v", uint16(3333+j))))) 49 | } 50 | } 51 | } 52 | 53 | return cancel 54 | } 55 | 56 | func createRingTopology(n int, opts []peer.Options, peers []*peer.Peer, tables []dht.Table, transports []*transport.Transport) context.CancelFunc { 57 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 58 | for i := range peers { 59 | go peers[i].Run(ctx) 60 | tables[i].AddPeer(opts[(i+1)%n].PrivKey.Signatory(), 61 | wire.NewUnsignedAddress(wire.TCP, 62 | fmt.Sprintf("%v:%v", "localhost", uint16(3333+((i+1)%n))), uint64(time.Now().UnixNano()))) 63 | } 64 | return cancel 65 | } 66 | 67 | func createLineTopology(n int, opts []peer.Options, peers []*peer.Peer, tables []dht.Table, transports []*transport.Transport) context.CancelFunc { 68 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 69 | for i := range peers { 70 | go peers[i].Run(ctx) 71 | if i < n-1 { 72 | tables[i].AddPeer(opts[i+1].PrivKey.Signatory(), 73 | wire.NewUnsignedAddress(wire.TCP, 74 | fmt.Sprintf("%v:%v", "localhost", uint16(3333+i+1)), uint64(time.Now().UnixNano()))) 75 | 76 | } 77 | } 78 | return cancel 79 | } 80 | 81 | func createStarTopology(n int, opts []peer.Options, peers []*peer.Peer, tables []dht.Table, transports []*transport.Transport) context.CancelFunc { 82 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 83 | for i := range peers { 84 | go peers[i].Run(ctx) 85 | if i != 0 { 86 | tables[i].AddPeer(opts[0].PrivKey.Signatory(), 87 | wire.NewUnsignedAddress(wire.TCP, 88 | fmt.Sprintf("%v:%v", "localhost", uint16(3333)), uint64(time.Now().UnixNano()))) 89 | 90 | } 91 | } 92 | return cancel 93 | } 94 | 95 | var _ = Describe("Peer Discovery", func() { 96 | Context("when trying to discover other peers using the peer discovery client in a ring topology", func() { 97 | It("should successfully find all peers", func() { 98 | n := 5 99 | opts, peers, tables, _, _, transports := setup(n) 100 | 101 | cancelPeerContext := createRingTopology(n, opts, peers, tables, transports) 102 | defer cancelPeerContext() 103 | 104 | cancelPeerDiscoveryContext := testPeerDiscovery(n, peers, tables, transports) 105 | defer cancelPeerDiscoveryContext() 106 | }) 107 | }) 108 | 109 | Context("when trying to discover other peers using the peer discovery client in a line topology", func() { 110 | It("should successfully find all peers", func() { 111 | n := 5 112 | opts, peers, tables, _, _, transports := setup(n) 113 | 114 | cancelPeerContext := 
createLineTopology(n, opts, peers, tables, transports) 115 | defer cancelPeerContext() 116 | 117 | cancelPeerDiscoveryContext := testPeerDiscovery(n, peers, tables, transports) 118 | defer cancelPeerDiscoveryContext() 119 | }) 120 | }) 121 | 122 | Context("when trying to discover other peers using the peer discovery client in a star topology", func() { 123 | It("should successfully find all peers", func() { 124 | n := 5 125 | opts, peers, tables, _, _, transports := setup(n) 126 | 127 | cancelPeerContext := createStarTopology(n, opts, peers, tables, transports) 128 | defer cancelPeerContext() 129 | 130 | cancelPeerDiscoveryContext := testPeerDiscovery(n, peers, tables, transports) 131 | defer cancelPeerDiscoveryContext() 132 | }) 133 | }) 134 | 135 | Context("when sending malformed pings to peer", func() { 136 | It("peer should not panic", func() { 137 | 138 | n := 2 139 | opts, peers, tables, _, _, transports := setup(n) 140 | 141 | cancelPeerContext := createRingTopology(n, opts, peers, tables, transports) 142 | defer cancelPeerContext() 143 | 144 | ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) 145 | defer cancel() 146 | func(ctx context.Context) { 147 | var pingData [4]byte 148 | binary.LittleEndian.PutUint32(pingData[:], uint32(transports[0].Port())) 149 | 150 | msg := wire.Msg{ 151 | Version: wire.MsgVersion1, 152 | Type: wire.MsgTypePing, 153 | } 154 | 155 | count := 0 156 | ticker := time.NewTicker(time.Second) 157 | defer ticker.Stop() 158 | 159 | sendDuration := time.Second 160 | Outer: 161 | for { 162 | if count%2 == 1 { 163 | msg.Data = pingData[:] 164 | } else { 165 | msg.Data = nil 166 | } 167 | for _, sig := range transports[0].Table().Peers(2) { 168 | err := func() error { 169 | innerCtx, innerCancel := context.WithTimeout(ctx, sendDuration) 170 | defer innerCancel() 171 | msg.To = id.Hash(sig) 172 | return transports[0].Send(innerCtx, sig, msg) 173 | }() 174 | if err != nil { 175 | opts[0].Logger.Debug("pinging", zap.Error(err)) 176 | if err == context.Canceled || err == context.DeadlineExceeded { 177 | break 178 | } 179 | } 180 | select { 181 | case <-ticker.C: 182 | continue Outer 183 | default: 184 | } 185 | } 186 | select { 187 | case <-ctx.Done(): 188 | return 189 | case <-ticker.C: 190 | count++ 191 | } 192 | } 193 | }(ctx) 194 | }) 195 | }) 196 | }) 197 | -------------------------------------------------------------------------------- /peer/sync.go: -------------------------------------------------------------------------------- 1 | package peer 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | "time" 8 | 9 | "github.com/renproject/aw/channel" 10 | "github.com/renproject/aw/transport" 11 | "github.com/renproject/aw/wire" 12 | "github.com/renproject/id" 13 | "go.uber.org/zap" 14 | ) 15 | 16 | type pendingContent struct { 17 | // content is nil while synchronisation is happening. After synchronisation 18 | // has completed, content will be set. 19 | content []byte 20 | 21 | // cond is used to wait and notify goroutines about the completion of 22 | // synchronisation. 23 | cond *sync.Cond 24 | } 25 | 26 | // wait for the content to be synchronised. Returns a channel that the caller 27 | // can block on while waiting for the content. This is better than blocking 28 | // interally, because it can be composed with contexts/timeouts. 
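// For example (an illustrative sketch), the returned channel composes with a
// context like so, which is how Sync uses it below:
//
//	select {
//	case <-ctx.Done():
//		return nil, ctx.Err()
//	case content := <-pending.wait():
//		return content, nil
//	}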
29 | func (pending *pendingContent) wait() <-chan []byte { 30 | w := make(chan []byte, 1) 31 | go func() { 32 | pending.cond.L.Lock() 33 | for pending.content == nil { 34 | pending.cond.Wait() 35 | } 36 | content := make([]byte, len(pending.content), len(pending.content)) 37 | copy(content, pending.content) 38 | pending.cond.L.Unlock() 39 | w <- content 40 | }() 41 | return w 42 | } 43 | 44 | // signal that the content is synchronised. All goroutines waiting on the 45 | // content will be awaken and will create their own copies of the content. 46 | func (pending *pendingContent) signal(content []byte) { 47 | pending.cond.L.Lock() 48 | pending.content = content 49 | pending.cond.L.Unlock() 50 | pending.cond.Broadcast() 51 | } 52 | 53 | type Syncer struct { 54 | opts SyncerOptions 55 | filter *channel.SyncFilter 56 | transport *transport.Transport 57 | 58 | pendingMu *sync.Mutex 59 | pending map[string]*pendingContent 60 | } 61 | 62 | func NewSyncer(opts SyncerOptions, filter *channel.SyncFilter, transport *transport.Transport) *Syncer { 63 | return &Syncer{ 64 | opts: opts, 65 | filter: filter, 66 | transport: transport, 67 | 68 | pendingMu: new(sync.Mutex), 69 | pending: make(map[string]*pendingContent, 1024), 70 | } 71 | } 72 | 73 | func (syncer *Syncer) Sync(ctx context.Context, contentID []byte, hint *id.Signatory) ([]byte, error) { 74 | syncer.pendingMu.Lock() 75 | pending, ok := syncer.pending[string(contentID)] 76 | if !ok { 77 | pending = &pendingContent{ 78 | content: nil, 79 | cond: sync.NewCond(new(sync.Mutex)), 80 | } 81 | syncer.pending[string(contentID)] = pending 82 | } 83 | syncer.pendingMu.Unlock() 84 | 85 | // Allow synchronisation messages for the content ID. This is required in 86 | // order for channel to not filter inbound content (of unknown size). At the 87 | // end of the method, we Deny the content ID again, un-doing the Allow and 88 | // blocking content again. 89 | syncer.filter.Allow(contentID) 90 | defer func() { 91 | go func() { 92 | time.Sleep(syncer.opts.WiggleTimeout) 93 | syncer.filter.Deny(contentID) 94 | }() 95 | }() 96 | 97 | // Ensure that pending content is removed. 98 | defer func() { 99 | syncer.pendingMu.Lock() 100 | delete(syncer.pending, string(contentID)) 101 | syncer.pendingMu.Unlock() 102 | }() 103 | 104 | if ok { 105 | select { 106 | case <-ctx.Done(): 107 | return nil, ctx.Err() 108 | case content := <-pending.wait(): 109 | return content, nil 110 | } 111 | } 112 | 113 | // Get addresses close to our address. We will iterate over these addresses 114 | // in order and attempt to synchronise content by sending them pull 115 | // messages. 116 | peers := syncer.transport.Table().RandomPeers(syncer.opts.Alpha) 117 | if hint != nil { 118 | peers = append([]id.Signatory{*hint}, peers...) 
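// NOTE (illustrative sketch): a non-nil hint becomes the first pull target,
// which is useful when the caller already knows, or strongly suspects, which
// peer holds the content, for example (knownHolder is an assumed
// id.Signatory):
//
//	content, err := syncer.Sync(ctx, contentID, &knownHolder)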
119 | } 120 | 121 | for i := range peers { 122 | select { 123 | case <-ctx.Done(): 124 | return nil, ctx.Err() 125 | default: 126 | } 127 | p := peers[i] 128 | go func() { 129 | err := syncer.transport.Send(ctx, p, wire.Msg{ 130 | Version: wire.MsgVersion1, 131 | Type: wire.MsgTypePull, 132 | Data: contentID, 133 | }) 134 | if err != nil { 135 | syncer.opts.Logger.Debug("sync", zap.String("peer", p.String()), zap.Error(fmt.Errorf("pulling: %v", err))) 136 | } 137 | }() 138 | } 139 | select { 140 | case <-ctx.Done(): 141 | return nil, ctx.Err() 142 | case content := <-pending.wait(): 143 | return content, nil 144 | } 145 | } 146 | 147 | func (syncer *Syncer) DidReceiveMessage(from id.Signatory, msg wire.Msg) error { 148 | if msg.Type == wire.MsgTypeSync { 149 | // TODO: Fix Channel to not drop connection on first filtered message, 150 | // since it could be a valid message that is simply late (comes after the grace period) 151 | if syncer.filter.Filter(from, msg) { 152 | return nil 153 | } 154 | syncer.pendingMu.Lock() 155 | pending, ok := syncer.pending[string(msg.Data)] 156 | if ok && msg.SyncData != nil { 157 | pending.signal(msg.SyncData) 158 | } 159 | syncer.pendingMu.Unlock() 160 | } 161 | return nil 162 | } 163 | -------------------------------------------------------------------------------- /peer/sync_test.go: -------------------------------------------------------------------------------- 1 | package peer_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/renproject/aw/peer" 9 | "github.com/renproject/aw/wire" 10 | "github.com/renproject/id" 11 | 12 | . "github.com/onsi/ginkgo" 13 | . "github.com/onsi/gomega" 14 | ) 15 | 16 | var _ = Describe("Sync", func() { 17 | Context("when trying to sync valid content id on demand with nil hint", func() { 18 | It("should successfully receive corresponding message", func() { 19 | 20 | n := 2 21 | opts, peers, tables, contentResolvers, _, transports := setup(n) 22 | 23 | for i := range opts { 24 | opts[i].SyncerOptions = opts[i].SyncerOptions.WithWiggleTimeout(2 * time.Second) 25 | peers[i] = peer.New( 26 | opts[i], 27 | transports[i]) 28 | peers[i].Resolve(context.Background(), contentResolvers[i]) 29 | } 30 | 31 | tables[0].AddPeer(peers[1].ID(), wire.NewUnsignedAddress(wire.TCP, 32 | fmt.Sprintf("%v:%v", "localhost", uint16(3333+1)), uint64(time.Now().UnixNano()))) 33 | tables[1].AddPeer(peers[0].ID(), wire.NewUnsignedAddress(wire.TCP, 34 | fmt.Sprintf("%v:%v", "localhost", uint16(3333)), uint64(time.Now().UnixNano()))) 35 | 36 | for i := range peers { 37 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 38 | defer cancel() 39 | go peers[i].Run(ctx) 40 | } 41 | 42 | helloMsg := "Hi from peer 0" 43 | contentID := id.NewHash([]byte(helloMsg)) 44 | contentResolvers[0].InsertContent(contentID[:], []byte(helloMsg)) 45 | 46 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 47 | defer cancel() 48 | 49 | msg, err := peers[1].Sync(ctx, contentID[:], nil) 50 | for err != nil { 51 | select { 52 | case <-ctx.Done(): 53 | panic("Timeout expired before content was synced") 54 | default: 55 | } 56 | msg, err = peers[1].Sync(ctx, contentID[:], nil) 57 | } 58 | 59 | Ω(err).To(BeNil()) 60 | Ω(msg).To(Equal([]byte(helloMsg))) 61 | }) 62 | }) 63 | 64 | Context("when getting a successful sync response on sending multiple parallel sync requests", func() { 65 | It("should not drop connections for additional sync responses", func() { 66 | 67 | n := 5 68 | opts, peers, tables, 
contentResolvers, _, transports := setup(n) 69 | 70 | for i := range opts { 71 | opts[i].SyncerOptions = opts[i].SyncerOptions.WithWiggleTimeout(2 * time.Second) 72 | peers[i] = peer.New( 73 | opts[i], 74 | transports[i]) 75 | peers[i].Resolve(context.Background(), contentResolvers[i]) 76 | } 77 | 78 | for i := range peers { 79 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 80 | defer cancel() 81 | go peers[i].Run(ctx) 82 | 83 | for j := range peers { 84 | if i != j { 85 | tables[i].AddPeer(opts[j].PrivKey.Signatory(), 86 | wire.NewUnsignedAddress(wire.TCP, 87 | fmt.Sprintf("%v:%v", "localhost", uint16(3333+j)), uint64(time.Now().UnixNano()))) 88 | } 89 | helloMsg := fmt.Sprintf("Hello from peer %d", j) 90 | contentID := id.NewHash([]byte(helloMsg)) 91 | contentResolvers[i].InsertContent(contentID[:], []byte(helloMsg)) 92 | } 93 | } 94 | 95 | for i := range peers { 96 | for j := range peers { 97 | helloMsg := fmt.Sprintf("Hello from peer %d", j) 98 | contentID := id.NewHash([]byte(helloMsg)) 99 | 100 | ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 101 | defer cancel() 102 | msg, err := peers[i].Sync(ctx, contentID[:], nil) 103 | 104 | for { 105 | if err == nil { 106 | break 107 | } 108 | select { 109 | case <-ctx.Done(): 110 | break 111 | default: 112 | msg, err = peers[i].Sync(ctx, contentID[:], nil) 113 | } 114 | } 115 | 116 | Ω(msg).To(Equal([]byte(helloMsg))) 117 | } 118 | } 119 | }) 120 | }) 121 | 122 | Context("if a sync request fails", func() { 123 | It("the corresponding pending content condition variable should be deleted", func() { 124 | n := 2 125 | opts, peers, tables, contentResolvers, _, transports := setup(n) 126 | 127 | tables[0].AddPeer(opts[1].PrivKey.Signatory(), 128 | wire.NewUnsignedAddress(wire.TCP, 129 | fmt.Sprintf("%v:%v", "localhost", uint16(3333+1)), uint64(time.Now().UnixNano()))) 130 | tables[1].AddPeer(opts[0].PrivKey.Signatory(), 131 | wire.NewUnsignedAddress(wire.TCP, 132 | fmt.Sprintf("%v:%v", "localhost", uint16(3333)), uint64(time.Now().UnixNano()))) 133 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 134 | defer cancel() 135 | go peers[0].Run(ctx) 136 | go func(ctx context.Context) { 137 | once := false 138 | transports[1].Receive(ctx, func(from id.Signatory, packet wire.Packet) error { 139 | if !once { 140 | once = true 141 | return nil 142 | } 143 | 144 | if err := peers[1].Syncer().DidReceiveMessage(from, packet.Msg); err != nil { 145 | return err 146 | } 147 | if err := peers[1].Gossiper().DidReceiveMessage(from, packet.Msg); err != nil { 148 | return err 149 | } 150 | return nil 151 | }) 152 | transports[1].Run(ctx) 153 | }(ctx) 154 | 155 | helloMsg := "Hello World!" 
156 | contentID := id.NewHash([]byte(helloMsg)) 157 | contentResolvers[0].InsertContent(contentID[:], []byte(helloMsg)) 158 | 159 | func() { 160 | syncCtx, syncCancel := context.WithTimeout(context.Background(), 2*time.Second) 161 | defer syncCancel() 162 | msg, err := peers[1].Sync(syncCtx, contentID[:], nil) 163 | Expect(msg).To(BeNil()) 164 | Expect(err).To(Not(BeNil())) 165 | }() 166 | 167 | func() { 168 | syncCtx, syncCancel := context.WithTimeout(context.Background(), 2*time.Second) 169 | defer syncCancel() 170 | msg, err := peers[1].Sync(syncCtx, contentID[:], nil) 171 | for err != nil { 172 | select { 173 | case <-syncCtx.Done(): 174 | panic("Timeout expired before content was synced") 175 | default: 176 | } 177 | msg, err = peers[1].Sync(ctx, contentID[:], nil) 178 | } 179 | Ω(err).To(BeNil()) 180 | Ω(msg).To(Equal([]byte(helloMsg))) 181 | }() 182 | }) 183 | }) 184 | }) 185 | -------------------------------------------------------------------------------- /policy/allow.go: -------------------------------------------------------------------------------- 1 | package policy 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net" 7 | "strings" 8 | "sync" 9 | 10 | "golang.org/x/time/rate" 11 | ) 12 | 13 | // ErrRateLimited is returned when a connection is dropped because it has 14 | // exceeded its rate limit for connection attempts. 15 | var ErrRateLimited = errors.New("rate limited") 16 | 17 | // ErrMaxConnectionsExceeded is returned when a connection is dropped 18 | // because the maximum number of inbound/outbound connections has been 19 | // reached. 20 | var ErrMaxConnectionsExceeded = errors.New("max connections exceeded") 21 | 22 | // Allow is a function that filters connections. If an error is returned, the 23 | // connection is filtered and closed. Otherwise, it is maintained. A clean-up 24 | // function is also returned. This function is called after the connection is 25 | // closed, regardless of whether the closure was caused by filtering or normal 26 | // control-flow. 27 | type Allow func(net.Conn) (error, Cleanup) 28 | 29 | // Cleanup resource allocation, or reverse per-connection state mutations, done 30 | // by an Allow function. 31 | type Cleanup func() 32 | 33 | // All returns an Allow function that only passes a connection if all Allow 34 | // functions in a set pass for that connection. Execution is lazy; when one of 35 | // the Allow functions returns an error, no more Allow functions will be called. 36 | func All(fs ...Allow) Allow { 37 | return func(conn net.Conn) (error, Cleanup) { 38 | cleanup := func() {} 39 | for _, f := range fs { 40 | err, cleanupF := f(conn) 41 | if cleanupF != nil { 42 | cleanupCopy := cleanup 43 | cleanup = func() { 44 | cleanupF() 45 | cleanupCopy() 46 | } 47 | } 48 | if err != nil { 49 | return err, cleanup 50 | } 51 | } 52 | return nil, cleanup 53 | } 54 | } 55 | 56 | // Any returns an Allow function that passes a connection if any Allow functions 57 | // in a set pass for that connection. Execution is not lazy; even when one of 58 | // the Allow functions returns a non-nil error, all other Allow functions will 59 | // be called. 
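// For example (an illustrative sketch), a server loop applies an Allow
// function roughly the way the TCP listener in this module does (conn,
// allow, and handle are assumed to exist):
//
//	err, cleanup := allow(conn)
//	if err != nil {
//		conn.Close()
//		return
//	}
//	defer conn.Close()
//	if cleanup != nil {
//		defer cleanup()
//	}
//	handle(conn)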
60 | func Any(fs ...Allow) Allow { 61 | return func(conn net.Conn) (error, Cleanup) { 62 | cleanup := func() {} 63 | any := false 64 | errs := make([]string, 0, len(fs)) 65 | for _, f := range fs { 66 | err, cleanupF := f(conn) 67 | if cleanupF != nil { 68 | cleanupCopy := cleanup 69 | cleanup = func() { 70 | cleanupF() 71 | cleanupCopy() 72 | } 73 | } 74 | if err == nil { 75 | any = true 76 | continue 77 | } 78 | errs = append(errs, err.Error()) 79 | } 80 | if any { 81 | return nil, cleanup 82 | } 83 | return fmt.Errorf("%v", strings.Join(errs, ", ")), cleanup 84 | } 85 | } 86 | 87 | // RateLimit returns an Allow function that rejects an IP-address if it attempts 88 | // too many connections too quickly. 89 | func RateLimit(r rate.Limit, b, cap int) Allow { 90 | cap /= 2 91 | front := make(map[string]*rate.Limiter, cap) 92 | back := make(map[string]*rate.Limiter, cap) 93 | 94 | return func(conn net.Conn) (error, Cleanup) { 95 | remoteAddr := "" 96 | if tcpAddr, ok := conn.RemoteAddr().(*net.TCPAddr); ok { 97 | remoteAddr = tcpAddr.IP.String() 98 | } else { 99 | remoteAddr = conn.RemoteAddr().String() 100 | } 101 | 102 | allow := func(limiter *rate.Limiter) (error, func()) { 103 | if limiter.Allow() { 104 | return nil, nil 105 | } 106 | return ErrRateLimited, nil 107 | } 108 | 109 | limiter := front[remoteAddr] 110 | if limiter != nil { 111 | return allow(limiter) 112 | } 113 | 114 | limiter = back[remoteAddr] 115 | if limiter != nil { 116 | return allow(limiter) 117 | } 118 | 119 | if len(front) == cap { 120 | back = front 121 | front = make(map[string]*rate.Limiter, cap) 122 | } 123 | 124 | limiter = rate.NewLimiter(r, b) 125 | front[remoteAddr] = limiter 126 | return allow(limiter) 127 | } 128 | } 129 | 130 | // Max returns an Allow function that rejects connections once a maximum number 131 | // of connections have already been accepted and are being kept-alive. Once an 132 | // accepted connection is closed, it opens up room for another connection to be 133 | // accepted. 134 | func Max(maxConns int) Allow { 135 | connsMu := new(sync.RWMutex) 136 | conns := 0 137 | 138 | return func(conn net.Conn) (error, Cleanup) { 139 | if maxConns < 0 { 140 | return nil, nil 141 | } 142 | 143 | // Pre-emptive read-lock check. 144 | connsMu.RLock() 145 | allow := conns < maxConns 146 | connsMu.RUnlock() 147 | if !allow { 148 | return ErrMaxConnectionsExceeded, nil 149 | } 150 | 151 | // Concurrent-safe write-lock check. 152 | connsMu.Lock() 153 | if conns < maxConns { 154 | conns++ 155 | } else { 156 | allow = false 157 | } 158 | connsMu.Unlock() 159 | if !allow { 160 | return ErrMaxConnectionsExceeded, nil 161 | } 162 | 163 | return nil, func() { 164 | connsMu.Lock() 165 | conns-- 166 | connsMu.Unlock() 167 | } 168 | } 169 | } 170 | -------------------------------------------------------------------------------- /policy/allow_test.go: -------------------------------------------------------------------------------- 1 | package policy_test 2 | -------------------------------------------------------------------------------- /policy/policy.go: -------------------------------------------------------------------------------- 1 | // Package policy defines functions that control which connections are allowed, 2 | // and which ones are denied. For server-side imposed policies, Allow functions 3 | // are used to filter connections. For client-side imposed policies, Timeout 4 | // functions are used to filter connections. 
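//
// For example (an illustrative sketch), composing ExponentialBackoff with
// ConstantTimeout yields a timeout that doubles on every attempt:
//
//	timeout := policy.ExponentialBackoff(2, policy.ConstantTimeout(time.Second))
//	timeout(0) // 1s
//	timeout(1) // 2s
//	timeout(2) // 4s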
5 | // 6 | // When listening for incoming connections, it is impossible to control the 7 | // conditions under which the remote client attempts to dial a new connection. 8 | // As such, more powerful policy functions are required. The opposite is true 9 | // for the dialer; when dialing a new connection to a remote peer, it can be 10 | // assumed that the application-level logic will avoid dialing under conditions 11 | // that the dialer does not find desirable. Here, the Timeout functions are the 12 | // client-side analogy to the Allow functions. Although the Timeout functions 13 | // are much less powerful, they are sufficient for dialers. 14 | // 15 | // Policy functions (Allow functions and Timeout functions) are built in using a 16 | // functional style, and are designed to be composed together. In this way, 17 | // policy functions are able to define only a small and simple (and easily 18 | // testable) amounts of filtering logic, but still be composed together to 19 | // create more complex filtering logic. 20 | // 21 | // // Create a policy that only allows 100 concurrent connections at any one 22 | // // point. 23 | // maxConns := policy.Max(100) 24 | // // Create a policy that only allows 1 connection attempt per second per IP 25 | // // address. 26 | // rateLimit := policy.RateLimit(1.0, 1, 65535) 27 | // // Compose these policies together to require that all of them pass. 28 | // all := policy.All(maxConns, rateLimit) 29 | // // Or, compose these policies together to require that any of them pass. 30 | // any := policy.Any(maxConns, rateLimit) 31 | // 32 | // Timeout functions are similarly composable. 33 | // 34 | // // Create a policy to Timeout after 1 second. 35 | // one := policy.ConstantTimeout(time.Second) 36 | // // Create a policy to scale this constant timeout by 60% with every attempt. 37 | // backoff := policy.LinearBackoff(1.6, one) 38 | // // Create a policy to clamp the Timeout to an upper bound of one minute, no 39 | // // matter how many attempts there have been. 40 | // max := policy.MaxTimeout(time.Minute, backoff) 41 | // 42 | // The policy functions available by default (of course, the programmer is free 43 | // to implement their own policy functions) are quite simple. But, by providing 44 | // "higher order policies" such as All, Any, or LinearBackoff, policies can be 45 | // composed in more interesting ways. In practice, most of the policies that 46 | // should be implemented at the raw connection level can be created by composing 47 | // the available policies. 48 | package policy 49 | -------------------------------------------------------------------------------- /policy/policy_suite_test.go: -------------------------------------------------------------------------------- 1 | package policy_test 2 | 3 | import ( 4 | "testing" 5 | 6 | . "github.com/onsi/ginkgo" 7 | . 
"github.com/onsi/gomega" 8 | ) 9 | 10 | func TestPolicy(t *testing.T) { 11 | RegisterFailHandler(Fail) 12 | RunSpecs(t, "Policy suite") 13 | } 14 | -------------------------------------------------------------------------------- /policy/policy_test.go: -------------------------------------------------------------------------------- 1 | package policy_test 2 | -------------------------------------------------------------------------------- /policy/timeout.go: -------------------------------------------------------------------------------- 1 | package policy 2 | 3 | import ( 4 | "math" 5 | "time" 6 | ) 7 | 8 | // Timeout functions accept an attempt (from 0 to the maximum integer) and 9 | // return an expected duration for which the attempt should run. 10 | type Timeout func(int) time.Duration 11 | 12 | // ConstantTimeout returns a Timeout function that always returns a constant 13 | // duration. 14 | func ConstantTimeout(duration time.Duration) Timeout { 15 | return func(int) time.Duration { 16 | return duration 17 | } 18 | } 19 | 20 | // MaxTimeout returns a Timeout function that restricts another Timeout function 21 | // to return a maximum duration. 22 | func MaxTimeout(duration time.Duration, timeout Timeout) Timeout { 23 | return func(attempt int) time.Duration { 24 | customTimeout := timeout(attempt) 25 | if customTimeout > duration { 26 | return duration 27 | } 28 | return customTimeout 29 | } 30 | } 31 | 32 | // LinearBackoff returns a Timeout function that scales the duration returned by 33 | // another Timeout function linearly with respect to the attempt. 34 | func LinearBackoff(rate float64, timeout Timeout) Timeout { 35 | return func(attempt int) time.Duration { 36 | return time.Duration(rate*float64(attempt)) * timeout(attempt) 37 | } 38 | } 39 | 40 | // ExponentialBackoff returns a Timeout function that scales the duration 41 | // returned by another Timeout function exponentially with respect to the 42 | // attempt. 43 | func ExponentialBackoff(rate float64, timeout Timeout) Timeout { 44 | return func(attempt int) time.Duration { 45 | return time.Duration(math.Pow(rate, float64(attempt))) * timeout(attempt) 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /policy/timeout_test.go: -------------------------------------------------------------------------------- 1 | package policy_test 2 | -------------------------------------------------------------------------------- /tcp/tcp.go: -------------------------------------------------------------------------------- 1 | package tcp 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net" 7 | "time" 8 | 9 | "github.com/renproject/aw/policy" 10 | ) 11 | 12 | // Listen for connections from remote peers until the context is done. The 13 | // allow function will be used to control the acceptance/rejection of connection 14 | // attempts, and can be used to implement maximum connection limits, per-IP 15 | // rate-limiting, and so on. This function spawns all accepted connections into 16 | // their own background goroutines that run the handle function, and then 17 | // clean-up the connection. This function blocks until the context is done. 
18 | func Listen(ctx context.Context, address string, handle func(net.Conn), handleErr func(error), allow policy.Allow) error { 19 | // Create a TCP listener from given address and return an error if unable to do so 20 | listener, err := new(net.ListenConfig).Listen(ctx, "tcp", address) 21 | if err != nil { 22 | return err 23 | } 24 | 25 | // The 'ctx' we passed to Listen() will not unblock `Listener.Accept()` if 26 | // context exceeding the deadline. We need to manually close the listener 27 | // to stop `Listener.Accept()` from blocking. 28 | // See https://github.com/golang/go/issues/28120 29 | go func() { 30 | <- ctx.Done() 31 | listener.Close() 32 | }() 33 | return ListenWithListener(ctx, listener, handle, handleErr, allow) 34 | } 35 | 36 | // ListenWithListener is the same as Listen but instead of specifying an 37 | // address, it accepts an already constructed listener. 38 | // 39 | // NOTE: The listener passed to this function will be closed when the given 40 | // context finishes. 41 | func ListenWithListener(ctx context.Context, listener net.Listener, handle func(net.Conn), handleErr func(error), allow policy.Allow) error { 42 | if handle == nil { 43 | return fmt.Errorf("nil handle function") 44 | } 45 | 46 | if handleErr == nil { 47 | handleErr = func(err error) {} 48 | } 49 | 50 | defer listener.Close() 51 | 52 | for { 53 | select { 54 | case <-ctx.Done(): 55 | return ctx.Err() 56 | default: 57 | } 58 | 59 | conn, err := listener.Accept() 60 | if err != nil { 61 | handleErr(fmt.Errorf("accept connection: %w", err)) 62 | continue 63 | } 64 | 65 | if allow == nil { 66 | go func() { 67 | defer conn.Close() 68 | 69 | handle(conn) 70 | }() 71 | continue 72 | } 73 | 74 | if err, cleanup := allow(conn); err == nil { 75 | go func() { 76 | defer conn.Close() 77 | 78 | defer func() { 79 | if cleanup != nil { 80 | cleanup() 81 | } 82 | }() 83 | handle(conn) 84 | }() 85 | continue 86 | } 87 | conn.Close() 88 | } 89 | } 90 | 91 | // ListenerWithAssignedPort creates a new listener on a random port assigned by 92 | // the OS. On success, both the listener and port are returned. 93 | func ListenerWithAssignedPort(ctx context.Context, ip string) (net.Listener, int, error) { 94 | listener, err := new(net.ListenConfig).Listen(ctx, "tcp", fmt.Sprintf("%v:%v", ip, 0)) 95 | if err != nil { 96 | return nil, 0, err 97 | } 98 | port := listener.Addr().(*net.TCPAddr).Port 99 | return listener, port, nil 100 | } 101 | 102 | // Dial a remote peer until a connection is successfully established, or until 103 | // the context is done. Multiple dial attempts can be made, and the timeout 104 | // function is used to define an upper bound on dial attempts. This function 105 | // blocks until the connection is handled (and the handle function returns). 106 | // This function will clean-up the connection. 
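The Listen and Dial halves are designed to meet in the middle: the listener filters inbound connections with an Allow policy, and the dialer paces itself with a Timeout policy. A rough usage sketch follows; the loopback address, port, and handle functions are placeholders, and Dial's full signature appears just below.

package main

import (
	"context"
	"fmt"
	"io"
	"net"
	"time"

	"github.com/renproject/aw/policy"
	"github.com/renproject/aw/tcp"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// Accept at most 10 concurrent connections, rate-limited per IP address.
	allow := policy.All(policy.Max(10), policy.RateLimit(1.0, 1, 65535))

	// Listen in the background. The handle function writes a greeting and
	// returns; Listen then closes the connection for us.
	go func() {
		err := tcp.Listen(
			ctx,
			"127.0.0.1:3333",
			func(conn net.Conn) { fmt.Fprintln(conn, "hello") },
			func(err error) { fmt.Println("listen:", err) },
			allow,
		)
		fmt.Println("listener done:", err)
	}()

	// Dial the listener, retrying with a constant one-second timeout per
	// attempt until the context is done.
	err := tcp.Dial(
		ctx,
		"127.0.0.1:3333",
		func(conn net.Conn) {
			greeting := make([]byte, 6)
			if _, err := io.ReadFull(conn, greeting); err == nil {
				fmt.Printf("%s", greeting)
			}
		},
		func(err error) { fmt.Println("dial:", err) },
		policy.ConstantTimeout(time.Second),
	)
	fmt.Println("dialer done:", err)
}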
107 | func Dial(ctx context.Context, address string, handle func(net.Conn), handleErr func(error), timeout func(int) time.Duration) error { 108 | dialer := new(net.Dialer) 109 | 110 | if handle == nil { 111 | return fmt.Errorf("nil handle function") 112 | } 113 | 114 | if handleErr == nil { 115 | handleErr = func(error) {} 116 | } 117 | 118 | if timeout == nil { 119 | timeout = func(int) time.Duration { return time.Second } 120 | } 121 | 122 | for attempt := 1; ; attempt++ { 123 | select { 124 | case <-ctx.Done(): 125 | return fmt.Errorf("dialing %w", ctx.Err()) 126 | default: 127 | } 128 | 129 | dialCtx, dialCancel := context.WithTimeout(ctx, timeout(attempt)) 130 | conn, err := dialer.DialContext(dialCtx, "tcp", address) 131 | if err != nil { 132 | handleErr(err) 133 | <-dialCtx.Done() 134 | dialCancel() 135 | continue 136 | } 137 | dialCancel() 138 | 139 | return func() (err error) { 140 | defer func() { 141 | err = conn.Close() 142 | }() 143 | 144 | handle(conn) 145 | return 146 | }() 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /tcp/tcp_suite_test.go: -------------------------------------------------------------------------------- 1 | package tcp_test 2 | 3 | import ( 4 | "testing" 5 | 6 | . "github.com/onsi/ginkgo" 7 | . "github.com/onsi/gomega" 8 | ) 9 | 10 | func TestTCP(t *testing.T) { 11 | RegisterFailHandler(Fail) 12 | RunSpecs(t, "TCP Suite") 13 | } 14 | -------------------------------------------------------------------------------- /tcp/tcp_test.go: -------------------------------------------------------------------------------- 1 | package tcp_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "math/rand" 8 | "net" 9 | "time" 10 | 11 | "github.com/renproject/aw/policy" 12 | "github.com/renproject/aw/tcp" 13 | 14 | . "github.com/onsi/ginkgo" 15 | . "github.com/onsi/gomega" 16 | ) 17 | 18 | var _ = Describe("TCP", func() { 19 | 20 | // run test scenarios with some configurations that enable enough variation 21 | // to test everything we are interested in testing. 22 | run := func(ctx context.Context, dialDelay, listenerDelay time.Duration, enableListener, rejectInboundConns bool) { 23 | // Channel for communicating to the main goroutine which messages have 24 | // been received by the listener and the dialer. 25 | messageReceived := make(chan [100]byte, 2) 26 | message := [100]byte{} 27 | 28 | // Generate a random message. 29 | r := rand.New(rand.NewSource(time.Now().UnixNano())) 30 | r.Read(message[:]) 31 | 32 | port := 12345 33 | 34 | if enableListener { 35 | allow := policy.Max(0) 36 | if !rejectInboundConns { 37 | // We want to allow the dialing attempt. To catch different 38 | // cases, sometimes we use an explicit policy, and other times 39 | // we use a nil policy. 40 | if r.Int()%2 == 0 { 41 | allow = policy.Max(1) 42 | } else { 43 | allow = nil 44 | } 45 | } 46 | var listener net.Listener 47 | var err error 48 | ip := "127.0.0.1" 49 | listener, port, err = tcp.ListenerWithAssignedPort(ctx, ip) 50 | Expect(err).ToNot(HaveOccurred()) 51 | go func() { 52 | defer GinkgoRecover() 53 | 54 | time.Sleep(listenerDelay) 55 | 56 | err := tcp.ListenWithListener( 57 | ctx, 58 | listener, 59 | func(conn net.Conn) { 60 | defer GinkgoRecover() 61 | 62 | // Assume that the dialer will send a message first, and 63 | // try to receive the message. 
64 | received := [100]byte{} 65 | n, err := io.ReadFull(conn, received[:]) 66 | 67 | // Check that the expected message was received, 68 | Expect(n).To(Equal(100)) 69 | Expect(err).ToNot(HaveOccurred()) 70 | Expect(received[:]).To(Equal(message[:])) 71 | messageReceived <- received 72 | 73 | // and then send it back. 74 | n, err = conn.Write(message[:]) 75 | Expect(n).To(Equal(100)) 76 | Expect(err).ToNot(HaveOccurred()) 77 | }, 78 | nil, 79 | allow, 80 | ) 81 | Expect(err).To(Equal(context.Canceled)) 82 | }() 83 | } 84 | 85 | // Delaying the dialer is useful to allow the listener to boot. This 86 | // delay might need to be longer than expected, depending on the CI 87 | // environment. 88 | time.Sleep(dialDelay) 89 | 90 | err := tcp.Dial( 91 | ctx, 92 | fmt.Sprintf("127.0.0.1:%v", port), 93 | func(conn net.Conn) { 94 | defer GinkgoRecover() 95 | 96 | // Send a message and, if the listener is enabled, then expect 97 | // it to succeed. Otherwise, expect it to fail. 98 | n, err := conn.Write(message[:]) 99 | Expect(n).To(Equal(100)) 100 | Expect(err).ToNot(HaveOccurred()) 101 | 102 | // Same thing with receiving messages. 103 | received := [100]byte{} 104 | n, err = io.ReadFull(conn, received[:]) 105 | if !rejectInboundConns { 106 | Expect(n).To(Equal(100)) 107 | Expect(err).ToNot(HaveOccurred()) 108 | Expect(received[:]).To(Equal(message[:])) 109 | messageReceived <- received 110 | } else { 111 | Expect(err).To(HaveOccurred()) 112 | } 113 | }, 114 | nil, 115 | policy.ConstantTimeout(time.Second)) 116 | 117 | // If the listener is enabled, and there is no policy for rejecting 118 | // inbound connection attempts, then we expect messages to have been 119 | // sent and received. 120 | if enableListener && !rejectInboundConns { 121 | Expect(err).ToNot(HaveOccurred()) 122 | Expect(<-messageReceived).To(Equal(message)) 123 | Expect(<-messageReceived).To(Equal(message)) 124 | } else if rejectInboundConns { 125 | Expect(err).ToNot(HaveOccurred()) 126 | } else { 127 | Expect(err).To(HaveOccurred()) 128 | } 129 | } 130 | 131 | Context("when dialing a listener that is accepting inbound connections", func() { 132 | It("should send and receive messages", func() { 133 | for iter := 0; iter < 3; iter++ { 134 | func() { 135 | // Allow some time for the previous iteration to shutdown. 136 | defer time.Sleep(10 * time.Millisecond) 137 | 138 | ctx, cancel := context.WithCancel(context.Background()) 139 | defer cancel() 140 | run( 141 | ctx, 142 | time.Millisecond, // Delay before dialing 143 | time.Duration(0), // Delay before listening 144 | true, // Enable listening 145 | false, // No listening policy 146 | ) 147 | }() 148 | } 149 | }) 150 | }) 151 | 152 | Context("when dialing a listener that is not accepting inbound connections", func() { 153 | It("should return an error", func() { 154 | for iter := 0; iter < 3; iter++ { 155 | func() { 156 | // Allow some time for the previous iteration to shutdown. 
157 | defer time.Sleep(10 * time.Millisecond) 158 | 159 | ctx, cancel := context.WithTimeout(context.Background(), time.Second) 160 | defer cancel() 161 | run( 162 | ctx, 163 | time.Millisecond, // Delay before dialing 164 | time.Duration(0), // Delay before listening 165 | false, // Disable listening 166 | false, // No listening policy 167 | ) 168 | }() 169 | } 170 | }) 171 | 172 | Context("when the listener begins accepting inbound connections", func() { 173 | It("should begin sending and receiving messages", func() { 174 | for iter := 0; iter < 3; iter++ { 175 | func() { 176 | // Allow some time for the previous iteration to shutdown. 177 | defer time.Sleep(10 * time.Millisecond) 178 | 179 | ctx, cancel := context.WithCancel(context.Background()) 180 | defer cancel() 181 | run( 182 | ctx, 183 | time.Millisecond, // Delay before dialing 184 | time.Second, // Delay before listening 185 | true, // Enable listening 186 | false, // No listening policy 187 | ) 188 | }() 189 | } 190 | }) 191 | }) 192 | }) 193 | 194 | Context("when dialing a listener that is rejecting inbound connections", func() { 195 | It("should return an error", func() { 196 | for iter := 0; iter < 3; iter++ { 197 | func() { 198 | // Allow some time for the previous iteration to shutdown. 199 | defer time.Sleep(10 * time.Millisecond) 200 | 201 | ctx, cancel := context.WithCancel(context.Background()) 202 | defer cancel() 203 | run( 204 | ctx, 205 | time.Millisecond, // Delay before dialing 206 | time.Second, // Delay before listening 207 | true, // Enable listening 208 | true, // No listening policy 209 | ) 210 | }() 211 | } 212 | }) 213 | }) 214 | }) 215 | -------------------------------------------------------------------------------- /transport/transport.go: -------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "net" 9 | "sync" 10 | "syscall" 11 | "time" 12 | 13 | "github.com/renproject/aw/dht" 14 | 15 | "github.com/renproject/aw/channel" 16 | "github.com/renproject/aw/codec" 17 | "github.com/renproject/aw/handshake" 18 | "github.com/renproject/aw/policy" 19 | "github.com/renproject/aw/tcp" 20 | "github.com/renproject/aw/wire" 21 | "github.com/renproject/id" 22 | 23 | "go.uber.org/zap" 24 | ) 25 | 26 | // Default options. 27 | var ( 28 | DefaultHost = "localhost" 29 | DefaultPort = uint16(3333) 30 | DefaultEncoder = codec.PlainEncoder 31 | DefaultDecoder = codec.PlainDecoder 32 | DefaultDialTimeout = policy.ConstantTimeout(time.Second) 33 | DefaultClientTimeout = 10 * time.Second 34 | DefaultServerTimeout = 10 * time.Second 35 | DefaultExpiryTimeout = time.Minute 36 | ) 37 | 38 | // Options used to parameterise the behaviour of a Transport. 39 | type Options struct { 40 | Logger *zap.Logger 41 | Host string 42 | Port uint16 43 | Encoder codec.Encoder 44 | Decoder codec.Decoder 45 | DialTimeout policy.Timeout 46 | ClientTimeout time.Duration 47 | ServerTimeout time.Duration 48 | OncePoolOptions handshake.OncePoolOptions 49 | ExpiryDuration time.Duration 50 | } 51 | 52 | // DefaultOptions returns Options with sensible defaults. 
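The Options struct is normally built from DefaultOptions and then narrowed with the With* builder methods defined immediately below, which is also how the transport test later in this dump constructs its Transport. A minimal sketch; the host, port, and timeouts here are arbitrary:

package main

import (
	"time"

	"github.com/renproject/aw/transport"
	"go.uber.org/zap"
)

func main() {
	logger, err := zap.NewDevelopment()
	if err != nil {
		panic(err)
	}

	// Start from the defaults and override only what differs for this node.
	opts := transport.DefaultOptions().
		WithLogger(logger).
		WithHost("0.0.0.0").
		WithPort(4444).
		WithClientTimeout(10 * time.Second).
		WithServerTimeout(10 * time.Second)
	_ = opts // Pass to transport.New along with a channel client, handshake, and DHT table.
}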
53 | func DefaultOptions() Options { 54 | logger, err := zap.NewDevelopment() 55 | if err != nil { 56 | panic(err) 57 | } 58 | return Options{ 59 | Logger: logger, 60 | Host: DefaultHost, 61 | Port: DefaultPort, 62 | Encoder: DefaultEncoder, 63 | Decoder: DefaultDecoder, 64 | DialTimeout: DefaultDialTimeout, 65 | ClientTimeout: DefaultClientTimeout, 66 | ServerTimeout: DefaultServerTimeout, 67 | OncePoolOptions: handshake.DefaultOncePoolOptions(), 68 | ExpiryDuration: DefaultExpiryTimeout, 69 | } 70 | } 71 | 72 | func (opts Options) WithLogger(logger *zap.Logger) Options { 73 | opts.Logger = logger 74 | return opts 75 | } 76 | 77 | func (opts Options) WithHost(host string) Options { 78 | opts.Host = host 79 | return opts 80 | } 81 | 82 | func (opts Options) WithPort(port uint16) Options { 83 | opts.Port = port 84 | return opts 85 | } 86 | 87 | func (opts Options) WithClientTimeout(timeout time.Duration) Options { 88 | opts.ClientTimeout = timeout 89 | return opts 90 | } 91 | 92 | func (opts Options) WithServerTimeout(timeout time.Duration) Options { 93 | opts.ServerTimeout = timeout 94 | return opts 95 | } 96 | 97 | func (opts Options) WithOncePoolOptions(oncePoolOpts handshake.OncePoolOptions) Options { 98 | opts.OncePoolOptions = oncePoolOpts 99 | return opts 100 | } 101 | 102 | func (opts Options) WithExpiry(minimumDuration time.Duration) Options { 103 | opts.ExpiryDuration = minimumDuration 104 | return opts 105 | } 106 | 107 | type Transport struct { 108 | opts Options 109 | 110 | self id.Signatory 111 | client *channel.Client 112 | once handshake.Handshake 113 | 114 | linksMu *sync.RWMutex 115 | links map[id.Signatory]bool 116 | 117 | connsMu *sync.RWMutex 118 | conns map[id.Signatory]int64 119 | 120 | table dht.Table 121 | } 122 | 123 | func New(opts Options, self id.Signatory, client *channel.Client, h handshake.Handshake, table dht.Table) *Transport { 124 | oncePool := handshake.NewOncePool(opts.OncePoolOptions) 125 | return &Transport{ 126 | opts: opts, 127 | 128 | self: self, 129 | client: client, 130 | once: handshake.Once(self, &oncePool, h), 131 | 132 | linksMu: new(sync.RWMutex), 133 | links: map[id.Signatory]bool{}, 134 | 135 | connsMu: new(sync.RWMutex), 136 | conns: map[id.Signatory]int64{}, 137 | 138 | table: table, 139 | } 140 | } 141 | 142 | func (t *Transport) Table() dht.Table { 143 | return t.table 144 | } 145 | 146 | func (t *Transport) Self() id.Signatory { 147 | return t.self 148 | } 149 | 150 | func (t *Transport) Host() string { 151 | return t.opts.Host 152 | } 153 | 154 | func (t *Transport) Port() uint16 { 155 | return t.opts.Port 156 | } 157 | 158 | func (t *Transport) Send(ctx context.Context, remote id.Signatory, msg wire.Msg) error { 159 | remoteAddr, ok := t.table.PeerAddress(remote) 160 | if !ok { 161 | return fmt.Errorf("peer not found: %v", remote) 162 | } 163 | 164 | if t.IsConnected(remote) { 165 | t.opts.Logger.Debug("send", zap.Bool("connected", true), zap.String("remote", remote.String()), zap.String("addr", remoteAddr.String())) 166 | return t.client.Send(ctx, remote, msg) 167 | } 168 | 169 | if t.IsLinked(remote) { 170 | t.opts.Logger.Debug("send", zap.Bool("linked", true), zap.String("remote", remote.String()), zap.String("addr", remoteAddr.String())) 171 | go t.dial(ctx, remote, remoteAddr) 172 | return t.client.Send(ctx, remote, msg) 173 | } 174 | 175 | t.opts.Logger.Debug("send", zap.Bool("linked", false), zap.Bool("connected", false), zap.String("remote", remote.String()), zap.String("addr", remoteAddr.String())) 176 | 
t.client.Bind(remote) 177 | go func() { 178 | defer t.client.Unbind(remote) 179 | t.dial(ctx, remote, remoteAddr) 180 | }() 181 | return t.client.Send(ctx, remote, msg) 182 | } 183 | 184 | func (t *Transport) Receive(ctx context.Context, receiver func(id.Signatory, wire.Packet) error) { 185 | t.client.Receive(ctx, receiver) 186 | } 187 | 188 | func (t *Transport) Link(remote id.Signatory) { 189 | t.linksMu.Lock() 190 | defer t.linksMu.Unlock() 191 | 192 | if t.links[remote] { 193 | return 194 | } 195 | t.client.Bind(remote) 196 | t.links[remote] = true 197 | } 198 | 199 | func (t *Transport) Unlink(remote id.Signatory) { 200 | t.linksMu.Lock() 201 | defer t.linksMu.Unlock() 202 | 203 | if t.links[remote] { 204 | t.client.Unbind(remote) 205 | delete(t.links, remote) 206 | } 207 | } 208 | 209 | func (t *Transport) IsLinked(remote id.Signatory) bool { 210 | t.linksMu.Lock() 211 | defer t.linksMu.Unlock() 212 | 213 | return t.links[remote] 214 | } 215 | 216 | func (t *Transport) IsConnected(remote id.Signatory) bool { 217 | t.connsMu.RLock() 218 | defer t.connsMu.RUnlock() 219 | 220 | return t.conns[remote] > 0 221 | } 222 | 223 | func (t *Transport) Run(ctx context.Context) { 224 | for { 225 | select { 226 | case <-ctx.Done(): 227 | return 228 | default: 229 | t.run(ctx) 230 | } 231 | } 232 | } 233 | 234 | func (t *Transport) run(ctx context.Context) { 235 | defer func() { 236 | if r := recover(); r != nil { 237 | t.opts.Logger.DPanic("recover", zap.Error(fmt.Errorf("%v", r))) 238 | } 239 | }() 240 | 241 | // Listen for incoming connection attempts. 242 | t.opts.Logger.Info("listening", zap.String("host", t.opts.Host), zap.Uint16("port", t.opts.Port)) 243 | err := tcp.Listen( 244 | ctx, 245 | fmt.Sprintf("%v:%v", t.opts.Host, t.opts.Port), 246 | func(conn net.Conn) { 247 | addr := conn.RemoteAddr().String() 248 | enc, dec, remote, err := t.once(conn, t.opts.Encoder, t.opts.Decoder) 249 | if err != nil { 250 | var e wire.NegligibleError 251 | if !errors.As(err, &e) { 252 | t.opts.Logger.Error("handshake", zap.String("addr", addr), zap.Error(err)) 253 | } 254 | return 255 | } 256 | 257 | enc = codec.LengthPrefixEncoder(codec.PlainEncoder, enc) 258 | dec = codec.LengthPrefixDecoder(codec.PlainDecoder, dec) 259 | 260 | // If the Transport is linked to the remote peer, then the 261 | // network connection should be kept alive until the remote peer 262 | // is unlinked (or the network connection faults). 263 | if t.IsLinked(remote) { 264 | t.opts.Logger.Debug("accepted", zap.Bool("linked", true), zap.String("remote", remote.String()), zap.String("addr", addr)) 265 | defer t.opts.Logger.Debug("accepted: drop", zap.Bool("linked", true), zap.String("remote", remote.String()), zap.String("addr", addr)) 266 | 267 | // Attaching a connection will block until the Channel is 268 | // unbound (which happens when the Transport is unlinked), the 269 | // connection is replaced, or the connection faults. 270 | t.connect(remote) 271 | defer t.disconnect(remote) 272 | if err := t.client.Attach(ctx, remote, conn, enc, dec); err != nil { 273 | // If ctx is canceled, this usually means the entire transport has been shutdown 274 | // and we can safely ignore all errors with client.Attach. 275 | if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { 276 | t.opts.Logger.Error("incoming attachment", zap.String("remote", remote.String()), zap.String("addr", addr), zap.Error(err)) 277 | } 278 | } 279 | return 280 | } 281 | 282 | // Otherwise, this connection should be short-lived. 
A Channel still 283 | // needs to be created (because one probably does not exist), but a 284 | // bounded time should be used. 285 | ctx, cancel := context.WithTimeout(ctx, t.opts.ServerTimeout) 286 | defer cancel() 287 | 288 | t.opts.Logger.Debug("accepted", zap.Bool("linked", false), zap.Duration("timeout", t.opts.ServerTimeout), zap.String("remote", remote.String()), zap.String("addr", addr)) 289 | defer t.opts.Logger.Debug("accepted: drop", zap.Bool("linked", false), zap.Duration("timeout", t.opts.ServerTimeout), zap.String("remote", remote.String()), zap.String("addr", addr)) 290 | 291 | t.client.Bind(remote) 292 | defer t.client.Unbind(remote) 293 | 294 | t.connect(remote) 295 | defer t.disconnect(remote) 296 | if err := t.client.Attach(ctx, remote, conn, enc, dec); err != nil { 297 | if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { 298 | t.opts.Logger.Error("incoming attachment", zap.String("remote", remote.String()), zap.String("addr", addr), zap.Error(err)) 299 | } 300 | } 301 | }, 302 | func(err error) { 303 | if !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) && !errors.Is(err, syscall.ECONNRESET) { 304 | t.opts.Logger.Error("listen", zap.Error(err)) 305 | } 306 | }, 307 | nil) 308 | if err != nil { 309 | if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { 310 | t.opts.Logger.Error("listen", zap.Error(err)) 311 | } 312 | } 313 | } 314 | 315 | func (t *Transport) dial(retryCtx context.Context, remote id.Signatory, remoteAddr wire.Address) { 316 | // It is tempting to skip dialing if there is already a connection. However, 317 | // it is desirable to be able to re-dial in the case that the network 318 | // address has changed. As such, we do not do any skip checks, and assume 319 | // that dial is only called when the caller is absolutely sure that a dial 320 | // should happen. 
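Putting the pieces of this file together: a Transport is constructed from a channel client, a handshake, and a DHT table, then run in the background; Link keeps connections to a peer alive, while unlinked peers only get short-lived connections bounded by the client/server timeouts. The wiring below is a rough sketch modeled on transport_test.go further down; the remote peer, its address, and the ports are placeholders, and error handling is elided.

package main

import (
	"context"
	"time"

	"github.com/renproject/aw/channel"
	"github.com/renproject/aw/dht"
	"github.com/renproject/aw/handshake"
	"github.com/renproject/aw/transport"
	"github.com/renproject/aw/wire"
	"github.com/renproject/id"
)

func main() {
	privKey := id.NewPrivKey()
	self := privKey.Signatory()

	// Wire together the dependencies that transport.New expects: a channel
	// client, an ECIES handshake (accepting any remote signatory), and an
	// in-memory peer table.
	client := channel.NewClient(channel.DefaultOptions(), self)
	h := handshake.Filter(func(id.Signatory) error { return nil }, handshake.ECIES(privKey))
	table := dht.NewInMemTable(self)
	t := transport.New(transport.DefaultOptions().WithPort(4444), self, client, h, table)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go t.Run(ctx) // Accept incoming connections in the background.

	// Register a (hypothetical) remote peer and link it so that the
	// underlying connection is kept alive across messages.
	remoteKey := id.NewPrivKey()
	remote := remoteKey.Signatory()
	table.AddPeer(remote, wire.NewUnsignedAddress(wire.TCP, "127.0.0.1:4445", uint64(time.Now().UnixNano())))
	t.Link(remote)

	// Send dials on demand. Unlinked peers would instead be served by a
	// short-lived connection, as described in the comments above.
	err := t.Send(ctx, remote, wire.Msg{Version: wire.MsgVersion1, Type: wire.MsgTypePing})
	_ = err // A real node would log the error and let the expiry logic clean up.
}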
321 | 322 | if remoteAddr.Protocol != wire.TCP { 323 | t.opts.Logger.Debug("skipping non-tcp address", zap.String("addr", remoteAddr.String())) 324 | return 325 | } 326 | 327 | exit := make(chan struct{}) 328 | for { 329 | dialCtx, cancel := context.WithTimeout(context.Background(), t.opts.ClientTimeout) 330 | 331 | t.opts.Logger.Debug("dialing", zap.String("remote", remote.String()), zap.String("addr", remoteAddr.String())) 332 | 333 | err := tcp.Dial( 334 | dialCtx, 335 | remoteAddr.Value, 336 | func(conn net.Conn) { 337 | addr := conn.RemoteAddr().String() 338 | enc, dec, r, err := t.once(conn, t.opts.Encoder, t.opts.Decoder) 339 | if err != nil { 340 | var e wire.NegligibleError 341 | if !errors.As(err, &e) { 342 | t.opts.Logger.Error("handshake", zap.String("remote", remote.String()), zap.String("addr", addr), zap.Error(err)) 343 | } 344 | return 345 | } 346 | if !r.Equal(&remote) { 347 | t.opts.Logger.Error("handshake", zap.String("expected", remote.String()), zap.String("got", r.String()), zap.Error(fmt.Errorf("bad remote"))) 348 | return 349 | } 350 | 351 | enc = codec.LengthPrefixEncoder(codec.PlainEncoder, enc) 352 | dec = codec.LengthPrefixDecoder(codec.PlainDecoder, dec) 353 | 354 | t.connect(remote) 355 | defer t.disconnect(remote) 356 | 357 | if t.IsLinked(remote) { 358 | t.opts.Logger.Debug("dialed", zap.Bool("linked", true), zap.String("remote", remote.String()), zap.String("addr", addr)) 359 | defer t.opts.Logger.Debug("dialed: drop", zap.Bool("linked", true), zap.String("remote", remote.String()), zap.String("addr", addr)) 360 | 361 | // If the Transport is linked to the remote peer, then the 362 | // network connection should be kept alive until the remote peer 363 | // is unlinked (or the network connection faults). To do this, 364 | // we override the context and re-use it. Otherwise, the 365 | // previously defined context will be used, which will 366 | // eventually timeout. 367 | dialCtx = context.Background() 368 | } else { 369 | t.opts.Logger.Debug("dialed", zap.Bool("linked", false), zap.Duration("timeout", t.opts.ClientTimeout), zap.String("remote", remote.String()), zap.String("addr", addr)) 370 | defer t.opts.Logger.Debug("dialed: drop", zap.Bool("linked", false), zap.Duration("timeout", t.opts.ClientTimeout), zap.String("remote", remote.String()), zap.String("addr", addr)) 371 | } 372 | 373 | if err := t.client.Attach(dialCtx, remote, conn, enc, dec); err != nil { 374 | // Context deadline exceeds means we decide to drop the 375 | // connection and the error could be ignored. 
376 | if !errors.Is(err, context.DeadlineExceeded) { 377 | t.opts.Logger.Error("outgoing", zap.String("remote", remote.String()), zap.String("addr", addr), zap.Error(err)) 378 | } 379 | } 380 | }, 381 | func(err error) { 382 | t.opts.Logger.Debug("dial", zap.String("remote", remote.String()), zap.String("addr", remoteAddr.String()), zap.Error(err)) 383 | t.table.AddExpiry(remote, t.opts.ExpiryDuration) 384 | if t.table.HandleExpired(remote) { 385 | close(exit) 386 | cancel() 387 | } 388 | }, 389 | t.opts.DialTimeout) 390 | if err != nil { 391 | t.opts.Logger.Debug("dial", zap.String("remote", remote.String()), zap.String("addr", remoteAddr.String()), zap.Error(err)) 392 | select { 393 | case <-exit: 394 | break 395 | case <-retryCtx.Done(): 396 | case <-dialCtx.Done(): 397 | if !t.IsConnected(remote) { 398 | // Cancel current dial context if restarting loop 399 | cancel() 400 | continue 401 | } 402 | } 403 | } else { 404 | t.table.DeleteExpiry(remote) 405 | } 406 | 407 | // Cancel last dial context before exiting 408 | cancel() 409 | return 410 | } 411 | } 412 | 413 | func (t *Transport) connect(remote id.Signatory) { 414 | t.connsMu.Lock() 415 | defer t.connsMu.Unlock() 416 | 417 | t.conns[remote]++ 418 | } 419 | 420 | func (t *Transport) disconnect(remote id.Signatory) { 421 | t.connsMu.Lock() 422 | defer t.connsMu.Unlock() 423 | 424 | if t.conns[remote] > 0 { 425 | if t.conns[remote]--; t.conns[remote] == 0 { 426 | delete(t.conns, remote) 427 | } 428 | } 429 | } 430 | -------------------------------------------------------------------------------- /transport/transport_suite_test.go: -------------------------------------------------------------------------------- 1 | package transport_test 2 | 3 | import ( 4 | "testing" 5 | 6 | . "github.com/onsi/ginkgo" 7 | . "github.com/onsi/gomega" 8 | ) 9 | 10 | func TestTransport(t *testing.T) { 11 | RegisterFailHandler(Fail) 12 | RunSpecs(t, "Transport suite") 13 | } 14 | -------------------------------------------------------------------------------- /transport/transport_test.go: -------------------------------------------------------------------------------- 1 | package transport_test 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "github.com/renproject/aw/channel" 8 | "github.com/renproject/aw/dht" 9 | "github.com/renproject/aw/handshake" 10 | "github.com/renproject/aw/transport" 11 | "github.com/renproject/aw/wire" 12 | "github.com/renproject/id" 13 | 14 | . "github.com/onsi/ginkgo" 15 | . "github.com/onsi/gomega" 16 | ) 17 | 18 | var _ = Describe("Transport", func() { 19 | Describe("Dial", func() { 20 | Context("when failing to connect to peer", func() { 21 | It("should create an expiry and delete peer after expiration", func() { 22 | privKey := id.NewPrivKey() 23 | self := privKey.Signatory() 24 | h := handshake.Filter(func(id.Signatory) error { return nil }, handshake.ECIES(privKey)) 25 | client := channel.NewClient( 26 | channel.DefaultOptions(), 27 | self) 28 | table := dht.NewInMemTable(self) 29 | transport := transport.New( 30 | transport.DefaultOptions(). 31 | WithClientTimeout(10*time.Second). 32 | WithOncePoolOptions(handshake.DefaultOncePoolOptions().WithMinimumExpiryAge(10*time.Second)). 33 | WithExpiry(5*time.Second). 
34 | WithPort(uint16(3333)), 35 | self, 36 | client, 37 | h, 38 | table, 39 | ) 40 | 41 | privKey2 := id.NewPrivKey() 42 | table.AddPeer(privKey2.Signatory(), 43 | wire.NewUnsignedAddress(wire.TCP, ":3334", uint64(time.Now().UnixNano()))) 44 | _, ok := table.PeerAddress(privKey2.Signatory()) 45 | Expect(ok).To(BeTrue()) 46 | 47 | ctx, cancel := context.WithCancel(context.Background()) 48 | go func() { 49 | transport.Send(ctx, privKey2.Signatory(), wire.Msg{}) 50 | }() 51 | 52 | time.Sleep(7 * time.Second) 53 | cancel() 54 | 55 | _, ok = table.PeerAddress(privKey2.Signatory()) 56 | Expect(ok).To(BeFalse()) 57 | }) 58 | }) 59 | }) 60 | }) 61 | -------------------------------------------------------------------------------- /wire/addr.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "crypto/ecdsa" 5 | "crypto/sha256" 6 | "encoding/base64" 7 | "encoding/json" 8 | "fmt" 9 | "strconv" 10 | "strings" 11 | 12 | "github.com/ethereum/go-ethereum/crypto" 13 | "github.com/renproject/id" 14 | "github.com/renproject/surge" 15 | ) 16 | 17 | // Protocol defines the network protocol used by an address for 18 | // sending/receiving data over-the-wire. 19 | type Protocol uint8 20 | 21 | func (p Protocol) String() string { 22 | switch p { 23 | case TCP: 24 | return "tcp" 25 | case UDP: 26 | return "udp" 27 | case WebSocket: 28 | return "ws" 29 | default: 30 | return "unknown" 31 | } 32 | } 33 | 34 | func (p Protocol) MarshalJSON() ([]byte, error) { 35 | return json.Marshal(uint8(p)) 36 | } 37 | 38 | func (p *Protocol) UnmarshalJSON(data []byte) error { 39 | return json.Unmarshal(data, (*uint8)(p)) 40 | } 41 | 42 | // Protocol values for the different network address protocols that are 43 | // supported. 44 | const ( 45 | UndefinedProtocol = Protocol(0) 46 | TCP = Protocol(1) 47 | UDP = Protocol(2) 48 | WebSocket = Protocol(3) 49 | ) 50 | 51 | // NewAddressHash returns the Hash of an Address for signing by the peer. An 52 | // error is returned when the arguments too large and cannot be marshaled into 53 | // bytes without exceeding memory allocation restrictions. 54 | func NewAddressHash(protocol Protocol, value string, nonce uint64) (id.Hash, error) { 55 | buf := make([]byte, surge.SizeHintU8+surge.SizeHintString(value)+surge.SizeHintU64) 56 | return NewAddressHashWithBuffer(protocol, value, nonce, buf) 57 | } 58 | 59 | // NewAddressHashWithBuffer writes the Hash of an Address into a bytes buffer 60 | // for signing by the peer. An error is returned when the arguments are too 61 | // large and cannot be marshaled into bytes without exceeding memory allocation 62 | // restrictions. This function is useful when doing a lot of hashing, because it 63 | // allows for buffer re-use. 64 | func NewAddressHashWithBuffer(protocol Protocol, value string, nonce uint64, data []byte) (id.Hash, error) { 65 | var err error 66 | buf := data 67 | rem := surge.MaxBytes 68 | if buf, rem, err = surge.MarshalU8(uint8(protocol), buf, rem); err != nil { 69 | return id.Hash{}, err 70 | } 71 | if buf, rem, err = surge.MarshalString(value, buf, rem); err != nil { 72 | return id.Hash{}, err 73 | } 74 | if buf, rem, err = surge.MarshalU64(nonce, buf, rem); err != nil { 75 | return id.Hash{}, err 76 | } 77 | return sha256.Sum256(data[:len(data)-len(buf)]), nil 78 | } 79 | 80 | // An Address is a verifiable and expirable network address associated with a 81 | // specific peer. 
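NewAddressHashWithBuffer exists purely so that callers hashing many addresses can reuse a single buffer instead of allocating on every call. A small sketch of both entry points; the address value and nonces are arbitrary:

package main

import (
	"fmt"

	"github.com/renproject/aw/wire"
	"github.com/renproject/surge"
)

func main() {
	value := "127.0.0.1:3333"

	// One-off hashing allocates its own buffer.
	hash, err := wire.NewAddressHash(wire.TCP, value, 1)
	if err != nil {
		panic(err)
	}
	fmt.Println(hash)

	// When hashing many addresses, allocate the buffer once and reuse it.
	buf := make([]byte, surge.SizeHintU8+surge.SizeHintString(value)+surge.SizeHintU64)
	for nonce := uint64(2); nonce < 5; nonce++ {
		if _, err := wire.NewAddressHashWithBuffer(wire.TCP, value, nonce, buf); err != nil {
			panic(err)
		}
		// Use the returned hash here.
	}
}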
The peer can be verified by checking the Signatory of the peer 82 | // against the Signature in the Address. The Address can be expired by issuing a 83 | // new Address for the same peer, using a later nonce. By convention, nonces are 84 | // interpreted as seconds since UNIX epoch. 85 | type Address struct { 86 | Protocol Protocol `json:"protocol"` 87 | Value string `json:"value"` 88 | Nonce uint64 `json:"nonce"` 89 | Signature id.Signature `json:"signature"` 90 | } 91 | 92 | // NewUnsignedAddress returns an Address that has an empty signature. The Sign 93 | // method should be called before the returned Address is used. 94 | func NewUnsignedAddress(protocol Protocol, value string, nonce uint64) Address { 95 | return Address{ 96 | Protocol: protocol, 97 | Value: value, 98 | Nonce: nonce, 99 | } 100 | } 101 | 102 | // SizeHint returns the number of bytes needed to represent this Address in 103 | // binary. 104 | func (addr Address) SizeHint() int { 105 | return surge.SizeHintU8 + 106 | surge.SizeHintString(addr.Value) + 107 | surge.SizeHintU64 + 108 | addr.Signature.SizeHint() 109 | } 110 | 111 | // Marshal this Address into binary. 112 | func (addr Address) Marshal(buf []byte, rem int) ([]byte, int, error) { 113 | var err error 114 | if buf, rem, err = surge.MarshalU8(uint8(addr.Protocol), buf, rem); err != nil { 115 | return buf, rem, err 116 | } 117 | if buf, rem, err = surge.MarshalString(addr.Value, buf, rem); err != nil { 118 | return buf, rem, err 119 | } 120 | if buf, rem, err = surge.MarshalU64(addr.Nonce, buf, rem); err != nil { 121 | return buf, rem, err 122 | } 123 | return addr.Signature.Marshal(buf, rem) 124 | } 125 | 126 | // Unmarshal from binary into this Address. 127 | func (addr *Address) Unmarshal(buf []byte, rem int) ([]byte, int, error) { 128 | var err error 129 | buf, rem, err = surge.UnmarshalU8((*uint8)(&addr.Protocol), buf, rem) 130 | if err != nil { 131 | return buf, rem, err 132 | } 133 | buf, rem, err = surge.UnmarshalString(&addr.Value, buf, rem) 134 | if err != nil { 135 | return buf, rem, err 136 | } 137 | buf, rem, err = surge.UnmarshalU64(&addr.Nonce, buf, rem) 138 | if err != nil { 139 | return buf, rem, err 140 | } 141 | return addr.Signature.Unmarshal(buf, rem) 142 | } 143 | 144 | // Sign this Address and set its Signature. 145 | func (addr *Address) Sign(privKey *id.PrivKey) error { 146 | buf := make([]byte, surge.SizeHintU8+surge.SizeHintString(addr.Value)+surge.SizeHintU64) 147 | return addr.SignWithBuffer(privKey, buf) 148 | } 149 | 150 | // SignWithBuffer will Sign the Address and set its Signature. It uses a Buffer 151 | // for all marshaling to allow for buffer re-use. 152 | func (addr *Address) SignWithBuffer(privKey *id.PrivKey, buf []byte) error { 153 | hash, err := NewAddressHashWithBuffer(addr.Protocol, addr.Value, addr.Nonce, buf) 154 | if err != nil { 155 | return fmt.Errorf("hashing address: %v", err) 156 | } 157 | signature, err := crypto.Sign(hash[:], (*ecdsa.PrivateKey)(privKey)) 158 | if err != nil { 159 | return fmt.Errorf("signing address hash: %v", err) 160 | } 161 | if n := copy(addr.Signature[:], signature); n != len(addr.Signature) { 162 | return fmt.Errorf("copying signature: expected n=%v, got n=%v", len(addr.Signature), n) 163 | } 164 | return nil 165 | } 166 | 167 | // Verify that the Address was signed by a specific Signatory. 
168 | func (addr *Address) Verify(signatory id.Signatory) error { 169 | buf := make([]byte, surge.SizeHintU8+surge.SizeHintString(addr.Value)+surge.SizeHintU64) 170 | return addr.VerifyWithBuffer(signatory, buf) 171 | } 172 | 173 | // VerifyWithBuffer will verify that the Address was signed by a specific 174 | // Signatory. It uses a Buffer for all marshaling to allow for buffer re-use. 175 | func (addr *Address) VerifyWithBuffer(signatory id.Signatory, buf []byte) error { 176 | hash, err := NewAddressHashWithBuffer(addr.Protocol, addr.Value, addr.Nonce, buf) 177 | if err != nil { 178 | return fmt.Errorf("hashing address: %v", err) 179 | } 180 | verifiedPubKey, err := crypto.SigToPub(hash[:], addr.Signature[:]) 181 | if err != nil { 182 | return fmt.Errorf("identifying address signature: %v", err) 183 | } 184 | verifiedSignatory := id.NewSignatory((*id.PubKey)(verifiedPubKey)) 185 | if !signatory.Equal(&verifiedSignatory) { 186 | return fmt.Errorf("verifying address signatory: expected %v, got %v", signatory, verifiedSignatory) 187 | } 188 | return nil 189 | } 190 | 191 | // Signatory returns the Signatory from the Address, based on the Signature. If 192 | // the Address is unsigned, then the empty Signatory is returned. 193 | func (addr *Address) Signatory() (id.Signatory, error) { 194 | buf := make([]byte, surge.SizeHintU8+surge.SizeHintString(addr.Value)+surge.SizeHintU64) 195 | return addr.SignatoryWithBuffer(buf) 196 | } 197 | 198 | // SignatoryWithBuffer returns the Signatory from the Address, based on the 199 | // Signature. If the Address is unsigned, then the empty Signatory is returned. 200 | // It uses a Buffer for all marshaling to allow for buffer re-use. 201 | func (addr *Address) SignatoryWithBuffer(buf []byte) (id.Signatory, error) { 202 | // Check whether or not the Address is unsigned. 203 | if addr.Signature.Equal(&id.Signature{}) { 204 | return id.Signatory{}, nil 205 | } 206 | 207 | // If the Address is signed, extract the Signatory and return it. 208 | hash, err := NewAddressHashWithBuffer(addr.Protocol, addr.Value, addr.Nonce, buf) 209 | if err != nil { 210 | return id.Signatory{}, fmt.Errorf("hashing address: %v", err) 211 | } 212 | pubKey, err := crypto.SigToPub(hash[:], addr.Signature[:]) 213 | if err != nil { 214 | return id.Signatory{}, fmt.Errorf("identifying address signature: %v", err) 215 | } 216 | return id.NewSignatory((*id.PubKey)(pubKey)), nil 217 | } 218 | 219 | // String returns a human-readable representation of the Address. The string 220 | // representation is safe for use in URLs and filenames. 221 | func (addr Address) String() string { 222 | return fmt.Sprintf("/%v/%v/%v/%v", addr.Protocol, addr.Value, addr.Nonce, addr.Signature) 223 | } 224 | 225 | // Equal compares two Addressees. Returns true if they are the same, otherwise 226 | // returns false. 227 | func (addr *Address) Equal(other *Address) bool { 228 | return addr.Protocol == other.Protocol && 229 | addr.Value == other.Value && 230 | addr.Nonce == other.Nonce && 231 | addr.Signature.Equal(&other.Signature) 232 | } 233 | 234 | // DecodeString into a wire-compatible Address. 235 | func DecodeString(addr string) (Address, error) { 236 | // Remove any leading slashes. 
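A typical lifecycle for these helpers is: build an unsigned Address, Sign it with the peer's private key, and later Verify it (or recover the signer with Signatory). A short sketch, assuming id.NewPrivKey returns the *id.PrivKey that Sign expects; the address value and nonce are placeholders:

package main

import (
	"fmt"
	"time"

	"github.com/renproject/aw/wire"
	"github.com/renproject/id"
)

func main() {
	// Sign takes a *id.PrivKey; id.NewPrivKey is assumed to return one.
	privKey := id.NewPrivKey()

	addr := wire.NewUnsignedAddress(wire.TCP, "127.0.0.1:3333", uint64(time.Now().UnixNano()))
	if err := addr.Sign(privKey); err != nil {
		panic(err)
	}

	// Verify against the expected signatory.
	if err := addr.Verify(privKey.Signatory()); err != nil {
		panic(err)
	}

	// Or recover the signatory from the signature alone.
	signatory, err := addr.Signatory()
	if err != nil {
		panic(err)
	}
	expected := privKey.Signatory()
	fmt.Println(signatory.Equal(&expected)) // true
	fmt.Println(addr.String())              // /tcp/127.0.0.1:3333/<nonce>/<signature>
}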
237 | if strings.HasPrefix(addr, "/") { 238 | addr = addr[1:] 239 | } 240 | 241 | addrParts := strings.Split(addr, "/") 242 | if len(addrParts) != 4 { 243 | return Address{}, fmt.Errorf("invalid format %v", addr) 244 | } 245 | var protocol Protocol 246 | switch addrParts[0] { 247 | case "tcp": 248 | protocol = TCP 249 | case "udp": 250 | protocol = UDP 251 | case "ws": 252 | protocol = WebSocket 253 | default: 254 | return Address{}, fmt.Errorf("invalid protocol %v", addrParts[0]) 255 | } 256 | value := addrParts[1] 257 | nonce, err := strconv.ParseUint(addrParts[2], 10, 64) 258 | if err != nil { 259 | return Address{}, err 260 | } 261 | var sig id.Signature 262 | sigBytes, err := base64.RawURLEncoding.DecodeString(addrParts[3]) 263 | if err != nil { 264 | return Address{}, err 265 | } 266 | if len(sigBytes) != 65 { 267 | return Address{}, fmt.Errorf("invalid signature %v", addrParts[3]) 268 | } 269 | copy(sig[:], sigBytes) 270 | return Address{ 271 | Protocol: protocol, 272 | Value: value, 273 | Nonce: nonce, 274 | Signature: sig, 275 | }, nil 276 | } 277 | -------------------------------------------------------------------------------- /wire/addr_test.go: -------------------------------------------------------------------------------- 1 | package wire_test 2 | 3 | import ( 4 | "math/rand" 5 | 6 | "github.com/renproject/aw/wire" 7 | 8 | . "github.com/onsi/ginkgo" 9 | . "github.com/onsi/gomega" 10 | ) 11 | 12 | var _ = Describe("Address", func() { 13 | Context("when generating an address hash", func() { 14 | It("should generate a different hash for a different address", func() { 15 | r := rand.New(rand.NewSource(GinkgoRandomSeed())) 16 | addr := wire.NewUnsignedAddress(wire.TCP, "", 0) 17 | 18 | // Generate a hash for the address. 19 | h1, err := wire.NewAddressHash(addr.Protocol, addr.Value, addr.Nonce) 20 | Expect(err).ToNot(HaveOccurred()) 21 | 22 | // Generate a new address hash with a different protocol. 23 | h2, err := wire.NewAddressHash(wire.Protocol(r.Uint32()), addr.Value, addr.Nonce) 24 | Expect(err).ToNot(HaveOccurred()) 25 | 26 | // Generate a new address hash with a different value. 27 | b := make([]byte, 100) 28 | _, err = rand.Read(b) 29 | Expect(err).ToNot(HaveOccurred()) 30 | h3, err := wire.NewAddressHash(addr.Protocol, string(b), addr.Nonce) 31 | Expect(err).ToNot(HaveOccurred()) 32 | 33 | // Generate a new address hash with a different nonce. 34 | h4, err := wire.NewAddressHash(addr.Protocol, addr.Value, r.Uint64()) 35 | Expect(err).ToNot(HaveOccurred()) 36 | 37 | // Ensure all address hashes are different. 38 | Expect(h1).ToNot(Equal(h2)) 39 | Expect(h1).ToNot(Equal(h3)) 40 | Expect(h1).ToNot(Equal(h4)) 41 | Expect(h2).ToNot(Equal(h3)) 42 | Expect(h2).ToNot(Equal(h4)) 43 | Expect(h3).ToNot(Equal(h4)) 44 | }) 45 | }) 46 | }) 47 | -------------------------------------------------------------------------------- /wire/error.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | // NegligibleError are errors can be ignored by the logger. 
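NegligibleError is a plain wrapper (defined just below) that callers unwrap with errors.As to decide whether an error is worth logging; the transport's handshake error handling uses exactly this pattern. A minimal sketch:

package main

import (
	"errors"
	"fmt"

	"github.com/renproject/aw/wire"
)

func main() {
	// Wrap an error that the logger should treat as noise rather than a fault.
	err := wire.NewNegligibleError(fmt.Errorf("remote peer went away"))

	// Callers use errors.As to decide whether the error deserves a log line.
	var negligible wire.NegligibleError
	if errors.As(err, &negligible) {
		// Ignore, or log at debug level only.
		return
	}
	fmt.Println("unexpected error:", err)
}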
4 | type NegligibleError struct { 5 | Err error 6 | } 7 | 8 | func (ne NegligibleError) Error() string { 9 | return ne.Err.Error() 10 | } 11 | 12 | func NewNegligibleError(err error) error { 13 | return NegligibleError{Err: err} 14 | } 15 | -------------------------------------------------------------------------------- /wire/sigaddr.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/renproject/id" 7 | ) 8 | 9 | type SignatoryAndAddress struct { 10 | Signatory id.Signatory 11 | Address Address 12 | } 13 | 14 | func (sigAndAddr SignatoryAndAddress) SizeHint() int { 15 | return sigAndAddr.Signatory.SizeHint() + sigAndAddr.Address.SizeHint() 16 | } 17 | 18 | func (sigAndAddr SignatoryAndAddress) Marshal(buf []byte, rem int) ([]byte, int, error) { 19 | buf, rem, err := sigAndAddr.Signatory.Marshal(buf, rem) 20 | if err != nil { 21 | return buf, rem, fmt.Errorf("marshal signatory: %v", err) 22 | } 23 | buf, rem, err = sigAndAddr.Address.Marshal(buf, rem) 24 | if err != nil { 25 | return buf, rem, fmt.Errorf("marshal address: %v", err) 26 | } 27 | return buf, rem, err 28 | } 29 | 30 | func (sigAndAddr *SignatoryAndAddress) Unmarshal(buf []byte, rem int) ([]byte, int, error) { 31 | buf, rem, err := (&sigAndAddr.Signatory).Unmarshal(buf, rem) 32 | if err != nil { 33 | return buf, rem, fmt.Errorf("unmarshal signatory: %v", err) 34 | } 35 | buf, rem, err = sigAndAddr.Address.Unmarshal(buf, rem) 36 | if err != nil { 37 | return buf, rem, fmt.Errorf("unmarshal address: %v", err) 38 | } 39 | return buf, rem, err 40 | } 41 | -------------------------------------------------------------------------------- /wire/sigaddr_test.go: -------------------------------------------------------------------------------- 1 | package wire_test 2 | -------------------------------------------------------------------------------- /wire/wire.go: -------------------------------------------------------------------------------- 1 | package wire 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | 7 | "github.com/renproject/id" 8 | 9 | "github.com/renproject/surge" 10 | ) 11 | 12 | // Enumerate all valid MsgVersion values. 13 | const ( 14 | MsgVersion1 = uint16(1) 15 | ) 16 | 17 | // Enumerate all valid MsgType values. 18 | const ( 19 | MsgTypePush = uint16(1) 20 | MsgTypePull = uint16(2) 21 | MsgTypeSync = uint16(3) 22 | MsgTypeSend = uint16(4) 23 | MsgTypePing = uint16(5) 24 | MsgTypePingAck = uint16(6) 25 | ) 26 | 27 | // Msg defines the low-level message structure that is sent on-the-wire between 28 | // peers. 29 | type Msg struct { 30 | Version uint16 `json:"version"` 31 | Type uint16 `json:"type"` 32 | To id.Hash `json:"to"` 33 | Data []byte `json:"data"` 34 | SyncData []byte `json:"syncData"` 35 | } 36 | 37 | // Packet defines a struct that captures the incoming message and the corresponding IP address 38 | type Packet struct { 39 | Msg Msg 40 | IPAddr net.Addr 41 | } 42 | 43 | // SizeHint returns the number of bytes required to represent a Msg in binary. 44 | func (msg Msg) SizeHint() int { 45 | return surge.SizeHintU16 + 46 | surge.SizeHintU16 + 47 | id.SizeHintHash + 48 | surge.SizeHintBytes(msg.Data) 49 | } 50 | 51 | // Marshal a Msg to binary. 
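Msg is marshaled with the surge-style (buf, rem) API shown below: SizeHint sizes the buffer, and rem bounds how many bytes marshaling may consume. A small round-trip sketch; surge.MaxBytes is the same bound used in addr.go:

package main

import (
	"fmt"

	"github.com/renproject/aw/wire"
	"github.com/renproject/surge"
)

func main() {
	msg := wire.Msg{
		Version: wire.MsgVersion1,
		Type:    wire.MsgTypePing,
		Data:    []byte("ping"),
	}

	// Size the buffer with SizeHint and marshal into it.
	buf := make([]byte, msg.SizeHint())
	if _, _, err := msg.Marshal(buf, surge.MaxBytes); err != nil {
		panic(err)
	}

	// Unmarshal back into a fresh Msg.
	decoded := wire.Msg{}
	if _, _, err := decoded.Unmarshal(buf, surge.MaxBytes); err != nil {
		panic(err)
	}
	fmt.Println(decoded.Type == msg.Type, string(decoded.Data))
}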
52 | func (msg Msg) Marshal(buf []byte, rem int) ([]byte, int, error) { 53 | buf, rem, err := surge.MarshalU16(msg.Version, buf, rem) 54 | if err != nil { 55 | return buf, rem, fmt.Errorf("marshal version: %v", err) 56 | } 57 | buf, rem, err = surge.MarshalU16(msg.Type, buf, rem) 58 | if err != nil { 59 | return buf, rem, fmt.Errorf("marshal type: %v", err) 60 | } 61 | buf, rem, err = surge.Marshal(msg.To, buf, rem) 62 | if err != nil { 63 | return buf, rem, fmt.Errorf("marshal to: %v", err) 64 | } 65 | buf, rem, err = surge.MarshalBytes(msg.Data, buf, rem) 66 | if err != nil { 67 | return buf, rem, fmt.Errorf("marshal data: %v", err) 68 | } 69 | return buf, rem, err 70 | } 71 | 72 | // Unmarshal a Msg from binary. 73 | func (msg *Msg) Unmarshal(buf []byte, rem int) ([]byte, int, error) { 74 | buf, rem, err := surge.UnmarshalU16(&msg.Version, buf, rem) 75 | if err != nil { 76 | return buf, rem, fmt.Errorf("unmarshal version: %v", err) 77 | } 78 | buf, rem, err = surge.UnmarshalU16(&msg.Type, buf, rem) 79 | if err != nil { 80 | return buf, rem, fmt.Errorf("unmarshal type: %v", err) 81 | } 82 | buf, rem, err = surge.Unmarshal(&msg.To, buf, rem) 83 | if err != nil { 84 | return buf, rem, fmt.Errorf("unmarshal to: %v", err) 85 | } 86 | buf, rem, err = surge.Unmarshal(&msg.Data, buf, rem) 87 | if err != nil { 88 | return buf, rem, fmt.Errorf("unmarshal data: %v", err) 89 | } 90 | return buf, rem, err 91 | } 92 | -------------------------------------------------------------------------------- /wire/wire_suite_test.go: -------------------------------------------------------------------------------- 1 | package wire_test 2 | 3 | import ( 4 | "testing" 5 | 6 | . "github.com/onsi/ginkgo" 7 | . "github.com/onsi/gomega" 8 | ) 9 | 10 | func TestWire(t *testing.T) { 11 | RegisterFailHandler(Fail) 12 | RunSpecs(t, "Wire suite") 13 | } 14 | -------------------------------------------------------------------------------- /wire/wire_test.go: -------------------------------------------------------------------------------- 1 | package wire_test 2 | --------------------------------------------------------------------------------