├── .travis.yml ├── .gitignore ├── README.md ├── go.mod ├── receivequeue_test.go ├── varint.go ├── listener.go ├── multipath.go ├── dialer.go ├── receivequeue.go ├── go.sum ├── conn.go ├── subflow.go ├── LICENSE └── multipath_test.go /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - 1.14.3 5 | - tip 6 | 7 | script: 8 | - go test -v -race 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # Dependency directories (remove the comment below to include it) 15 | # vendor/ 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Multipath aggregates ordered and reliable connections over multiple paths together for throughput and resilience. It relies on existing dialers and listeners to create `net.Conn`s and wrap them as subflows on which it basically does two things: 2 | 3 | 1. On the sender side, transmits data over the subflow with the lowest roundtrip time, and if it takes long to get an acknowledgement, retransmits data over other subflows one by one. 4 | 1. On the receiver side, reorders the data received from all subflows and delivers ordered byte stream (`net.Conn`) to the upper layer. 5 | 6 | See docs in [multipath.go](multipath.go) for details. 7 | 8 | This code is used in https://github.com/benjojo/bondcat, and hat's off to [@benjojo](https://github.com/benjojo) for implementing [many fixes](https://github.com/benjojo/bondcat/tree/main/multipath). 
9 | 10 | At a high level, this concept is built on a similar notion as [MultiPath TCP](https://www.multipath-tcp.org/), but our "multipath" works at a higher level in the stack where it implements different protocols that each have their own subflow but are all running on their own TCP or reliable UDP transports underneath. 11 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/getlantern/multipath 2 | 3 | go 1.18 4 | 5 | require ( 6 | github.com/dustin/go-humanize v1.0.0 7 | github.com/getlantern/ema v0.0.0-20190620044903-5943d28f40e4 8 | github.com/getlantern/golog v0.0.0-20211223150227-d4d95a44d873 9 | github.com/google/uuid v1.1.2 10 | github.com/libp2p/go-buffer-pool v0.0.2 11 | github.com/stretchr/testify v1.8.0 12 | ) 13 | 14 | require ( 15 | github.com/davecgh/go-spew v1.1.1 // indirect 16 | github.com/getlantern/context v0.0.0-20190109183933-c447772a6520 // indirect 17 | github.com/getlantern/errors v1.0.1 // indirect 18 | github.com/getlantern/hex v0.0.0-20190417191902-c6586a6fe0b7 // indirect 19 | github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55 // indirect 20 | github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f // indirect 21 | github.com/go-stack/stack v1.8.0 // indirect 22 | github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect 23 | github.com/pmezard/go-difflib v1.0.0 // indirect 24 | go.uber.org/atomic v1.7.0 // indirect 25 | go.uber.org/multierr v1.6.0 // indirect 26 | go.uber.org/zap v1.19.1 // indirect 27 | gopkg.in/yaml.v3 v3.0.1 // indirect 28 | ) 29 | -------------------------------------------------------------------------------- /receivequeue_test.go: -------------------------------------------------------------------------------- 1 | package multipath 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func 
TestRead(t *testing.T) { 11 | q := newReceiveQueue(2) 12 | fn := uint64(minFrameNumber - 1) 13 | addFrame := func(s string) { 14 | fn++ 15 | q.add(&rxFrame{fn: fn, bytes: []byte(s)}, nil) 16 | } 17 | shouldRead := func(s string) { 18 | b := make([]byte, 3) 19 | n, err := q.read(b) 20 | assert.NoError(t, err) 21 | assert.Equal(t, s, string(b[:n])) 22 | } 23 | 24 | addFrame("abcd") 25 | shouldRead("abc") 26 | addFrame("abcd") 27 | shouldRead("dab") 28 | shouldRead("cd") 29 | addFrame("abcd") 30 | // adding the same frame number again should have no effect 31 | q.add(&rxFrame{fn: fn, bytes: []byte("1234")}, nil) 32 | shouldRead("abc") 33 | shouldRead("d") 34 | 35 | shouldWaitBeforeRead := func(d time.Duration, s string) { 36 | start := time.Now() 37 | b := make([]byte, 3) 38 | n, err := q.read(b) 39 | assert.NoError(t, err) 40 | assert.Equal(t, s, string(b[:n])) 41 | assert.InDelta(t, time.Since(start), d, float64(50*time.Millisecond)) 42 | } 43 | delay := 100 * time.Millisecond 44 | time.AfterFunc(delay, func() { 45 | addFrame("abcd") 46 | }) 47 | shouldWaitBeforeRead(delay, "abc") 48 | time.AfterFunc(delay, func() { 49 | addFrame("abc") 50 | }) 51 | shouldWaitBeforeRead(0, "d") 52 | shouldWaitBeforeRead(delay, "abc") 53 | 54 | // frames can be added out of order 55 | q.add(&rxFrame{fn: fn + 2, bytes: []byte("1234")}, nil) 56 | time.AfterFunc(delay, func() { 57 | addFrame("abcd") 58 | }) 59 | shouldWaitBeforeRead(delay, "abc") 60 | shouldWaitBeforeRead(0, "d12") 61 | } 62 | 63 | func TestReadRXQEarlyClose(t *testing.T) { 64 | q := newReceiveQueue(10) 65 | fn := uint64(minFrameNumber - 1) 66 | addFrame := func(s string) { 67 | fn++ 68 | q.add(&rxFrame{fn: fn, bytes: []byte(s)}, nil) 69 | } 70 | shouldRead := func(s string) { 71 | b := make([]byte, 5) 72 | n, err := q.read(b) 73 | assert.NoError(t, err) 74 | assert.Equal(t, s, string(b[:n])) 75 | } 76 | 77 | addFrame("Hello") 78 | shouldRead("Hello") 79 | addFrame("World") 80 | addFrame("Burld") 81 | q.close() 82 | 
time.Sleep(time.Millisecond * 101) 83 | shouldRead("World") 84 | shouldRead("Burld") 85 | b := make([]byte, 10) 86 | _, err := q.read(b) 87 | if err != ErrClosed { 88 | t.FailNow() 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /varint.go: -------------------------------------------------------------------------------- 1 | package multipath 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | ) 8 | 9 | // Adapted from github.com/lucas-clemente/quic-go/internal/utils/varint.go 10 | 11 | // taken from the QUIC draft 12 | const ( 13 | maxVarInt1 = 63 14 | maxVarInt2 = 16383 15 | maxVarInt4 = 1073741823 16 | maxVarInt8 = 4611686018427387903 17 | ) 18 | 19 | // ReadVarInt reads a number in the QUIC varint format 20 | func ReadVarInt(b io.ByteReader) (uint64, error) { 21 | firstByte, err := b.ReadByte() 22 | if err != nil { 23 | return 0, err 24 | } 25 | // the first two bits of the first byte encode the length 26 | len := 1 << ((firstByte & 0xc0) >> 6) 27 | b1 := firstByte & (0xff - 0xc0) 28 | if len == 1 { 29 | return uint64(b1), nil 30 | } 31 | b2, err := b.ReadByte() 32 | if err != nil { 33 | return 0, err 34 | } 35 | if len == 2 { 36 | return uint64(b2) + uint64(b1)<<8, nil 37 | } 38 | b3, err := b.ReadByte() 39 | if err != nil { 40 | return 0, err 41 | } 42 | b4, err := b.ReadByte() 43 | if err != nil { 44 | return 0, err 45 | } 46 | if len == 4 { 47 | return uint64(b4) + uint64(b3)<<8 + uint64(b2)<<16 + uint64(b1)<<24, nil 48 | } 49 | b5, err := b.ReadByte() 50 | if err != nil { 51 | return 0, err 52 | } 53 | b6, err := b.ReadByte() 54 | if err != nil { 55 | return 0, err 56 | } 57 | b7, err := b.ReadByte() 58 | if err != nil { 59 | return 0, err 60 | } 61 | b8, err := b.ReadByte() 62 | if err != nil { 63 | return 0, err 64 | } 65 | return uint64(b8) + uint64(b7)<<8 + uint64(b6)<<16 + uint64(b5)<<24 + uint64(b4)<<32 + uint64(b3)<<40 + uint64(b2)<<48 + uint64(b1)<<56, nil 66 | } 67 | 68 | // WriteVarInt writes a 
number in the QUIC varint format 69 | func WriteVarInt(b *bytes.Buffer, i uint64) { 70 | if i <= maxVarInt1 { 71 | b.WriteByte(uint8(i)) 72 | } else if i <= maxVarInt2 { 73 | b.Write([]byte{uint8(i>>8) | 0x40, uint8(i)}) 74 | } else if i <= maxVarInt4 { 75 | b.Write([]byte{uint8(i>>24) | 0x80, uint8(i >> 16), uint8(i >> 8), uint8(i)}) 76 | } else if i <= maxVarInt8 { 77 | b.Write([]byte{ 78 | uint8(i>>56) | 0xc0, uint8(i >> 48), uint8(i >> 40), uint8(i >> 32), 79 | uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i), 80 | }) 81 | } else { 82 | panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i)) 83 | } 84 | } 85 | 86 | // VarIntLen determines the number of bytes that will be needed to write a number 87 | func VarIntLen(i uint64) int { 88 | if i <= maxVarInt1 { 89 | return 1 90 | } 91 | if i <= maxVarInt2 { 92 | return 2 93 | } 94 | if i <= maxVarInt4 { 95 | return 4 96 | } 97 | if i <= maxVarInt8 { 98 | return 8 99 | } 100 | // Don't use a fmt.Sprintf here to format the error message. 101 | // The function would then exceed the inlining budget. 
102 | panic(struct { 103 | message string 104 | num uint64 105 | }{"value doesn't fit into 62 bits: ", i}) 106 | } 107 | 108 | type byteReader struct { 109 | io.Reader 110 | b [1]byte 111 | } 112 | 113 | func (r byteReader) ReadByte() (byte, error) { 114 | _, err := r.Reader.Read(r.b[:]) 115 | if err != nil { 116 | return 0, err 117 | } 118 | return r.b[0], nil 119 | } 120 | -------------------------------------------------------------------------------- /listener.go: -------------------------------------------------------------------------------- 1 | package multipath 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "net" 7 | "sync" 8 | "time" 9 | 10 | "github.com/google/uuid" 11 | ) 12 | 13 | type mpListener struct { 14 | listeners []net.Listener 15 | listenerStats []StatsTracker 16 | mpConns map[connectionID]*mpConn 17 | muMPConns sync.Mutex 18 | chNextAccepted chan net.Conn 19 | startOnce sync.Once 20 | chClose chan struct{} 21 | closeOnce sync.Once 22 | } 23 | 24 | func NewListener(listeners []net.Listener, stats []StatsTracker) net.Listener { 25 | if len(listeners) != len(stats) { 26 | panic("the number of stats trackers should match listeners") 27 | } 28 | mpl := &mpListener{ 29 | listeners: listeners, 30 | listenerStats: stats, 31 | mpConns: make(map[connectionID]*mpConn), 32 | chNextAccepted: make(chan net.Conn), 33 | chClose: make(chan struct{}), 34 | } 35 | return mpl 36 | } 37 | 38 | func (mpl *mpListener) Accept() (net.Conn, error) { 39 | mpl.startOnce.Do(mpl.start) 40 | select { 41 | case <-mpl.chClose: 42 | return nil, ErrClosed 43 | case conn := <-mpl.chNextAccepted: 44 | return conn, nil 45 | } 46 | } 47 | 48 | func (mpl *mpListener) Close() error { 49 | mpl.closeOnce.Do(func() { close(mpl.chClose) }) 50 | return nil 51 | } 52 | 53 | // Addr satisfies the net.Listener interface. It returns a fake addr. 
54 | func (mpl *mpListener) Addr() net.Addr { 55 | return fakeAddr{} 56 | } 57 | 58 | func (mpl *mpListener) start() { 59 | for i, l := range mpl.listeners { 60 | go func(l net.Listener, st StatsTracker) { 61 | for { 62 | if err := mpl.acceptFrom(l, st); err != nil { 63 | select { 64 | case <-mpl.chClose: 65 | return 66 | default: 67 | log.Debugf("failed to accept on %s: %v", l.Addr(), err) 68 | } 69 | } 70 | } 71 | }(l, mpl.listenerStats[i]) 72 | } 73 | } 74 | 75 | func (mpl *mpListener) acceptFrom(l net.Listener, st StatsTracker) error { 76 | conn, err := l.Accept() 77 | if err != nil { 78 | return err 79 | } 80 | var leadBytes [leadBytesLength]byte 81 | _, err = io.ReadFull(conn, leadBytes[:]) 82 | if err != nil { 83 | return err 84 | } 85 | if uint8(leadBytes[0]) != 0 { 86 | return ErrUnexpectedVersion 87 | } 88 | var cid connectionID 89 | copy(cid[:], leadBytes[1:]) 90 | newConn := false 91 | if cid == zeroCID { 92 | newConn = true 93 | cid = connectionID(uuid.New()) 94 | copy(leadBytes[1:], cid[:]) 95 | log.Tracef("New connection from %v, assigned CID %x", conn.RemoteAddr(), cid) 96 | } else { 97 | log.Tracef("New subflow of CID %x from %v", cid, conn.RemoteAddr()) 98 | } 99 | probeStart := time.Now() 100 | // echo lead bytes back to the client 101 | if _, err := conn.Write(leadBytes[:]); err != nil { 102 | return err 103 | } 104 | mpl.muMPConns.Lock() 105 | bc, exists := mpl.mpConns[cid] 106 | if !exists { 107 | if newConn { 108 | bc = newMPConn(cid, conn.RemoteAddr()) 109 | mpl.mpConns[cid] = bc 110 | } else { 111 | mpl.muMPConns.Unlock() 112 | return fmt.Errorf("unexpected subflow of CID %v from %v", cid, conn.RemoteAddr()) 113 | } 114 | } 115 | mpl.muMPConns.Unlock() 116 | bc.add(fmt.Sprintf("%x(%s)", cid, conn.LocalAddr().String()), conn, false, probeStart, st) 117 | if newConn { 118 | mpl.chNextAccepted <- bc 119 | } 120 | return nil 121 | } 122 | 123 | func (mpl *mpListener) remove(cid connectionID) { 124 | mpl.muMPConns.Lock() 125 | 
delete(mpl.mpConns, cid) 126 | mpl.muMPConns.Unlock() 127 | } 128 | -------------------------------------------------------------------------------- /multipath.go: -------------------------------------------------------------------------------- 1 | // Package multipath provides a simple way to aggregate multiple network paths 2 | // between a pair of hosts to form a single connection from the upper layer 3 | // perspective, for throughput and resilience. 4 | // 5 | // The term connection, path and subflow used here are the same as mentioned in 6 | // MP-TCP https://www.rfc-editor.org/rfc/rfc8684.html#name-terminology 7 | // 8 | // Each subflow is a bidirectional byte stream each side in the following form 9 | // until being disrupted or the connection ends. When establishing the very 10 | // first subflow, the client sends an all-zero connnection ID (CID) and the 11 | // server sends the assigned CID back. Subsequent subflows use the same CID. 12 | // 13 | // ---------------------------------------------------- 14 | // | version(1) | cid(16) | frames (...) | 15 | // ---------------------------------------------------- 16 | // 17 | // There are two types of frames. Data frame carries application data while ack 18 | // frame carries acknowledgement to the frame just received. When one data 19 | // frame is not acked in time, it is sent over another subflow, until all 20 | // available subflows have been tried. 
Payload size and frame number uses 21 | // variable-length integer encoding as described here: 22 | // https://tools.ietf.org/html/draft-ietf-quic-transport-29#section-16 23 | // 24 | // -------------------------------------------------------- 25 | // | payload size(1-8) | frame number (1-8) | payload | 26 | // -------------------------------------------------------- 27 | // 28 | // --------------------------------------- 29 | // | 00000000 | ack frame number (1-8) | 30 | // --------------------------------------- 31 | // 32 | // Ack frames with frame number < 10 are reserved for control. For now only 0 33 | // and 1 are used, for ping and pong frame respectively. They are for updating 34 | // RTT on inactive subflows and detecting recovered subflows. 35 | // 36 | // Ping frame: 37 | // ------------------------- 38 | // | 00000000 | 00000000 | 39 | // ------------------------- 40 | // 41 | // Pong frame: 42 | // ------------------------- 43 | // | 00000000 | 00000001 | 44 | // ------------------------- 45 | // 46 | package multipath 47 | 48 | import ( 49 | "bytes" 50 | "errors" 51 | "sync" 52 | "sync/atomic" 53 | "time" 54 | 55 | "github.com/getlantern/golog" 56 | "github.com/google/uuid" 57 | pool "github.com/libp2p/go-buffer-pool" 58 | ) 59 | 60 | const ( 61 | minFrameNumber uint64 = 10 62 | frameTypePing uint64 = 0 63 | frameTypePong uint64 = 1 64 | 65 | maxFrameSizeToCalculateRTT uint64 = 1500 66 | leadBytesLength = 1 + 16 // 1 byte version + 16 bytes CID 67 | // Assuming an average 1KB frame size, it would be able to buffer 4MB of 68 | // data without back pressure before the upper layer reads them. 
69 | recieveQueueLength = 4096 70 | maxVarIntLength = 8 71 | probeInterval = time.Minute 72 | longRTT = time.Minute 73 | rttAlpha = 0.5 // this causes EMA to reflect changes more rapidly 74 | ) 75 | 76 | var ( 77 | ErrUnexpectedVersion = errors.New("unexpected version") 78 | ErrUnexpectedCID = errors.New("unexpected connnection ID") 79 | ErrClosed = errors.New("closed connection") 80 | ErrFailOnAllDialers = errors.New("fail on all dialers") 81 | log = golog.LoggerFor("multipath") 82 | zeroCID connectionID 83 | ) 84 | 85 | type connectionID uuid.UUID 86 | 87 | type rxFrame struct { 88 | fn uint64 89 | bytes []byte 90 | } 91 | 92 | type transmissionDatapoint struct { 93 | sf *subflow 94 | txTime time.Time 95 | } 96 | 97 | type sendFrame struct { 98 | fn uint64 99 | sz uint64 100 | buf []byte 101 | released *int32 // 1 == true; 0 == false. Use pointer so copied object still references the same address, as buf does 102 | retransmissions int 103 | sentVia []transmissionDatapoint // Contains the subflows it's already been written to, and when 104 | beingRetransmitted uint64 105 | changeLock sync.Mutex 106 | } 107 | 108 | func composeFrame(fn uint64, b []byte) *sendFrame { 109 | sz := len(b) 110 | buf := pool.Get(maxVarIntLength + maxVarIntLength + sz) 111 | wb := bytes.NewBuffer(buf[:0]) 112 | WriteVarInt(wb, uint64(sz)) 113 | WriteVarInt(wb, fn) 114 | if sz > 0 { 115 | wb.Write(b) 116 | } 117 | var released int32 118 | return &sendFrame{fn: fn, sz: uint64(sz), buf: wb.Bytes(), released: &released} 119 | } 120 | 121 | func (f *sendFrame) isDataFrame() bool { 122 | return f.sz > 0 123 | } 124 | 125 | func (f *sendFrame) release() { 126 | if atomic.CompareAndSwapInt32(f.released, 0, 1) { 127 | pool.Put(f.buf) 128 | } 129 | } 130 | 131 | // StatsTracker allows getting a sense of how the paths perform. Its methods 132 | // are called when each subflow sends or receives a frame. 
133 | type StatsTracker interface { 134 | OnRecv(uint64) 135 | OnSent(uint64) 136 | OnRetransmit(uint64) 137 | UpdateRTT(time.Duration) 138 | } 139 | 140 | type NullTracker struct{} 141 | 142 | func (st NullTracker) OnRecv(uint64) {} 143 | func (st NullTracker) OnSent(uint64) {} 144 | func (st NullTracker) OnRetransmit(uint64) {} 145 | func (st NullTracker) UpdateRTT(time.Duration) {} 146 | -------------------------------------------------------------------------------- /dialer.go: -------------------------------------------------------------------------------- 1 | package multipath 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "math/rand" 8 | "net" 9 | "sort" 10 | "sync/atomic" 11 | "time" 12 | 13 | "github.com/dustin/go-humanize" 14 | "github.com/getlantern/ema" 15 | ) 16 | 17 | // Dialer is the interface each subflow dialer needs to satisify. It is also 18 | // the type of the multipath dialer. 19 | type Dialer interface { 20 | DialContext(ctx context.Context) (net.Conn, error) 21 | Label() string 22 | } 23 | 24 | // Stats is also provided by the multipath dialer so the caller can get and 25 | // print the status of each path. 
26 | type Stats interface { 27 | FormatStats() (stats []string) 28 | } 29 | 30 | type subflowDialer struct { 31 | Dialer 32 | label string 33 | successes uint64 34 | consecSuccesses uint64 35 | failures uint64 36 | framesSent uint64 37 | framesRetransmit uint64 38 | framesRecv uint64 39 | bytesSent uint64 40 | bytesRetransmit uint64 41 | bytesRecv uint64 42 | emaRTT *ema.EMA 43 | } 44 | 45 | func (sfd *subflowDialer) DialContext(ctx context.Context) (net.Conn, error) { 46 | conn, err := sfd.Dialer.DialContext(ctx) 47 | if err == nil { 48 | atomic.AddUint64(&sfd.successes, 1) 49 | atomic.AddUint64(&sfd.consecSuccesses, 1) 50 | } else { 51 | // reset RTT to deprioritize this dialer 52 | sfd.emaRTT.SetDuration(longRTT) 53 | atomic.AddUint64(&sfd.failures, 1) 54 | atomic.StoreUint64(&sfd.consecSuccesses, 0) 55 | } 56 | return conn, err 57 | } 58 | 59 | func (sfd *subflowDialer) OnRecv(n uint64) { 60 | atomic.AddUint64(&sfd.framesRecv, 1) 61 | atomic.AddUint64(&sfd.bytesRecv, n) 62 | } 63 | func (sfd *subflowDialer) OnSent(n uint64) { 64 | atomic.AddUint64(&sfd.framesSent, 1) 65 | atomic.AddUint64(&sfd.bytesSent, n) 66 | } 67 | func (sfd *subflowDialer) OnRetransmit(n uint64) { 68 | atomic.AddUint64(&sfd.framesRetransmit, 1) 69 | atomic.AddUint64(&sfd.bytesRetransmit, n) 70 | } 71 | func (sfd *subflowDialer) UpdateRTT(rtt time.Duration) { 72 | sfd.emaRTT.UpdateDuration(rtt) 73 | } 74 | 75 | type mpDialer struct { 76 | dest string 77 | dialers []*subflowDialer 78 | } 79 | 80 | func NewDialer(dest string, dialers []Dialer) Dialer { 81 | var subflowDialers []*subflowDialer 82 | for _, d := range dialers { 83 | subflowDialers = append(subflowDialers, &subflowDialer{Dialer: d, label: d.Label(), emaRTT: ema.NewDuration(longRTT, rttAlpha)}) 84 | } 85 | d := &mpDialer{dest, subflowDialers} 86 | return d 87 | } 88 | 89 | // DialContext dials the addr using all dialers and returns a connection 90 | // contains subflows from whatever dialers available. 
91 | func (mpd *mpDialer) DialContext(ctx context.Context) (net.Conn, error) { 92 | var bc *mpConn 93 | dialOne := func(d *subflowDialer, cid connectionID) (connectionID, bool) { 94 | conn, err := d.DialContext(ctx) 95 | if err != nil { 96 | log.Errorf("failed to dial %s: %v", d.Label(), err) 97 | return zeroCID, false 98 | } 99 | probeStart := time.Now() 100 | newCID, err := mpd.handshake(conn, cid) 101 | if err != nil { 102 | log.Errorf("failed to handshake %s, continuing: %v", d.Label(), err) 103 | conn.Close() 104 | return zeroCID, false 105 | } 106 | if cid == zeroCID { 107 | bc = newMPConn(newCID, conn.RemoteAddr()) 108 | go func() { 109 | for { 110 | time.Sleep(time.Second) 111 | select { 112 | case <-ctx.Done(): 113 | return 114 | default: 115 | bc.pendingAckMu.RLock() 116 | oldest := time.Duration(0) 117 | oldestFN := uint64(0) 118 | for fn, frame := range bc.pendingAckMap { 119 | if time.Since(frame.sentAt) > oldest { 120 | oldest = time.Since(frame.sentAt) 121 | oldestFN = fn 122 | } 123 | } 124 | bc.pendingAckMu.RUnlock() 125 | if oldest > time.Second { 126 | log.Debugf("Frame %d has not been acked for %v\n", oldestFN, oldest) 127 | } 128 | } 129 | } 130 | }() 131 | } 132 | bc.add(fmt.Sprintf("%x(%s)", newCID, d.label), conn, true, probeStart, d) 133 | return newCID, true 134 | } 135 | dialers := mpd.sorted() 136 | for i, d := range dialers { 137 | // dial the first connection with zero connection ID 138 | cid, ok := dialOne(d, zeroCID) 139 | if !ok { 140 | continue 141 | } 142 | if i < len(dialers)-1 { 143 | // dial the rest in parallel with server assigned connection ID 144 | for _, d := range dialers[i+1:] { 145 | go dialOne(d, cid) 146 | } 147 | } 148 | return bc, nil 149 | } 150 | return nil, ErrFailOnAllDialers 151 | } 152 | 153 | // handshake exchanges version and cid with the peer and returns the connnection ID 154 | // both end agrees if no error happens. 
155 | func (mpd *mpDialer) handshake(conn net.Conn, cid connectionID) (connectionID, error) { 156 | var leadBytes [leadBytesLength]byte 157 | // the first byte, version, is implicitly set to 0 158 | copy(leadBytes[1:], cid[:]) 159 | _, err := conn.Write(leadBytes[:]) 160 | if err != nil { 161 | return zeroCID, err 162 | } 163 | _, err = io.ReadFull(conn, leadBytes[:]) 164 | if err != nil { 165 | return zeroCID, err 166 | } 167 | if uint8(leadBytes[0]) != 0 { 168 | return zeroCID, ErrUnexpectedVersion 169 | } 170 | var newCID connectionID 171 | copy(newCID[:], leadBytes[1:]) 172 | if cid != zeroCID && cid != newCID { 173 | return zeroCID, ErrUnexpectedCID 174 | } 175 | return newCID, nil 176 | } 177 | 178 | func (mpd *mpDialer) Label() string { 179 | return fmt.Sprintf("multipath dialer to %s with %d paths", mpd.dest, len(mpd.dialers)) 180 | } 181 | 182 | func (mpd *mpDialer) sorted() []*subflowDialer { 183 | dialersCopy := make([]*subflowDialer, len(mpd.dialers)) 184 | copy(dialersCopy, mpd.dialers) 185 | sort.Slice(dialersCopy, func(i, j int) bool { 186 | it := dialersCopy[i].emaRTT.GetDuration() 187 | jt := dialersCopy[j].emaRTT.GetDuration() 188 | // both have unknown RTT or fail to dial, give each a chance 189 | if it == jt { 190 | return rand.Intn(2) > 0 191 | } 192 | return it < jt 193 | }) 194 | return dialersCopy 195 | } 196 | 197 | func (mpd *mpDialer) FormatStats() (stats []string) { 198 | for _, d := range mpd.sorted() { 199 | stats = append(stats, fmt.Sprintf("%s S: %4d(%3d) F: %4d RTT: %6.0fms SENT: %7d/%7s RECV: %7d/%7s RT: %7d/%7s", 200 | d.label, 201 | atomic.LoadUint64(&d.successes), 202 | atomic.LoadUint64(&d.consecSuccesses), 203 | atomic.LoadUint64(&d.failures), 204 | d.emaRTT.GetDuration().Seconds()*1000, 205 | atomic.LoadUint64(&d.framesSent), humanize.Bytes(atomic.LoadUint64(&d.bytesSent)), 206 | atomic.LoadUint64(&d.framesRecv), humanize.Bytes(atomic.LoadUint64(&d.bytesRecv)), 207 | atomic.LoadUint64(&d.framesRetransmit), 
humanize.Bytes(atomic.LoadUint64(&d.bytesRetransmit)))) 208 | } 209 | return 210 | } 211 | -------------------------------------------------------------------------------- /receivequeue.go: -------------------------------------------------------------------------------- 1 | package multipath 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | "sync/atomic" 7 | "time" 8 | 9 | pool "github.com/libp2p/go-buffer-pool" 10 | ) 11 | 12 | // receiveQueue keeps received frames for the upper layer to read. It is 13 | // maintained as a ring buffer with fixed size. It takes advantage of the fact 14 | // that the frame number is sequential, so when a new frame arrives, it is 15 | // placed at buf[frameNumber % size]. 16 | type receiveQueue struct { 17 | readFrameTip uint64 18 | buf []rxFrame 19 | size uint64 20 | // rp stands for read pointer, point to the index of the frame containing 21 | // data yet to be read. 22 | rp uint64 23 | availableFrameChannel chan bool 24 | readNotifyChannel chan bool 25 | readDeadline time.Time 26 | deadlineLock sync.Mutex 27 | closing uint32 // 1 == true, 0 == false -- This is used to "drain" the Queue 28 | fullyClosed uint32 // 1 == true, 0 == false 29 | readLock *sync.Mutex 30 | } 31 | 32 | func newReceiveQueue(size int) *receiveQueue { 33 | rq := &receiveQueue{ 34 | buf: make([]rxFrame, size), 35 | size: uint64(size), 36 | rp: minFrameNumber % uint64(size), // frame number starts with minFrameNumber, so should the read pointer 37 | availableFrameChannel: make(chan bool, 1), 38 | readNotifyChannel: make(chan bool), 39 | readLock: &sync.Mutex{}, 40 | } 41 | return rq 42 | } 43 | 44 | func (rq *receiveQueue) add(f *rxFrame, sf *subflow) { 45 | select { 46 | case rq.availableFrameChannel <- true: 47 | default: 48 | } 49 | // Another thing to protect against, is that we might be 50 | // locally blocked on a full receiveQueue. 
If that is the 51 | // case then we don't want to return instantly from this 52 | // function since that will just case retransmits to fire 53 | // over and over again, causing mass bandwidth loss. 54 | // Instead let's quickly check if we have all of the data we need 55 | // to read, and if we do, hang until we don't have that problem anymore 56 | if rq.isFull() { 57 | for { 58 | var abort bool 59 | select { 60 | case <-rq.readNotifyChannel: 61 | if !rq.isFull() { 62 | abort = true 63 | break 64 | } 65 | } 66 | if abort { 67 | break 68 | } 69 | } 70 | } 71 | select { 72 | case rq.availableFrameChannel <- true: 73 | default: 74 | } 75 | 76 | readFrameTip := atomic.LoadUint64(&rq.readFrameTip) 77 | 78 | if readFrameTip != 0 { 79 | if readFrameTip > f.fn || readFrameTip == f.fn { 80 | sf.ack(f.fn) 81 | return 82 | } 83 | } 84 | 85 | if f.fn > readFrameTip+rq.size && readFrameTip != 0 { 86 | log.Debugf("Near corruption incident?? %v vs the max peek of %v (frametip %d)", f.fn, readFrameTip+rq.size-1, readFrameTip) 87 | return // Nope! this will corrupt the buffer 88 | } 89 | 90 | if rq.tryAdd(f) { 91 | sf.ack(f.fn) 92 | return 93 | } 94 | 95 | // Protect against the socket being closed 96 | if atomic.LoadUint32(&rq.fullyClosed) == 1 { 97 | pool.Put(f.bytes) 98 | return 99 | } 100 | 101 | } 102 | 103 | func (rq *receiveQueue) isFull() bool { 104 | printFull := false 105 | for i := uint64(0); i < rq.size; i++ { 106 | expectedFrameNumber := atomic.LoadUint64(&rq.readFrameTip) + i 107 | idx := expectedFrameNumber % rq.size 108 | 109 | rq.readLock.Lock() 110 | if rq.buf[idx].fn != expectedFrameNumber { 111 | if printFull { 112 | log.Tracef("receiveQueue is %d%% full! 
(%d/%d)", int((float32(i) / float32(rq.size) * 100)), i, rq.size) 113 | } 114 | rq.readLock.Unlock() 115 | return false 116 | } 117 | 118 | if rq.buf[idx].bytes == nil { 119 | rq.readLock.Unlock() 120 | return false 121 | } 122 | rq.readLock.Unlock() 123 | 124 | if i == rq.size/2 { 125 | printFull = true 126 | } 127 | } 128 | 129 | return true 130 | } 131 | 132 | func (rq *receiveQueue) tryAdd(f *rxFrame) bool { 133 | rq.readLock.Lock() 134 | idx := f.fn % rq.size 135 | if rq.buf[idx].bytes == nil { 136 | // empty slot 137 | rq.buf[idx] = *f 138 | if idx == rq.rp { 139 | select { 140 | case rq.availableFrameChannel <- true: 141 | default: 142 | } 143 | } 144 | rq.readLock.Unlock() 145 | return true 146 | } else if rq.buf[idx].fn == f.fn { 147 | rq.readLock.Unlock() 148 | // retransmission, ignore 149 | log.Tracef("Got a retransmit. for %d", f.fn) 150 | pool.Put(f.bytes) 151 | return true 152 | } 153 | rq.readLock.Unlock() 154 | 155 | if idx != 0 { 156 | log.Tracef("Not what I was looking for, I'm looking for frame %v", rq.buf[idx-1].fn+1) 157 | } 158 | return false 159 | } 160 | 161 | func (rq *receiveQueue) read(b []byte) (int, error) { 162 | for { 163 | rq.readLock.Lock() 164 | if rq.buf[rq.rp].bytes != nil { 165 | rq.readLock.Unlock() 166 | break 167 | } 168 | rq.readLock.Unlock() 169 | 170 | if atomic.LoadUint32(&rq.fullyClosed) == 1 { 171 | return 0, ErrClosed 172 | } 173 | if atomic.LoadUint32(&rq.closing) == 1 { 174 | // if we are closing, then we should check if there is anything left to send 175 | // before sending ErrClosed back upstream, otherwise we may close "early" with 176 | // some data still inside of us! 
177 | break 178 | } 179 | 180 | if rq.dlExceeded() { 181 | return 0, context.DeadlineExceeded 182 | } 183 | 184 | select { 185 | case rq.readNotifyChannel <- true: 186 | default: 187 | } 188 | <-rq.availableFrameChannel 189 | } 190 | 191 | rq.readLock.Lock() 192 | defer rq.readLock.Unlock() 193 | 194 | totalN := 0 195 | cur := rq.buf[rq.rp].bytes 196 | for cur != nil && totalN < len(b) { 197 | oldFrameTip := atomic.LoadUint64(&rq.readFrameTip) 198 | if (rq.buf[rq.rp].fn != oldFrameTip+1) && (rq.buf[rq.rp].fn != oldFrameTip) && oldFrameTip != 0 { 199 | log.Errorf("receiveQueue buffer corruption detected [%v vs %v] (The crash happened at idx = %d)", rq.buf[rq.rp].fn, oldFrameTip+1, rq.rp) 200 | log.Tracef("All Buffers: ") 201 | for idx, v := range rq.buf { 202 | log.Tracef("\t[%d]fn %d, [%d]byte\n", idx, v.fn, len(v.bytes)) 203 | } 204 | rq.close() 205 | return 0, ErrClosed 206 | } 207 | n := copy(b[totalN:], cur) 208 | if n == len(cur) { 209 | log.Tracef("Finished with read frame %d\n", rq.buf[rq.rp].fn) 210 | atomic.StoreUint64(&rq.readFrameTip, rq.buf[rq.rp].fn) 211 | pool.Put(cur) 212 | rq.buf[rq.rp].bytes = nil 213 | rq.rp = (rq.rp + 1) % rq.size 214 | } else { 215 | // The frames in the ring buffer are never overridden, so we can 216 | // safely update the bytes to reflect the next read position. 
217 | rq.buf[rq.rp].bytes = cur[n:] 218 | log.Tracef("Partial read frame %d\n", rq.buf[rq.rp].fn) 219 | } 220 | totalN += n 221 | cur = rq.buf[rq.rp].bytes 222 | } 223 | 224 | select { 225 | case rq.readNotifyChannel <- true: 226 | default: 227 | } 228 | 229 | if totalN == 0 && atomic.LoadUint32(&rq.closing) == 1 { 230 | // close fully 231 | atomic.StoreUint32(&rq.fullyClosed, 1) 232 | return 0, ErrClosed 233 | } 234 | 235 | return totalN, nil 236 | } 237 | 238 | func (rq *receiveQueue) setReadDeadline(dl time.Time) { 239 | rq.deadlineLock.Lock() 240 | rq.readDeadline = dl 241 | rq.deadlineLock.Unlock() 242 | if !dl.IsZero() { 243 | ttl := dl.Sub(time.Now()) 244 | if ttl <= 0 { 245 | for { 246 | abort := false 247 | select { 248 | case rq.availableFrameChannel <- true: 249 | default: 250 | abort = true 251 | } 252 | if abort { 253 | break 254 | } 255 | } 256 | } else { 257 | time.AfterFunc(ttl, func() { 258 | rq.availableFrameChannel <- true 259 | }) 260 | } 261 | } 262 | } 263 | 264 | func (rq *receiveQueue) dlExceeded() bool { 265 | return !rq.readDeadline.IsZero() && !rq.readDeadline.After(time.Now()) 266 | } 267 | 268 | func (rq *receiveQueue) close() { 269 | atomic.StoreUint32(&rq.closing, 1) 270 | abort := false 271 | 272 | for { 273 | select { 274 | case rq.availableFrameChannel <- true: 275 | default: 276 | abort = true 277 | } 278 | if abort { 279 | break 280 | } 281 | } 282 | } 283 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= 2 | github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= 3 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 5 | github.com/davecgh/go-spew v1.1.1/go.mod 
h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= 7 | github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= 8 | github.com/getlantern/context v0.0.0-20190109183933-c447772a6520 h1:NRUJuo3v3WGC/g5YiyF790gut6oQr5f3FBI88Wv0dx4= 9 | github.com/getlantern/context v0.0.0-20190109183933-c447772a6520/go.mod h1:L+mq6/vvYHKjCX2oez0CgEAJmbq1fbb/oNJIWQkBybY= 10 | github.com/getlantern/ema v0.0.0-20190620044903-5943d28f40e4 h1:PMK8QQn9GLTQXdHnqoNhyToOa8snagaZVt9Xb36NEUc= 11 | github.com/getlantern/ema v0.0.0-20190620044903-5943d28f40e4/go.mod h1:tzRwT19aDrWSr6yRDs8iOvaXXCau96EgWsgGT9wIpoQ= 12 | github.com/getlantern/errors v1.0.1 h1:XukU2whlh7OdpxnkXhNH9VTLVz0EVPGKDV5K0oWhvzw= 13 | github.com/getlantern/errors v1.0.1/go.mod h1:l+xpFBrCtDLpK9qNjxs+cHU6+BAdlBaxHqikB6Lku3A= 14 | github.com/getlantern/golog v0.0.0-20211223150227-d4d95a44d873 h1:nnod94N4hMKb7pyJmnXDk+HR23o1S2CbZ4oMKzHbp9A= 15 | github.com/getlantern/golog v0.0.0-20211223150227-d4d95a44d873/go.mod h1:+ZU1h+iOVqWReBpky6d5Y2WL0sF2Llxu+QcxJFs2+OU= 16 | github.com/getlantern/hex v0.0.0-20190417191902-c6586a6fe0b7 h1:micT5vkcr9tOVk1FiH8SWKID8ultN44Z+yzd2y/Vyb0= 17 | github.com/getlantern/hex v0.0.0-20190417191902-c6586a6fe0b7/go.mod h1:dD3CgOrwlzca8ed61CsZouQS5h5jIzkK9ZWrTcf0s+o= 18 | github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55 h1:XYzSdCbkzOC0FDNrgJqGRo8PCMFOBFL9py72DRs7bmc= 19 | github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55/go.mod h1:6mmzY2kW1TOOrVy+r41Za2MxXM+hhqTtY3oBKd2AgFA= 20 | github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f h1:wrYrQttPS8FHIRSlsrcuKazukx/xqO/PpLZzZXsF+EA= 21 | github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f/go.mod h1:D5ao98qkA6pxftxoqzibIBBrLSUli+kYnJqrgBf9cIA= 22 | github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= 23 | github.com/go-stack/stack v1.8.0/go.mod 
h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= 24 | github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= 25 | github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 26 | github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= 27 | github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= 28 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 29 | github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= 30 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 31 | github.com/libp2p/go-buffer-pool v0.0.2 h1:QNK2iAFa8gjAe1SPz6mHSMuCcjs+X1wlHzeOSqcmlfs= 32 | github.com/libp2p/go-buffer-pool v0.0.2/go.mod h1:MvaB6xw5vOrDl8rYZGLFdKAuk/hRoRZd1Vi32+RXyFM= 33 | github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c h1:rp5dCmg/yLR3mgFuSOe4oEnDDmGLROTvMragMUXpTQw= 34 | github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c/go.mod h1:X07ZCGwUbLaax7L0S3Tw4hpejzu63ZrrQiUe6W0hcy0= 35 | github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= 36 | github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 37 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 38 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 39 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 40 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 41 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 42 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 43 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 44 | github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= 45 | 
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 46 | github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= 47 | go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= 48 | go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= 49 | go.uber.org/goleak v1.1.11-0.20210813005559-691160354723 h1:sHOAIxRGBp443oHZIPB+HsUGaksVCXVQENPxwTfQdH4= 50 | go.uber.org/goleak v1.1.11-0.20210813005559-691160354723/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= 51 | go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4= 52 | go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= 53 | go.uber.org/zap v1.19.1 h1:ue41HOKd1vGURxrmeKIgELGb3jPW9DMUDGtsinblHwI= 54 | go.uber.org/zap v1.19.1/go.mod h1:j3DNczoxDZroyBnOT1L/Q79cfUMGZxlv/9dzN7SM1rI= 55 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 56 | golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 57 | golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= 58 | golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= 59 | golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 60 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 61 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 62 | golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= 63 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 64 | golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod 
h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 65 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 66 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 67 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 68 | golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 69 | golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 70 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 71 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 72 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 73 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 74 | golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= 75 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 76 | golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= 77 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 78 | golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 79 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 80 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 81 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= 82 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 
83 | gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= 84 | gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 85 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 86 | gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 87 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 88 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 89 | -------------------------------------------------------------------------------- /conn.go: -------------------------------------------------------------------------------- 1 | package multipath 2 | 3 | import ( 4 | "net" 5 | "sort" 6 | "sync" 7 | "sync/atomic" 8 | "time" 9 | ) 10 | 11 | type mpConn struct { 12 | cid connectionID 13 | remoteAddr net.Addr 14 | lastFN uint64 15 | subflows []*subflow 16 | muSubflows sync.RWMutex 17 | recvQueue *receiveQueue 18 | closed uint32 // 1 == true, 0 == false 19 | writerMaybeReady chan bool 20 | tryRetransmit chan bool 21 | 22 | pendingAckMap map[uint64]*pendingAck 23 | pendingAckMu *sync.RWMutex 24 | } 25 | 26 | func newMPConn(cid connectionID, remoteAddr net.Addr) *mpConn { 27 | mpc := &mpConn{ 28 | cid: cid, 29 | remoteAddr: remoteAddr, 30 | lastFN: minFrameNumber - 1, 31 | recvQueue: newReceiveQueue(recieveQueueLength), 32 | writerMaybeReady: make(chan bool, 1), 33 | tryRetransmit: make(chan bool, 1), 34 | pendingAckMap: make(map[uint64]*pendingAck), 35 | pendingAckMu: &sync.RWMutex{}, 36 | } 37 | go mpc.retransmitLoop() 38 | return mpc 39 | } 40 | 41 | func (bc *mpConn) Read(b []byte) (n int, err error) { 42 | return bc.recvQueue.read(b) 43 | } 44 | 45 | func (bc *mpConn) Write(b []byte) (n int, err error) { 46 | frame := composeFrame(atomic.AddUint64(&bc.lastFN, 1), b) 47 | 48 | for { 49 | bc.pendingAckMu.RLock() 50 | inflight := len(bc.pendingAckMap) 51 | bc.pendingAckMu.RUnlock() 
52 | if inflight > 500 { 53 | time.Sleep(time.Millisecond * 100) 54 | log.Tracef("too many inflights") 55 | continue 56 | } 57 | 58 | for _, sf := range bc.sortedSubflows() { 59 | 60 | if atomic.LoadUint64(&sf.actuallyBusyOnWrite) == 1 { 61 | // Avoid a possibly blocked writer for a retransmit 62 | continue 63 | } 64 | 65 | select { 66 | case sf.sendQueue <- frame: 67 | return len(b), nil 68 | default: 69 | } 70 | } 71 | if len(bc.sortedSubflows()) == 0 { 72 | return 0, ErrClosed 73 | } 74 | 75 | <-bc.writerMaybeReady 76 | } 77 | } 78 | 79 | func (bc *mpConn) Close() error { 80 | bc.close() 81 | for _, sf := range bc.sortedSubflows() { 82 | sf.close() 83 | } 84 | return nil 85 | } 86 | 87 | func (bc *mpConn) close() { 88 | atomic.StoreUint32(&bc.closed, 1) 89 | bc.recvQueue.close() 90 | } 91 | 92 | type fakeAddr struct{} 93 | 94 | func (fakeAddr) Network() string { return "multipath" } 95 | func (fakeAddr) String() string { return "multipath" } 96 | 97 | func (bc *mpConn) LocalAddr() net.Addr { 98 | return fakeAddr{} 99 | } 100 | 101 | func (bc *mpConn) RemoteAddr() net.Addr { 102 | return bc.remoteAddr 103 | } 104 | 105 | func (bc *mpConn) SetDeadline(t time.Time) error { 106 | bc.SetReadDeadline(t) 107 | return bc.SetWriteDeadline(t) 108 | } 109 | 110 | func (bc *mpConn) SetReadDeadline(t time.Time) error { 111 | bc.recvQueue.setReadDeadline(t) 112 | return nil 113 | } 114 | 115 | func (bc *mpConn) SetWriteDeadline(t time.Time) error { 116 | bc.muSubflows.RLock() 117 | defer bc.muSubflows.RUnlock() 118 | for _, sf := range bc.subflows { 119 | if err := sf.conn.SetWriteDeadline(t); err != nil { 120 | return err 121 | } 122 | } 123 | return nil 124 | } 125 | 126 | func (bc *mpConn) retransmit(frame *sendFrame) { 127 | frame.changeLock.Lock() 128 | defer frame.changeLock.Unlock() 129 | 130 | if atomic.LoadUint64(&frame.beingRetransmitted) == 1 { 131 | return 132 | } 133 | atomic.StoreUint64(&frame.beingRetransmitted, 1) 134 | defer func() { 135 | 
atomic.StoreUint64(&frame.beingRetransmitted, 0) 136 | }() 137 | 138 | subflows := bc.sortedSubflows() 139 | 140 | alreadyTransmittedOnAllSubflows := false 141 | for { 142 | abort := false 143 | if bc.closed == 1 { 144 | return 145 | } 146 | 147 | var selectedSubflow *subflow 148 | 149 | abort, alreadyTransmittedOnAllSubflows, selectedSubflow = selectSubflowForRetransmit(subflows, frame, false) 150 | if selectedSubflow == nil { 151 | abort, alreadyTransmittedOnAllSubflows, selectedSubflow = selectSubflowForRetransmit(subflows, frame, true) 152 | if selectedSubflow == nil { 153 | abort = true 154 | alreadyTransmittedOnAllSubflows = true 155 | break 156 | } 157 | } 158 | 159 | select { 160 | case <-selectedSubflow.chClose: 161 | continue 162 | case selectedSubflow.sendQueue <- frame: 163 | frame.retransmissions++ 164 | log.Debugf("retransmitted frame %d via %s", frame.fn, selectedSubflow.to) 165 | if frame.sentVia == nil { 166 | frame.sentVia = make([]transmissionDatapoint, 0) 167 | } 168 | frame.sentVia = append(frame.sentVia, transmissionDatapoint{selectedSubflow, time.Now()}) 169 | return 170 | default: 171 | } 172 | 173 | if abort { 174 | break 175 | } 176 | <-bc.tryRetransmit 177 | } 178 | 179 | if !alreadyTransmittedOnAllSubflows { 180 | log.Debugf("frame %d is being retransmitted on all subflows of %x", frame.fn, bc.cid) 181 | } 182 | 183 | return 184 | } 185 | 186 | func selectSubflowForRetransmit(subflows []*subflow, frame *sendFrame, timeFallback bool) (bool, bool, *subflow) { 187 | var selectedSubflow *subflow 188 | for _, sf := range subflows { 189 | if atomic.LoadUint64(&sf.actuallyBusyOnWrite) == 1 { 190 | // Avoid a possibly blocked writer for a retransmit 191 | // let's avoid re-sending a frame down the same socket twice. 192 | // Since at best, it just double sends a frame into the send buffer 193 | // and at worst it blocks other frames from entering a send buffer. 194 | continue 195 | } 196 | // Have we used this subflow before for this frame? 
197 | usedBefore := false 198 | var avoidTime time.Time 199 | for _, avoidSF := range frame.sentVia { 200 | if sf == avoidSF.sf { 201 | usedBefore = true 202 | avoidTime = avoidSF.txTime 203 | } 204 | } 205 | 206 | // It may be acceptable to use a subflow that has been used before 207 | // if we are in timeFallback mode 208 | if usedBefore { 209 | if timeFallback { 210 | if time.Since(avoidTime) > time.Second { 211 | usedBefore = false 212 | } 213 | } else { 214 | continue 215 | } 216 | } 217 | 218 | if !usedBefore { 219 | // frame.sentVia 220 | return false, false, sf 221 | } 222 | } 223 | return true, true, selectedSubflow 224 | } 225 | 226 | func (bc *mpConn) sortedSubflows() []*subflow { 227 | bc.muSubflows.RLock() 228 | subflows := make([]*subflow, len(bc.subflows)) 229 | copy(subflows, bc.subflows) 230 | bc.muSubflows.RUnlock() 231 | sort.Slice(subflows, func(i, j int) bool { 232 | return subflows[i].getRTT() < subflows[j].getRTT() 233 | }) 234 | return subflows 235 | } 236 | 237 | func (bc *mpConn) add(to string, c net.Conn, clientSide bool, probeStart time.Time, tracker StatsTracker) { 238 | bc.muSubflows.Lock() 239 | defer bc.muSubflows.Unlock() 240 | bc.subflows = append(bc.subflows, startSubflow(to, c, bc, clientSide, probeStart, tracker)) 241 | } 242 | 243 | func (bc *mpConn) remove(theSubflow *subflow) { 244 | bc.muSubflows.Lock() 245 | var remains []*subflow 246 | for _, sf := range bc.subflows { 247 | if sf != theSubflow { 248 | remains = append(remains, sf) 249 | } 250 | } 251 | bc.subflows = remains 252 | left := len(remains) 253 | bc.muSubflows.Unlock() 254 | if left == 0 { 255 | bc.close() 256 | } 257 | } 258 | 259 | func (bc *mpConn) retransmitLoop() { 260 | evalTick := time.NewTicker(time.Millisecond * 100) 261 | for { 262 | select { 263 | case <-evalTick.C: 264 | } 265 | if atomic.LoadUint32(&bc.closed) == 1 { 266 | return 267 | } 268 | 269 | bc.pendingAckMu.RLock() 270 | RetransmitFrames := make([]pendingAck, 0) 271 | for fn, frame := range 
bc.pendingAckMap { 272 | if time.Since(frame.sentAt) > frame.outboundSf.retransTimer() { 273 | if bc.pendingAckMap[fn] != nil { 274 | RetransmitFrames = append(RetransmitFrames, *frame) 275 | } 276 | } 277 | } 278 | bc.pendingAckMu.RUnlock() 279 | 280 | sort.Slice(RetransmitFrames, func(i, j int) bool { 281 | return RetransmitFrames[i].fn < RetransmitFrames[j].fn 282 | }) 283 | 284 | for _, frame := range RetransmitFrames { 285 | sendframe := frame.framePtr 286 | sendframe.changeLock.Lock() 287 | if bc.isPendingAck(frame.fn) { 288 | // No ack means the subflow fails or has a longer RTT 289 | // log.Errorf("Retransmitting! %#v", frame.fn) 290 | if sendframe.beingRetransmitted == 0 { 291 | go bc.retransmit(sendframe) 292 | } 293 | sendframe.changeLock.Unlock() 294 | } else { 295 | // It is ok to release buffer here as the frame will never 296 | // be retransmitted again. 297 | sendframe.release() 298 | sendframe.changeLock.Unlock() 299 | bc.pendingAckMu.Lock() 300 | delete(bc.pendingAckMap, frame.fn) 301 | bc.pendingAckMu.Unlock() 302 | } 303 | } 304 | 305 | } 306 | } 307 | 308 | func (bc *mpConn) isPendingAck(fn uint64) bool { 309 | if fn > minFrameNumber { 310 | bc.pendingAckMu.RLock() 311 | defer bc.pendingAckMu.RUnlock() 312 | return bc.pendingAckMap[fn] != nil 313 | } 314 | return false 315 | 316 | } 317 | -------------------------------------------------------------------------------- /subflow.go: -------------------------------------------------------------------------------- 1 | package multipath 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "math/rand" 7 | "net" 8 | "sync" 9 | "sync/atomic" 10 | "time" 11 | 12 | "github.com/getlantern/ema" 13 | pool "github.com/libp2p/go-buffer-pool" 14 | ) 15 | 16 | type pendingAck struct { 17 | fn uint64 18 | sz uint64 19 | sentAt time.Time 20 | outboundSf *subflow 21 | framePtr *sendFrame 22 | } 23 | 24 | type subflow struct { 25 | to string 26 | conn net.Conn 27 | mpc *mpConn 28 | 29 | chClose chan struct{} 30 | closeOnce 
sync.Once 31 | sendQueue chan *sendFrame 32 | pendingPing *pendingAck // Only for pings 33 | muPendingPing sync.RWMutex 34 | emaRTT *ema.EMA 35 | tracker StatsTracker 36 | actuallyBusyOnWrite uint64 37 | finishedClosing chan bool 38 | } 39 | 40 | func startSubflow(to string, c net.Conn, mpc *mpConn, clientSide bool, probeStart time.Time, tracker StatsTracker) *subflow { 41 | sf := &subflow{ 42 | to: to, 43 | conn: c, 44 | mpc: mpc, 45 | chClose: make(chan struct{}), 46 | sendQueue: make(chan *sendFrame, 1), 47 | finishedClosing: make(chan bool, 1), 48 | // pendingPing is used for storing the subflow's ping data. Handy since pings are subflow dependent 49 | pendingPing: nil, 50 | emaRTT: ema.NewDuration(longRTT, rttAlpha), 51 | tracker: tracker, 52 | } 53 | go sf.sendLoop() 54 | if clientSide { 55 | initialRTT := time.Since(probeStart) 56 | tracker.UpdateRTT(initialRTT) 57 | sf.emaRTT.SetDuration(initialRTT) 58 | // pong immediately so the server can calculate the RTT between when it 59 | // sends the leading bytes and receives the pong frame. 60 | sf.ack(frameTypePong) 61 | } else { 62 | // server side subflow expects a pong frame to calculate RTT. 
63 | sf.muPendingPing.Lock() 64 | sf.pendingPing = &pendingAck{frameTypePong, 0, probeStart, sf, nil} 65 | sf.muPendingPing.Unlock() 66 | } 67 | go func() { 68 | if err := sf.readLoop(); err != nil && err != io.EOF { 69 | log.Debugf("read loop to %s ended: %v", sf.to, err) 70 | } 71 | }() 72 | return sf 73 | } 74 | 75 | func (sf *subflow) readLoop() (err error) { 76 | ch := make(chan *rxFrame) 77 | r := byteReader{Reader: sf.conn} 78 | go sf.readLoopFrames(ch, r) 79 | 80 | probeTimer := time.NewTimer(randomize(probeInterval)) 81 | go sf.probe() // Force a ping out right away, to calibrate our own timings 82 | 83 | for { 84 | select { 85 | case frame := <-ch: // Fed by readLoopFrames 86 | if frame == nil { 87 | return 88 | } 89 | sf.mpc.recvQueue.add(frame, sf) 90 | if !probeTimer.Stop() { 91 | <-probeTimer.C 92 | } 93 | probeTimer.Reset(randomize(probeInterval)) 94 | case <-probeTimer.C: 95 | go sf.probe() 96 | probeTimer.Reset(randomize(probeInterval)) 97 | } 98 | } 99 | } 100 | 101 | func (sf *subflow) readLoopFrames(ch chan *rxFrame, r byteReader) bool { 102 | defer close(ch) 103 | var err error 104 | for { 105 | // The is the core "reactor" where frames are read. The frame format 106 | // can be found in the top of multipath.go 107 | var sz, fn uint64 108 | sz, err = ReadVarInt(r) 109 | if err != nil { 110 | sf.close() 111 | return true 112 | } 113 | fn, err = ReadVarInt(r) 114 | if err != nil { 115 | sf.close() 116 | return true 117 | } 118 | if sz == 0 { 119 | sf.gotACK(fn) 120 | continue 121 | } 122 | log.Tracef("got frame %d from %s with %d bytes", fn, sf.to, sz) 123 | if sz > 1<<20 { 124 | // This almost always happens due to frame corruption. 
125 | log.Errorf("Frame of size %v from %s is impossible", sz, sf.to) 126 | sf.close() 127 | return true 128 | } 129 | buf := pool.Get(int(sz)) 130 | _, err = io.ReadFull(r, buf) 131 | if err != nil { 132 | pool.Put(buf) 133 | sf.close() 134 | return true 135 | } 136 | 137 | if fn > (atomic.LoadUint64(&sf.mpc.recvQueue.readFrameTip) + sf.mpc.recvQueue.size) { 138 | // This frame dropped is too far in the future to apply 139 | continue 140 | } 141 | 142 | ch <- &rxFrame{fn: fn, bytes: buf} 143 | sf.tracker.OnRecv(sz) 144 | select { 145 | case <-sf.chClose: 146 | return true 147 | default: 148 | } 149 | } 150 | } 151 | 152 | func (sf *subflow) sendLoop() { 153 | closing := false 154 | closeCountdown := time.NewTimer(time.Millisecond * 33) 155 | closeCountdown.Stop() 156 | defer func() { 157 | sf.finishedClosing <- true 158 | }() 159 | 160 | go func() { 161 | <-sf.chClose 162 | closeCountdown.Reset(time.Millisecond * 33) 163 | closing = true 164 | }() 165 | 166 | for { 167 | select { 168 | case <-closeCountdown.C: 169 | sf.conn.Close() 170 | return 171 | case frame := <-sf.sendQueue: 172 | if closing { 173 | closeCountdown.Reset(time.Millisecond * 33) 174 | } 175 | if closing { 176 | closing = true 177 | } 178 | 179 | frame.changeLock.Lock() 180 | if frame.retransmissions != 0 { 181 | log.Tracef("Retransmit on %d, for the %dth time", frame.fn, frame.retransmissions) 182 | } 183 | if *frame.released == 1 { 184 | log.Errorf("Tried to send a frame that has already been released! 
Frame Number: %v", frame.fn) 185 | 186 | select { 187 | case sf.mpc.writerMaybeReady <- true: 188 | default: 189 | } 190 | 191 | frame.changeLock.Unlock() 192 | continue 193 | } 194 | if frame.retransmissions == 0 { 195 | if frame.sentVia == nil { 196 | frame.sentVia = make([]transmissionDatapoint, 0) 197 | } 198 | frame.sentVia = append(frame.sentVia, transmissionDatapoint{sf, time.Now()}) 199 | } 200 | 201 | sf.addPendingAck(frame) 202 | frame.changeLock.Unlock() 203 | 204 | atomic.StoreUint64(&sf.actuallyBusyOnWrite, 1) 205 | n, err := sf.conn.Write(frame.buf) 206 | atomic.StoreUint64(&sf.actuallyBusyOnWrite, 0) 207 | var abort bool 208 | for { 209 | // wake all writers up, since they might have something to send now that we likely 210 | // have free capacity. 211 | select { 212 | case sf.mpc.writerMaybeReady <- true: 213 | default: 214 | abort = true 215 | } 216 | if abort { 217 | break 218 | } 219 | } 220 | 221 | // only wake up one re-transmitter, to better control the possible hored of them 222 | select { 223 | case sf.mpc.tryRetransmit <- true: 224 | default: 225 | } 226 | 227 | if err != nil { 228 | log.Debugf("failed to write frame %d to %s: %v", frame.fn, sf.to, err) 229 | 230 | if frame.isDataFrame() { 231 | go sf.mpc.retransmit(frame) 232 | } 233 | 234 | if n != 0 && len(frame.buf) != n { 235 | log.Tracef("We may have corrupted the output %#v vs %#v", n, len(frame.buf)) 236 | // In this case, we will not try and write the remaining, and instead we will assume 237 | // that writing to the socket again will only make this worse, so aborting the subflow 238 | sf.close() 239 | return 240 | } 241 | 242 | sf.close() 243 | return 244 | } 245 | 246 | if n != len(frame.buf) { 247 | panic(fmt.Sprintf("expect to write %d bytes on %s, written %d", len(frame.buf), sf.to, n)) 248 | } 249 | if !frame.isDataFrame() { 250 | frame.release() 251 | continue 252 | } 253 | log.Tracef("done writing frame %d with %d bytes via %s", frame.fn, frame.sz, sf.to) 254 | 
frame.changeLock.Lock() 255 | if frame.retransmissions == 0 { 256 | sf.tracker.OnSent(frame.sz) 257 | } else { 258 | sf.tracker.OnRetransmit(frame.sz) 259 | } 260 | frame.changeLock.Unlock() 261 | } 262 | } 263 | } 264 | 265 | func (sf *subflow) ack(fn uint64) { 266 | if sf == nil { 267 | // This should only ever happen in testing. 268 | log.Debugf("Nil subflow requested to do an ack! (should only happen on tests)") 269 | return 270 | } 271 | 272 | select { 273 | case <-sf.chClose: 274 | case sf.sendQueue <- composeFrame(fn, nil): 275 | } 276 | } 277 | 278 | func (sf *subflow) gotACK(fn uint64) { 279 | log.Tracef("got ack for frame %d from %s", fn, sf.to) 280 | if fn == frameTypePing { 281 | log.Tracef("pong to %s", sf.to) 282 | sf.ack(frameTypePong) 283 | return 284 | } 285 | 286 | sf.mpc.pendingAckMu.RLock() 287 | pending := sf.mpc.pendingAckMap[fn] 288 | if sf.mpc.pendingAckMap[fn] != nil { 289 | sf.mpc.pendingAckMu.RUnlock() 290 | sf.mpc.pendingAckMu.Lock() 291 | delete(sf.mpc.pendingAckMap, fn) 292 | sf.mpc.pendingAckMu.Unlock() 293 | } else { 294 | sf.mpc.pendingAckMu.RUnlock() 295 | return 296 | } 297 | 298 | if time.Since(pending.sentAt) < time.Second { 299 | pending.outboundSf.updateRTT(time.Since(pending.sentAt)) 300 | } else { 301 | pending.outboundSf.updateRTT(time.Second) 302 | } 303 | } 304 | 305 | func (sf *subflow) updateRTT(rtt time.Duration) { 306 | sf.tracker.UpdateRTT(rtt) 307 | sf.emaRTT.UpdateDuration(rtt) 308 | } 309 | 310 | func (sf *subflow) getRTT() time.Duration { 311 | recorded := sf.emaRTT.GetDuration() 312 | // RTT is updated only when ack is received or retransmission timer raises, 313 | // which can be stale when the subflow starts hanging. If that happens, the 314 | // time since the earliest yet-to-be-acknowledged frame being sent is more 315 | // up-to-date. 
316 | var realtime time.Duration 317 | sf.muPendingPing.RLock() 318 | if sf.pendingPing != nil { 319 | realtime = time.Since(sf.pendingPing.sentAt) 320 | } else { 321 | sf.muPendingPing.RUnlock() 322 | return recorded 323 | } 324 | sf.muPendingPing.RUnlock() 325 | if realtime > recorded { 326 | return realtime 327 | } else { 328 | return recorded 329 | } 330 | } 331 | 332 | func (sf *subflow) addPendingAck(frame *sendFrame) { 333 | switch frame.fn { 334 | case frameTypePing: 335 | // we expect pong for ping 336 | sf.muPendingPing.Lock() 337 | sf.pendingPing = &pendingAck{frameTypePong, 0, time.Now(), sf, nil} 338 | sf.muPendingPing.Unlock() 339 | case frameTypePong: 340 | // expect no response for pong 341 | default: 342 | if frame.isDataFrame() { 343 | sf.mpc.pendingAckMu.Lock() 344 | sf.mpc.pendingAckMap[frame.fn] = &pendingAck{frame.fn, frame.sz, time.Now(), sf, frame} 345 | sf.mpc.pendingAckMu.Unlock() 346 | } 347 | } 348 | } 349 | 350 | func (sf *subflow) isPendingAck(fn uint64) bool { 351 | if fn > minFrameNumber { 352 | sf.mpc.pendingAckMu.RLock() 353 | defer sf.mpc.pendingAckMu.RUnlock() 354 | return sf.mpc.pendingAckMap[fn] != nil 355 | } 356 | return false 357 | } 358 | 359 | func (sf *subflow) probe() { 360 | log.Tracef("ping %s", sf.to) 361 | sf.ack(frameTypePing) 362 | } 363 | 364 | func (sf *subflow) retransTimer() time.Duration { 365 | d := sf.emaRTT.GetDuration() * 2 366 | if d > 512*time.Millisecond { 367 | d = 512 * time.Millisecond 368 | } 369 | if d < 1*time.Millisecond { 370 | d = time.Millisecond 371 | } 372 | return d 373 | } 374 | 375 | func (sf *subflow) close() { 376 | sf.closeOnce.Do(func() { 377 | log.Tracef("closing subflow to %s", sf.to) 378 | sf.mpc.remove(sf) 379 | close(sf.chClose) 380 | drainTime := time.Now() 381 | maxDrainTime := time.NewTimer(time.Second) 382 | select { 383 | case <-maxDrainTime.C: 384 | case <-sf.finishedClosing: 385 | } 386 | maxDrainTime.Stop() 387 | log.Debugf("Took %v to close subflow", 
time.Since(drainTime)) 388 | }) 389 | } 390 | 391 | func randomize(d time.Duration) time.Duration { 392 | return d/2 + time.Duration(rand.Int63n(int64(d))) 393 | } 394 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /multipath_test.go: -------------------------------------------------------------------------------- 1 | package multipath 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "math/rand" 8 | "net" 9 | "net/http" 10 | _ "net/http/pprof" 11 | "strconv" 12 | "sync" 13 | "sync/atomic" 14 | "testing" 15 | "time" 16 | 17 | "github.com/stretchr/testify/assert" 18 | ) 19 | 20 | var frozenListeners [100]int 21 | var frozenDailers [100]int 22 | var frozenTrackingLock sync.Mutex 23 | 24 | func TestE2E(t *testing.T) { 25 | // Failing on timeout 26 | go func() { 27 | http.ListenAndServe("localhost:6060", nil) 28 | }() 29 | listeners := []net.Listener{} 30 | trackers := []StatsTracker{} 31 | dialers := []Dialer{} 32 | for i := 0; i < 3; i++ { 33 | l, err := net.Listen("tcp", ":") 34 | if !assert.NoError(t, err) { 35 | continue 36 | } 37 | defer l.Close() 38 | listeners = append(listeners, newTestListener(l, i)) 39 | trackers = append(trackers, NullTracker{}) 40 | // simulate one or more dialers to each listener 41 | for j := 0; j <= 1; j++ { 42 | dialers = 
append(dialers, newTestDialer(l.Addr().String(), len(dialers)))
		}
	}
	log.Debugf("Testing with %d listeners and %d dialers", len(listeners), len(dialers))
	bl := NewListener(listeners, trackers)
	defer bl.Close()
	bd := NewDialer("endpoint", dialers)

	// Periodically log which dialers/listeners are currently frozen, but only
	// when the picture changes, to keep the output readable.
	go func() {
		previous := ""
		for {
			frozenTrackingLock.Lock()
			current := "Dailers: \n"
			for idx, d := range dialers {
				current += fmt.Sprintf("\t(%d) - %v\n", frozenDailers[idx], d.(*testDialer).name)
			}
			current += "Listeners: \n"
			for idx, lst := range listeners {
				current += fmt.Sprintf("\t(%d) - %v\n", frozenListeners[idx], lst.(*testListener).l.Addr())
			}
			if current != previous {
				log.Debug(current)
				previous = current
			}
			frozenTrackingLock.Unlock()
			time.Sleep(time.Millisecond * 33)
		}
	}()

	// Echo server: accept multipath connections and write back whatever is read.
	go func() {
		for {
			conn, err := bl.Accept()
			select {
			case <-bl.(*mpListener).chClose:
				return
			default:
			}
			assert.NoError(t, err)
			go func() {
				defer conn.Close()
				buf := make([]byte, 10240)
				for {
					n, err := conn.Read(buf)
					if err != nil {
						return
					}
					log.Debugf("server read %d bytes", n)
					echoed, err := conn.Write(buf[:n])
					if err != nil {
						return
					}
					log.Debugf("server wrote back %d bytes", echoed)
				}
			}()
		}
	}()
	conn, err := bd.DialContext(context.Background())
	if !assert.NoError(t, err) {
		return
	}
	defer conn.Close()
	b := make([]byte, 4)
	// roundtrip writes five small labelled payloads and reads each echo back.
	roundtrip := func() {
		for i := 0; i < 5; i++ {
			copy(b, []byte(strconv.Itoa(i)))
			n, err := conn.Write(b)
			assert.NoError(t, err)
			assert.Equal(t, len(b), n)
			log.Debugf("client written '%s'", b)
			_, err = io.ReadFull(conn, b)
			assert.NoError(t, err)
			log.Debugf("client read '%s'", b)
		}
	}
	roundtrip()

	// Freeze all but one listener, one at a time; traffic must still flow.
	for i := 0; i < len(listeners)-1; i++ {
		log.Debugf("========listener[%d] is hanging", i)
		frozenTrackingLock.Lock()
		frozenListeners[i] = 1
		frozenTrackingLock.Unlock()
		listeners[i].(*testListener).setDelay(time.Hour)
		roundtrip()
	}
	// Then freeze all but one dialer, one at a time.
	for i := 0; i < len(dialers)-1; i++ {
		log.Debugf("========%s is hanging", dialers[i].Label())
		frozenTrackingLock.Lock()
		frozenDailers[i] = 1
		frozenTrackingLock.Unlock()
		dialers[i].(*testDialer).setDelay(time.Hour)
		roundtrip()
	}
	log.Debugf("========reenabled listener #0 and %s", dialers[0].Label())
	listeners[0].(*testListener).setDelay(0)
	dialers[0].(*testDialer).setDelay(0)
	frozenTrackingLock.Lock()
	frozenListeners[0] = 0
	frozenDailers[0] = 0
	frozenTrackingLock.Unlock()

	log.Debug("========the last listener is hanging")
	listeners[len(listeners)-1].(*testListener).setDelay(time.Hour)
	frozenTrackingLock.Lock()
	frozenListeners[len(listeners)-1] = 1
	frozenTrackingLock.Unlock()

	roundtrip()
	log.Debugf("========%s is hanging", dialers[len(dialers)-1].Label())
	dialers[len(dialers)-1].(*testDialer).setDelay(time.Hour)
	frozenTrackingLock.Lock()
	frozenDailers[len(dialers)-1] = 1
	frozenTrackingLock.Unlock()
	roundtrip()

	log.Debugf("========Now test writing and reading back tons of data")
	payload := make([]byte, 81920)
	readBack := make([]byte, 81920)
	rand.Read(payload)
	for i := 0; i < 10; i++ {
		n, err := conn.Write(payload[:rand.Intn(len(payload))])
		assert.NoError(t, err)
		log.Debugf("client wrote %d bytes", n)
		_, err = io.ReadFull(conn, readBack[:n])
		assert.NoError(t, err)
		assert.EqualValues(t, payload[:n], readBack[:n])
	}

	// wake up all sleeping goroutines to clean up resources
	for i := 0; i < len(listeners); i++ {
		listeners[i].(*testListener).setDelay(0)
	}
	for i := 0; i < len(dialers); i++ {
dialers[i].(*testDialer).setDelay(0) 175 | } 176 | } 177 | 178 | func TestE2EEarlyClose(t *testing.T) { 179 | // Reusing the testE2E infra as much as possible 180 | // 181 | // this test is here to ensure that mpConns transfer the full 182 | // set of data transmitted when the connection is closed, to avoid truncation. 183 | listeners := []net.Listener{} 184 | trackers := []StatsTracker{} 185 | dialers := []Dialer{} 186 | for i := 0; i < 3; i++ { 187 | l, err := net.Listen("tcp", ":") 188 | if !assert.NoError(t, err) { 189 | continue 190 | } 191 | defer l.Close() 192 | listeners = append(listeners, newTestListener(l, i)) 193 | trackers = append(trackers, NullTracker{}) 194 | // simulate one or more dialers to each listener 195 | for j := 0; j <= 1; j++ { 196 | dialers = append(dialers, newTestDialer(l.Addr().String(), len(dialers))) 197 | } 198 | } 199 | log.Debugf("Testing with %d listeners and %d dialers", len(listeners), len(dialers)) 200 | bl := NewListener(listeners, trackers) 201 | defer bl.Close() 202 | bd := NewDialer("endpoint", dialers) 203 | 204 | go func() { 205 | lastDebug := "" 206 | for { 207 | frozenTrackingLock.Lock() 208 | newDebug := "Dailers: \n" 209 | for k, v := range dialers { 210 | newDebug += fmt.Sprintf("\t(%d) - %v\n", frozenDailers[k], v.(*testDialer).name) 211 | } 212 | newDebug += "Listeners: \n" 213 | for k, v := range listeners { 214 | newDebug += fmt.Sprintf("\t(%d) - %v\n", frozenListeners[k], v.(*testListener).l.Addr()) 215 | } 216 | if newDebug != lastDebug { 217 | log.Debug(newDebug) 218 | lastDebug = newDebug 219 | } 220 | frozenTrackingLock.Unlock() 221 | time.Sleep(time.Millisecond * 33) 222 | } 223 | }() 224 | 225 | go func() { 226 | for { 227 | conn, err := bl.Accept() 228 | select { 229 | case <-bl.(*mpListener).chClose: 230 | return 231 | default: 232 | } 233 | assert.NoError(t, err) 234 | go func() { 235 | defer conn.Close() 236 | dataLeftToSend := 10 * 100000000 // 10MB 237 | b := make([]byte, 10240) 238 | for { 239 | 
var n int 240 | var err error 241 | if dataLeftToSend < len(b) { 242 | n, err = conn.Write(b[:dataLeftToSend]) 243 | } else { 244 | n, err = conn.Write(b) 245 | } 246 | if err != nil { 247 | return 248 | } 249 | dataLeftToSend = dataLeftToSend - n 250 | 251 | if dataLeftToSend == 0 { 252 | return 253 | } 254 | } 255 | }() 256 | } 257 | }() 258 | conn, err := bd.DialContext(context.Background()) 259 | if !assert.NoError(t, err) { 260 | return 261 | } 262 | defer conn.Close() 263 | 264 | readBytes := 0 265 | for { 266 | b := make([]byte, 1024) 267 | n, err := conn.Read(b) 268 | if err != nil { 269 | fmt.Printf("Connection closed early at %v (%v)\n", readBytes, err) 270 | t.FailNow() 271 | } 272 | readBytes += n 273 | if readBytes == 10*100000000 { 274 | // pass! 275 | break 276 | } 277 | } 278 | t.Fatalf("aaa %v", readBytes) 279 | } 280 | 281 | func TestE2EEarlyCloseOtherWay(t *testing.T) { 282 | // Reusing the testE2E infra as much as possible 283 | // 284 | // this test is here to ensure that mpConns transfer the full 285 | // set of data transmitted when the connection is closed, to avoid truncation. 
286 | listeners := []net.Listener{} 287 | trackers := []StatsTracker{} 288 | dialers := []Dialer{} 289 | for i := 0; i < 3; i++ { 290 | l, err := net.Listen("tcp", ":") 291 | if !assert.NoError(t, err) { 292 | continue 293 | } 294 | defer l.Close() 295 | listeners = append(listeners, newTestListener(l, i)) 296 | trackers = append(trackers, NullTracker{}) 297 | // simulate one or more dialers to each listener 298 | for j := 0; j <= 1; j++ { 299 | dialers = append(dialers, newTestDialer(l.Addr().String(), len(dialers))) 300 | } 301 | } 302 | log.Debugf("Testing with %d listeners and %d dialers", len(listeners), len(dialers)) 303 | bl := NewListener(listeners, trackers) 304 | defer bl.Close() 305 | bd := NewDialer("endpoint", dialers) 306 | 307 | go func() { 308 | lastDebug := "" 309 | for { 310 | frozenTrackingLock.Lock() 311 | newDebug := "Dailers: \n" 312 | for k, v := range dialers { 313 | newDebug += fmt.Sprintf("\t(%d) - %v\n", frozenDailers[k], v.(*testDialer).name) 314 | } 315 | newDebug += "Listeners: \n" 316 | for k, v := range listeners { 317 | newDebug += fmt.Sprintf("\t(%d) - %v\n", frozenListeners[k], v.(*testListener).l.Addr()) 318 | } 319 | if newDebug != lastDebug { 320 | log.Debug(newDebug) 321 | lastDebug = newDebug 322 | } 323 | frozenTrackingLock.Unlock() 324 | time.Sleep(time.Millisecond * 33) 325 | } 326 | }() 327 | 328 | go func() { 329 | for { 330 | conn, err := bl.Accept() 331 | select { 332 | case <-bl.(*mpListener).chClose: 333 | return 334 | default: 335 | } 336 | assert.NoError(t, err) 337 | go func() { 338 | defer conn.Close() 339 | 340 | readBytes := 0 341 | for { 342 | b := make([]byte, 1024) 343 | n, err := conn.Read(b) 344 | if err != nil { 345 | fmt.Printf("Connection closed early at %v (%v)\n", readBytes, err) 346 | t.FailNow() 347 | } 348 | readBytes += n 349 | if readBytes == 10*100000000 { 350 | // pass! 
351 | return 352 | } 353 | } 354 | }() 355 | } 356 | }() 357 | conn, err := bd.DialContext(context.Background()) 358 | if !assert.NoError(t, err) { 359 | return 360 | } 361 | 362 | defer conn.Close() 363 | dataLeftToSend := 10 * 100000000 // 10MB 364 | b := make([]byte, 10210) 365 | for { 366 | var n int 367 | var err error 368 | if dataLeftToSend < len(b) { 369 | n, err = conn.Write(b[:dataLeftToSend]) 370 | } else { 371 | n, err = conn.Write(b) 372 | } 373 | if err != nil { 374 | t.Fatalf("Failed to write %v", err) 375 | } 376 | dataLeftToSend = dataLeftToSend - n 377 | 378 | if dataLeftToSend == 0 { 379 | return 380 | } 381 | } 382 | 383 | } 384 | 385 | type testDialer struct { 386 | delayEnforcer 387 | addr string 388 | idx int 389 | } 390 | 391 | func newTestDialer(addr string, idx int) *testDialer { 392 | var lock sync.Mutex 393 | td := &testDialer{ 394 | delayEnforcer{cond: sync.NewCond(&lock)}, addr, idx, 395 | } 396 | td.delayEnforcer.name = td.Label() 397 | return td 398 | } 399 | 400 | func (td *testDialer) DialContext(ctx context.Context) (net.Conn, error) { 401 | var d net.Dialer 402 | conn, err := d.DialContext(ctx, "tcp", td.addr) 403 | if err != nil { 404 | return nil, err 405 | } 406 | return &laggedConn{conn, conn, td.delayEnforcer.sleep}, nil 407 | } 408 | 409 | func (td *testDialer) Label() string { 410 | return fmt.Sprintf("test dialer #%d to %v", td.idx, td.addr) 411 | } 412 | 413 | type testListener struct { 414 | net.Listener 415 | delayEnforcer 416 | l net.Listener 417 | } 418 | 419 | func newTestListener(l net.Listener, idx int) *testListener { 420 | var lock sync.Mutex 421 | tl := &testListener{l, delayEnforcer{cond: sync.NewCond(&lock)}, l} 422 | tl.delayEnforcer.name = fmt.Sprintf("listener %d", idx) 423 | return tl 424 | } 425 | 426 | func (tl *testListener) Accept() (net.Conn, error) { 427 | conn, err := tl.l.Accept() 428 | if err != nil { 429 | return nil, err 430 | } 431 | return &laggedConn{conn, conn, tl.delayEnforcer.sleep}, nil 
432 | } 433 | 434 | type laggedConn struct { 435 | net.Conn 436 | conn net.Conn // has to be the same as the net.Conn 437 | sleep func() 438 | } 439 | 440 | func (c *laggedConn) Read(b []byte) (int, error) { 441 | c.sleep() 442 | return c.conn.Read(b) 443 | } 444 | 445 | func TestDelayEnforcer(t *testing.T) { 446 | var lock sync.Mutex 447 | d := delayEnforcer{cond: sync.NewCond(&lock)} 448 | var wg sync.WaitGroup 449 | d.setDelay(time.Hour) 450 | wg.Add(1) 451 | start := time.Now() 452 | go func() { 453 | d.sleep() 454 | wg.Done() 455 | }() 456 | time.Sleep(100 * time.Millisecond) 457 | d.setDelay(0) 458 | wg.Wait() 459 | assert.InDelta(t, time.Since(start), 100*time.Millisecond, float64(10*time.Millisecond)) 460 | } 461 | 462 | type delayEnforcer struct { 463 | name string 464 | delay int64 465 | cond *sync.Cond 466 | } 467 | 468 | func (e *delayEnforcer) setDelay(d time.Duration) { 469 | atomic.StoreInt64(&e.delay, int64(d)) 470 | log.Debugf("%s delay is set to %v", e.name, d) 471 | e.cond.Broadcast() 472 | } 473 | 474 | func (e *delayEnforcer) sleep() { 475 | e.cond.L.Lock() 476 | defer e.cond.L.Unlock() 477 | for { 478 | d := atomic.LoadInt64(&e.delay) 479 | if delay := time.Duration(d); delay > 0 { 480 | log.Debugf("%s sleep for %v", e.name, delay) 481 | time.AfterFunc(delay, func() { 482 | e.cond.Broadcast() 483 | }) 484 | e.cond.Wait() 485 | log.Debugf("%s done sleeping", e.name) 486 | } else { 487 | return 488 | } 489 | } 490 | } 491 | --------------------------------------------------------------------------------