├── .gitignore ├── LICENSE.md ├── MINT-LICENSE.md ├── README.md ├── aead.go ├── aead_test.go ├── bin ├── client │ └── main.go ├── server │ └── main.go ├── shim │ └── main.go └── tester │ └── main.go ├── codec.go ├── codec_test.go ├── common.go ├── common_test.go ├── congestion.go ├── connbuffer.go ├── connection.go ├── connection_test.go ├── crypto.go ├── deploy ├── Dockerfile ├── logserver │ ├── package.json │ └── server.js ├── mk-endpoint.sh ├── mk-localhost.sh ├── run-local.sh └── run-looped.sh ├── errors.go ├── frame.go ├── frame_test.go ├── log.go ├── minq.png ├── minq.svg ├── packet.go ├── packet_test.go ├── record-layer.go ├── server.go ├── server_test.go ├── stream.go ├── stream_test.go ├── timer.go ├── tls.go ├── tracking.go ├── tracking_test.go ├── transport.go ├── transport_parameters.go └── udp_transport.go /.gitignore: -------------------------------------------------------------------------------- 1 | *.test 2 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Eric Rescorla 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /MINT-LICENSE.md: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Richard Barnes 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WARNING 2 | 3 | **This implementation is not under active development, and has diverged from the QUIC specification.** 4 | 5 | The QUIC WG maintains [a list of active implementations](https://github.com/quicwg/base-drafts/wiki/Implementations). 6 | 7 | ![A mink forming a Q](/minq.png) 8 | 9 | ----- 10 | 11 | minq -- A minimal QUIC stack 12 | ============================ 13 | Minq is a minimal implementation of QUIC, as documented at 14 | https://quicwg.github.io/. Minq partly implements draft-05 15 | (it advertises -04 but it's actually more like the editor's copy) 16 | with TLS 1.3 draft-20 or draft-21. 17 | 18 | Currently it will do: 19 | 20 | - A 1-RTT handshake (with self-generated and unverified certificates) 21 | - Some ACK processing 22 | - Primitive retransmission (manual, no timers) 23 | - 1-RTT application data 24 | - Exchange of stream close (though this doesn't really have much impact) 25 | 26 | Important missing pieces for the first implementation draft include: 27 | 28 | - Handling ACK ranges 29 | - Real timeout and retransmission support 30 | 31 | Other defects include: 32 | 33 | - Doesn't properly clean up state, so things will just grow without bound 34 | - TLS configuration and verification 35 | - A huge other pile of unknown and known defects. 36 | 37 | 38 | ## WARNING 39 | 40 | Minq is absolutely not suitable for any kind of production use and should 41 | only be used for testing. In particular, it explicitly doesn't validate 42 | certificates. 
43 | 44 | 45 | 46 | ## Quick Start (untested but should be rightish) 47 | 48 | cd ${GOPATH}/src 49 | go get github.com/ekr/minq 50 | cd github.com/bifurcation/mint 51 | git remote add ekr https://github.com/ekr/mint 52 | git fetch ekr 53 | git checkout ekr/quic_record_layer 54 | cd ../../ekr/minq 55 | go test 56 | 57 | This should produce something like this: 58 | 59 | Result = 010002616263 60 | Result2 = 010002616263 61 | Result = 0102616263 62 | Result2 = 0102616263 63 | {1 2 [97 98 99]} 64 | {1 1 [8 16]} 65 | {3 2 [8 16 24 32]} 66 | Checking client state 67 | Checking server state 68 | Encoded frame ab00deadbeef0000000000000001 69 | Encoded frame bb0100deadbeef00000000000000010e00000001 70 | Result = 820123456789abcdefdeadbeefff000001 71 | Result2 = 820123456789abcdefdeadbeefff000001 72 | PASS 73 | ok github.com/ekr/minq 1.285s 74 | 75 | It's the "ok" at the end that's important. 76 | 77 | There are two test programs that live in ```minq/bin/client``` and 78 | ```minq/bin/server```. When run with ```-echo```, the server echoes back the 79 | received data with its letter case flipped. The client is just a passthrough. 80 | 81 | From ```${GOPATH}/src/github.com/ekr```, run 82 | 83 | go run minq/bin/server/main.go 84 | go run minq/bin/client/main.go 85 | 86 | in separate windows; that should have the desired result. 87 | 88 | 89 | ## Logging 90 | 91 | To enable logging, set the ```MINQ_LOG``` environment variable, as 92 | in ```MINQ_LOG=connection go test```. Valid values are: 93 | 94 | // Pre-defined log types 95 | const ( 96 | logTypeAead = "aead" 97 | logTypeCodec = "codec" 98 | logTypeConnBuffer = "connbuffer" 99 | logTypeConnection = "connection" 100 | logTypeAck = "ack" 101 | logTypeFrame = "frame" 102 | logTypeHandshake = "handshake" 103 | logTypeTls = "tls" 104 | logTypeTrace = "trace" 105 | logTypeServer = "server" 106 | logTypeUdp = "udp" 107 | ) 108 | 109 | Multiple log types can be enabled at once by separating them with commas.
110 | 111 | ## Mint 112 | 113 | Minq depends on Mint (https://www.github.com/bifurcation/mint) for TLS. 114 | Right now we are on the following branch: 115 | 116 | https://github.com/ekr/mint/tree/quic_record_layer 117 | 118 | This branch is more experimental than usual. 119 | 120 | -------------------------------------------------------------------------------- /aead.go: -------------------------------------------------------------------------------- 1 | package minq 2 | 3 | import ( 4 | "crypto/aes" 5 | "crypto/cipher" 6 | ) 7 | 8 | // aeadWrapper contains an existing AEAD object and does the 9 | // QUIC nonce masking. 10 | type aeadWrapper struct { 11 | iv []byte 12 | cipher cipher.AEAD 13 | } 14 | 15 | func (a *aeadWrapper) NonceSize() int { 16 | return a.cipher.NonceSize() 17 | } 18 | func (a *aeadWrapper) Overhead() int { 19 | return a.cipher.Overhead() 20 | } 21 | 22 | func (a *aeadWrapper) fmtNonce(in []byte) []byte { 23 | // The input "nonce" is an 8-byte packet number; it is right-aligned in a zeroed 12-byte buffer and XORed with the IV. 24 | assert(len(in) == 8) 25 | assert(a.NonceSize() == 12) 26 | assert(len(a.iv) == a.NonceSize()) 27 | 28 | nonce := make([]byte, a.NonceSize()) 29 | copy(nonce[len(nonce)-len(in):], in) 30 | for i, b := range a.iv { 31 | nonce[i] ^= b 32 | } 33 | 34 | logf(logTypeAead, "Nonce=%x", nonce) 35 | return nonce 36 | } 37 | 38 | func (a *aeadWrapper) Seal(dst []byte, nonce []byte, plaintext []byte, aad []byte) []byte { 39 | logf(logTypeAead, "AES protecting aad len=%d, plaintext len=%d", len(aad), len(plaintext)) 40 | logf(logTypeTrace, "AES input AAD=%x P=%x", aad, plaintext) 41 | ret := a.cipher.Seal(dst, a.fmtNonce(nonce), plaintext, aad) 42 | logf(logTypeTrace, "AES output %x", ret) 43 | 44 | return ret 45 | } 46 | 47 | func (a *aeadWrapper) Open(dst []byte, nonce []byte, ciphertext []byte, aad []byte) ([]byte, error) { 48 | logf(logTypeAead, "AES unprotecting aad len=%d, ciphertext len=%d", len(aad), len(ciphertext)) 49 | logf(logTypeTrace, "AES input AAD=%x C=%x", aad, ciphertext) 50 | ret,
err := a.cipher.Open(dst, a.fmtNonce(nonce), ciphertext, aad) 51 | if err != nil { 52 | return nil, err 53 | } 54 | logf(logTypeTrace, "AES output %x", ret) 55 | return ret, nil 56 | } 57 | 58 | func newWrappedAESGCM(key []byte, iv []byte) (cipher.AEAD, error) { 59 | logf(logTypeAead, "New AES GCM context: key=%x iv=%x", key, iv) 60 | a, err := aes.NewCipher(key) 61 | if err != nil { 62 | return nil, err 63 | } 64 | 65 | aead, err := cipher.NewGCM(a) 66 | if err != nil { 67 | return nil, err 68 | } 69 | 70 | return &aeadWrapper{iv, aead}, nil 71 | } 72 | -------------------------------------------------------------------------------- /aead_test.go: -------------------------------------------------------------------------------- 1 | package minq 2 | 3 | import ( 4 | "crypto/cipher" 5 | "testing" 6 | ) 7 | 8 | var kTestKey1 = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} 9 | var kTestIV1 = []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} 10 | var kTestKey2 = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} 11 | var kTestIV2 = []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13} 12 | 13 | var ktestAeadHdr1 = []byte{1, 2, 3} 14 | var ktestAeadHdr2 = []byte{1, 2, 4} 15 | var ktestAeadBody1 = []byte{5, 6, 7} 16 | var ktestAeadBody2 = []byte{5, 6, 8} 17 | 18 | var kNonce0 = []byte{0, 0, 0, 0, 0, 0, 0, 0} 19 | var kNonce1 = []byte{0, 0, 0, 0, 0, 0, 0, 1} 20 | 21 | func testAeadSuccess(t *testing.T, aead cipher.AEAD) { 22 | ct := aead.Seal(nil, kNonce0, ktestAeadBody1, ktestAeadHdr1) 23 | 24 | pt, err := aead.Open(nil, kNonce0, ct, ktestAeadHdr1) 25 | assertNotError(t, err, "Could not unprotect") 26 | 27 | assertByteEquals(t, pt, ktestAeadBody1) 28 | } 29 | 30 | func testAeadWrongPacketNumber(t *testing.T, aead cipher.AEAD) { 31 | ct := aead.Seal(nil, kNonce0, ktestAeadBody1, ktestAeadHdr1) 32 | 33 | _, err := aead.Open(nil, kNonce1, ct, ktestAeadHdr1) 34 | assertError(t, err, "Shouldn't have unprotected") 35 | } 36 | 37 | func testAeadWrongHeader(t *testing.T, aead
cipher.AEAD) { 38 | 39 | ct := aead.Seal(nil, kNonce0, ktestAeadBody1, ktestAeadHdr1) 40 | 41 | _, err := aead.Open(nil, kNonce0, ct, ktestAeadHdr2) 42 | assertError(t, err, "Shouldn't have unprotected") 43 | } 44 | 45 | func testAeadCorruptCT(t *testing.T, aead cipher.AEAD) { 46 | ct := aead.Seal(nil, kNonce0, ktestAeadBody1, ktestAeadHdr1) 47 | 48 | ct[0]++ 49 | _, err := aead.Open(nil, kNonce0, ct, ktestAeadHdr1) 50 | assertError(t, err, "Shouldn't have unprotected") 51 | } 52 | 53 | func testAeadCorruptTag(t *testing.T, aead cipher.AEAD) { 54 | ct := aead.Seal(nil, kNonce0, ktestAeadBody1, ktestAeadHdr1) 55 | ct[len(ct)-1]++ 56 | _, err := aead.Open(nil, kNonce0, ct, ktestAeadHdr1) 57 | assertError(t, err, "Shouldn't have unprotected") 58 | } 59 | 60 | func testAeadWrongAead(t *testing.T, aead cipher.AEAD, aead2 cipher.AEAD) { 61 | ct := aead.Seal(nil, kNonce0, ktestAeadBody1, ktestAeadHdr1) 62 | _, err := aead2.Open(nil, kNonce0, ct, ktestAeadHdr1) 63 | assertError(t, err, "Shouldn't have unprotected") 64 | } 65 | 66 | func testAeadAll(t *testing.T, aead cipher.AEAD) { 67 | t.Run("Success", func(t *testing.T) { testAeadSuccess(t, aead) }) 68 | t.Run("WrongHeader", func(t *testing.T) { testAeadWrongHeader(t, aead) }) 69 | t.Run("CorruptCT", func(t *testing.T) { testAeadCorruptCT(t, aead) }) 70 | t.Run("CorruptTag", func(t *testing.T) { testAeadCorruptTag(t, aead) }) 71 | } 72 | 73 | func makeWrappedAead(t *testing.T, key []byte, iv []byte) cipher.AEAD { 74 | a, err := newWrappedAESGCM(key, iv) 75 | assertNotError(t, err, "Couldn't make AEAD") 76 | return a 77 | } 78 | 79 | func TestAeadAES128GCM(t *testing.T) { 80 | a1 := makeWrappedAead(t, kTestKey1, kTestIV1) 81 | a2 := makeWrappedAead(t, kTestKey2, kTestIV1) 82 | a3 := makeWrappedAead(t, kTestKey1, kTestIV2) 83 | 84 | testAeadAll(t, a1) 85 | t.Run("WrongKey", func(t *testing.T) { testAeadWrongAead(t, a1, a2) }) 86 | t.Run("WrongIV", func(t *testing.T) { testAeadWrongAead(t, a1, a3) }) 87 |
t.Run("WrongPacketNumber", func(t *testing.T) { testAeadWrongPacketNumber(t, a1) }) 88 | } 89 | -------------------------------------------------------------------------------- /bin/client/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "github.com/ekr/minq" 7 | "log" 8 | "net" 9 | "os" 10 | "runtime/pprof" 11 | "time" 12 | ) 13 | 14 | var addr string 15 | var serverName string 16 | var doHttp string 17 | var httpCount int 18 | var heartbeat int 19 | var cpuProfile string 20 | var resume bool 21 | var httpLeft int 22 | var zeroRtt bool 23 | 24 | type connHandler struct { 25 | bytesRead int 26 | } 27 | 28 | func (h *connHandler) StateChanged(s minq.State) { 29 | log.Println("State changed to ", s) 30 | } 31 | 32 | func (h *connHandler) NewStream(s minq.Stream) { 33 | } 34 | 35 | func (h *connHandler) NewRecvStream(s minq.RecvStream) { 36 | } 37 | 38 | func (h *connHandler) StreamReadable(s minq.RecvStream) { 39 | for { 40 | b := make([]byte, 1024) 41 | 42 | n, err := s.Read(b) 43 | switch err { 44 | case nil: 45 | break 46 | case minq.ErrorWouldBlock: 47 | return 48 | case minq.ErrorStreamIsClosed, minq.ErrorConnIsClosed: 49 | log.Println("") 50 | httpLeft-- 51 | return 52 | default: 53 | log.Println("Error: ", err) 54 | httpLeft-- 55 | return 56 | } 57 | b = b[:n] 58 | h.bytesRead += n 59 | os.Stdout.Write(b) 60 | fmt.Fprintf(os.Stderr, "Total bytes read = %d\n", h.bytesRead) 61 | } 62 | } 63 | 64 | func readUDP(s *net.UDPConn) ([]byte, error) { 65 | b := make([]byte, 8192) 66 | 67 | s.SetReadDeadline(time.Now().Add(time.Second)) 68 | n, _, err := s.ReadFromUDP(b) 69 | if err != nil { 70 | e, o := err.(net.Error) 71 | if o && e.Timeout() { 72 | return nil, minq.ErrorWouldBlock 73 | } 74 | log.Println("Error reading from UDP socket: ", err) 75 | return nil, err 76 | } 77 | 78 | if n == len(b) { 79 | log.Println("Underread from UDP socket") 80 | return
nil, fmt.Errorf("UDP datagram may have been truncated at %d bytes", len(b)) // err is nil here; report truncation explicitly. 81 | } 82 | b = b[:n] 83 | return b, nil 84 | } 85 | 86 | func makeConnection(config *minq.TlsConfig, uaddr *net.UDPAddr) (*net.UDPConn, *minq.Connection) { 87 | usock, err := net.ListenUDP("udp", nil) 88 | if err != nil { 89 | log.Println("Couldn't create connected UDP socket") 90 | return nil, nil 91 | } 92 | 93 | utrans := minq.NewUdpTransport(usock, uaddr) 94 | 95 | conn := minq.NewConnection(utrans, minq.RoleClient, 96 | config, &connHandler{}) 97 | 98 | log.Printf("Client conn id=%v\n", conn.ClientId()) 99 | 100 | // Start things off. 101 | _, err = conn.CheckTimer() 102 | 103 | return usock, conn 104 | } 105 | 106 | func completeConnection(usock *net.UDPConn, conn *minq.Connection) error { 107 | for conn.GetState() != minq.StateEstablished { 108 | b, err := readUDP(usock) 109 | if err != nil { 110 | if err == minq.ErrorWouldBlock { 111 | _, err = conn.CheckTimer() 112 | if err != nil { 113 | return err 114 | } 115 | continue 116 | } 117 | return err 118 | } 119 | 120 | err = conn.Input(b) 121 | if err != nil { 122 | log.Println("Error", err) 123 | return err 124 | } 125 | } 126 | 127 | log.Printf("Connection established server CID = %v\n", conn.ServerId()) 128 | return nil 129 | } 130 | 131 | func main() { 132 | log.Println("PID=", os.Getpid()) 133 | flag.StringVar(&addr, "addr", "localhost:4433", "[host:port]") 134 | flag.StringVar(&serverName, "server-name", "", "SNI") 135 | flag.StringVar(&doHttp, "http", "", "Do HTTP/0.9 with provided URL") 136 | flag.IntVar(&httpCount, "httpCount", 1, "Number of parallel HTTP requests to start") 137 | flag.IntVar(&heartbeat, "heartbeat", 0, "heartbeat frequency [ms]") 138 | flag.StringVar(&cpuProfile, "cpuprofile", "", "write cpu profile to file") 139 | flag.BoolVar(&resume, "resume", false, "Test resumption") 140 | flag.BoolVar(&zeroRtt, "zerortt", false, "Test 0-RTT") 141 | flag.Parse() 142 | 143 | if zeroRtt { 144 | resume = true 145 | if doHttp == "" { 146 | log.Printf("Need HTTP to do 0-RTT") 147 | return
148 | } 149 | } 150 | if cpuProfile != "" { 151 | f, err := os.Create(cpuProfile) 152 | if err != nil { 153 | log.Printf("Could not create CPU profile file %v err=%v\n", cpuProfile, err) 154 | return 155 | } 156 | pprof.StartCPUProfile(f) 157 | log.Println("CPU profiler started") 158 | defer pprof.StopCPUProfile() 159 | } 160 | 161 | // Default to the host component of addr. 162 | if serverName == "" { 163 | host, _, err := net.SplitHostPort(addr) 164 | if err != nil { 165 | log.Println("Couldn't split host/port", err) 166 | } 167 | serverName = host 168 | } 169 | config := minq.NewTlsConfig(serverName) 170 | 171 | inner_main(&config, false) 172 | if resume { 173 | inner_main(&config, true) 174 | } 175 | } 176 | func inner_main(config *minq.TlsConfig, resuming bool) { 177 | 178 | uaddr, err := net.ResolveUDPAddr("udp", addr) 179 | if err != nil { 180 | log.Println("Invalid UDP addr", err) 181 | return 182 | } 183 | 184 | usock, conn := makeConnection(config, uaddr) 185 | if conn == nil { 186 | return 187 | } 188 | defer usock.Close() // inner_main runs once per connection (twice when resuming); release each socket. 189 | if !resuming || !zeroRtt { 190 | err = completeConnection(usock, conn) 191 | if err != nil { 192 | return 193 | } 194 | } 195 | 196 | // Hopefully reduce the risk of reordering 197 | time.Sleep(100 * time.Millisecond) 198 | 199 | // Make all the streams we need 200 | streams := make([]minq.Stream, httpCount) 201 | for i := 0; i < httpCount; i++ { 202 | streams[i] = conn.CreateStream() 203 | if streams[i] == nil { 204 | log.Println("Couldn't create stream") 205 | return 206 | } 207 | } 208 | httpLeft = httpCount 209 | 210 | udpin := make(chan []byte) 211 | stdin := make(chan []byte) 212 | 213 | // Read from the UDP socket.
214 | go func() { 215 | for { 216 | b, err := readUDP(usock) 217 | if err == minq.ErrorWouldBlock { 218 | udpin <- make([]byte, 0) 219 | continue 220 | } 221 | udpin <- b 222 | if b == nil { 223 | return 224 | } 225 | } 226 | }() 227 | 228 | if heartbeat > 0 && doHttp == "" { 229 | ticker := time.NewTicker(time.Millisecond * time.Duration(heartbeat)) 230 | go func() { 231 | for t := range ticker.C { 232 | stdin <- []byte(fmt.Sprintf("Heartbeat at %v\n", t)) 233 | } 234 | }() 235 | } 236 | 237 | if doHttp != "" { 238 | req := "GET " + doHttp + "\r\n" 239 | for _, str := range streams { 240 | str.Write([]byte(req)) 241 | str.Close() 242 | } 243 | } 244 | 245 | if resuming && zeroRtt { 246 | log.Println("Completing connection after we sent 0-RTT send in 0-RTT") 247 | err = completeConnection(usock, conn) 248 | if err != nil { 249 | return 250 | } 251 | } 252 | 253 | if doHttp == "" { 254 | // Read from stdin. 255 | go func() { 256 | for { 257 | b := make([]byte, 1024) 258 | n, err := os.Stdin.Read(b) 259 | if err != nil { 260 | stdin <- nil 261 | return 262 | } 263 | b = b[:n] 264 | stdin <- b 265 | } 266 | }() 267 | } 268 | for { 269 | select { 270 | case u := <-udpin: 271 | if len(u) == 0 { 272 | _, err = conn.CheckTimer() 273 | } else { 274 | err = conn.Input(u) 275 | } 276 | if err != nil { 277 | log.Println("Error", err) 278 | return 279 | } 280 | if doHttp != "" && httpLeft == 0 { 281 | return 282 | } 283 | case i := <-stdin: 284 | if i == nil { 285 | // TODO(piet@devae.re) close the appropriate stream(s) 286 | } 287 | streams[0].Write(i) 288 | if err != nil { 289 | log.Println("Error", err) 290 | return 291 | } 292 | } 293 | 294 | } 295 | } 296 | -------------------------------------------------------------------------------- /bin/server/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "crypto" 6 | "crypto/x509" 7 | "flag" 8 | "fmt" 9 |
"github.com/cloudflare/cfssl/helpers" 10 | "github.com/ekr/minq" 11 | "io" 12 | "io/ioutil" 13 | "log" 14 | "net" 15 | "os" 16 | "runtime/pprof" 17 | "strconv" 18 | "strings" 19 | "time" 20 | ) 21 | 22 | var addr string 23 | var serverName string 24 | var keyFile string 25 | var certFile string 26 | var logFile string 27 | var logOut *os.File 28 | var doHttp bool 29 | var statelessReset bool 30 | var cpuProfile string 31 | var echo bool 32 | var standalone bool 33 | 34 | // Shared data structures. 35 | type conn struct { 36 | conn *minq.Connection 37 | last time.Time 38 | } 39 | 40 | func (c *conn) checkTimer() { 41 | t := time.Now() 42 | if t.After(c.last.Add(time.Second)) { 43 | c.conn.CheckTimer() 44 | c.last = time.Now() 45 | } 46 | } 47 | 48 | var conns = make(map[string]*conn) 49 | 50 | // feedthroughServerHandler prints received stream data and, when echo is set, echoes it back case-flipped. 51 | type feedthroughServerHandler struct { 52 | echo bool 53 | } 54 | 55 | func (h *feedthroughServerHandler) NewConnection(c *minq.Connection) { 56 | log.Println("New connection") 57 | c.SetHandler(&feedthroughConnHandler{h.echo, 0}) 58 | conns[c.ServerId().String()] = &conn{c, time.Now()} 59 | } 60 | 61 | type feedthroughConnHandler struct { 62 | echo bool 63 | bytesRead int 64 | } 65 | 66 | func (h *feedthroughConnHandler) StateChanged(s minq.State) { 67 | log.Println("State changed to ", s) 68 | } 69 | 70 | func (h *feedthroughConnHandler) NewStream(s minq.Stream) { 71 | log.Println("Created new stream id=", s.Id()) 72 | } 73 | func (h *feedthroughConnHandler) NewRecvStream(s minq.RecvStream) { 74 | log.Println("Created new stream id=", s.Id()) 75 | } 76 | 77 | func (h *feedthroughConnHandler) StreamReadable(s minq.RecvStream) { 78 | log.Println("Ready to read for stream id=", s.Id()) 79 | for { 80 | b := make([]byte, 1024) 81 | 82 | n, err := s.Read(b) 83 | switch err { 84 | case nil: 85 | break 86 | case minq.ErrorWouldBlock: 87 | return 88 | case minq.ErrorStreamIsClosed, minq.ErrorConnIsClosed: 89 | log.Println("") 90 | return 91 | default: 92
| log.Println("Error: ", err) 93 | return 94 | } 95 | b = b[:n] 96 | h.bytesRead += n 97 | os.Stdout.Write(b) 98 | log.Printf("Total bytes read = %d\n", h.bytesRead) 99 | if h.echo { 100 | // Flip the case so we can distinguish echo 101 | for i := range b { 102 | if b[i] > 0x40 { 103 | b[i] ^= 0x20 104 | } 105 | } 106 | if ss, ok := s.(minq.SendStream); ok { // A receive-only stream cannot echo; skip it instead of panicking. 107 | ss.Write(b) 108 | } 109 | } 110 | } 111 | } 112 | 113 | // An HTTP 0.9 Handler 114 | type httpServerHandler struct { 115 | } 116 | 117 | func (h *httpServerHandler) NewConnection(c *minq.Connection) { 118 | log.Println("New connection") 119 | c.SetHandler(&httpConnHandler{make(map[uint64]*httpStream, 0)}) 120 | conns[c.ServerId().String()] = &conn{c, time.Now()} 121 | } 122 | 123 | type httpStream struct { 124 | s minq.Stream 125 | buf []byte 126 | closed bool 127 | } 128 | 129 | type httpConnHandler struct { 130 | streams map[uint64]*httpStream 131 | } 132 | 133 | func (h *httpConnHandler) StateChanged(s minq.State) { 134 | log.Println("State changed to ", s) 135 | } 136 | 137 | func (h *httpConnHandler) NewStream(s minq.Stream) { 138 | h.streams[s.Id()] = &httpStream{s, nil, false} 139 | } 140 | 141 | func (h *httpConnHandler) NewRecvStream(s minq.RecvStream) { 142 | log.Println("For some reason some opened a unidirectional stream. Ignoring") 143 | } 144 | 145 | func (h *httpStream) Respond(val []byte) { 146 | h.s.Write(val) 147 | h.s.Close() 148 | h.closed = true 149 | } 150 | 151 | func (h *httpStream) Error(err string) { 152 | h.Respond([]byte(err)) 153 | } 154 | 155 | // We expect the URL to be one of two things: 156 | // 157 | // A number, in which case we respond with that number of 158 | // Xs, up to 10,000 159 | // A non-number, in which case we respond with 10 repetitions 160 | // of that value.
161 | func (h *httpConnHandler) StreamReadable(s minq.RecvStream) { 162 | log.Println("Ready to read for stream id=", s.Id()) 163 | st, ok := h.streams[s.Id()] 164 | if !ok || st.closed { // Streams not registered by NewStream (e.g. unidirectional ones) have no entry. 165 | return 166 | } 167 | 168 | b := make([]byte, 1024) 169 | n, err := s.Read(b) 170 | if err != nil && err != minq.ErrorWouldBlock { 171 | log.Println("Error reading") 172 | return 173 | } 174 | b = b[:n] 175 | log.Printf("Read %v bytes from peer %x\n", n, b) 176 | 177 | st.buf = append(st.buf, b...) 178 | 179 | // See if we received a complete LF 180 | str := string(st.buf) 181 | idx := strings.IndexRune(str, '\n') 182 | if idx == -1 { 183 | return 184 | } 185 | str = str[:idx] 186 | 187 | // OK, we have a complete line. 188 | toks := strings.Split(str, " ") 189 | if toks[0] != "GET" { 190 | st.Error(fmt.Sprintf("Bogus method: %v", toks[0])) 191 | return 192 | } 193 | if len(toks) < 2 { 194 | st.Error("No resource") 195 | return 196 | } 197 | 198 | val := strings.TrimSpace(toks[1]) 199 | 200 | if !strings.HasPrefix(val, "/") { // Also rejects an empty resource instead of panicking on val[0]. 201 | st.Error(fmt.Sprintf("Bad value: %v", val)) 202 | return 203 | } 204 | val = val[1:] 205 | 206 | count, err := strconv.ParseUint(val, 10, 32) 207 | var rsp []byte 208 | if err == nil { 209 | if count > 10000 { 210 | count = 10000 211 | } 212 | rsp = bytes.Repeat([]byte{'X'}, int(count)) 213 | } else { 214 | rspstr := "" 215 | for i := 0; i < 10; i++ { 216 | rspstr += val 217 | rspstr += "--" 218 | } 219 | rspstr += "\n" 220 | rsp = []byte(rspstr) 221 | } 222 | st.Respond(rsp) 223 | } 224 | 225 | func logFunc(format string, args ...interface{}) { 226 | fmt.Fprintf(logOut, format, args...)
227 | fmt.Fprintf(logOut, "\n") 228 | } 229 | 230 | func main() { 231 | flag.StringVar(&addr, "addr", "localhost:4433", "[host:port]") 232 | flag.StringVar(&serverName, "server-name", "localhost", "[SNI]") 233 | flag.StringVar(&keyFile, "key", "", "Key file") 234 | flag.StringVar(&certFile, "cert", "", "Cert file") 235 | flag.StringVar(&logFile, "log", "", "Log file") 236 | flag.BoolVar(&doHttp, "http", false, "Do HTTP/0.9") 237 | flag.BoolVar(&echo, "echo", false, "Run as an echo server") 238 | flag.BoolVar(&statelessReset, "stateless-reset", false, "Do stateless reset") 239 | flag.StringVar(&cpuProfile, "cpuprofile", "", "write cpu profile to file") 240 | flag.BoolVar(&standalone, "standalone", false, "Run standalone") 241 | flag.Parse() 242 | 243 | var key crypto.Signer 244 | var certChain []*x509.Certificate 245 | 246 | if cpuProfile != "" { 247 | f, err := os.Create(cpuProfile) 248 | if err != nil { 249 | log.Printf("Could not create CPU profile file %v err=%v\n", cpuProfile, err) 250 | return 251 | } 252 | pprof.StartCPUProfile(f) 253 | log.Println("CPU profiler started") 254 | defer pprof.StopCPUProfile() 255 | } 256 | 257 | config := minq.NewTlsConfig(serverName) 258 | config.ForceHrr = statelessReset 259 | 260 | if keyFile != "" && certFile == "" { 261 | log.Println("Can't specify -key without -cert") 262 | return 263 | } 264 | 265 | if keyFile == "" && certFile != "" { 266 | log.Println("Can't specify -cert without -key") 267 | return 268 | } 269 | 270 | if keyFile != "" && certFile != "" { 271 | keyPEM, err := ioutil.ReadFile(keyFile) 272 | if err != nil { 273 | log.Printf("Couldn't open keyFile %v err=%v", keyFile, err) 274 | return 275 | } 276 | key, err = helpers.ParsePrivateKeyPEM(keyPEM) 277 | if err != nil { 278 | log.Println("Couldn't parse private key: ", err) 279 | return 280 | } 281 | 282 | certPEM, err := ioutil.ReadFile(certFile) 283 | if err != nil { 284 | log.Printf("Couldn't open certFile %v err=%v", certFile, err) 285 | return 286 | } 287
| certChain, err = helpers.ParseCertificatesPEM(certPEM) 288 | if err != nil { 289 | log.Println("Couldn't parse certificates: ", err) 290 | return 291 | } 292 | config.CertificateChain = certChain 293 | config.Key = key 294 | } 295 | 296 | if logFile != "" { 297 | var err error 298 | logOut, err = os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) 299 | if err != nil { 300 | log.Println("Couldn't open file") 301 | return 302 | } 303 | minq.SetLogOutput(logFunc) 304 | } 305 | uaddr, err := net.ResolveUDPAddr("udp", addr) 306 | if err != nil { 307 | log.Println("Invalid UDP addr: ", err) 308 | return 309 | } 310 | 311 | usock, err := net.ListenUDP("udp", uaddr) 312 | if err != nil { 313 | log.Println("Couldn't listen on UDP: ", err) 314 | return 315 | } 316 | 317 | var handler minq.ServerHandler 318 | if doHttp { 319 | handler = &httpServerHandler{} 320 | } else { 321 | handler = &feedthroughServerHandler{echo} 322 | } 323 | server := minq.NewServer(minq.NewUdpTransportFactory(usock), &config, handler) 324 | 325 | stdin := make(chan []byte) 326 | if !standalone { 327 | go func() { 328 | for { 329 | b := make([]byte, 1024) 330 | n, err := os.Stdin.Read(b) 331 | if err == io.EOF { 332 | log.Println("EOF received") 333 | close(stdin) 334 | return 335 | } else if err != nil { 336 | log.Println("Error reading from stdin") 337 | return 338 | } 339 | b = b[:n] 340 | stdin <- b 341 | } 342 | }() 343 | } 344 | 345 | for { 346 | 347 | select { 348 | case _, open := <-stdin: 349 | if !open { 350 | log.Println("Shutdown signal received from stdin.
Goodnight.") 351 | return 352 | } 353 | default: 354 | } 355 | 356 | b := make([]byte, 8192) 357 | 358 | usock.SetDeadline(time.Now().Add(time.Second)) 359 | n, addr, err := usock.ReadFromUDP(b) 360 | if err != nil { 361 | e, o := err.(net.Error) 362 | if !o || !e.Timeout() { 363 | log.Println("Error reading from UDP socket: ", err) 364 | return 365 | } 366 | n = 0 367 | } 368 | 369 | // If we read data, process it. 370 | if n > 0 { 371 | if n == len(b) { 372 | log.Println("Underread from UDP socket") 373 | return 374 | } 375 | b = b[:n] 376 | 377 | _, err = server.Input(addr, b) 378 | if err != nil { 379 | log.Println("server.Input returned error: ", err) 380 | return 381 | } 382 | } 383 | 384 | // Check the timers. 385 | server.CheckTimer() 386 | } 387 | } 388 | -------------------------------------------------------------------------------- /bin/shim/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "github.com/ekr/minq" 7 | "net" 8 | "time" 9 | ) 10 | 11 | var addr string 12 | var server bool 13 | 14 | type connHandler struct { 15 | } 16 | 17 | func (h *connHandler) StateChanged(s minq.State) { 18 | fmt.Println("State changed to ", s) 19 | } 20 | 21 | func (h *connHandler) NewStream(s *minq.Stream) { 22 | } 23 | 24 | func (h *connHandler) StreamReadable(s *minq.Stream) { 25 | } 26 | 27 | func readUDP(s *net.UDPConn) ([]byte, error) { 28 | b := make([]byte, 8192) 29 | 30 | s.SetReadDeadline(time.Now().Add(time.Second)) 31 | n, _, err := s.ReadFromUDP(b) 32 | if err != nil { 33 | e, o := err.(net.Error) 34 | if o && e.Timeout() { 35 | return nil, minq.ErrorWouldBlock 36 | } 37 | fmt.Println("Error reading from UDP socket: ", err) 38 | return nil, err 39 | } 40 | 41 | if n == len(b) { 42 | fmt.Println("Underread from UDP socket") 43 | return nil, fmt.Errorf("UDP datagram may have been truncated at %d bytes", len(b)) // err is nil here; report truncation explicitly. 44 | } 45 | b = b[:n] 46 | return b, nil 47 | } 48 | 49 | func main() { 50 | flag.StringVar(&addr,
"localhost:4433", "[host:port]") 51 | flag.BoolVar(&server, "server", false, "Run as server") 52 | flag.Parse() 53 | 54 | uaddr, err := net.ResolveUDPAddr("udp", addr) 55 | if err != nil { 56 | fmt.Println("Invalid UDP addr", err) 57 | return 58 | } 59 | 60 | usock, err := net.ListenUDP("udp", nil) 61 | if err != nil { 62 | fmt.Println("Couldn't create connected UDP socket") 63 | return 64 | } 65 | 66 | role := minq.RoleClient 67 | if server { 68 | _, port, err := net.SplitHostPort(usock.LocalAddr().String()) 69 | if err != nil { 70 | return 71 | } 72 | fmt.Println(port) 73 | role = minq.RoleServer 74 | } 75 | fmt.Printf("Remote addr=%v\n", addr) 76 | utrans := minq.NewUdpTransport(usock, uaddr) 77 | config := minq.NewTlsConfig("localhost") 78 | conn := minq.NewConnection(utrans, role, &config, nil) 79 | // Start things off. 80 | fmt.Println("Starting") 81 | if _, err = conn.CheckTimer(); err != nil { 82 | fmt.Println("Error", err) 83 | return 84 | } 85 | for conn.GetState() != minq.StateEstablished { 86 | b, err := readUDP(usock) 87 | if err != nil { 88 | if err == minq.ErrorWouldBlock { 89 | _, err = conn.CheckTimer() 90 | if err != nil { 91 | return 92 | } 93 | continue 94 | } 95 | return 96 | } 97 | 98 | err = conn.Input(b) 99 | if err != nil { 100 | fmt.Println("Error", err) 101 | return 102 | } 103 | } 104 | 105 | fmt.Println("Connection established") 106 | } 107 | -------------------------------------------------------------------------------- /bin/tester/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/hex" 5 | "flag" 6 | "fmt" 7 | "io/ioutil" 8 | "strings" 9 | 10 | "github.com/ekr/minq" 11 | ) 12 | 13 | var infile string 14 | var serverName string 15 | var dehex bool 16 | 17 | type stdoutTransport struct { 18 | } 19 | 20 | func (t *stdoutTransport) Send(p []byte) error { 21 | fmt.Printf("Output=%v", hex.Dump(p)) 22 | return nil 23 | } 24 | 25 | type connHandler struct { 26 | } 27 | 28 | func (h *connHandler)
StateChanged(s minq.State) { 29 | fmt.Println("State changed to ", s) 30 | } 31 | 32 | func (h *connHandler) NewStream(s *minq.Stream) { 33 | fmt.Println("New stream") 34 | } 35 | 36 | func (h *connHandler) StreamReadable(s *minq.Stream) { 37 | fmt.Println("Stream readable") 38 | } 39 | 40 | func main() { 41 | flag.StringVar(&infile, "infile", "input", "input file") 42 | flag.StringVar(&serverName, "server-name", "", "SNI") 43 | flag.BoolVar(&dehex, "hex", false, "file is in hex") 44 | flag.Parse() 45 | 46 | in, err := ioutil.ReadFile(infile) 47 | if err != nil { 48 | fmt.Println("Couldn't read file") 49 | return 50 | } 51 | if dehex { 52 | s := string(in) 53 | s = strings.Replace(s, " ", "", -1) 54 | s = strings.Replace(s, "\n", "", -1) 55 | in, err = hex.DecodeString(s) 56 | if err != nil { 57 | fmt.Println("Couldn't hex decode input") 58 | return 59 | } 60 | } 61 | 62 | strans := &stdoutTransport{} 63 | config := minq.NewTlsConfig(serverName) 64 | conn := minq.NewConnection(strans, minq.RoleServer, &config, nil) 65 | err = conn.Input(in) 66 | if err != nil { 67 | fmt.Println("Couldn't process input: ", err) 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /codec.go: -------------------------------------------------------------------------------- 1 | package minq 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "reflect" 8 | "runtime" 9 | "strconv" 10 | "strings" 11 | "unicode" 12 | ) 13 | 14 | const ( 15 | codecDefaultSize = ^uintptr(0) 16 | ) 17 | 18 | func uintEncode(buf *bytes.Buffer, v reflect.Value, encodingSize uintptr) error { 19 | size := v.Type().Size() 20 | if encodingSize != codecDefaultSize { 21 | if encodingSize > size { 22 | return fmt.Errorf("Requested a length longer than the native type") 23 | } 24 | size = encodingSize 25 | } 26 | 27 | uintEncodeInt(buf, v.Uint(), size) 28 | return nil 29 | } 30 | 31 | func uintEncodeInt(buf *bytes.Buffer, val uint64, size uintptr) { 32 | // Now encode the low-order bytes
of the value. 33 | for b := size; b > 0; b -= 1 { 34 | buf.WriteByte(byte(val >> ((b - 1) * 8))) 35 | } 36 | } 37 | 38 | // isVarint determines if the field is a varint. This reads the mint/syntax tag 39 | // for the field, but only supports a simple "varint". 40 | func isVarint(f reflect.StructField) bool { 41 | return f.Tag.Get("tls") == "varint" 42 | } 43 | 44 | func varintEncode(buf *bytes.Buffer, v uint64) { 45 | switch { 46 | case v < (uint64(1) << 6): 47 | uintEncodeInt(buf, v, 1) 48 | case v < (uint64(1) << 14): 49 | uintEncodeInt(buf, v|(1<<14), 2) 50 | case v < (uint64(1) << 30): 51 | uintEncodeInt(buf, v|(2<<30), 4) 52 | case v < (uint64(1) << 62): 53 | uintEncodeInt(buf, v|(3<<62), 8) 54 | default: 55 | panic("varint value is too large") 56 | } 57 | } 58 | 59 | func arrayEncode(buf *bytes.Buffer, v reflect.Value) error { 60 | b := v.Bytes() 61 | logf(logTypeCodec, "Encoding array length=%d", len(b)) 62 | buf.Write(b) 63 | 64 | return nil 65 | } 66 | 67 | // Check to see if fields 68 | func ignoreField(name string) bool { 69 | return unicode.IsLower(rune(name[0])) 70 | } 71 | 72 | // Length specifications are of the form: 73 | // 74 | // lengthbits: "B:L1,L2,...LN 75 | // 76 | // where B is the rightmost bit of the length bits and 77 | // L_n are the various lengths (in bytes) indicated by 78 | // the bit values in sequence. N must be a power of 2 79 | // and the right number of bytes is drawn to compute it. 80 | type lengthSpec struct { 81 | rightBit uint 82 | numBits uint 83 | values []int 84 | } 85 | 86 | func parseLengthSpecification(spec string) (*lengthSpec, error) { 87 | spl := strings.Split(spec, ":") 88 | assert(len(spl) == 2) 89 | 90 | // Rightmost bit. 91 | p, err := strconv.ParseUint(spl[0], 10, 8) 92 | if err != nil { 93 | return nil, err 94 | } 95 | bitr := uint(p) 96 | vals := strings.Split(spl[1], ",") 97 | 98 | // Figure out how many bits we need. 
99 | nvals := int(1) 100 | var bits int 101 | for bits = 1; bits <= 8; bits++ { 102 | nvals <<= 1 103 | if nvals == len(vals) { 104 | break 105 | } 106 | } 107 | assert(bits < 9) 108 | 109 | // Now compute the values 110 | valArr := make([]int, nvals) 111 | for i, v := range vals { 112 | valArr[i], err = strconv.Atoi(v) 113 | if err != nil { 114 | return nil, err 115 | } 116 | } 117 | 118 | return &lengthSpec{ 119 | bitr, 120 | uint(bits), 121 | valArr, 122 | }, nil 123 | } 124 | 125 | func computeLengthFromSpec(t byte, f reflect.StructField) uintptr { 126 | st := f.Tag.Get("lengthbits") 127 | if st == "" { 128 | return codecDefaultSize 129 | } 130 | 131 | spec, err := parseLengthSpecification(st) 132 | assert(err == nil) 133 | 134 | mask := byte(0) 135 | bit := uint(0) 136 | for ; bit < spec.numBits; bit++ { 137 | mask |= (1 << bit) 138 | } 139 | idx := int(t >> (spec.rightBit - 1) & mask) 140 | 141 | return uintptr(spec.values[idx]) 142 | } 143 | 144 | // Encode all the fields of a struct to a bytestring. 145 | func encode(i interface{}) (ret []byte, err error) { 146 | var buf bytes.Buffer 147 | var res error 148 | reflected := reflect.ValueOf(i).Elem() 149 | fields := reflected.NumField() 150 | 151 | for j := 0; j < fields; j += 1 { 152 | field := reflected.Field(j) 153 | tipe := reflected.Type().Field(j) 154 | 155 | if ignoreField(tipe.Name) { 156 | continue 157 | } 158 | 159 | logf(logTypeCodec, "Type name %s Kind=%v", tipe.Name, field.Kind()) 160 | 161 | switch field.Kind() { 162 | case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 163 | // Call the length overrider to tell us if we shoud be using a shorter 164 | // encoding. 
165 | encodingSize := uintptr(codecDefaultSize) 166 | lFunc, getLength := reflected.Type().MethodByName(tipe.Name + "__length") 167 | logf(logTypeCodec, "Looking for length overrider for type %v", tipe.Name) 168 | if getLength { 169 | lengthResult := lFunc.Func.Call([]reflect.Value{reflect.ValueOf(i).Elem()}) 170 | encodingSize = uintptr(lengthResult[0].Uint()) 171 | logf(logTypeCodec, "Overriden length to %v", encodingSize) 172 | } 173 | if isVarint(tipe) { 174 | if encodingSize != 0 { 175 | varintEncode(&buf, field.Uint()) 176 | } 177 | res = nil 178 | break 179 | } 180 | 181 | res = uintEncode(&buf, field, encodingSize) 182 | case reflect.Array, reflect.Slice: 183 | res = arrayEncode(&buf, field) 184 | default: 185 | return nil, fmt.Errorf("Unknown type") 186 | } 187 | 188 | if res != nil { 189 | return nil, res 190 | } 191 | } 192 | 193 | ret = buf.Bytes() 194 | logf(logTypeCodec, "Total encoded length = %v", len(ret)) 195 | return ret, nil 196 | } 197 | 198 | func uintDecodeIntBuf(val []byte) uint64 { 199 | tmp := uint64(0) 200 | for b := 0; b < len(val); b++ { 201 | tmp = (tmp << 8) + uint64(val[b]) 202 | } 203 | return tmp 204 | } 205 | 206 | func uintDecodeInt(r io.Reader, size uintptr) (uint64, error) { 207 | val := make([]byte, size) 208 | _, err := io.ReadFull(r, val) 209 | if err != nil { 210 | return 0, err 211 | } 212 | 213 | return uintDecodeIntBuf(val), nil 214 | } 215 | 216 | func uintDecode(r io.Reader, v reflect.Value, encodingSize uintptr) (uintptr, error) { 217 | size := v.Type().Size() 218 | if encodingSize != codecDefaultSize { 219 | if encodingSize > size { 220 | return 0, fmt.Errorf("Requested a length longer than the native type") 221 | } 222 | size = encodingSize 223 | } 224 | 225 | tmp, err := uintDecodeInt(r, size) 226 | if err != nil { 227 | return 0, err 228 | } 229 | 230 | v.SetUint(tmp) 231 | 232 | return size, nil 233 | } 234 | 235 | func varintDecode(r io.Reader, v reflect.Value) (uintptr, error) { 236 | p := make([]byte, 8) 237 | 
_, err := r.Read(p[:1]) 238 | if err != nil { 239 | return 0, err 240 | } 241 | 242 | value := uint64(p[0] & 0x3f) 243 | extra := uintptr(1<<(p[0]>>6)) - 1 244 | if extra > 0 { 245 | tail, err := uintDecodeInt(r, extra) 246 | if err != nil { 247 | return 0, err 248 | } 249 | value = (value << (8 * extra)) | tail 250 | } 251 | 252 | v.SetUint(value) 253 | return 1 + extra, nil 254 | } 255 | 256 | func encodeArgs(args ...interface{}) []byte { 257 | var buf bytes.Buffer 258 | var res error 259 | 260 | for _, arg := range args { 261 | reflected := reflect.ValueOf(arg) 262 | switch reflected.Kind() { 263 | case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 264 | res = uintEncode(&buf, reflected, codecDefaultSize) 265 | case reflect.Array, reflect.Slice: 266 | res = arrayEncode(&buf, reflected) 267 | default: 268 | panic(fmt.Sprintf("Unknown type")) 269 | } 270 | if res != nil { 271 | panic(fmt.Sprintf("Encoding error")) 272 | } 273 | } 274 | 275 | return buf.Bytes() 276 | } 277 | 278 | func arrayDecode(r io.Reader, v reflect.Value, encodingSize uintptr) (uintptr, error) { 279 | logf(logTypeCodec, "encodingSize = %v", encodingSize) 280 | 281 | val := make([]byte, encodingSize) 282 | 283 | logf(logTypeCodec, "Reading array of size %v", encodingSize) 284 | 285 | // Go will return EOF if you try to read 0 bytes off a closed stream. 286 | if encodingSize == 0 { 287 | return 0, nil 288 | } 289 | _, err := io.ReadFull(r, val) 290 | if err != nil { 291 | return 0, err 292 | } 293 | 294 | v.SetBytes(val) 295 | return encodingSize, nil 296 | } 297 | 298 | // Decode all the fields of a struct from a bytestring. 
Takes 299 | // a pointer to the struct to fill in 300 | func decode(i interface{}, data []byte) (uintptr, error) { 301 | buf := bytes.NewReader(data) 302 | var res error 303 | reflected := reflect.ValueOf(i).Elem() 304 | fields := reflected.NumField() 305 | bytesread := uintptr(0) 306 | 307 | for j := 0; j < fields; j++ { 308 | br := uintptr(0) 309 | field := reflected.Field(j) 310 | tipe := reflected.Type().Field(j) 311 | 312 | if ignoreField(tipe.Name) { 313 | continue 314 | } 315 | 316 | // Call the length overrider to tell us if we should be using a shorter 317 | // encoding. 318 | encodingSize := uintptr(codecDefaultSize) 319 | lFunc, getLength := reflected.Type().MethodByName(tipe.Name + "__length") 320 | if getLength { 321 | lengthResult := lFunc.Func.Call([]reflect.Value{reflect.ValueOf(i).Elem()}) 322 | encodingSize = uintptr(lengthResult[0].Uint()) 323 | logf(logTypeCodec, "Length overrider for %s returns %v", tipe.Name, encodingSize) 324 | } 325 | 326 | switch field.Kind() { 327 | case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 328 | if isVarint(tipe) && encodingSize != 0 { 329 | br, res = varintDecode(buf, field) 330 | } else { 331 | br, res = uintDecode(buf, field, encodingSize) 332 | } 333 | case reflect.Array, reflect.Slice: 334 | if encodingSize == codecDefaultSize { 335 | encodingSize = uintptr(buf.Len()) 336 | } 337 | br, res = arrayDecode(buf, field, encodingSize) 338 | default: 339 | return 0, fmt.Errorf("Unknown type") 340 | } 341 | if res != nil { 342 | logf(logTypeCodec, "Error while reading field %v: %v", tipe.Name, res) 343 | return bytesread, res 344 | } 345 | bytesread += br 346 | } 347 | 348 | return bytesread, nil 349 | } 350 | 351 | func backtrace() string { 352 | bt := string("") 353 | for i := 1; ; i++ { 354 | _, file, line, ok := runtime.Caller(i) 355 | if !ok { 356 | break 357 | } 358 | bt = fmt.Sprintf("%v: %d\n", file, line) + bt 359 | } 360 | return bt 361 | } 362 | 
-------------------------------------------------------------------------------- /codec_test.go: -------------------------------------------------------------------------------- 1 | package minq 2 | 3 | import ( 4 | "encoding/hex" 5 | "fmt" 6 | "reflect" 7 | "testing" 8 | ) 9 | 10 | type Uint8Indirect uint8 11 | 12 | type TestStructDefaultLengths struct { 13 | U8 Uint8Indirect 14 | U16 uint16 15 | B []byte 16 | } 17 | 18 | type TestStructOverrideLengths struct { 19 | U8 uint8 20 | U16 uint16 21 | B []byte 22 | } 23 | 24 | func (t TestStructOverrideLengths) U16__length() uintptr { 25 | return 1 26 | } 27 | 28 | func (t TestStructOverrideLengths) B__length() uintptr { 29 | return 3 30 | } 31 | 32 | func codecEDE(t *testing.T, s interface{}, s2 interface{}, expectedLen uintptr) { 33 | res, err := encode(s) 34 | assertNotError(t, err, "Could not encode") 35 | 36 | fmt.Println("Result = ", hex.EncodeToString(res)) 37 | // TODO(ekr@rtfm.com). What is the type of len(). 38 | assertEquals(t, uintptr(expectedLen), uintptr(len(res))) 39 | 40 | _, err = decode(s2, res) 41 | assertNotError(t, err, "Could not decode") 42 | 43 | res2, err := encode(s2) 44 | assertNotError(t, err, "Could not re-encode") 45 | fmt.Println("Result2 = ", hex.EncodeToString(res2)) 46 | assertByteEquals(t, res, res2) 47 | } 48 | 49 | func TestCodecDefaultEncode(t *testing.T) { 50 | s := TestStructDefaultLengths{1, 2, []byte{'a', 'b', 'c'}} 51 | var s2 TestStructDefaultLengths 52 | 53 | codecEDE(t, &s, &s2, 6) 54 | } 55 | 56 | func TestCodecOverrideEncode(t *testing.T) { 57 | s := TestStructOverrideLengths{1, 2, []byte{'a', 'b', 'c'}} 58 | var s2 TestStructOverrideLengths 59 | 60 | codecEDE(t, &s, &s2, 5) 61 | } 62 | 63 | func TestCodecOverrideDecodeLength(t *testing.T) { 64 | s := TestStructOverrideLengths{1, 2, []byte{'a', 'b', 'c'}} 65 | var s2 TestStructOverrideLengths 66 | 67 | res, err := encode(&s) 68 | assertNotError(t, err, "Could not encode") 69 | 70 | modified := append(res, 'd') 71 | _, err 
= decode(&s2, modified) 72 | assertNotError(t, err, "Could not decode") 73 | 74 | fmt.Println(s2) 75 | 76 | res2, err := encode(&s2) 77 | assertNotError(t, err, "Could not re-encode") 78 | 79 | assertByteEquals(t, res, res2) 80 | } 81 | 82 | func TestParseLengthSpec(t *testing.T) { 83 | // 1 bit, 2 values 84 | spec, err := parseLengthSpecification("1:8,16") 85 | assertNotError(t, err, "Couldn't parse single bit value") 86 | fmt.Println(*spec) 87 | assertX(t, reflect.DeepEqual(*spec, lengthSpec{1, 1, []int{8, 16}}), 88 | "Spec parsed correctly") 89 | 90 | // 2 bit, 4 values 91 | spec, err = parseLengthSpecification("3:8,16,24,32") 92 | assertNotError(t, err, "Couldn't parse two bit value") 93 | fmt.Println(*spec) 94 | assertX(t, reflect.DeepEqual(*spec, lengthSpec{3, 2, []int{8, 16, 24, 32}}), 95 | "Spec parsed correctly") 96 | 97 | } 98 | -------------------------------------------------------------------------------- /common.go: -------------------------------------------------------------------------------- 1 | package minq 2 | 3 | import ( 4 | "os" 5 | ) 6 | 7 | var ( 8 | debug = checkDebug() 9 | ) 10 | 11 | func checkDebug() bool { 12 | if os.Getenv("MINQ_DEBUG") == "true" { 13 | return true 14 | } 15 | return false 16 | } 17 | 18 | func assert(t bool) { 19 | if !t { 20 | panic("Assert") 21 | } 22 | } 23 | 24 | func dup(b []byte) []byte { 25 | ret := make([]byte, len(b)) 26 | copy(ret, b) 27 | return ret 28 | } 29 | -------------------------------------------------------------------------------- /common_test.go: -------------------------------------------------------------------------------- 1 | package minq 2 | 3 | import ( 4 | "bytes" 5 | "encoding/hex" 6 | "fmt" 7 | "runtime" 8 | "testing" 9 | ) 10 | 11 | /* STOLEN FROM MINT. 
12 | The MIT License (MIT) 13 | 14 | Copyright (c) 2016 Richard Barnes 15 | 16 | Permission is hereby granted, free of charge, to any person obtaining a copy 17 | of this software and associated documentation files (the "Software"), to deal 18 | in the Software without restriction, including without limitation the rights 19 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 20 | copies of the Software, and to permit persons to whom the Software is 21 | furnished to do so, subject to the following conditions: 22 | 23 | The above copyright notice and this permission notice shall be included in 24 | all copies or substantial portions of the Software. 25 | 26 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 27 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 28 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 29 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 30 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 31 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 32 | THE SOFTWARE. 
33 | */ 34 | 35 | func unhex(h string) []byte { 36 | b, err := hex.DecodeString(h) 37 | if err != nil { 38 | panic(err) 39 | } 40 | return b 41 | } 42 | 43 | func assertX(t *testing.T, test bool, msg string) { 44 | prefix := string("") 45 | for i := 1; ; i++ { 46 | _, file, line, ok := runtime.Caller(i) 47 | if !ok { 48 | break 49 | } 50 | prefix = fmt.Sprintf("%v: %d\n", file, line) + prefix 51 | } 52 | if !test { 53 | t.Fatalf(prefix + msg) 54 | } 55 | } 56 | 57 | func assertError(t *testing.T, err error, msg string) { 58 | assertX(t, err != nil, msg) 59 | } 60 | 61 | func assertNotError(t *testing.T, err error, msg string) { 62 | if err != nil { 63 | msg += ": " + err.Error() 64 | } 65 | assertX(t, err == nil, msg) 66 | } 67 | 68 | func assertNotNil(t *testing.T, x interface{}, msg string) { 69 | assertX(t, x != nil, msg) 70 | } 71 | 72 | func assertEquals(t *testing.T, a, b interface{}) { 73 | assertX(t, a == b, fmt.Sprintf("%+v != %+v", a, b)) 74 | } 75 | 76 | func assertByteEquals(t *testing.T, a, b []byte) { 77 | assertX(t, bytes.Equal(a, b), fmt.Sprintf("%+v != %+v", hex.EncodeToString(a), hex.EncodeToString(b))) 78 | } 79 | 80 | func assertNotByteEquals(t *testing.T, a, b []byte) { 81 | assertX(t, !bytes.Equal(a, b), fmt.Sprintf("%+v == %+v", hex.EncodeToString(a), hex.EncodeToString(b))) 82 | } 83 | 84 | /* END STOLEN FROM MINT. */ 85 | -------------------------------------------------------------------------------- /congestion.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package minq is a minimal implementation of QUIC, as documented at 3 | https://quicwg.github.io/. Minq partly implements draft-04. 
4 | 5 | */ 6 | package minq 7 | 8 | import ( 9 | "math" 10 | "time" 11 | // "fmt" 12 | ) 13 | 14 | // congestion control related constants 15 | const ( 16 | kDefaultMss = 1460 // bytes 17 | kInitalWindow = 10 * kDefaultMss 18 | kMinimumWindow = 2 * kDefaultMss 19 | kMaximumWindow = kInitalWindow 20 | kLossReductionFactor = 0.5 21 | ) 22 | 23 | // loss dectection related constants 24 | const ( 25 | kMaxTLPs = 2 26 | kReorderingThreshold = 3 27 | kTimeReorderingFraction = 0.125 28 | kMinTLPTimeout = 10 * time.Millisecond 29 | kMinRTOTimeout = 200 * time.Millisecond 30 | kDelayedAckTimeout = 25 * time.Millisecond 31 | kDefaultInitialRtt = 100 * time.Millisecond 32 | ) 33 | 34 | type CongestionController interface { 35 | onPacketSent(pn uint64, isAckOnly bool, sentBytes int) 36 | onAckReceived(acks ackRanges, delay time.Duration) 37 | bytesAllowedToSend() int 38 | setLostPacketHandler(handler func(pn uint64)) 39 | rto() time.Duration 40 | } 41 | 42 | /* 43 | * DUMMY congestion controller 44 | */ 45 | 46 | type CongestionControllerDummy struct { 47 | } 48 | 49 | func (cc *CongestionControllerDummy) onPacketSent(pn uint64, isAckOnly bool, sentBytes int) { 50 | } 51 | 52 | func (cc *CongestionControllerDummy) onAckReceived(acks ackRanges, delay time.Duration) { 53 | } 54 | 55 | func (cc *CongestionControllerDummy) bytesAllowedToSend() int { 56 | /* return the the maximum int value */ 57 | return int(^uint(0) >> 1) 58 | } 59 | 60 | func (cc *CongestionControllerDummy) setLostPacketHandler(handler func(pn uint64)) { 61 | } 62 | 63 | func (cc *CongestionControllerDummy) rto() time.Duration { 64 | return kMinRTOTimeout 65 | } 66 | 67 | /* 68 | * draft-ietf-quic-recovery congestion controller 69 | */ 70 | 71 | type CongestionControllerIetf struct { 72 | // Congestion control related 73 | bytesInFlight int 74 | congestionWindow int 75 | endOfRecovery uint64 76 | sstresh int 77 | 78 | // Loss detection related 79 | lossDetectionAlarm int //TODO(ekr@rtfm.com) set this to the 
right type 80 | handshakeCount int 81 | tlpCount int 82 | rtoCount int 83 | largestSendBeforeRto uint64 84 | timeOfLastSentPacket time.Time 85 | largestSendPacket uint64 86 | largestAckedPacket uint64 87 | maxAckDelay time.Duration 88 | minRtt time.Duration 89 | // largestRtt time.Duration 90 | smoothedRtt time.Duration 91 | rttVar time.Duration 92 | smoothedRttTcp time.Duration 93 | rttVarTcp time.Duration 94 | reorderingThreshold int 95 | timeReorderingFraction float32 96 | lossTime time.Time 97 | sentPackets map[uint64]packetEntry 98 | 99 | // others 100 | lostPacketHandler func(pn uint64) 101 | conn *Connection 102 | } 103 | 104 | type packetEntry struct { 105 | pn uint64 106 | txTime time.Time 107 | bytes int 108 | ackOnly bool 109 | } 110 | 111 | func (cc *CongestionControllerIetf) onPacketSent(pn uint64, isAckOnly bool, sentBytes int) { 112 | cc.timeOfLastSentPacket = time.Now() 113 | cc.largestSendPacket = pn 114 | packetData := packetEntry{pn, time.Now(), 0, isAckOnly} 115 | cc.conn.log(logTypeCongestion, "Packet send pn: %d len:%d ackonly: %v\n", pn, sentBytes, isAckOnly) 116 | if !isAckOnly { 117 | cc.onPacketSentCC(sentBytes) 118 | packetData.bytes = sentBytes 119 | cc.setLossDetectionAlarm() 120 | } 121 | cc.sentPackets[pn] = packetData 122 | } 123 | 124 | // acks is received to be a sorted list, where the largest packet numbers are at the beginning 125 | func (cc *CongestionControllerIetf) onAckReceived(acks ackRanges, ackDelay time.Duration) { 126 | 127 | // keep track of largest packet acked overall 128 | if acks[0].lastPacket > cc.largestAckedPacket { 129 | cc.largestAckedPacket = acks[0].lastPacket 130 | } 131 | 132 | // If the largest acked is newly acked update rtt 133 | lastPacket, present := cc.sentPackets[acks[0].lastPacket] 134 | if present { 135 | latestRtt := time.Since(cc.sentPackets[acks[0].lastPacket].txTime) 136 | cc.conn.log(logTypeCongestion, "latestRtt: %v, ackDelay: %v", latestRtt, ackDelay) 137 | cc.updateRttTcp(latestRtt) 138 | 
139 | // Update the minRtt, but ignore ackDelay. 140 | if latestRtt < cc.minRtt { 141 | cc.minRtt = latestRtt 142 | } 143 | 144 | // Now reduce by ackDelay if it doesn't reduce the RTT below the minimum. 145 | if latestRtt-cc.minRtt > ackDelay { 146 | latestRtt -= ackDelay 147 | // And update the maximum observed ACK delay. 148 | if !lastPacket.ackOnly && ackDelay > cc.maxAckDelay { 149 | cc.maxAckDelay = ackDelay 150 | } 151 | } 152 | 153 | cc.updateRtt(latestRtt) 154 | } 155 | 156 | // find and proccess newly acked packets 157 | for _, ackBlock := range acks { 158 | for pn := ackBlock.lastPacket; pn > (ackBlock.lastPacket - ackBlock.count); pn-- { 159 | cc.conn.log(logTypeCongestion, "Ack for pn %d received", pn) 160 | _, present := cc.sentPackets[pn] 161 | if present { 162 | cc.conn.log(logTypeCongestion, "First ack for pn %d received", pn) 163 | cc.onPacketAcked(pn) 164 | } 165 | } 166 | } 167 | 168 | cc.detectLostPackets() 169 | cc.setLossDetectionAlarm() 170 | } 171 | 172 | func (cc *CongestionControllerIetf) setLostPacketHandler(handler func(pn uint64)) { 173 | cc.lostPacketHandler = handler 174 | } 175 | 176 | func (cc *CongestionControllerIetf) updateRtt(latestRtt time.Duration) { 177 | if cc.smoothedRtt == 0 { 178 | cc.smoothedRtt = latestRtt 179 | cc.rttVar = time.Duration(int64(latestRtt) / 2) 180 | } else { 181 | rttDelta := cc.smoothedRtt - latestRtt 182 | if rttDelta < 0 { 183 | rttDelta = -rttDelta 184 | } 185 | cc.rttVar = time.Duration(int64(cc.rttVar)*3/4 + int64(rttDelta)*1/4) 186 | cc.smoothedRtt = time.Duration(int64(cc.smoothedRtt)*7/8 + int64(latestRtt)*1/8) 187 | } 188 | cc.conn.log(logTypeCongestion, "New RTT estimate: %v, variance: %v", cc.smoothedRtt, cc.rttVar) 189 | } 190 | 191 | func (cc *CongestionControllerIetf) updateRttTcp(latestRtt time.Duration) { 192 | if cc.smoothedRttTcp == 0 { 193 | cc.smoothedRttTcp = latestRtt 194 | cc.rttVarTcp = time.Duration(int64(latestRtt) / 2) 195 | } else { 196 | rttDelta := cc.smoothedRttTcp - 
latestRtt 197 | if rttDelta < 0 { 198 | rttDelta = -rttDelta 199 | } 200 | cc.rttVarTcp = time.Duration(int64(cc.rttVarTcp)*3/4 + int64(rttDelta)*3/4) 201 | cc.smoothedRttTcp = time.Duration(int64(cc.smoothedRttTcp)*7/8 + int64(latestRtt)*1/8) 202 | } 203 | cc.conn.log(logTypeCongestion, "New RTT(TCP) estimate: %v, variance: %v", cc.smoothedRttTcp, cc.rttVarTcp) 204 | } 205 | 206 | func (cc *CongestionControllerIetf) rto() time.Duration { 207 | // max(SRTT + 4*RTTVAR + MaxAckDelay, minRTO) 208 | rto := cc.smoothedRtt + 4*cc.rttVar + cc.maxAckDelay 209 | if rto < kMinRTOTimeout { 210 | return kMinRTOTimeout 211 | } 212 | return rto 213 | } 214 | 215 | func (cc *CongestionControllerIetf) onPacketAcked(pn uint64) { 216 | cc.onPacketAckedCC(pn) 217 | //TODO(ekr@rtfm.com) some RTO stuff here 218 | delete(cc.sentPackets, pn) 219 | } 220 | 221 | func (cc *CongestionControllerIetf) setLossDetectionAlarm() { 222 | //TODO(ekr@rtfm.com) 223 | } 224 | 225 | func (cc *CongestionControllerIetf) onLossDetectionAlarm() { 226 | //TODO(ekr@rtfm.com) 227 | } 228 | 229 | func (cc *CongestionControllerIetf) detectLostPackets() { 230 | var lostPackets []packetEntry 231 | //TODO(ekr@rtfm.com) implement loss detection different from reorderingThreshold 232 | for _, packet := range cc.sentPackets { 233 | if (cc.largestAckedPacket > packet.pn) && 234 | (cc.largestAckedPacket-packet.pn > uint64(cc.reorderingThreshold)) { 235 | lostPackets = append(lostPackets, packet) 236 | } 237 | } 238 | 239 | if len(lostPackets) > 0 { 240 | cc.onPacketsLost(lostPackets) 241 | } 242 | for _, packet := range lostPackets { 243 | delete(cc.sentPackets, packet.pn) 244 | } 245 | } 246 | 247 | func (cc *CongestionControllerIetf) onPacketSentCC(bytes_sent int) { 248 | cc.bytesInFlight += bytes_sent 249 | cc.conn.log(logTypeCongestion, "%d bytes added to bytesInFlight", bytes_sent) 250 | } 251 | 252 | func (cc *CongestionControllerIetf) onPacketAckedCC(pn uint64) { 253 | cc.bytesInFlight -= 
cc.sentPackets[pn].bytes 254 | cc.conn.log(logTypeCongestion, "%d bytes from packet %d removed from bytesInFlight", cc.sentPackets[pn].bytes, pn) 255 | 256 | if pn < cc.endOfRecovery { 257 | // Do not increase window size during recovery 258 | return 259 | } 260 | if cc.congestionWindow < cc.sstresh { 261 | // Slow start 262 | cc.congestionWindow += cc.sentPackets[pn].bytes 263 | cc.conn.log(logTypeCongestion, "PDV Slow Start: increasing window size with %d bytes to %d", 264 | cc.sentPackets[pn].bytes, cc.congestionWindow) 265 | } else { 266 | 267 | // Congestion avoidance 268 | cc.congestionWindow += kDefaultMss * cc.sentPackets[pn].bytes / cc.congestionWindow 269 | cc.conn.log(logTypeCongestion, "PDV Congestion Avoidance: increasing window size to %d", 270 | cc.congestionWindow) 271 | } 272 | } 273 | 274 | func (cc *CongestionControllerIetf) onPacketsLost(packets []packetEntry) { 275 | var largestLostPn uint64 = 0 276 | for _, packet := range packets { 277 | 278 | // First remove lost packets from bytesInFlight and inform the connection 279 | // of the loss 280 | cc.conn.log(logTypeCongestion, "Packet pn: %d len: %d is lost", packet.pn, packet.bytes) 281 | cc.bytesInFlight -= packet.bytes 282 | if cc.lostPacketHandler != nil { 283 | cc.lostPacketHandler(packet.pn) 284 | } 285 | 286 | // and keep track of the largest lost packet 287 | if packet.pn > largestLostPn { 288 | largestLostPn = packet.pn 289 | } 290 | } 291 | 292 | // Now start a new recovery epoch if the largest lost packet is larger than the 293 | // end of the previous recovery epoch 294 | if cc.endOfRecovery < largestLostPn { 295 | cc.endOfRecovery = cc.largestSendPacket 296 | cc.congestionWindow = int(float32(cc.congestionWindow) * kLossReductionFactor) 297 | if kMinimumWindow > cc.congestionWindow { 298 | cc.congestionWindow = kMinimumWindow 299 | } 300 | cc.sstresh = cc.congestionWindow 301 | cc.conn.log(logTypeCongestion, "PDV Recovery started. 
Window size: %d, sstresh: %d, endOfRecovery %d", 302 | cc.congestionWindow, cc.sstresh, cc.endOfRecovery) 303 | } 304 | } 305 | 306 | func (cc *CongestionControllerIetf) bytesAllowedToSend() int { 307 | cc.conn.log(logTypeCongestion, "Remaining congestion window size: %d", cc.congestionWindow-cc.bytesInFlight) 308 | return cc.congestionWindow - cc.bytesInFlight 309 | } 310 | 311 | func newCongestionControllerIetf(conn *Connection) *CongestionControllerIetf { 312 | return &CongestionControllerIetf{ 313 | 0, // bytesInFlight 314 | kInitalWindow, // congestionWindow 315 | 0, // endOfRecovery 316 | int(^uint(0) >> 1), // sstresh 317 | 0, // lossDetectionAlarm 318 | 0, // handshakeCount 319 | 0, // tlpCount 320 | 0, // rtoCount 321 | 0, // largestSendBeforeRto 322 | time.Unix(0, 0), // timeOfLastSentPacket 323 | 0, // largestSendPacket 324 | 0, // largestAckedPacket 325 | 0, // maxAckDelay 326 | 100 * time.Second, // minRtt 327 | 0, // smoothedRtt 328 | 0, // rttVar 329 | 0, // smoothedRttTcp 330 | 0, // rttVarTcp 331 | kReorderingThreshold, // reorderingThreshold 332 | math.MaxFloat32, // timeReorderingFraction 333 | time.Unix(0, 0), // lossTime 334 | make(map[uint64]packetEntry), // sentPackets 335 | nil, // lostPacketHandler 336 | conn, // conn 337 | } 338 | } 339 | -------------------------------------------------------------------------------- /connbuffer.go: -------------------------------------------------------------------------------- 1 | package minq 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "net" 7 | "time" 8 | ) 9 | 10 | type connBuffer struct { 11 | r *bytes.Buffer 12 | w *bytes.Buffer 13 | } 14 | 15 | func (p *connBuffer) Read(data []byte) (n int, err error) { 16 | logf(logTypeConnBuffer, "Reading %v", n) 17 | n, err = p.r.Read(data) 18 | 19 | // Suppress bytes.Buffer's EOF on an empty buffer 20 | if err == io.EOF { 21 | err = nil 22 | } 23 | return 24 | } 25 | 26 | func (p *connBuffer) Write(data []byte) (n int, err error) { 27 | logf(logTypeConnBuffer, 
"Writing %v", n)
	return p.w.Write(data)
}

func (p *connBuffer) Close() error {
	return nil
}

func (p *connBuffer) LocalAddr() net.Addr { return nil }
func (p *connBuffer) RemoteAddr() net.Addr { return nil }
func (p *connBuffer) SetDeadline(t time.Time) error { return nil }
func (p *connBuffer) SetReadDeadline(t time.Time) error { return nil }
func (p *connBuffer) SetWriteDeadline(t time.Time) error { return nil }

func newConnBuffer() *connBuffer {
	return &connBuffer{
		bytes.NewBuffer(nil),
		bytes.NewBuffer(nil),
	}
}

// input appends data to the read side of the buffer.
func (p *connBuffer) input(data []byte) error {
	logf(logTypeConnBuffer, "input %v", len(data))
	_, err := p.r.Write(data)
	return err
}

// getOutput drains everything written so far and returns it.
// The bytes are copied out. bytes.Buffer.Bytes() aliases the buffer's own
// storage, and after Reset the next Write would overwrite the slice the
// caller is still holding.
func (p *connBuffer) getOutput() []byte {
	b := make([]byte, p.w.Len())
	copy(b, p.w.Bytes())
	p.w.Reset()
	return b
}

// OutputLen returns the number of written bytes not yet drained by getOutput.
func (p *connBuffer) OutputLen() int {
	return p.w.Len()
}
--------------------------------------------------------------------------------
/crypto.go:
--------------------------------------------------------------------------------
package minq

import (
	"crypto/cipher"
	"encoding/hex"
	"github.com/bifurcation/mint"
)

type cryptoState struct {
	aead cipher.AEAD
	pne  pneCipherFactory
}

func infallibleHexDecode(s string) []byte {
	b, err := hex.DecodeString(s)
	if err != nil {
		panic("didn't hex decode " + s)
	}
	return b
}

var kQuicVersionSalt = infallibleHexDecode("9c108f98520a5c5c32968e950e8a2c5fe06d6c38")

const clientCtSecretLabel = "client in"
const serverCtSecretLabel = "server in"

const clientPpSecretLabel = "EXPORTER-QUIC client 1rtt"
const serverPpSecretLabel = "EXPORTER-QUIC server 1rtt"

// newCryptoStateInner derives the packet key, IV and packet-number key from
// secret, then builds the AEAD and the packet-number cipher factory.
func newCryptoStateInner(secret []byte, cs *mint.CipherSuiteParams) (*cryptoState, error) {
	var st cryptoState
	var err error

	k := mint.HkdfExpandLabel(cs.Hash, secret, "key", []byte{}, cs.KeyLen)
	iv := mint.HkdfExpandLabel(cs.Hash, secret, "iv", []byte{}, cs.IvLen)
	pn := mint.HkdfExpandLabel(cs.Hash, secret, "pn", []byte{}, cs.KeyLen)
	// NOTE: this logs raw key material whenever the "aead" log type is enabled.
	logf(logTypeAead, "key=%x iv=%x pn=%x", k, iv, pn)
	st.aead, err = newWrappedAESGCM(k, iv)
	if err != nil {
		return nil, err
	}
	st.pne = newPneCipherFactoryAES(pn)

	return &st, nil
}

// generateCleartextKeys derives the initial (cleartext) crypto state for one
// direction. The input secret is extracted with the fixed version salt, then
// expanded with label.
func generateCleartextKeys(secret []byte, label string, cs *mint.CipherSuiteParams) (*cryptoState, error) {
	logf(logTypeTls, "Cleartext keys: cid=%x initial_salt=%x", secret, kQuicVersionSalt)
	extracted := mint.HkdfExtract(cs.Hash, kQuicVersionSalt, secret)
	inner := mint.HkdfExpandLabel(cs.Hash, extracted, label, []byte{}, cs.Hash.Size())
	logf(logTypeAead, "initial_secret (%s) = %x", label, inner)
	return newCryptoStateInner(inner, cs)
}

func newCryptoStateFromTls(t *tlsConn, label string) (*cryptoState, error) {
	panic("TODO")
}
--------------------------------------------------------------------------------
/deploy/Dockerfile:
--------------------------------------------------------------------------------
FROM golang

RUN go get github.com/bifurcation/mint
RUN (cd /go/src/github.com/bifurcation/mint; git remote add ekr https://github.com/ekr/mint; git fetch ekr; git checkout ekr/quic_record_layer)
RUN go get github.com/cloudflare/cfssl/helpers
RUN go get github.com/ekr/minq
RUN go install github.com/ekr/minq/bin/server
RUN go install github.com/ekr/minq/bin/client
RUN apt-get update
RUN apt-get install -y tcpdump
RUN curl -sL https://deb.nodesource.com/setup_6.x | bash -
RUN apt-get install -y nodejs
RUN (cd /go/src/github.com/ekr/minq/deploy/logserver; npm install)

ARG SERVERNAME=localhost
ENV SNAME=$SERVERNAME
ENV MINQ_LOG='connection,handshake,stream,packet'
ENTRYPOINT ["/bin/sh","/go/src/github.com/ekr/minq/deploy/run-looped.sh"]
CMD [$SNAME]

EXPOSE 4433/udp
EXPOSE 3000/tcp

--------------------------------------------------------------------------------
/deploy/logserver/package.json:
--------------------------------------------------------------------------------
{
  "name": "logserver",
  "version": "0.0.1",
  "dependencies": {
    "express": ""
  }
}

--------------------------------------------------------------------------------
/deploy/logserver/server.js:
--------------------------------------------------------------------------------
var express = require('express');
var fs = require('fs');
var readline = require('readline');
// Anchored at both ends: the whole connid must be hex, not just its tail.
var connid_regex = /^[0-9a-fA-F]+$/;

var app = express();

var port = process.env.PORT || 3000;

// process.argv is [node, script, logfile, ...], so the log file is argv[2].
if (process.argv.length < 3) {
    console.log("Need to specify log file");
    return;
}
var file = process.argv[2];
console.log(file);

// Escape HTML-significant characters so log text cannot inject markup.
function escapeHtml(s) {
    return s.replace(/&/g, '&amp;').replace(/</g, '&lt;').replace(/>/g, '&gt;');
}

// Serve every log line mentioning the given connection ID as a <pre> block.
app.get('/:connid', function(request, response) {
    var connid = request.params.connid;
    if (!connid.match(connid_regex)) {
        response.status(400).send("Bogus connid (non-hex characters)");
        return;
    }

    if(connid.length < 4) {
        response.status(400).send("Bogus connid (too short)");
        return;
    }

    connid = connid.toLowerCase();

    var match = 'Conn: ' + connid + "_";
    var data = "<pre>";
    const rl = readline.createInterface({
        input: fs.createReadStream(file),
        terminal: false
    });
    rl.on('line', function(l) {
        // Plain substring test. String.prototype.search would compile
        // match as a regular expression.
        if (l.indexOf(match) != -1) {
            data += escapeHtml(l);
            data += "\n";
        }
    });
    rl.on('close', function() {
        data += "</pre>";
        response.send(data);
    });
});

app.listen(port, function() {
    console.log("Listening on " + port);
    console.log("Logfile = " + file);
});

--------------------------------------------------------------------------------
/deploy/mk-endpoint.sh:
--------------------------------------------------------------------------------
#!/bin/sh
docker build -f deploy/Dockerfile --no-cache -t mozilla/minq --build-arg SERVERNAME=minq.dev.mozaws.net .
docker tag mozilla/minq:latest mozilla/minq:$(git rev-parse HEAD)


--------------------------------------------------------------------------------
/deploy/mk-localhost.sh:
--------------------------------------------------------------------------------
#!/bin/sh
docker build --no-cache -f deploy/Dockerfile -t minq --build-arg SERVERNAME=localhost .

--------------------------------------------------------------------------------
/deploy/run-local.sh:
--------------------------------------------------------------------------------
docker run --name minq --rm --publish 4433:4433/udp --publish 3000:3000 minq:latest
--------------------------------------------------------------------------------
/deploy/run-looped.sh:
--------------------------------------------------------------------------------
#!/bin/sh
nodejs /go/src/github.com/ekr/minq/deploy/logserver/server.js /tmp/minq.log &
while true; do
    echo -n "Starting server as "
    echo ${SNAME}
    MINQ_LOG=connection,packet /go/bin/server -addr 0.0.0.0:4433 -server-name ${SNAME} -log /tmp/minq.log -http -standalone
    echo "Server crashed"
done


--------------------------------------------------------------------------------
/errors.go:
--------------------------------------------------------------------------------
package minq

import (
	"fmt"
)

// Errors which don't necessarily cause connection teardown.
8 | type intError struct { 9 | err string 10 | sub string 11 | fatal bool 12 | } 13 | 14 | func (e intError) Error() string { 15 | return e.err 16 | } 17 | 18 | func fatalError(format string, args ...interface{}) error { 19 | return intError{ 20 | fmt.Sprintf(format, args...), 21 | "", 22 | true, 23 | } 24 | } 25 | 26 | func internalError(format string, args ...interface{}) error { 27 | str := fmt.Sprintf(format, args...) 28 | if debug { 29 | panic("Internal error: " + str) 30 | } 31 | 32 | return intError{ 33 | str, 34 | "", 35 | true, 36 | } 37 | } 38 | 39 | func nonFatalError(format string, args ...interface{}) error { 40 | return intError{ 41 | fmt.Sprintf(format, args...), 42 | "", 43 | false, 44 | } 45 | } 46 | 47 | func err2string(err interface{}) string { 48 | switch e := err.(type) { 49 | case error: 50 | return e.Error() 51 | case string: 52 | return e 53 | default: 54 | panic("Bogus argument to err2string") 55 | } 56 | } 57 | 58 | func wrapE(err interface{}, sub interface{}) error { 59 | return intError{ 60 | err2string(err), 61 | err2string(sub), 62 | isFatalError(err), 63 | } 64 | } 65 | 66 | // An error is fatal if either. 67 | // 68 | // It's a regular error (i.e., not an intError) 69 | // e.fatal is true 70 | func isFatalError(e interface{}) bool { 71 | if e == nil { 72 | return false 73 | } 74 | 75 | i, ok := e.(intError) 76 | if !ok { 77 | return true 78 | } 79 | 80 | return i.fatal 81 | } 82 | 83 | // Return codes. 
84 | var ErrorWouldBlock = nonFatalError("Would have blocked (QUIC)") 85 | var ErrorDestroyConnection = fatalError("Terminate connection") 86 | var ErrorReceivedVersionNegotiation = fatalError("Received a version negotiation packet advertising a different version than ours") 87 | var ErrorConnIsClosed = fatalError("Connection is closed") 88 | var ErrorConnIsClosing = nonFatalError("Connection is closing") 89 | var ErrorStreamReset = fatalError("Stream was reset") 90 | var ErrorStreamIsClosed = fatalError("Stream is closed") 91 | var ErrorInvalidPacket = nonFatalError("Invalid packet") 92 | var ErrorConnectionTimedOut = fatalError("Connection timed out") 93 | var ErrorMissingValue = fatalError("Expected value is missing") 94 | var ErrorInvalidEncoding = fatalError("Invalid encoding") 95 | var ErrorProtocolViolation = fatalError("Protocol violation") 96 | var ErrorFrameFormatError = fatalError("Frame format error") 97 | var ErrorFlowControlError = fatalError("Flow control error") 98 | 99 | // Protocol errors 100 | type ErrorCode uint16 101 | 102 | const ( 103 | kQuicErrorNoError = ErrorCode(0x0000) 104 | kQuicErrorProtocolViolation = ErrorCode(0x000A) 105 | ) 106 | -------------------------------------------------------------------------------- /frame.go: -------------------------------------------------------------------------------- 1 | package minq 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/bifurcation/mint/syntax" 8 | ) 9 | 10 | type frameType uint8 11 | 12 | type frameNonSyntax interface { 13 | unmarshal(b []byte) (int, error) 14 | } 15 | 16 | const ( 17 | kFrameTypePadding = frameType(0x0) 18 | kFrameTypeRstStream = frameType(0x1) 19 | kFrameTypeConnectionClose = frameType(0x2) 20 | kFrameTypeApplicationClose = frameType(0x3) 21 | kFrameTypeMaxData = frameType(0x4) 22 | kFrameTypeMaxStreamData = frameType(0x5) 23 | kFrameTypeMaxStreamId = frameType(0x6) 24 | kFrameTypePing = frameType(0x7) 25 | kFrameTypeBlocked = frameType(0x8) 26 | 
kFrameTypeStreamBlocked = frameType(0x9) 27 | kFrameTypeStreamIdBlocked = frameType(0xa) 28 | kFrameTypeNewConnectionId = frameType(0xb) 29 | kFrameTypeStopSending = frameType(0xc) 30 | kFrameTypeAck = frameType(0x1a) 31 | kFrameTypeAckECN = frameType(0x1b) 32 | kFrameTypePathChallenge = frameType(0xe) 33 | kFrameTypePathResponse = frameType(0xf) 34 | kFrameTypeStream = frameType(0x10) 35 | kFrameTypeStreamMax = frameType(0x17) 36 | kFrameTypeCryptoHs = frameType(0x18) 37 | ) 38 | 39 | const ( 40 | kFrameTypeStreamFlagFIN = frameType(0x01) 41 | kFrameTypeStreamFlagLEN = frameType(0x02) 42 | kFrameTypeStreamFlagOFF = frameType(0x04) 43 | ) 44 | 45 | const ( 46 | // Assume maximal sizes for these. 47 | kMaxAckHeaderLength = 33 48 | kMaxAckBlockEntryLength = 16 49 | kMaxAckGap = 255 50 | kMaxAckBlocks = 255 51 | ) 52 | 53 | type innerFrame interface { 54 | getType() frameType 55 | String() string 56 | } 57 | 58 | type frame struct { 59 | stream uint64 60 | f innerFrame 61 | encoded []byte 62 | pns []uint64 63 | lostPns []uint64 64 | time time.Time 65 | needsTransmit bool 66 | } 67 | 68 | func (f frame) String() string { 69 | return f.f.String() 70 | } 71 | 72 | func newFrame(stream uint64, inner innerFrame) *frame { 73 | return &frame{stream, inner, nil, nil, nil, time.Unix(0, 0), true} 74 | } 75 | 76 | // Encode internally if not already encoded. 77 | func (f *frame) encode() error { 78 | if f.encoded != nil { 79 | return nil 80 | } 81 | var err error 82 | f.encoded, err = syntax.Marshal(f.f) 83 | logf(logTypeFrame, "Frame encoded, total length=%v", len(f.encoded)) 84 | return err 85 | } 86 | 87 | func (f *frame) length() (int, error) { 88 | err := f.encode() 89 | if err != nil { 90 | return 0, err 91 | } 92 | return len(f.encoded), nil 93 | } 94 | 95 | // Decode an arbitrary frame. 
96 | func decodeFrame(data []byte) (uintptr, *frame, error) { 97 | var inner innerFrame 98 | var n int 99 | var err error 100 | 101 | t := data[0] 102 | logf(logTypeFrame, "Frame type byte %v", t) 103 | switch { 104 | case t == uint8(kFrameTypePadding): 105 | inner = &paddingFrame{} 106 | case t == uint8(kFrameTypeRstStream): 107 | inner = &rstStreamFrame{} 108 | case t == uint8(kFrameTypeConnectionClose): 109 | inner = &connectionCloseFrame{} 110 | case t == uint8(kFrameTypeApplicationClose): 111 | inner = &applicationCloseFrame{} 112 | case t == uint8(kFrameTypeMaxData): 113 | inner = &maxDataFrame{} 114 | case t == uint8(kFrameTypeMaxStreamData): 115 | inner = &maxStreamDataFrame{} 116 | case t == uint8(kFrameTypeMaxStreamId): 117 | inner = &maxStreamIdFrame{} 118 | case t == uint8(kFrameTypePing): 119 | inner = &pingFrame{} 120 | case t == uint8(kFrameTypeBlocked): 121 | inner = &blockedFrame{} 122 | case t == uint8(kFrameTypeStreamBlocked): 123 | inner = &streamBlockedFrame{} 124 | case t == uint8(kFrameTypeStreamIdBlocked): 125 | inner = &streamIdBlockedFrame{} 126 | case t == uint8(kFrameTypeNewConnectionId): 127 | inner = &newConnectionIdFrame{} 128 | case t == uint8(kFrameTypeStopSending): 129 | inner = &stopSendingFrame{} 130 | case t == uint8(kFrameTypeAck): 131 | inner = &ackFrame{} 132 | case t == uint8(kFrameTypePathChallenge): 133 | inner = &pathChallengeFrame{} 134 | case t == uint8(kFrameTypePathResponse): 135 | inner = &pathResponseFrame{} 136 | case t >= uint8(kFrameTypeStream) && t <= uint8(kFrameTypeStreamMax): 137 | inner = &streamFrame{} 138 | case t == uint8(kFrameTypeCryptoHs): 139 | inner = &cryptoHsFrame{} 140 | default: 141 | logf(logTypeConnection, "Unknown frame type %v", t) 142 | return 0, nil, fmt.Errorf("Received unknown frame type: %v", t) 143 | } 144 | 145 | ns, ok := inner.(frameNonSyntax) 146 | if ok { 147 | n, err = ns.unmarshal(data) 148 | 149 | } else { 150 | n, err = syntax.Unmarshal(data, inner) 151 | } 152 | if err != nil 
{ 153 | return 0, nil, err 154 | } 155 | 156 | return uintptr(n), &frame{0, inner, data[:n], nil, nil, time.Now(), false}, nil 157 | } 158 | 159 | // Frame definitions below this point. 160 | 161 | // PADDING 162 | type paddingFrame struct { 163 | Typ frameType 164 | } 165 | 166 | func (f paddingFrame) String() string { 167 | return "P" 168 | } 169 | 170 | func (f paddingFrame) getType() frameType { 171 | return kFrameTypePadding 172 | } 173 | 174 | func newPaddingFrame(stream uint64) *frame { 175 | return newFrame(stream, &paddingFrame{0}) 176 | } 177 | 178 | // RST_STREAM 179 | type rstStreamFrame struct { 180 | Type frameType 181 | StreamId uint64 `tls:"varint"` 182 | ErrorCode uint16 183 | FinalOffset uint64 `tls:"varint"` 184 | } 185 | 186 | func (f rstStreamFrame) String() string { 187 | return fmt.Sprintf("RST_STREAM stream=%x errorCode=%d finalOffset=%x", f.StreamId, f.ErrorCode, f.FinalOffset) 188 | } 189 | 190 | func (f rstStreamFrame) getType() frameType { 191 | return kFrameTypeRstStream 192 | } 193 | 194 | func newRstStreamFrame(streamId uint64, errorCode uint16, finalOffset uint64) *frame { 195 | return newFrame(streamId, &rstStreamFrame{ 196 | kFrameTypeRstStream, 197 | uint64(streamId), 198 | errorCode, 199 | finalOffset}) 200 | } 201 | 202 | // STOP_SENDING 203 | type stopSendingFrame struct { 204 | Type frameType 205 | StreamId uint64 `tls:"varint"` 206 | ErrorCode uint16 207 | } 208 | 209 | func (f stopSendingFrame) String() string { 210 | return fmt.Sprintf("STOP_SENDING stream=%x errorCode=%d", f.StreamId, f.ErrorCode) 211 | } 212 | 213 | func (f stopSendingFrame) getType() frameType { 214 | return kFrameTypeStopSending 215 | } 216 | 217 | func newStopSendingFrame(streamId uint64, errorCode uint16) *frame { 218 | return newFrame(streamId, &stopSendingFrame{ 219 | kFrameTypeStopSending, 220 | uint64(streamId), 221 | errorCode}) 222 | } 223 | 224 | // CONNECTION_CLOSE 225 | type connectionCloseFrame struct { 226 | Type frameType 227 | ErrorCode 
uint16 228 | ReasonPhrase []byte `tls:"head=varint"` 229 | } 230 | 231 | func (f connectionCloseFrame) String() string { 232 | return fmt.Sprintf("CONNECTION_CLOSE errorCode=%x", f.ErrorCode) 233 | } 234 | 235 | func (f connectionCloseFrame) getType() frameType { 236 | return kFrameTypeConnectionClose 237 | } 238 | 239 | func newConnectionCloseFrame(errcode ErrorCode, reason string) *frame { 240 | return newFrame(0, &connectionCloseFrame{ 241 | kFrameTypeConnectionClose, 242 | uint16(errcode), 243 | []byte(reason), 244 | }) 245 | } 246 | 247 | // APPLICATION_CLOSE 248 | type applicationCloseFrame struct { 249 | Type frameType 250 | ErrorCode uint16 251 | ReasonPhrase []byte `tls:"head=varint"` 252 | } 253 | 254 | func (f applicationCloseFrame) String() string { 255 | return fmt.Sprintf("APPLICATION_CLOSE errorCode=%x", f.ErrorCode) 256 | } 257 | 258 | func (f applicationCloseFrame) getType() frameType { 259 | return kFrameTypeApplicationClose 260 | } 261 | 262 | func newApplicationCloseFrame(errcode uint16, reason string) *frame { 263 | return newFrame(0, &applicationCloseFrame{ 264 | kFrameTypeApplicationClose, 265 | uint16(errcode), 266 | []byte(reason), 267 | }) 268 | } 269 | 270 | // MAX_DATA 271 | type maxDataFrame struct { 272 | Type frameType 273 | MaximumData uint64 `tls:"varint"` 274 | } 275 | 276 | func (f maxDataFrame) String() string { 277 | return fmt.Sprintf("MAX_DATA %d", f.MaximumData) 278 | } 279 | 280 | func (f maxDataFrame) getType() frameType { 281 | return kFrameTypeMaxData 282 | } 283 | 284 | func newMaxData(m uint64) *frame { 285 | return newFrame(0, &maxDataFrame{kFrameTypeMaxData, m}) 286 | } 287 | 288 | // MAX_STREAM_DATA 289 | type maxStreamDataFrame struct { 290 | Type frameType 291 | StreamId uint64 `tls:"varint"` 292 | MaximumStreamData uint64 `tls:"varint"` 293 | } 294 | 295 | func newMaxStreamData(stream uint64, offset uint64) *frame { 296 | return newFrame(stream, 297 | &maxStreamDataFrame{ 298 | kFrameTypeMaxStreamData, 299 | 
stream, 300 | offset, 301 | }) 302 | } 303 | 304 | func (f maxStreamDataFrame) String() string { 305 | return fmt.Sprintf("MAX_STREAM_DATA stream=%d %d", f.StreamId, f.MaximumStreamData) 306 | } 307 | 308 | func (f maxStreamDataFrame) getType() frameType { 309 | return kFrameTypeMaxStreamData 310 | } 311 | 312 | // MAX_STREAM_ID 313 | type maxStreamIdFrame struct { 314 | Type frameType 315 | MaximumStreamId uint64 `tls:"varint"` 316 | } 317 | 318 | func newMaxStreamId(id uint64) *frame { 319 | return newFrame(0, 320 | &maxStreamIdFrame{ 321 | kFrameTypeMaxStreamId, 322 | id, 323 | }) 324 | } 325 | 326 | func (f maxStreamIdFrame) String() string { 327 | return fmt.Sprintf("MAX_STREAM_ID %d", f.MaximumStreamId) 328 | } 329 | 330 | func (f maxStreamIdFrame) getType() frameType { 331 | return kFrameTypeMaxStreamId 332 | } 333 | 334 | // PING 335 | type pingFrame struct { 336 | Type frameType 337 | } 338 | 339 | func (f pingFrame) String() string { 340 | return "PING" 341 | } 342 | 343 | func (f pingFrame) getType() frameType { 344 | return kFrameTypePing 345 | } 346 | 347 | // BLOCKED 348 | type blockedFrame struct { 349 | Type frameType 350 | Offset uint64 `tls:"varint"` 351 | } 352 | 353 | func (f blockedFrame) String() string { 354 | return "BLOCKED" 355 | } 356 | 357 | func (f blockedFrame) getType() frameType { 358 | return kFrameTypeBlocked 359 | } 360 | 361 | func newBlockedFrame(offset uint64) *frame { 362 | return newFrame(0, &blockedFrame{kFrameTypeBlocked, offset}) 363 | } 364 | 365 | // STREAM_BLOCKED 366 | type streamBlockedFrame struct { 367 | Type frameType 368 | StreamId uint64 `tls:"varint"` 369 | Offset uint64 `tls:"varint"` 370 | } 371 | 372 | func (f streamBlockedFrame) String() string { 373 | return "STREAM_BLOCKED" 374 | } 375 | 376 | func (f streamBlockedFrame) getType() frameType { 377 | return kFrameTypeStreamBlocked 378 | } 379 | 380 | func newStreamBlockedFrame(id uint64, offset uint64) *frame { 381 | return newFrame(0, 
&streamBlockedFrame{kFrameTypeStreamBlocked, id, offset}) 382 | } 383 | 384 | // STREAM_ID_BLOCKED 385 | type streamIdBlockedFrame struct { 386 | Type frameType 387 | StreamId uint64 `tls:"varint"` 388 | } 389 | 390 | func (f streamIdBlockedFrame) String() string { 391 | return "STREAM_ID_BLOCKED" 392 | } 393 | 394 | func (f streamIdBlockedFrame) getType() frameType { 395 | return kFrameTypeStreamIdBlocked 396 | } 397 | 398 | func newStreamIdBlockedFrame(id uint64) *frame { 399 | return newFrame(0, &streamIdBlockedFrame{ 400 | kFrameTypeStreamIdBlocked, 401 | id}) 402 | } 403 | 404 | // NEW_CONNECTION_ID 405 | type newConnectionIdFrame struct { 406 | Type frameType 407 | Sequence uint16 `tls:"varint"` 408 | ConnectionId ConnectionId 409 | ResetToken [16]byte 410 | } 411 | 412 | func (f newConnectionIdFrame) String() string { 413 | return fmt.Sprintf("NEW_CONNECTION_ID %d=%x", f.Sequence, f.ConnectionId) 414 | } 415 | 416 | func (f newConnectionIdFrame) getType() frameType { 417 | return kFrameTypeNewConnectionId 418 | } 419 | 420 | func newNewConnectionIdFrame(seq uint16, cid ConnectionId, resetToken []byte) *frame { 421 | f := &newConnectionIdFrame{ 422 | Type: kFrameTypeNewConnectionId, 423 | Sequence: seq, 424 | ConnectionId: cid, 425 | } 426 | assert(len(resetToken) == len(f.ResetToken)) 427 | copy(f.ResetToken[:], resetToken) 428 | return newFrame(0, f) 429 | } 430 | 431 | // ACK 432 | type ackBlock struct { 433 | Gap uint64 `tls:"varint"` 434 | Length uint64 `tls:"varint"` 435 | } 436 | 437 | type ackFrameHeader struct { 438 | Type frameType 439 | LargestAcknowledged uint64 `tls:"varint"` 440 | AckDelay uint64 `tls:"varint"` 441 | AckBlockCount uint64 `tls:"varint"` 442 | FirstAckBlock uint64 `tls:"varint"` 443 | } 444 | 445 | type ackFrame struct { 446 | ackFrameHeader 447 | AckBlockSection []*ackBlock `tls:"head=none"` 448 | } 449 | 450 | func (f ackFrame) String() string { 451 | return fmt.Sprintf("ACK numBlocks=%d largestAck=%x", f.AckBlockCount, 
f.LargestAcknowledged) 452 | } 453 | 454 | func (f ackFrame) getType() frameType { 455 | return kFrameTypeAck 456 | } 457 | 458 | // ACK frames can't presently be decoded with syntax, so we need 459 | // a custom decoder. 460 | func (f *ackFrame) unmarshal(buf []byte) (int, error) { 461 | // First, decode the header 462 | read := int(0) 463 | n, err := syntax.Unmarshal(buf, &f.ackFrameHeader) 464 | if err != nil { 465 | return 0, err 466 | } 467 | buf = buf[n:] 468 | read += n 469 | 470 | // Now decode each block 471 | for i := uint64(0); i < f.AckBlockCount; i++ { 472 | blk := &ackBlock{} 473 | n, err := syntax.Unmarshal(buf, blk) 474 | if err != nil { 475 | return 0, err 476 | } 477 | buf = buf[n:] 478 | read += n 479 | 480 | f.AckBlockSection = append(f.AckBlockSection, blk) 481 | } 482 | 483 | return read, nil 484 | } 485 | 486 | func newAckFrame(recvd *recvdPackets, rs ackRanges, left int) (*frame, int, error) { 487 | if left < kMaxAckHeaderLength { 488 | return nil, 0, nil 489 | } 490 | logf(logTypeFrame, "Making ACK frame %v", rs) 491 | 492 | left -= kMaxAckHeaderLength 493 | 494 | last := rs[0].lastPacket 495 | largestAckData, ok := recvd.packets[last] 496 | // Should always be there. Packets only get removed after being set to ack2, 497 | // which means we should not be acking it again. 498 | assert(ok) 499 | 500 | // FIRST, fill in the basic info of the ACK frame 501 | var f ackFrame 502 | f.Type = kFrameTypeAck 503 | f.LargestAcknowledged = last 504 | delay := time.Since(largestAckData.t).Nanoseconds() 505 | f.AckDelay = uint64(delay) / 1000 >> kTpDefaultAckDelayExponent 506 | f.AckBlockCount = 0 507 | f.FirstAckBlock = rs[0].count - 1 508 | 509 | // ...and account for the first block. 
510 | last -= f.FirstAckBlock 511 | addedRanges := 1 512 | 513 | // SECOND, add the remaining ACK blocks that fit and that we have 514 | for (left > 0) && (addedRanges < len(rs)) { 515 | // calculate blocks needed for the next range 516 | gap := last - rs[addedRanges].lastPacket - 1 517 | 518 | gap = last - rs[addedRanges].lastPacket - 1 519 | b := &ackBlock{ 520 | gap, 521 | rs[addedRanges].count - 1, 522 | } 523 | 524 | last = rs[addedRanges].lastPacket - rs[addedRanges].count 525 | 526 | f.AckBlockCount++ 527 | f.AckBlockSection = append(f.AckBlockSection, b) 528 | addedRanges++ 529 | left -= kMaxAckBlockEntryLength // Assume worst-case. 530 | } 531 | 532 | return newFrame(0, &f), addedRanges, nil 533 | } 534 | 535 | // PATH_CHALLENGE 536 | type pathChallengeFrame struct { 537 | Type frameType 538 | Data [8]byte 539 | } 540 | 541 | func (f pathChallengeFrame) String() string { 542 | return "PATH_CHALLENGE" 543 | } 544 | 545 | func (f pathChallengeFrame) getType() frameType { 546 | return kFrameTypePathChallenge 547 | } 548 | 549 | func newPathChallengeFrame(data []byte) *frame { 550 | payload := &pathChallengeFrame{Type: kFrameTypePathChallenge} 551 | assert(len(data) == len(payload.Data)) 552 | copy(payload.Data[:], data) 553 | return newFrame(0, payload) 554 | } 555 | 556 | // PATH_RESPONSE 557 | type pathResponseFrame struct { 558 | Type frameType 559 | Data [8]byte 560 | } 561 | 562 | func (f pathResponseFrame) String() string { 563 | return "PATH_RESPONSE" 564 | } 565 | 566 | func (f pathResponseFrame) getType() frameType { 567 | return kFrameTypePathResponse 568 | } 569 | 570 | func newPathResponseFrame(data []byte) *frame { 571 | payload := &pathResponseFrame{Type: kFrameTypePathResponse} 572 | assert(len(data) == len(payload.Data)) 573 | copy(payload.Data[:], data) 574 | return newFrame(0, payload) 575 | } 576 | 577 | // STREAM 578 | type streamFrame struct { 579 | Typ frameType 580 | StreamId uint64 `tls:"varint"` 581 | Offset uint64 `tls:"varint"` 582 
| Data []byte `tls:"head=varint"` 583 | } 584 | 585 | func (f streamFrame) String() string { 586 | return fmt.Sprintf("STREAM stream=%d offset=%d len=%d FIN=%v", f.StreamId, f.Offset, len(f.Data), f.hasFin()) 587 | } 588 | 589 | func (f streamFrame) getType() frameType { 590 | return kFrameTypeStream 591 | } 592 | 593 | func (f streamFrame) hasFin() bool { 594 | if f.Typ&kFrameTypeStreamFlagFIN == 0 { 595 | return false 596 | } 597 | return true 598 | } 599 | 600 | func newStreamFrame(stream uint64, offset uint64, data []byte, last bool) *frame { 601 | logf(logTypeFrame, "Creating stream frame with data length=%d", len(data)) 602 | assert(len(data) <= 65535) 603 | // TODO(ekr@tfm.com): One might want to allow non 604 | // D bit, but not for now. 605 | // Set all of SSOO to 1 606 | typ := kFrameTypeStream | kFrameTypeStreamFlagLEN | kFrameTypeStreamFlagOFF 607 | if last { 608 | typ |= kFrameTypeStreamFlagFIN 609 | } 610 | return newFrame( 611 | stream, 612 | &streamFrame{ 613 | typ, 614 | stream, 615 | offset, 616 | dup(data), 617 | }) 618 | } 619 | 620 | func decodeVarint(buf []byte) (int, uint64, error) { 621 | var vi struct { 622 | Val uint64 `tls:"varint"` 623 | } 624 | 625 | n, err := syntax.Unmarshal(buf, &vi) 626 | if err != nil { 627 | return 0, 0, err 628 | } 629 | 630 | return n, vi.Val, nil 631 | } 632 | 633 | // Stream frames can't presently be decoded with syntax, so we need 634 | // a custom decoder. 
635 | func (f *streamFrame) unmarshal(buf []byte) (int, error) { 636 | f.Typ = frameType(buf[0]) 637 | buf = buf[1:] 638 | var read = int(1) 639 | var n int 640 | var err error 641 | 642 | n, f.StreamId, err = decodeVarint(buf) 643 | if err != nil { 644 | return 0, err 645 | } 646 | buf = buf[n:] 647 | read += n 648 | 649 | if f.Typ&kFrameTypeStreamFlagOFF != 0 { 650 | n, f.Offset, err = decodeVarint(buf) 651 | if err != nil { 652 | return 0, err 653 | } 654 | buf = buf[n:] 655 | read += n 656 | } 657 | 658 | if f.Typ&kFrameTypeStreamFlagLEN != 0 { 659 | var l uint64 660 | n, l, err = decodeVarint(buf) 661 | if err != nil { 662 | return 0, err 663 | } 664 | buf = buf[n:] 665 | read += n 666 | 667 | logf(logTypeFrame, "Expecting %v bytes", l) 668 | 669 | if l > uint64(len(buf)) { 670 | return 0, fmt.Errorf("Insufficient bytes left") 671 | } 672 | f.Data = dup(buf[:l]) 673 | read += int(l) 674 | } else { 675 | f.Data = dup(buf) 676 | read += len(buf) 677 | } 678 | 679 | return read, nil 680 | } 681 | 682 | // CRYPTO_HS 683 | type cryptoHsFrame struct { 684 | Typ frameType 685 | Offset uint64 `tls:"varint"` 686 | Data []byte `tls:"head=varint"` 687 | } 688 | 689 | func (f cryptoHsFrame) getType() frameType { 690 | return kFrameTypeCryptoHs 691 | } 692 | 693 | func (f cryptoHsFrame) String() string { 694 | return fmt.Sprintf("CRYPTO_HS len=%d", len(f.Data)) 695 | } 696 | 697 | func newCryptoHsFrame(offset uint64, data []byte) *frame { 698 | logf(logTypeFrame, "Creating crypto_hs frame with data length=%d", len(data)) 699 | 700 | return newFrame( 701 | 0, 702 | &cryptoHsFrame{ 703 | kFrameTypeCryptoHs, 704 | offset, 705 | dup(data), 706 | }, 707 | ) 708 | } 709 | -------------------------------------------------------------------------------- /frame_test.go: -------------------------------------------------------------------------------- 1 | package minq 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "testing" 7 | ) 8 | 9 | func testEncodeDecodeEncode(t *testing.T, f 
*frame) { 10 | err := f.encode() 11 | assertNotError(t, err, "Encode failed") 12 | fmt.Printf("Encoded: [%x]\n", f.encoded) 13 | 14 | consumed, f2, err := decodeFrame(f.encoded) 15 | assertNotError(t, err, "Failed to decode frame") 16 | assertEquals(t, len(f.encoded), int(consumed)) 17 | f2.encoded = nil // So we re-encode 18 | 19 | err = f2.encode() 20 | assertNotError(t, err, "Encode failed") 21 | assertByteEquals(t, f.encoded, f2.encoded) 22 | 23 | fmt.Printf("%+v\n", f2) 24 | } 25 | 26 | func TestStreamFrame(t *testing.T) { 27 | s := newStreamFrame(1, 0, 28 | bytes.Repeat([]byte{0xa0}, 100), false) 29 | testEncodeDecodeEncode(t, s) 30 | } 31 | 32 | func TestAckFrameOneRange(t *testing.T) { 33 | ar := []ackRange{{0xdeadbeef, 2}} 34 | 35 | recvd := newRecvdPackets(logf) 36 | recvd.init(ar[0].lastPacket) 37 | recvd.packetSetReceived(ar[0].lastPacket, false, false) 38 | 39 | f, _, err := newAckFrame(recvd, ar, 33) 40 | assertNotError(t, err, "Couldn't make ack frame") 41 | 42 | testEncodeDecodeEncode(t, f) 43 | } 44 | 45 | func TestAckFrameTwoRanges(t *testing.T) { 46 | ar := []ackRange{{0xdeadbeef, 2}, {0xdeadbee0, 1}} 47 | 48 | recvd := newRecvdPackets(logf) 49 | recvd.init(ar[0].lastPacket) 50 | recvd.packetSetReceived(ar[0].lastPacket, false, false) 51 | 52 | f, _, err := newAckFrame(recvd, ar, 49) 53 | assertNotError(t, err, "Couldn't make ack frame") 54 | 55 | testEncodeDecodeEncode(t, f) 56 | } 57 | 58 | func TestFixedSizedData(t *testing.T) { 59 | f := newPathChallengeFrame([]byte{1, 2, 3, 4, 5, 6, 7, 8}) 60 | testEncodeDecodeEncode(t, f) 61 | f = newPathResponseFrame([]byte{10, 9, 8, 7, 6, 5, 4, 3}) 62 | testEncodeDecodeEncode(t, f) 63 | } 64 | -------------------------------------------------------------------------------- /log.go: -------------------------------------------------------------------------------- 1 | // Lightly modified from Mint 2 | 3 | package minq 4 | 5 | import ( 6 | "fmt" 7 | "log" 8 | "os" 9 | "strings" 10 | ) 11 | 12 | // We use this 
environment variable to control logging. It should be a 13 | // comma-separated list of log tags (see below) or "*" to enable all logging. 14 | const logConfigVar = "MINQ_LOG" 15 | 16 | // Pre-defined log types 17 | const ( 18 | logTypeAead = "aead" 19 | logTypeCodec = "codec" 20 | logTypeConnBuffer = "connbuffer" 21 | logTypeConnection = "connection" 22 | logTypeAck = "ack" 23 | logTypeFrame = "frame" 24 | logTypeHandshake = "handshake" 25 | logTypeTls = "tls" 26 | logTypeTrace = "trace" 27 | logTypeServer = "server" 28 | logTypeUdp = "udp" 29 | logTypeStream = "stream" 30 | logTypeFlowControl = "flow" 31 | logTypePacket = "packet" // Just send notes on which packets are sent and received 32 | logTypeCongestion = "congestion" 33 | ) 34 | 35 | var ( 36 | logFunction = log.Printf 37 | logAll = false 38 | logSettings = map[string]bool{} 39 | ) 40 | 41 | func init() { 42 | parseLogEnv(os.Environ()) 43 | } 44 | 45 | func parseLogEnv(env []string) { 46 | for _, stmt := range env { 47 | if strings.HasPrefix(stmt, logConfigVar+"=") { 48 | val := stmt[len(logConfigVar)+1:] 49 | 50 | if val == "*" { 51 | logAll = true 52 | } else { 53 | for _, t := range strings.Split(val, ",") { 54 | logSettings[t] = true 55 | } 56 | } 57 | } 58 | } 59 | } 60 | 61 | func logf(tag string, format string, args ...interface{}) { 62 | if logAll || logSettings[tag] { 63 | fullFormat := fmt.Sprintf("[%s] %s", tag, format) 64 | logFunction(fullFormat, args...) 65 | } 66 | } 67 | 68 | type loggingFunction func(string, string, ...interface{}) 69 | 70 | func SetLogOutput(f func(string, ...interface{})) { 71 | logFunction = f 72 | } 73 | 74 | func newConnectionLogger(c *Connection) loggingFunction { 75 | return func(tag string, format string, args ...interface{}) { 76 | if logAll || logSettings[tag] { 77 | logf(tag, c.String()+": "+format, args...) 
78 | }
79 | }
80 | }
81 | // newStreamLogger wraps f so that every message format is prefixed with "<dir> stream <id>: ".
82 | func newStreamLogger(id uint64, dir string, f loggingFunction) loggingFunction {
83 | 	extra := fmt.Sprintf("%s stream %d: ", dir, id)
84 | 	return func(tag string, format string, args ...interface{}) {
85 | 		f(tag, extra+format, args...)
86 | 	}
87 | }
88 |
--------------------------------------------------------------------------------
/minq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ekr/minq/a5bd85261472f54ed1ba18b1a127d929c2d4782a/minq.png
--------------------------------------------------------------------------------
/minq.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | image/svg+xml
98 |
--------------------------------------------------------------------------------
/packet.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import (
4 | 	"bytes"
5 | 	"crypto/aes"
6 | 	"crypto/cipher"
7 | 	"encoding/hex"
8 | 	"fmt"
9 | )
10 |
11 | // QUIC packet header layouts handled by this file.
12 | /*
13 | Long header
14 |
15 | 0 1 2 3
16 | 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
17 | +-+-+-+-+-+-+-+-+
18 | |1| Type (7) |
19 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
20 | | Version (32) |
21 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
22 | |DCIL(4)|SCIL(4)|
23 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
24 | | Destination Connection ID (0/32..144) ...
25 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
26 | | Source Connection ID (0/32..144) ...
27 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
28 | | Payload Length (i) ...
29 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
30 | | Packet Number (8/16/32) |
31 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
32 | | Payload (*) ...
33 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
34 |
35 | // Initial Header: same as long header but with Token
36 | +-+-+-+-+-+-+-+-+
37 | |1| 0x7f |
38 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
39 | | Version (32) |
40 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
41 | |DCIL(4)|SCIL(4)|
42 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
43 | | Destination Connection ID (0/32..144) ...
44 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
45 | | Source Connection ID (0/32..144) ...
46 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
47 | | Token Length (i) ...
48 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
49 | | Token (*) ...
50 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
51 | | Length (i) ...
52 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
53 | | Packet Number (8/16/32) |
54 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
55 | | Payload (*) ...
56 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
57 |
58 | 0 1 2 3
59 | 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
60 | +-+-+-+-+-+-+-+-+
61 | |0|K|1|1|0|R R R|
62 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
63 | | Destination Connection ID (0..144) ...
64 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
65 | | Packet Number (8/16/32) ...
66 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
67 | | Protected Payload (*) ...
68 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
69 | */
70 | // Flag bits in the first byte of a packet header (see the diagrams above).
71 | const (
72 | 	packetFlagLongHeader  = byte(0x80)
73 | 	packetFlagK           = byte(0x40)
74 | 	packetFlagShortHeader = byte(0x30)
75 | )
76 |
77 | // packetType differs considerably from the spec: long and short headers share
78 | // one value space. Long-header types are the low 7 bits of the first byte
79 | // (0-0x7f, see getHeaderType); short headers are reported as packetTypeProtectedShort (0x00).
80 | type packetType byte
81 |
82 | const (
83 | 	packetTypeInitial        = packetType(0x7f)
84 | 	packetTypeRetry          = packetType(0x7e)
85 | 	packetTypeHandshake      = packetType(0x7d)
86 | 	packetType0RTTProtected  = packetType(0x7c)
87 | 	packetTypeProtectedShort = packetType(0x00) // Not a real type
88 | )
89 | // isLongHeader reports whether the long-header flag bit (0x80) is set.
90 | func (pt packetType) isLongHeader() bool {
91 | 	return pt&packetType(packetFlagLongHeader) != 0
92 | }
93 | // isProtected is false for long-header Initial, Handshake and Retry packets and true for everything else, including every short-header packet.
94 | func (pt packetType) isProtected() bool {
95 | 	if !pt.isLongHeader() {
96 | 		return true
97 | 	}
98 |
99 | 	switch pt & 0x7f {
100 | 	case packetTypeInitial, packetTypeHandshake, packetTypeRetry:
101 | 		return false
102 | 	}
103 | 	return true
104 | }
105 | // String names the normalized packet types and falls back to hex for any other value (including raw on-the-wire long-header bytes).
106 | func (pt packetType) String() string {
107 | 	switch pt {
108 | 	case packetTypeInitial:
109 | 		return "Initial"
110 | 	case packetTypeRetry:
111 | 		return "Retry"
112 | 	case packetTypeHandshake:
113 | 		return "Handshake"
114 | 	case packetType0RTTProtected:
115 | 		return "0-RTT"
116 | 	case packetTypeProtectedShort:
117 | 		return "1-RTT"
118 | 	default:
119 | 		return fmt.Sprintf("%x", uint8(pt))
120 | 	}
121 | }
122 |
123 | // kCidDefaultLength is the length of connection ID we generate.
124 | // TODO: make this configurable.
125 | const kCidDefaultLength = 5
126 |
127 | // ConnectionId identifies the connection that a packet belongs to.
128 | type ConnectionId []byte
129 |
130 | // String returns the connection ID as lowercase hex.
131 | func (c ConnectionId) String() string {
132 | 	return hex.EncodeToString(c)
133 | }
134 |
135 | // EncodeLength produces the 4-bit length code used in the long packet header: 0 for an empty ID, otherwise len(c)-3 (len(c) must be 4..18).
136 | func (c ConnectionId) EncodeLength() byte {
137 | 	if len(c) == 0 {
138 | 		return 0
139 | 	}
140 | 	assert(len(c) >= 4 && len(c) <= 18)
141 | 	return byte(len(c) - 3)
142 | }
143 |
144 | // The PDU definition for the header.
145 | // These types are capitalized so that |codec| can use them.
146 | type packetHeader struct {
147 | 	// Type is the on-the-wire form of the packet type.
148 | 	// Consult getHeaderType if you want a value that corresponds to the
149 | 	// definition of packetType.
150 | 	Type                    packetType
151 | 	Version                 VersionNumber
152 | 	ConnectionIDLengths     byte // DCIL in the high nibble, SCIL in the low nibble; encoded for long headers only
153 | 	DestinationConnectionID ConnectionId
154 | 	SourceConnectionID      ConnectionId
155 | 	TokenLength             uint8 // one byte on the wire, Initial only; NOTE(review): the Initial diagram shows a varint here
156 | 	Token                   []byte
157 | 	PayloadLength           uint64 `tls:"varint"`
158 |
159 | 	// In order to decode a short header, the length of the connection
160 | 	// ID must be set in |shortCidLength| before decoding.
161 | 	shortCidLength uintptr
162 | }
163 | // String summarizes the header as SHORT or LONG plus its normalized packet type.
164 | func (p packetHeader) String() string {
165 | 	ht := "SHORT"
166 | 	if p.Type.isLongHeader() {
167 | 		ht = "LONG"
168 | 	}
169 | 	return fmt.Sprintf("%s PT=%v", ht, p.getHeaderType())
170 | }
171 | // getHeaderType maps Type onto the packetType constants: its low 7 bits for long headers, packetTypeProtectedShort for short headers.
172 | func (p *packetHeader) getHeaderType() packetType {
173 | 	if p.Type.isLongHeader() {
174 | 		return p.Type & 0x7f
175 | 	}
176 | 	return packetTypeProtectedShort
177 | }
178 | // packet is a header plus its full packet number and payload (presumably plaintext frames, since newPacket counts AEAD overhead separately).
179 | type packet struct {
180 | 	packetHeader
181 | 	PacketNumber uint64 // Never more than 32 bits on the wire.
182 | 	payload      []byte
183 | }
184 |
185 | // This reads from p.ConnectionIDLengths.
186 | func (p packetHeader) ConnectionIDLengths__length() uintptr {
187 | 	if p.Type.isLongHeader() {
188 | 		return 1
189 | 	}
190 | 	return 0
191 | }
192 | // TokenLength__length reports the codec size of TokenLength: one byte in Initial packets, absent otherwise.
193 | func (p packetHeader) TokenLength__length() uintptr {
194 | 	if p.getHeaderType() != packetTypeInitial {
195 | 		assert(len(p.Token) == 0)
196 | 		return 0
197 | 	}
198 | 	return 1
199 | }
200 | // Token__length reports the codec size of Token: TokenLength bytes in Initial packets, absent otherwise.
201 | func (p packetHeader) Token__length() uintptr {
202 | 	if p.getHeaderType() != packetTypeInitial {
203 | 		assert(len(p.Token) == 0)
204 | 		return 0
205 | 	}
206 | 	return uintptr(p.TokenLength)
207 | }
208 | // DestinationConnectionID__length uses shortCidLength for short headers and decodes DCIL (high nibble; 0 or length-3) for long headers.
209 | func (p packetHeader) DestinationConnectionID__length() uintptr {
210 | 	if !p.Type.isLongHeader() {
211 | 		return p.shortCidLength
212 | 	}
213 | 	l := p.ConnectionIDLengths >> 4
214 | 	if l != 0 {
215 | 		l += 3
216 | 	}
217 | 	return uintptr(l)
218 | }
219 | // SourceConnectionID__length is 0 for short headers and decodes SCIL (low nibble; 0 or length-3) for long headers.
220 | func (p packetHeader) SourceConnectionID__length() uintptr {
221 | 	if !p.Type.isLongHeader() {
222 | 		return 0
223 | 	}
224 | 	l := p.ConnectionIDLengths & 0xf
225 | 	if l != 0 {
226 | 		l += 3
227 | 	}
228 | 	return uintptr(l)
229 | }
230 | // PayloadLength__length: the payload length is only encoded in long headers.
231 | func (p packetHeader) PayloadLength__length() uintptr {
232 | 	if p.Type.isLongHeader() {
233 | 		return codecDefaultSize
234 | 	}
235 | 	return 0
236 | }
237 | // Version__length: long headers carry a 4-byte version, short headers none.
238 | func (p packetHeader) Version__length() uintptr {
239 | 	if p.Type.isLongHeader() {
240 | 		return 4
241 | 	}
242 | 	return 0
243 | }
244 | // newPacket builds a packet whose PayloadLength covers a 4-byte packet number, the payload and aeadOverhead bytes of AEAD expansion.
245 | func newPacket(pt packetType, destCid ConnectionId, srcCid ConnectionId, ver VersionNumber, pn uint64, payload []byte, aeadOverhead int) *packet {
246 | 	if pt == packetTypeProtectedShort {
247 | 		// Only support writing the 32-bit packet number.
248 | 		pt = packetType(0x2 | packetFlagShortHeader)
249 | 		srcCid = nil
250 | 	} else {
251 | 		pt = pt | packetType(packetFlagLongHeader)
252 | 	}
253 | 	lengths := (destCid.EncodeLength() << 4) | srcCid.EncodeLength()
254 | 	return &packet{
255 | 		packetHeader: packetHeader{
256 | 			Type:                    pt,
257 | 			ConnectionIDLengths:     lengths,
258 | 			DestinationConnectionID: destCid,
259 | 			SourceConnectionID:      srcCid,
260 | 			Version:                 ver,
261 | 			PayloadLength:           uint64(len(payload) + 4 + aeadOverhead),
262 | 		},
263 | 		PacketNumber: pn,
264 | 		payload:      payload,
265 | 	}
266 | }
267 | // versionNegotiationPacket carries the concatenated encodings of the supported versions.
268 | type versionNegotiationPacket struct {
269 | 	Versions []byte
270 | }
271 | // newVersionNegotiationPacket concatenates the wire encoding of each version.
272 | func newVersionNegotiationPacket(versions []VersionNumber) *versionNegotiationPacket {
273 | 	var buf bytes.Buffer
274 |
275 | 	for _, v := range versions {
276 | 		buf.Write(encodeArgs(v))
277 | 	}
278 |
279 | 	return &versionNegotiationPacket{buf.Bytes()}
280 | }
281 |
282 | /*
283 | We don't use these.
284 |
285 | func encodePacket(c ConnectionState, aead Aead, p *Packet) ([]byte, error) {
286 | hdr, err := encode(&p.packetHeader)
287 | if err != nil {
288 | return nil, err
289 | }
290 |
291 | b, err := aead.protect(p.packetHeader.PacketNumber, hdr, p.payload)
292 | if err != nil {
293 | return nil, err
294 | }
295 |
296 | return encodeArgs(hdr, b), nil
297 | }
298 |
299 | func decodePacket(c ConnectionState, aead Aead, b []byte) (*Packet, error) {
300 | // Parse the header
301 | var hdr packetHeader
302 | br, err := decode(&hdr, b)
303 | if err != nil {
304 | return nil, err
305 | }
306 |
307 | hdr.PacketNumber = c.expandPacketNumber(hdr.PacketNumber)
308 | pt, err := aead.unprotect(hdr.PacketNumber, b[0:br], b[br:])
309 | if err != nil {
310 | return nil, err
311 | }
312 |
313 | return &Packet{hdr, pt}, nil
314 | }
315 | */
316 | // dumpPacket renders payload as "<len>=[frame, frame, ...]" for logging; bytes that fail to decode are shown as hex.
317 | func dumpPacket(payload []byte) string {
318 | 	first := true
319 | 	ret := fmt.Sprintf("%d=[", len(payload))
320 |
321 | 	for len(payload) > 0 {
322 | 		if !first {
323 | 			ret += ", "
324 | 		}
325 | 		first = false
326 | 		n, f, err := decodeFrame(payload)
327 | 		if err != nil {
328 | 			ret += fmt.Sprintf("Undecoded: [%x]", payload)
329 | 			break
330 | 		}
331 | 		payload = payload[n:]
332 | 		// TODO(ekr@rtfm.com): Not sure why %v doesn't work
333 | 		ret += f.String()
334 | 	}
335 | 	ret += "]"
336 | 	return ret
337 | }
338 | // pneCipherFactory creates the keystream for packet number encryption from a ciphertext sample.
339 | type pneCipherFactory interface {
340 | 	create(sample []byte) cipher.Stream
341 | }
342 | // pneCipherFactoryAES implements pneCipherFactory with AES in CTR mode, using the sample as the initial counter block.
343 | type pneCipherFactoryAES struct {
344 | 	block cipher.Block
345 | }
346 | // newPneCipherFactoryAES returns an AES-CTR factory for key (nil if aes.NewCipher rejects the key).
347 | func newPneCipherFactoryAES(key []byte) pneCipherFactory {
348 | 	inner, err := aes.NewCipher(key)
349 | 	assert(err == nil)
350 | 	if err != nil {
351 | 		return nil
352 | 	}
353 | 	return &pneCipherFactoryAES{block: inner}
354 | }
355 | // create returns a CTR stream seeded with sample, or nil unless sample is exactly one cipher block long.
356 | func (f *pneCipherFactoryAES) create(sample []byte) cipher.Stream {
357 | 	if len(sample) != f.block.BlockSize() {
358 | 		return nil
359 | 	}
360 | 	return cipher.NewCTR(f.block, sample)
361 | }
362 | // xorPacketNumber XORs p[hdrlen:hdrlen+len(pnbuf)] with a keystream derived from a 16-byte sample of p and stores the result in pnbuf (the same call encrypts and decrypts; hdr is unused). It returns an error, leaving pnbuf untouched, if p is too short to sample or no cipher can be created.
363 | func xorPacketNumber(hdr *packetHeader, hdrlen int, pnbuf []byte, p []byte, factory pneCipherFactory) error {
364 | 	logf(logTypeTrace, "PNE Operation: hdrlen=%v, hdr=%x, payload=%x", hdrlen, p[:hdrlen], p)
365 |
366 | 	// The packet must be at least long enough to contain
367 | 	// the header, plus a minimum 1-byte PN, plus the sample.
368 | 	sampleLength := 16
369 | 	if sampleLength > len(p)-(hdrlen+1) {
370 | 		logf(logTypePacket, "Packet too short")
371 | 		return fmt.Errorf("packet too short for packet number encryption (%d bytes, header %d)", len(p), hdrlen) // was nil: callers then used an unprotected/garbage PN
372 | 	}
373 |
374 | 	// The sample starts 4 bytes after the header, pulled back so it stays inside short packets.
375 | 	sampleOffset := hdrlen + 4
376 | 	if sampleOffset+sampleLength > len(p) {
377 | 		sampleOffset = len(p) - sampleLength
378 | 	}
379 |
380 | 	sample := p[sampleOffset : sampleOffset+sampleLength]
381 | 	logf(logTypeTrace, "PNE sample_offset=%d sample=%x", sampleOffset, sample)
382 | 	stream := factory.create(sample)
383 | 	if stream == nil {
384 | 		return fmt.Errorf("unable to create packet number cipher for a %d-byte sample", len(sample))
385 | 	}
386 | 	stream.XORKeyStream(pnbuf, p[hdrlen:hdrlen+len(pnbuf)])
387 | 	return nil
388 | }
389 | // pnPatterns lists the packet number encodings: a first byte equal to prefix under mask selects an encoding of length bytes.
390 | var pnPatterns = []struct {
391 | 	prefix byte
392 | 	mask   byte
393 | 	length int
394 | }{
395 | 	{
396 | 		0, 0x80, 1,
397 | 	},
398 | 	{
399 | 		0x80, 0xc0, 2,
400 | 	},
401 | 	{
402 | 		0xc0, 0xc0, 4,
403 | 	},
404 | }
405 | // encodePacketNumber encodes the low bits of pn in l bytes (l must be 1, 2 or 4), overwriting the top bits with the matching pnPatterns prefix.
406 | func encodePacketNumber(pn uint64, l int) []byte {
407 | 	var buf bytes.Buffer
408 | 	i := 0
409 |
410 | 	for i = range pnPatterns {
411 | 		if pnPatterns[i].length == l {
412 | 			break
413 | 		}
414 | 	}
415 | 	assert(pnPatterns[i].length == l) // otherwise i is left on the 4-byte pattern while only l bytes are written
416 | 	uintEncodeInt(&buf, pn, uintptr(l))
417 | 	b := buf.Bytes()
418 | 	b[0] &= ^pnPatterns[i].mask
419 | 	b[0] |= pnPatterns[i].prefix
420 |
421 | 	return b
422 | }
423 | // decodePacketNumber parses the packet number at the start of buf and returns its value (prefix bits removed) and encoded length.
424 | func decodePacketNumber(buf []byte) (uint64, int, error) {
425 | 	if len(buf) < 1 {
426 | 		return 0, 0, fmt.Errorf("Zero-length packet number")
427 | 	}
428 |
429 | 	i := 0
430 | 	for i = range pnPatterns {
431 | 		if pnPatterns[i].mask&buf[0] == pnPatterns[i].prefix {
432 | 			break
433 | 		}
434 | 	}
435 |
436 | 	pat := &pnPatterns[i]
437 | 	if len(buf) < pat.length {
438 | 		return 0, 0, fmt.Errorf("Buffer too short for packet number (%v < %v)", len(buf), pat.length)
439 | 	}
440 | 	buf = dup(buf[:pat.length])
441 | 	buf[0] &= ^pat.mask
442 |
443 | 	return uintDecodeIntBuf(buf), pat.length, nil
444 | }
445 |
--------------------------------------------------------------------------------
/packet_test.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import (
4 | 	"crypto"
5 | 	"encoding/hex"
6 | 	"fmt"
7 | 	"github.com/bifurcation/mint"
8 | 	"testing"
9 | )
10 |
11 | var (
12 | 	testCid7    = ConnectionId([]byte{7, 7, 7, 7, 7, 7, 7})
13 | 	testCid4    = ConnectionId([]byte{4, 4, 4, 4})
14 | 	testCid5    = ConnectionId([]byte{5, 5, 5, 5, 5})
15 | 	testVersion = VersionNumber(0xdeadbeef)
16 | 	testPn      = uint64(0xff000001)
17 | )
18 |
19 | // packetHeaderEDE encodes p, decodes it (cidLen supplies the short-header connection ID length), re-encodes, and requires both encodings to match.
20 | func packetHeaderEDE(t *testing.T, p *packetHeader, cidLen uintptr) {
21 | 	res, err := encode(p)
22 | 	assertNotError(t, err, "Could not encode")
23 | 	fmt.Println("Encoded = ", hex.EncodeToString(res))
24 |
25 | 	var p2 packetHeader
26 | 	p2.shortCidLength = cidLen
27 | 	_, err = decode(&p2, res)
28 | 	assertNotError(t, err, "Could not decode")
29 | 	fmt.Println("Decoded = ", p2)
30 |
31 | 	res2, err := encode(&p2)
32 | 	assertNotError(t, err, "Could not re-encode")
33 | 	fmt.Println("Encoded2 =", hex.EncodeToString(res2))
34 | 	assertByteEquals(t, res, res2)
35 | }
36 |
37 | func TestLongHeader(t *testing.T) {
38 | 	p := newPacket(packetTypeInitial, testCid7, testCid4, testVersion,
39 | 		testPn, make([]byte, 65), 16)
40 | 	p.Token = []byte{1, 2, 3}
41 | 	p.TokenLength = uint8(len(p.Token))
42 | 	packetHeaderEDE(t, &p.packetHeader, 0)
43 | }
44 |
45 | func TestShortHeader(t *testing.T) {
46 | 	p := newPacket(packetTypeProtectedShort, testCid7, testCid4, testVersion,
47 | 		testPn, make([]byte, 65), 16)
48 |
49 | 	// We have to provide assistance to the decoder for short headers.
50 | 	// Otherwise, it can't know how long the destination connection ID is.
51 | 	packetHeaderEDE(t, &p.packetHeader, uintptr(len(p.DestinationConnectionID)))
52 | }
53 | // testPNEDecrypt decodes the header of pbytes, removes packet number encryption with pnef, and checks the recovered packet number and its encoded length.
54 | func testPNEDecrypt(t *testing.T, pbytes []byte, pn uint64, pnLen int, pnef pneCipherFactory) {
55 | 	// Now decode the packet.
56 | 	hdr2 := packetHeader{shortCidLength: kCidDefaultLength}
57 |
58 | 	hdrlen2, err := decode(&hdr2, pbytes)
59 | 	assertNotError(t, err, "Couldn't decode encrypted packet")
60 |
61 | 	dpn := make([]byte, 4)
62 | 	err = xorPacketNumber(&hdr2, int(hdrlen2), dpn, pbytes, pnef)
63 | 	assertNotError(t, err, "Couldn't XOR the packet number")
64 | 	assertEquals(t, 4, len(dpn))
65 |
66 | 	pn2, l2, err := decodePacketNumber(dpn)
67 | 	assertNotError(t, err, "Couldn't decode packet number")
68 | 	assertEquals(t, l2, pnLen)
69 | 	assertEquals(t, pn2, pn)
70 | }
71 | // DISABLED_TestPNEVector checks PNE against a captured packet; lacking the Test prefix, it is never run by go test.
72 | func DISABLED_TestPNEVector(t *testing.T) {
73 | 	kPacketHex := "ffff00000d5006b858ec6f80452b0044efa5d8d307c2973fa0d63fd9b03a4e163b990dd778894a9edc8eacfbe4aa6fbf4a22ec7f906b5e8b8ae12e5fcc7924dfeee813842bb2149b805e55895084e8393200bb3fc618af7d08281485d914ce42303f5d772b200508a0c00253e332e36a84f657321ac4c8e2cc8a117e95871f12b1f36be8c4b76fa433dc4d3142e6547f4598bf4b192130aea6fc20da5158b2162b5a899957da05ded5c70907298fd885847f22a1ecb0a814fe0170e23cad20af64f05cc13c74e91824101afdcf5f1532fc2fde936a3a159f76283a26c738f778c76e6ca41fa7f134401d39027fd81de17a8021a9c0aaa9b4478fe5c0647941618f3bee410caf94c248d2a64b5e45845cd77de13a5ed94034d2bc5f457887351993c1ecfa34fd0c658fea3f8086d26808eef976262ecf0ad646b627945511dde83e26609cd5cfd7ed9f6207d76618b44c48bf623bf420dc7c127e5d5f529f083b71a17b17da329bfc38a74bf8cfcf315c7c070b71ebfae3ab351341a767adfdd9e57c738f5de9da53711e886d1472310b917a1c9798e3e9b13c7c74beb8d1b82345bea1349415679a9c64b0433b68c871ae08092a1f6106bc06337cd343866ee8185c03fcf3bb0666453f847905547199414c1e57535747be61cdf6778378f121d68df0181ee9e8d9932c1c593c0f8c0a1af0f5262b86205002dced9ecdaee2d0aa07dd4c14f98571e4bea72f8474f63697043e936ebb2bf9716ed0efbdc13005a75cee3a49babc61b9677764510eb19828df4e10fb38b79a1efbf04cc2d571949d5403f797361743dcc5e3bf3b4396f7ae1a3affbc9f72e540d920363970307e0725fa838d611803251a4a08ccca1983d5b29a583758be63343e88f5591d885b8af695f33adbdd0d941d260287e32ef5a98fd55ac137211021fdc23b5d7a5469f578bf7aff6529117996f9ebab5e6dc7b047b356332fea82fdd620eb86f3c1d3855c8b8075da59a7662f4a11b977d996b8b3c7657ad4a82a20a7f76ce376c0320086ed029dd615399307983113cc0aa973ecba691e7e4cdc80aefa7e8c8347baba050eaca7dc35a21aa854e531dc7758d7d10b8c8e42c1be3bbf266d055ac25c37279ebefa28bbe89a34ad1ab3d23d7a66d1c216a57650e6ec9fc8ba7adfb38e57f20c467166c8fe7944e67f82138160002004812c78ba4b5f0da917da4cc14cf8fc10dba3f533facb11ef06d8b8f178ea9c5e8acbbca7b7f0e1f6b7a70ec2d5108cc41178056295793bed357accbb03c0582dc69bc77a34030f38cce256c5a9cec6e862146e3f0463f10dd5833257d0a0359166a7e2027d98eaf26cf0d5a4a05f6ef8b742f5d314a31deeeabe4ebc3106547e79c6cb933105d907b4c8c60443e97a154694bab5edfc781a438675b9de6ed03c77f51458eab61ca2e80ac02cc8c037d8fb3cf129d7107f618d66032cc02238a211f78bfa44e7c1bbcfcc627771c188d1b3713ce5e75cd2325a0a2ba08268cad13b27d97696ef678b592d0ac80ad1bacb4a1ba75bea8c477f39fc32c2aa20f352bb0da1c49b7d3927bcd9dfaf229237081d5fa08924fefd923ff0ac6baad6864b7c10dc73379a5ebd9e4678a0c26517656e8e51fca2a51a33fb2cdd5d76d12674c240ba9a4893c1af69b8f2c4adf37c4a47551eb2006a732f6b3b2f338c078ede33946dfe4a55bf644d3b98848693ada1fcb6fc16cac339ee65c24dc64b0ae92005354af00ade71e6c5e2efd85c46131d948ff14096b0f06a41d83c8522f30beb4eaaf4a6f908fe2a6ee754c896"
74 | 	kPacket := unhex(kPacketHex)
75 | 	kPNLen := 4
76 | 	kPN := 0
77 |
78 | 	hdr2 := packetHeader{shortCidLength: kCidDefaultLength}
79 | 	_, err := decode(&hdr2, kPacket)
80 | 	assertNotError(t, err, "Couldn't decode encrypted packet")
81 |
82 | 	params := mint.CipherSuiteParams{
83 | 		Suite:  mint.TLS_AES_128_GCM_SHA256,
84 | 		Cipher: nil,
85 | 		Hash:   crypto.SHA256,
86 | 		KeyLen: 16,
87 | 		IvLen:  12,
88 | 	}
89 |
90 | 	cs, err := generateCleartextKeys(hdr2.DestinationConnectionID, clientCtSecretLabel,
91 | 		&params)
92 | 	assertNotError(t, err, "Couldn't generate cleartext keys")
93 | 	testPNEDecrypt(t, kPacket, uint64(kPN), kPNLen, cs.pne)
94 | }
95 | // testPNE protects a 2-byte packet number in place for a packet of type pt and checks that testPNEDecrypt recovers it.
96 | func testPNE(t *testing.T, pt packetType) {
97 | 	key := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
98 | 	payload := make([]byte, 65)
99 | 	p := newPacket(pt, testCid5, testCid4, testVersion,
100 | 		0xfe, payload, 16)
101 |
102 | 	hdr, err := encode(&p.packetHeader)
103 | 	assertNotError(t, err, "Couldn't encode packet header")
104 |
105 | 	pnbytes := encodePacketNumber(p.PacketNumber, 2)
106 |
107 | 	pbytes := append(hdr, pnbytes...)
108 | 	pbytes = append(pbytes, payload...)
109 |
110 | 	pnef := newPneCipherFactoryAES(key)
111 |
112 | 	// Encode the packet in place.
113 | 	err = xorPacketNumber(&p.packetHeader, len(hdr), pbytes[len(hdr):len(hdr)+len(pnbytes)], pbytes, pnef)
114 | 	assertNotError(t, err, "Couldn't XOR the packet number")
115 |
116 | 	// Now decode the packet.
117 | 	testPNEDecrypt(t, pbytes, p.PacketNumber, len(pnbytes), pnef)
118 | }
119 |
120 | func TestPNE(t *testing.T) {
121 | 	t.Run("Long", func(t *testing.T) {
122 | 		testPNE(t, packetTypeInitial)
123 | 	})
124 | 	t.Run("Short", func(t *testing.T) {
125 | 		testPNE(t, packetTypeProtectedShort)
126 | 	})
127 | }
128 |
129 | /*
130 | * TODO(ekr@rtfm.com): Rewrite this code and merge it into
131 | * connection.go
132 | // Mock for connection state
133 | type ConnectionStateMock struct {
134 | aead aeadFNV
135 | }
136 |
137 | func (c *ConnectionStateMock) established() bool { return false }
138 | func (c *ConnectionStateMock) zeroRttAllowed() bool { return false }
139 | func (c *ConnectionStateMock) expandPacketNumber(pn uint64) uint64 {
140 | return pn
141 | }
142 |
143 | func TestEDEPacket(t *testing.T) {
144 | var c ConnectionStateMock
145 |
146 | p := Packet{
147 | kTestpacketHeader,
148 | []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g'},
149 | }
150 |
151 | encoded, err := encodePacket(&c, &c.aead, &p)
152 | assertNotError(t, err, "Could not encode packet")
153 |
154 | p2, err := decodePacket(&c, &c.aead, encoded)
155 | assertNotError(t, err, "Could not decode packet")
156 |
157 | encoded2, err := encodePacket(&c, &c.aead, p2)
158 | assertNotError(t, err, "Could not re-encode packet")
159 |
160 | assertByteEquals(t, encoded, encoded2)
161 | }
162 | */
163 |
164 | func testPacketNumberED(t *testing.T, pn uint64, l int) { // round-trips pn through an l-byte packet number encoding
165 | 	b := encodePacketNumber(pn, l)
166 | 	assertEquals(t, l, len(b))
167 |
168 | 	pn2, l2, err := decodePacketNumber(b)
169 | 	assertNotError(t, err, "Error decoding packet number")
170 | 	assertEquals(t, l2, l)
171 |
172 | 	mask := uint64(0) // NOTE(review): l bytes of 0xff do not exclude the prefix bits the encoding overwrites; fine for the value used below
173 | 	for i := 0; i < l; i++ {
174 | 		mask <<= 8
175 | 		mask |= 0xff
176 | 	}
177 | 	assertEquals(t, mask&pn, pn2)
178 | }
179 | // TestPacketNumberED round-trips one value through the 1-, 2- and 4-byte encodings.
180 | func TestPacketNumberED(t *testing.T) {
181 | 	val := uint64(0x04030201)
182 |
183 | 	for _, i := range []int{1, 2, 4} {
184 | 		t.Run(fmt.Sprintf("%v", i), func(t *testing.T) {
185 | 			testPacketNumberED(t, val, i)
186 | 		})
187 | 	}
188 | }
189 |
--------------------------------------------------------------------------------
/record-layer.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import (
4 | 	"github.com/bifurcation/mint"
5 | 	"io"
6 | 	"sync"
7 | )
8 | // RecordLayerImpl implements mint's record layer on top of the connection's per-epoch crypto streams rather than TLS records.
9 | type RecordLayerImpl struct {
10 | 	sync.Mutex
11 | 	conn   *Connection
12 | 	epoch  mint.Epoch
13 | 	dir    mint.Direction
14 | 	buffer []byte // bytes read ahead by PeekRecordType and handed out by the next ReadRecord
15 | }
16 | // SetVersion is a no-op.
17 | func (r *RecordLayerImpl) SetVersion(v uint16) {
18 | 	// Do nothing
19 | }
20 | // SetLabel is a no-op.
21 | func (r *RecordLayerImpl) SetLabel(s string) {
22 | 	// Do nothing
23 | }
24 | // Rekey builds packet protection (newWrappedAESGCM) and a packet number cipher (newPneCipherFactoryAES) from keys, installs them as the recv or send cipher of epoch according to r.dir, and makes epoch current; factory is ignored.
25 | func (r *RecordLayerImpl) Rekey(epoch mint.Epoch, factory mint.AeadFactory, keys *mint.KeySet) error {
26 | 	logf(logTypeTls, "Rekey epoch=%v", epoch)
27 | 	// TODO(ekr@rtfm.com): Check to see if it's GCM.
28 | 	aead, err := newWrappedAESGCM(keys.Key, keys.Iv)
29 | 	if err != nil {
30 | 		return mint.AlertInternalError
31 | 	}
32 |
33 | 	st := cryptoState{
34 | 		aead: aead,
35 | 		pne:  newPneCipherFactoryAES(keys.Pn),
36 | 	}
37 |
38 | 	if r.dir == mint.DirectionRead {
39 | 		r.conn.encryptionLevels[epoch].recvCipher = &st
40 | 	} else {
41 | 		r.conn.encryptionLevels[epoch].sendCipher = &st
42 | 	}
43 | 	r.epoch = epoch
44 | 	return nil
45 | }
46 | // ResetClear is not implemented and panics.
47 | func (r *RecordLayerImpl) ResetClear(seq uint64) {
48 | 	panic("UNIMPLEMENTED")
49 | }
50 | func (r *RecordLayerImpl) DiscardReadKey(epoch mint.Epoch) { // no-op
51 | 	// Do nothing
52 | }
53 | // readBytes does a single read of up to 16384 bytes from the current epoch's receive crypto stream; ErrorWouldBlock maps to mint.AlertWouldBlock and any other error to mint.AlertInternalError.
54 | func (r *RecordLayerImpl) readBytes() ([]byte, error) {
55 | 	str := &(r.conn.encryptionLevels[r.epoch].recvCryptoStream.(*recvStream).recvStreamBase)
56 |
57 | 	b := make([]byte, 16384)
58 | 	n, err := str.read(b)
59 | 	logf(logTypeStream, "EKR: n=%d err=%v\n", n, err)
60 | 	if err == ErrorWouldBlock {
61 | 		return nil, mint.AlertWouldBlock
62 | 	}
63 | 	if err != nil {
64 | 		return nil, mint.AlertInternalError
65 | 	}
66 |
67 | 	return b[:n], nil
68 | }
69 | func (r *RecordLayerImpl) PeekRecordType(block bool) (mint.RecordType, error) { // reads ahead into r.buffer (block is ignored) and always reports a handshake record
70 | 	assert(r.buffer == nil)
71 | 	var err error
72 | 	r.buffer, err = r.readBytes()
73 | 	if err != nil {
74 | 		return 0, err
75 | 	}
76 | 	return mint.RecordTypeHandshake, nil
77 | }
78 | // ReadRecord returns the bytes buffered by PeekRecordType if any, otherwise reads new ones, wrapped as a handshake record for the current epoch.
79 | func (r *RecordLayerImpl) ReadRecord() (*mint.TLSPlaintext, error) {
80 | 	var b []byte
81 | 	var err error
82 | 	if r.buffer != nil {
83 | 		b = r.buffer
84 | 		r.buffer = nil
85 | 	} else {
86 | 		b, err = r.readBytes()
87 | 		if err != nil {
88 | 			return nil, err
89 | 		}
90 | 	}
91 | 	return mint.NewTLSPlaintext(mint.RecordTypeHandshake, r.epoch, b), nil
92 | }
93 | // WriteRecord writes the record's fragment to the current epoch's send crypto stream.
94 | func (r *RecordLayerImpl) WriteRecord(pt *mint.TLSPlaintext) error {
95 | 	logf(logTypeTls, "WriteRecord(epoch=%v, len=%v)", r.epoch, len(pt.Fragment()))
96 | 	_, err := r.conn.encryptionLevels[r.epoch].sendCryptoStream.(*sendStream).write(pt.Fragment(), nil)
97 | 	return err
98 | }
99 | // Epoch returns the epoch most recently installed by Rekey (the zero value before any Rekey).
100
| func (r *RecordLayerImpl) Epoch() mint.Epoch {
101 | 	return r.epoch
102 | }
103 | // RecordLayerFactoryImpl creates record layers bound to a single Connection.
104 | type RecordLayerFactoryImpl struct {
105 | 	conn *Connection
106 | }
107 | // newRecordLayerFactory returns a mint.RecordLayerFactory whose layers use conn's crypto streams.
108 | func newRecordLayerFactory(conn *Connection) mint.RecordLayerFactory {
109 | 	return &RecordLayerFactoryImpl{conn: conn}
110 | }
111 | // NewLayer ignores the io.ReadWriter supplied by mint; the returned layer talks to the factory's Connection instead.
112 | func (f *RecordLayerFactoryImpl) NewLayer(conn io.ReadWriter, dir mint.Direction) mint.RecordLayer {
113 | 	return &RecordLayerImpl{
114 | 		dir:    dir,
115 | 		conn:   f.conn,
116 | 		buffer: nil,
117 | 	}
118 | }
119 |
--------------------------------------------------------------------------------
/server.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import (
4 | 	"net"
5 | )
6 |
7 | // TransportFactory makes transports bound to a specific remote
8 | // address.
9 | type TransportFactory interface {
10 | 	// Make a transport object bound to |remote|.
11 | 	MakeTransport(remote *net.UDPAddr) (Transport, error)
12 | }
13 |
14 | // Server represents a QUIC server. A server can be fed an arbitrary
15 | // number of packets and will create Connections as needed, passing
16 | // each packet to the right connection.
17 | type Server struct {
18 | 	handler      ServerHandler
19 | 	transFactory TransportFactory
20 | 	tls          *TlsConfig
21 | 	addrTable    map[string]*Connection // keyed by the peer's UDP address string
22 | 	idTable      map[string]*Connection // keyed by connection ID string (ConnectionId.String)
23 | }
24 |
25 | // ServerHandler is the interface for the handler object which the Server
26 | // calls to notify of events.
27 | type ServerHandler interface {
28 | 	// A new connection has been created and can be found in |c|.
29 | 	NewConnection(c *Connection)
30 | }
31 |
32 | // SetHandler sets the handler notified of new connections; it may be nil.
33 | func (s *Server) SetHandler(h ServerHandler) {
34 | 	s.handler = h
35 | }
36 |
37 | // Input passes an incoming packet to the Server.
38 | func (s *Server) Input(addr *net.UDPAddr, data []byte) (*Connection, error) {
39 | 	logf(logTypeServer, "Received packet from %v", addr)
40 | 	hdr := packetHeader{shortCidLength: kCidDefaultLength}
41 | 	newConn := false
42 | 	// Only the header is decoded here; short headers are assumed to carry a kCidDefaultLength-byte connection ID.
43 | 	_, err := decode(&hdr, data)
44 | 	if err != nil {
45 | 		return nil, err
46 | 	}
47 |
48 | 	var conn *Connection
49 | 	// Route by destination connection ID first, then fall back to the peer address.
50 | 	if len(hdr.DestinationConnectionID) > 0 {
51 | 		logf(logTypeServer, "Received conn id %v", hdr.DestinationConnectionID)
52 | 		conn = s.idTable[hdr.DestinationConnectionID.String()]
53 | 		if conn != nil {
54 | 			logf(logTypeServer, "Found by conn id")
55 | 		}
56 | 	}
57 |
58 | 	if conn == nil {
59 | 		conn = s.addrTable[addr.String()]
60 | 	}
61 |
62 | 	if conn == nil {
63 | 		logf(logTypeServer, "New server connection from addr %v", addr)
64 | 		trans, err := s.transFactory.MakeTransport(addr)
65 | 		if err != nil {
66 | 			return nil, err
67 | 		}
68 | 		conn = NewConnection(trans, RoleServer, s.tls, nil)
69 | 		newConn = true
70 | 	}
71 | 	// A fatal error yields (nil, nil) and a new connection is never registered; non-fatal errors from conn.Input are not reported to the caller.
72 | 	err = conn.Input(data)
73 | 	if isFatalError(err) {
74 | 		logf(logTypeServer, "Fatal Error %v killing connection %v", err, conn)
75 | 		return nil, nil
76 | 	}
77 |
78 | 	if newConn {
79 | 		// Wait until handling the first packet before the connection is added
80 | 		// to the table. Firstly, to avoid having to remove it if there is an
81 | 		// error, but also because the server-chosen connection ID isn't set
82 | 		// until after the Initial is handled.
83 | 		s.idTable[conn.serverConnectionId.String()] = conn
84 | 		s.addrTable[addr.String()] = conn
85 | 		if s.handler != nil {
86 | 			s.handler.NewConnection(conn)
87 | 		}
88 | 	}
89 |
90 | 	return conn, nil
91 | }
92 |
93 | // CheckTimer runs every connection's timer and drops connections that report a fatal error.
94 | func (s *Server) CheckTimer() error { 95 | for _, conn := range s.idTable { 96 | _, err := conn.CheckTimer() 97 | if isFatalError(err) { 98 | logf(logTypeServer, "Fatal Error %v killing connection %v", err, conn) 99 | delete(s.idTable, conn.serverConnectionId.String()) 100 | // TODO(ekr@rtfm.com): Delete this from the addr table. 101 | } 102 | } 103 | return nil 104 | } 105 | 106 | // How many connections do we have? 107 | func (s *Server) ConnectionCount() int { 108 | return len(s.idTable) 109 | } 110 | 111 | // Create a new QUIC server with the provide TLS config. 112 | func NewServer(factory TransportFactory, tls *TlsConfig, handler ServerHandler) *Server { 113 | s := Server{ 114 | handler, 115 | factory, 116 | tls, 117 | make(map[string]*Connection), 118 | make(map[string]*Connection), 119 | } 120 | s.tls.init() 121 | return &s 122 | } 123 | -------------------------------------------------------------------------------- /server_test.go: -------------------------------------------------------------------------------- 1 | package minq 2 | 3 | import ( 4 | "net" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | // fake TransportFactory that comes populated with 10 | // a set of pre-fab transports keyed by name. 
11 | type testTransportFactory struct {
12 | 	transports map[string]*testTransport
13 | }
14 | // MakeTransport returns the transport registered for remote. NOTE(review): an unregistered address yields a nil *testTransport wrapped in the Transport result (not a nil interface) and a nil error.
15 | func (f *testTransportFactory) MakeTransport(remote *net.UDPAddr) (Transport, error) {
16 | 	return f.transports[remote.String()], nil
17 | }
18 | // addTransport registers t as the transport for remote.
19 | func (f *testTransportFactory) addTransport(remote *net.UDPAddr, t *testTransport) {
20 | 	f.transports[remote.String()] = t
21 | }
22 | // serverInputAll feeds every packet queued on trans to s as if sent from u, asserting they all land on the same connection, and returns it (nil if nothing was queued).
23 | func serverInputAll(t *testing.T, trans *testTransport, s *Server, u net.UDPAddr) (*Connection, error) {
24 | 	var clast *Connection
25 |
26 | 	for {
27 | 		p, err := trans.Recv()
28 | 		if err != nil && err != ErrorWouldBlock {
29 | 			return nil, err
30 | 		}
31 |
32 | 		if p == nil {
33 | 			return clast, nil
34 | 		}
35 |
36 | 		c, err := s.Input(&u, p)
37 | 		if err != nil {
38 | 			return nil, err
39 | 		}
40 |
41 | 		if clast == nil {
42 | 			clast = c
43 | 		}
44 | 		assertEquals(t, c, clast)
45 | 	}
46 | }
47 | // TestServer checks that packets from one address keep reaching the same server connection and that a second address gets a different one.
48 | func TestServer(t *testing.T) {
49 | 	// Have the client and server do a handshake.
50 | 	u, _ := net.ResolveUDPAddr("udp", "127.0.0.1:4443") // Just a fixed address
51 |
52 | 	cTrans, sTrans := newTestTransportPair(true)
53 | 	factory := &testTransportFactory{make(map[string]*testTransport)}
54 | 	factory.addTransport(u, sTrans)
55 |
56 | 	server := NewServer(factory, testTlsConfig(), nil)
57 | 	assertNotNil(t, server, "Couldn't make server")
58 |
59 | 	client := NewConnection(cTrans, RoleClient, testTlsConfig(), nil)
60 | 	assertNotNil(t, client, "Couldn't make client")
61 |
62 | 	n, err := client.CheckTimer()
63 | 	assertEquals(t, 1, n)
64 | 	assertNotError(t, err, "Couldn't send client initial")
65 |
66 | 	s1, err := serverInputAll(t, sTrans, server, *u)
67 | 	assertNotError(t, err, "Couldn't consume client initial")
68 |
69 | 	err = inputAll(client)
70 | 	assertNotError(t, err, "Error processing SH")
71 |
72 | 	s2, err := serverInputAll(t, sTrans, server, *u)
73 | 	assertNotError(t, err, "Error processing CFIN")
74 | 	// Make sure we get the same server back.
75 | 	assertEquals(t, s1, s2)
76 |
77 | 	// Now make a new client and ensure we get a different server connection
78 | 	u2, _ := net.ResolveUDPAddr("udp", "127.0.0.1:4444") // Just a fixed address
79 | 	cTrans2, sTrans2 := newTestTransportPair(true)
80 | 	factory.addTransport(u2, sTrans2)
81 | 	client = NewConnection(cTrans2, RoleClient, testTlsConfig(), nil)
82 | 	assertNotNil(t, client, "Couldn't make client")
83 |
84 | 	n, err = client.CheckTimer()
85 | 	assertEquals(t, 1, n)
86 | 	assertNotError(t, err, "Couldn't send client initial")
87 |
88 | 	s3, err := serverInputAll(t, sTrans2, server, *u2)
89 | 	assertNotError(t, err, "Couldn't consume client initial")
90 |
91 | 	assertX(t, s1 != s3, "Got the same server connection back with a different address")
92 | 	assertEquals(t, 2, len(server.addrTable))
93 | }
94 | // TestServerIdleTimeout checks that CheckTimer removes a connection once its idle timeout and draining period have elapsed.
95 | func TestServerIdleTimeout(t *testing.T) {
96 | 	// Have the client and server do a handshake.
97 | 	u, _ := net.ResolveUDPAddr("udp", "127.0.0.1:4443") // Just a fixed address
98 |
99 | 	cTrans, sTrans := newTestTransportPair(true)
100 | 	factory := &testTransportFactory{make(map[string]*testTransport)}
101 | 	factory.addTransport(u, sTrans)
102 |
103 | 	server := NewServer(factory, testTlsConfig(), nil)
104 | 	assertNotNil(t, server, "Couldn't make server")
105 |
106 | 	client := NewConnection(cTrans, RoleClient, testTlsConfig(), nil)
107 | 	assertNotNil(t, client, "Couldn't make client")
108 |
109 | 	n, err := client.CheckTimer()
110 | 	assertEquals(t, 1, n)
111 | 	assertNotError(t, err, "Couldn't send client initial")
112 |
113 | 	sconn, err := serverInputAll(t, sTrans, server, *u)
114 | 	assertNotError(t, err, "Couldn't consume client initial")
115 | 	assertNotNil(t, sconn, "no server connection")
116 |
117 | 	assertEquals(t, 1, server.ConnectionCount())
118 |
119 | 	// This pokes into internal state of the server to avoid having to include
120 | 	// sleep calls in tests. Don't do this at home kids.
121 | 	// Wind the timer on the connection back to short-circuit the idle timeout.
122 | 	sconn.lastInput = sconn.lastInput.Add(-1 - sconn.idleTimeout)
123 | 	server.CheckTimer()
124 | 	// A second nap to allow for draining period.
125 | 	sconn.closingEnd = sconn.closingEnd.Add(-1 - time.Second)
126 | 	server.CheckTimer()
127 |
128 | 	assertEquals(t, 0, server.ConnectionCount())
129 | }
130 |
--------------------------------------------------------------------------------
/stream.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import (
4 | 	"encoding/hex"
5 | 	"fmt"
6 | 	"io"
7 | )
8 |
9 | // SendStreamState is the state of a SendStream
10 | type SendStreamState uint8
11 |
12 | // SendStreamState values. Not all of these are tracked
13 | const (
14 | 	SendStreamStateOpen        = SendStreamState(0)
15 | 	SendStreamStateSend        = SendStreamState(1)
16 | 	SendStreamStateCloseQueued = SendStreamState(2) // Not in the spec
17 | 	SendStreamStateDataSent    = SendStreamState(3)
18 | 	SendStreamStateResetSent   = SendStreamState(4)
19 | 	SendStreamStateDataRecvd   = SendStreamState(5) // Not tracked
20 | 	SendStreamStateResetRecvd  = SendStreamState(6) // Not tracked
21 | )
22 |
23 | // String produces a nice string from a SendStreamState.
24 | func (s SendStreamState) String() string { 25 | switch s { 26 | case SendStreamStateOpen: 27 | return "SendStreamStateOpen" 28 | case SendStreamStateSend: 29 | return "SendStreamStateSend" 30 | case SendStreamStateCloseQueued: 31 | return "SendStreamStateCloseQueued" 32 | case SendStreamStateDataSent: 33 | return "SendStreamStateDataSent" 34 | case SendStreamStateResetSent: 35 | return "SendStreamStateResetSent" 36 | case SendStreamStateDataRecvd: 37 | return "SendStreamStateDataRecvd" 38 | case SendStreamStateResetRecvd: 39 | return "SendStreamStateResetRecvd" 40 | default: 41 | panic("Unknown SendStreamState") 42 | } 43 | } 44 | 45 | // RecvStreamState is the state of a RecvStream 46 | type RecvStreamState uint8 47 | 48 | // RecvStreamState values. Not all of these are tracked. 49 | const ( 50 | RecvStreamStateRecv = RecvStreamState(0) 51 | RecvStreamStateSizeKnown = RecvStreamState(1) 52 | RecvStreamStateDataRecvd = RecvStreamState(2) // Not tracked 53 | RecvStreamStateResetRecvd = RecvStreamState(3) 54 | RecvStreamStateDataRead = RecvStreamState(4) 55 | RecvStreamStateResetRead = RecvStreamState(5) 56 | ) 57 | 58 | // String produces a nice string from a RecvStreamState. 59 | func (s RecvStreamState) String() string { 60 | switch s { 61 | case RecvStreamStateRecv: 62 | return "RecvStreamStateRecv" 63 | case RecvStreamStateSizeKnown: 64 | return "RecvStreamStateSizeKnown" 65 | case RecvStreamStateDataRecvd: 66 | return "RecvStreamStateDataRecvd" 67 | case RecvStreamStateResetRecvd: 68 | return "RecvStreamStateResetRecvd" 69 | case RecvStreamStateDataRead: 70 | return "RecvStreamStateDataRead" 71 | case RecvStreamStateResetRead: 72 | return "RecvStreamStateResetRead" 73 | default: 74 | panic("Unknown RecvStreamState") 75 | } 76 | } 77 | 78 | // The structure here is a little convoluted. 79 | // 80 | // There are three primary interfaces: SendStream, RecvStream, and Stream. 
These 81 | // all implement hasIdentity and one or both (for Stream) of sendStreamMethods 82 | // or recvStreamMethods. 83 | // 84 | // The implementations are layered. 85 | // 86 | // streamCommon is at the bottom, it includes stuff common to sending and receiving. 87 | // 88 | // sendStreamBase and recvStreamBase add sending and receiving functions. These 89 | // know how to send and receive, but don't know about identifiers or 90 | // connections. This allows them to be tested in isolation. 91 | // 92 | // Those types don't know about connections, so sendStream and recvStream add 93 | // that by mixing in streamWithIdentity. The same applies to stream, which mixes 94 | // both sendStream and recvStream. These include the concrete implementations of 95 | // the interfaces. 96 | 97 | type hasIdentity interface { 98 | Id() uint64 99 | } 100 | 101 | type sendStreamMethods interface { 102 | io.WriteCloser 103 | Reset(uint16) error 104 | SendState() SendStreamState 105 | } 106 | 107 | type sendStreamPrivateMethods interface { 108 | setSendState(SendStreamState) 109 | outstandingQueuedBytes() int 110 | processMaxStreamData(uint64) 111 | outputWritable() []streamChunk 112 | flowControl() flowControl 113 | } 114 | 115 | type recvStreamMethods interface { 116 | io.Reader 117 | StopSending(uint16) error 118 | RecvState() RecvStreamState 119 | } 120 | 121 | type recvStreamPrivateMethods interface { 122 | setRecvState(RecvStreamState) 123 | handleReset(offset uint64) error 124 | clearReadable() bool 125 | newFrameData(uint64, bool, []byte, *flowControl) error 126 | updateMaxStreamData(bool) 127 | } 128 | 129 | // SendStream can send. 130 | type SendStream interface { 131 | hasIdentity 132 | sendStreamMethods 133 | } 134 | 135 | type sendStreamPrivate interface { 136 | SendStream 137 | sendStreamPrivateMethods 138 | } 139 | 140 | // RecvStream can receive. 
141 | type RecvStream interface { 142 | hasIdentity 143 | recvStreamMethods 144 | } 145 | 146 | type recvStreamPrivate interface { 147 | RecvStream 148 | recvStreamPrivateMethods 149 | } 150 | 151 | // Stream is both a send and receive stream. 152 | type Stream interface { 153 | hasIdentity 154 | sendStreamMethods 155 | recvStreamMethods 156 | } 157 | 158 | type streamPrivate interface { 159 | Stream 160 | sendStreamPrivateMethods 161 | recvStreamPrivateMethods 162 | } 163 | 164 | type streamChunk struct { 165 | offset uint64 166 | last bool 167 | data []byte 168 | } 169 | 170 | func (sc streamChunk) String() string { 171 | return fmt.Sprintf("chunk(offset=%v, len=%v, last=%v)", sc.offset, len(sc.data), sc.last) 172 | } 173 | 174 | type streamCommon struct { 175 | log loggingFunction 176 | chunks []streamChunk 177 | fc flowControl 178 | readOffset uint64 179 | } 180 | 181 | func (s *streamCommon) insertSortedChunk(offset uint64, last bool, payload []byte) { 182 | c := streamChunk{offset, last, dup(payload)} 183 | s.log(logTypeStream, "insert %v, current offset=%v", c, s.fc.used) 184 | s.log(logTypeTrace, "payload %v", hex.EncodeToString(payload)) 185 | if len(payload) == 0 && !last && offset != 0 { 186 | // Empty frame, ignore 187 | return 188 | } 189 | 190 | // First check if we can append the new slice at the end 191 | if nchunks := len(s.chunks); nchunks == 0 || offset > s.chunks[nchunks-1].offset { 192 | s.chunks = append(s.chunks, c) 193 | } else { 194 | // Otherwise find out where it should go 195 | var i int 196 | for i = 0; i < nchunks; i++ { 197 | if offset < s.chunks[i].offset { 198 | break 199 | } 200 | } 201 | 202 | // This may not be the fastest way to do this splice. 203 | tmp := make([]streamChunk, 0, nchunks+1) 204 | tmp = append(tmp, s.chunks[:i]...) 205 | tmp = append(tmp, c) 206 | tmp = append(tmp, s.chunks[i:]...) 
207 | s.chunks = tmp 208 | } 209 | s.log(logTypeStream, "Stream now has %v chunks", len(s.chunks)) 210 | } 211 | 212 | type sendStreamBase struct { 213 | streamCommon 214 | state SendStreamState 215 | } 216 | 217 | func (s *sendStreamBase) setSendState(state SendStreamState) { 218 | if state != s.state { 219 | s.log(logTypeStream, "set state %v->%v", s.state, state) 220 | s.state = state 221 | } 222 | } 223 | 224 | // SendState returns the current state of the receive stream. 225 | func (s *sendStreamBase) SendState() SendStreamState { 226 | return s.state 227 | } 228 | 229 | func (s *sendStreamBase) queue(payload []byte, cfc *flowControl) (int, error) { 230 | s.log(logTypeStream, "queueing %v bytes, flow control %v %v", len(payload), &s.fc, cfc) 231 | offset := s.fc.used 232 | allowed := s.fc.take(cfc, uint64(len(payload))) 233 | s.log(logTypeFlowControl, "flow control consumed %v %v", &s.fc, cfc) 234 | if allowed == 0 { 235 | s.log(logTypeFlowControl, "blocked write") 236 | return 0, ErrorWouldBlock 237 | } 238 | payload = payload[:allowed] 239 | s.insertSortedChunk(offset, false, payload) 240 | return int(allowed), nil 241 | } 242 | 243 | func (s *sendStreamBase) write(data []byte, connectionFlowControl *flowControl) (int, error) { 244 | switch s.state { 245 | case SendStreamStateOpen: 246 | s.setSendState(SendStreamStateSend) 247 | // Allow a zero-octet write on a stream that hasn't been opened. 
248 | if len(data) == 0 { 249 | return s.queue(data, connectionFlowControl) 250 | } 251 | case SendStreamStateSend: 252 | // OK to send 253 | default: 254 | return 0, ErrorStreamIsClosed 255 | } 256 | written := 0 257 | for len(data) > 0 { 258 | tocpy := 1024 259 | if tocpy > len(data) { 260 | tocpy = len(data) 261 | } 262 | n, err := s.queue(data[:tocpy], connectionFlowControl) 263 | if (err == ErrorWouldBlock) && (written > 0) { 264 | s.log(logTypeFlowControl, "write flow control blocked at offset %d", s.fc.used) 265 | break 266 | } 267 | if err != nil { 268 | return written, err 269 | } 270 | written += n 271 | 272 | data = data[tocpy:] 273 | } 274 | 275 | s.log(logTypeTrace, "wrote %d bytes", written) 276 | return written, nil 277 | } 278 | 279 | func (s *sendStreamBase) outstandingQueuedBytes() int { 280 | n := 0 281 | for _, ch := range s.chunks { 282 | n += len(ch.data) 283 | } 284 | return n 285 | } 286 | 287 | func (s *sendStreamBase) flowControl() flowControl { 288 | return s.fc 289 | } 290 | 291 | // Push out all pending frames. Set the stream state if the end of the stream is available. 
292 | func (s *sendStreamBase) outputWritable() []streamChunk { 293 | s.log(logTypeStream, "outputWritable, chunks=%v current max offset=%d)", len(s.chunks), s.fc.max) 294 | for _, ch := range s.chunks { 295 | if ch.last { 296 | s.setSendState(SendStreamStateDataSent) 297 | } 298 | } 299 | 300 | out := s.chunks 301 | s.chunks = nil 302 | return out 303 | } 304 | 305 | func (s *sendStreamBase) processMaxStreamData(offset uint64) { 306 | s.fc.update(offset) 307 | } 308 | 309 | func (s *sendStreamBase) close() { 310 | switch s.state { 311 | case SendStreamStateOpen, SendStreamStateSend: 312 | s.insertSortedChunk(s.fc.used, true, nil) 313 | s.setSendState(SendStreamStateCloseQueued) 314 | default: 315 | // NOOP 316 | } 317 | } 318 | 319 | type recvStreamBase struct { 320 | streamCommon 321 | state RecvStreamState 322 | readable bool 323 | } 324 | 325 | func (s *recvStreamBase) setRecvState(state RecvStreamState) { 326 | if state != s.state { 327 | s.log(logTypeStream, "set state %v->%v", s.state, state) 328 | s.state = state 329 | } 330 | } 331 | 332 | // RecvState returns the current state of the receive stream. 333 | func (s *recvStreamBase) RecvState() RecvStreamState { 334 | return s.state 335 | } 336 | 337 | // clearReadable clears the readable flag and returns true if it was set. 338 | func (s *recvStreamBase) clearReadable() bool { 339 | r := s.readable 340 | s.readable = false 341 | return r 342 | } 343 | 344 | // Add data to a stream. Return true if this is readable now. 345 | func (s *recvStreamBase) newFrameData(offset uint64, last bool, payload []byte, 346 | cfc *flowControl) error { 347 | s.log(logTypeStream, "new data offset=%d, len=%d", offset, len(payload)) 348 | s.log(logTypeFlowControl, "new data flow control %v %v", &s.fc, cfc) 349 | 350 | end := offset + uint64(len(payload)) 351 | if last { 352 | if end < s.fc.used { 353 | // The end can't be less than what we've received already. 
354 | return ErrorFlowControlError 355 | } 356 | if s.state == RecvStreamStateRecv { 357 | s.setRecvState(RecvStreamStateSizeKnown) 358 | } 359 | } else if end > s.fc.used { 360 | if s.state != RecvStreamStateRecv { 361 | // We shouldn't be increasing used in any other state. 362 | return ErrorFlowControlError 363 | } 364 | 365 | increase := end - s.fc.used 366 | taken := increase 367 | if !s.fc.unlimited { 368 | taken := s.fc.take(cfc, increase) 369 | s.log(logTypeFlowControl, "taken flow control %d, now %v %v", taken, &s.fc, cfc) 370 | } 371 | if taken < increase { 372 | // We didn't have that much available. 373 | return ErrorFlowControlError 374 | } 375 | } else if end <= s.readOffset { 376 | // No new data here. 377 | return nil 378 | } 379 | if s.state != RecvStreamStateRecv && s.state != RecvStreamStateSizeKnown { 380 | // We shouldn't be receiving in other states. 381 | return nil 382 | } 383 | 384 | s.insertSortedChunk(offset, last, payload) 385 | if s.chunks[0].offset <= s.readOffset { 386 | s.readable = true 387 | } 388 | 389 | return nil 390 | } 391 | 392 | func (s *recvStreamBase) read(b []byte) (int, error) { 393 | s.log(logTypeStream, "Reading len=%v read offset=%v available chunks=%v", 394 | len(b), s.readOffset, len(s.chunks)) 395 | 396 | if s.state == RecvStreamStateResetRecvd { 397 | s.log(logTypeStream, "Reading stopped for RST_STREAM") 398 | s.setRecvState(RecvStreamStateResetRead) 399 | return 0, ErrorStreamReset 400 | } 401 | 402 | read := 0 403 | 404 | for len(b) > 0 { 405 | if len(s.chunks) == 0 { 406 | break 407 | } 408 | 409 | chunk := s.chunks[0] 410 | s.log(logTypeTrace, "next chunk %v", chunk) 411 | // We have a gap. 412 | if chunk.offset > s.readOffset { 413 | break 414 | } 415 | 416 | // Remove leading bytes 417 | remove := s.readOffset - chunk.offset 418 | if remove > uint64(len(chunk.data)) { 419 | // Nothing left. 
420 | s.chunks = s.chunks[1:] 421 | continue 422 | } 423 | 424 | chunk.offset += remove 425 | chunk.data = chunk.data[remove:] 426 | 427 | // Now figure out how much we can read 428 | n := copy(b, chunk.data) 429 | s.log(logTypeTrace, "read %v at offset %v", n, s.readOffset) 430 | chunk.data = chunk.data[n:] 431 | chunk.offset += uint64(n) 432 | s.readOffset += uint64(n) 433 | b = b[n:] 434 | read += n 435 | 436 | // This chunk is empty. 437 | if len(chunk.data) == 0 { 438 | s.chunks = s.chunks[1:] 439 | 440 | if chunk.last { 441 | s.setRecvState(RecvStreamStateDataRead) 442 | s.chunks = nil 443 | break 444 | } 445 | } 446 | } 447 | 448 | // If we have read no data, say we would have blocked. 449 | if read == 0 { 450 | switch s.state { 451 | case RecvStreamStateRecv, RecvStreamStateSizeKnown: 452 | return 0, ErrorWouldBlock 453 | default: 454 | if s.chunks == nil { 455 | return 0, io.EOF 456 | } 457 | return 0, ErrorStreamIsClosed 458 | } 459 | } 460 | s.log(logTypeStream, "Returning %v bytes chunks=%v", read, len(s.chunks)) 461 | return read, nil 462 | } 463 | 464 | func (s *recvStreamBase) handleReset(offset uint64) error { 465 | switch s.state { 466 | case RecvStreamStateRecv: 467 | s.fc.used = offset 468 | case RecvStreamStateDataRecvd, RecvStreamStateResetRead: 469 | panic("we don't use this state") 470 | case RecvStreamStateSizeKnown, RecvStreamStateDataRead, RecvStreamStateResetRecvd: 471 | if offset != s.fc.used { 472 | return ErrorProtocolViolation 473 | } 474 | default: 475 | panic(fmt.Sprintf("unknown state %v", s.state)) 476 | } 477 | 478 | s.setRecvState(RecvStreamStateResetRecvd) 479 | s.chunks = nil 480 | return nil 481 | } 482 | 483 | // SendStream is a unidirectional stream for sending. 484 | type sendStream struct { 485 | c *Connection 486 | id uint64 487 | sendStreamBase 488 | } 489 | 490 | // Compile-time interface check. 
491 | var _ SendStream = &sendStream{} 492 | 493 | func newSendStream(c *Connection, id uint64, initialMax uint64) sendStreamPrivate { 494 | return &sendStream{ 495 | c: c, id: id, 496 | sendStreamBase: sendStreamBase{ 497 | streamCommon: streamCommon{ 498 | log: newStreamLogger(id, "send", c.log), 499 | fc: newFlowControl(initialMax), 500 | }, 501 | state: SendStreamStateOpen, 502 | }, 503 | } 504 | } 505 | 506 | // Id returns the id. 507 | func (s *sendStream) Id() uint64 { 508 | return s.id 509 | } 510 | 511 | // Write writes data. 512 | func (s *sendStream) Write(data []byte) (int, error) { 513 | s.log(logTypeStream, "Stream %v: writing %v bytes", s.Id(), len(data)) 514 | if s.c.isClosed() { 515 | return 0, ErrorConnIsClosed 516 | } 517 | 518 | n, err := s.write(data, &s.c.sendFlowControl) 519 | if err != nil { 520 | if err == ErrorWouldBlock { 521 | s.c.updateStreamBlocked(s) 522 | s.c.updateBlocked() 523 | } 524 | return n, err 525 | } 526 | 527 | s.c.sendQueued(false) 528 | return n, nil 529 | } 530 | 531 | // Close makes the stream end cleanly. 532 | func (s *sendStream) Close() error { 533 | s.close() 534 | s.c.sendQueued(false) 535 | return nil 536 | } 537 | 538 | // Reset abandons writing on the stream. 539 | func (s *sendStream) Reset(code uint16) error { 540 | s.setSendState(SendStreamStateResetSent) 541 | f := newRstStreamFrame(s.id, code, s.fc.used) 542 | return s.c.sendFrame(f) 543 | } 544 | 545 | // recvStream is the implementation of a unidirectional stream for receiving. 546 | type recvStream struct { 547 | c *Connection 548 | id uint64 549 | recvStreamBase 550 | } 551 | 552 | // Compile-time interface check. 
553 | var _ RecvStream = &recvStream{} 554 | 555 | func newRecvStream(c *Connection, id uint64, maxStreamData uint64) recvStreamPrivate { 556 | return &recvStream{ 557 | c: c, id: id, 558 | recvStreamBase: recvStreamBase{ 559 | streamCommon: streamCommon{ 560 | log: newStreamLogger(id, "recv", c.log), 561 | fc: newFlowControl(maxStreamData), 562 | }, 563 | state: RecvStreamStateRecv, 564 | readable: false, 565 | }, 566 | } 567 | } 568 | 569 | // Id returns the id. 570 | func (s *recvStream) Id() uint64 { 571 | return s.id 572 | } 573 | 574 | // updateMaxStreamData checks the current flow control limit and sends 575 | // MAX_STREAM_DATA as necessary. 576 | func (s *recvStream) updateMaxStreamData(force bool) { 577 | s.log(logTypeFlowControl, "credit flow control %v", &s.fc) 578 | if force || s.fc.remaining() < kInitialMaxStreamData/2 { 579 | s.fc.max = s.readOffset + kInitialMaxData 580 | s.log(logTypeFlowControl, "increased flow control to %v", &s.fc) 581 | s.c.issueStreamCredit(s, s.fc.max) 582 | } 583 | } 584 | 585 | // Read implements io.Reader. 586 | func (s *recvStream) Read(b []byte) (int, error) { 587 | if s.c.isClosed() { 588 | return 0, io.EOF 589 | } 590 | 591 | n, err := s.read(b) 592 | if err != nil { 593 | return 0, err 594 | } 595 | s.c.amountRead += uint64(n) 596 | // Now issue credit for stream flow control, ... 597 | s.updateMaxStreamData(false) 598 | // ..., connection flow control, ... 599 | s.c.issueCredit(false) 600 | // ..., and streams. 601 | if s.state == RecvStreamStateDataRead { 602 | s.c.issueStreamIdCredit(streamTypeFromId(s.id, s.c.role)) 603 | } 604 | return n, nil 605 | } 606 | 607 | func (s *recvStream) handleReset(offset uint64) error { 608 | err := s.recvStreamBase.handleReset(offset) 609 | if err != nil { 610 | return err 611 | } 612 | // Pretend that we read this much data. 
613 | s.c.amountRead += s.fc.used - s.readOffset 614 | s.readOffset = s.fc.used 615 | s.c.issueCredit(false) 616 | 617 | return nil 618 | } 619 | 620 | // StopSending requests a reset. 621 | func (s *recvStream) StopSending(code uint16) error { 622 | f := newStopSendingFrame(s.id, code) 623 | return s.c.sendFrame(f) 624 | } 625 | 626 | // stream is a bidirectional stream. 627 | type stream struct { 628 | c *Connection 629 | id uint64 630 | 631 | sendStreamPrivate 632 | recvStreamPrivate 633 | } 634 | 635 | // Compile-time interface check. 636 | var _ Stream = &stream{} 637 | 638 | func newStream(c *Connection, id uint64, sendMax uint64, recvMax uint64) streamPrivate { 639 | return &stream{ 640 | sendStreamPrivate: newSendStream(c, id, sendMax), 641 | recvStreamPrivate: newRecvStream(c, id, recvMax), 642 | } 643 | } 644 | 645 | // Id needs to be overwritten so that the ambiguity between send and receive can be resolved. 646 | func (s *stream) Id() uint64 { 647 | return s.sendStreamPrivate.Id() 648 | } 649 | 650 | type streamType uint8 651 | 652 | // These values match the low bits of the stream ID for a client, but the low 653 | // bit is flipped for a server. 
654 | const ( 655 | streamTypeBidirectionalLocal = streamType(0) 656 | streamTypeBidirectionalRemote = streamType(1) 657 | streamTypeUnidirectionalLocal = streamType(2) 658 | streamTypeUnidirectionalRemote = streamType(3) 659 | ) 660 | 661 | func streamTypeFromId(id uint64, role Role) streamType { 662 | t := id & 3 663 | if role == RoleServer { 664 | t ^= 1 665 | } 666 | return streamType(t) 667 | } 668 | 669 | func (t streamType) suffix(role Role) uint64 { 670 | suff := uint64(t) 671 | if role == RoleServer { 672 | suff ^= 1 673 | } 674 | return suff 675 | } 676 | 677 | func (t streamType) String() string { 678 | switch t { 679 | case streamTypeBidirectionalLocal: 680 | return "bidirectional local" 681 | case streamTypeBidirectionalRemote: 682 | return "bidirectional remote" 683 | case streamTypeUnidirectionalLocal: 684 | return "unidirectional local" 685 | case streamTypeUnidirectionalRemote: 686 | return "unidirectional remote" 687 | default: 688 | panic("unknown stream type") 689 | } 690 | } 691 | 692 | type streamSet struct { 693 | // t is the type of stream relative to the endpoints role 694 | t streamType 695 | // role is the endpoint's role 696 | role Role 697 | // nstreams is the maximum number of streams (as opposed to the maximum ID) 698 | nstreams int 699 | // typeless array of streams because go doesn't have generics 700 | streams []hasIdentity 701 | } 702 | 703 | func newStreamSet(t streamType, role Role, nstreams int) *streamSet { 704 | return &streamSet{t, role, nstreams, make([]hasIdentity, 0, nstreams)} 705 | } 706 | 707 | func (ss *streamSet) check(id uint64) { 708 | // If sizeof(int) == sizeof(uint64), then we will never overflow int. 709 | assert(^uint64(0) == uint64(^uint(0))) 710 | assert((id & (^uint64(0) >> 2)) == id) // The top bits should be clear. 
711 | assert((id & 3) == ss.t.suffix(ss.role)) 712 | } 713 | 714 | func (ss *streamSet) index(id uint64) int { 715 | ss.check(id) 716 | return int(id >> 2) 717 | } 718 | 719 | func (ss *streamSet) id(index int) uint64 { 720 | assert(index >= 0) 721 | return uint64(index<<2) | uint64(ss.t.suffix(ss.role)) 722 | } 723 | 724 | type flowControl struct { 725 | unlimited bool 726 | max uint64 727 | used uint64 728 | } 729 | 730 | func newFlowControl(initialMax uint64) flowControl { 731 | fc := flowControl{false, initialMax, 0} 732 | if initialMax == ^uint64(0) { 733 | fc.unlimited = true 734 | } 735 | return fc 736 | } 737 | 738 | func (fc *flowControl) String() string { 739 | if fc.unlimited { 740 | return ("Unlimited") 741 | } 742 | return fmt.Sprintf("%d/%d", fc.used, fc.max) 743 | } 744 | 745 | func (fc *flowControl) update(max uint64) { 746 | if max > fc.max { 747 | fc.max = max 748 | } 749 | } 750 | 751 | func (fc *flowControl) take(other *flowControl, amount uint64) uint64 { 752 | taken := uint64(0) 753 | if !fc.unlimited { 754 | taken = fc.remaining() 755 | if taken > other.remaining() { 756 | taken = other.remaining() 757 | } 758 | } else { 759 | taken = ^uint64(0) 760 | } 761 | if taken > amount { 762 | taken = amount 763 | } 764 | 765 | fc.used += taken 766 | // TODO(ekr@rtfm.com): Is this still needed. 
767 | if other != nil { 768 | other.used += taken 769 | } 770 | return taken 771 | } 772 | 773 | func (fc *flowControl) remaining() uint64 { 774 | return fc.max - fc.used 775 | } 776 | 777 | func (ss *streamSet) updateMax(id uint64) { 778 | ss.nstreams = ss.index(id) + 1 779 | } 780 | 781 | func (ss *streamSet) credit(n int) uint64 { 782 | ss.nstreams += n 783 | return ss.id(ss.nstreams - 1) 784 | } 785 | 786 | func (ss *streamSet) get(id uint64) hasIdentity { 787 | i := ss.index(id) 788 | if i >= len(ss.streams) { 789 | return nil 790 | } 791 | return ss.streams[i] 792 | } 793 | 794 | type streamSetCtor func(id uint64) hasIdentity 795 | 796 | func (ss *streamSet) create(ctor streamSetCtor) hasIdentity { 797 | i := len(ss.streams) 798 | if i >= ss.nstreams { 799 | return nil 800 | } 801 | ss.streams = append(ss.streams, ctor(ss.id(i))) 802 | return ss.streams[i] 803 | } 804 | 805 | func (ss *streamSet) ensure(id uint64, ctor streamSetCtor, 806 | notify func(s hasIdentity)) hasIdentity { 807 | i := ss.index(id) 808 | if i >= ss.nstreams { 809 | return nil 810 | } 811 | if i >= len(ss.streams) { 812 | needed := i - len(ss.streams) + 1 813 | start := len(ss.streams) 814 | ss.streams = append(ss.streams, make([]hasIdentity, needed)...) 
815 | for j := start; j < len(ss.streams); j++ { 816 | s := ctor(ss.id(j)) 817 | ss.check(s.Id()) 818 | ss.streams[j] = s 819 | notify(ss.streams[j]) 820 | } 821 | } 822 | return ss.streams[i] 823 | } 824 | 825 | func (ss *streamSet) forEach(f func(hasIdentity)) { 826 | for _, s := range ss.streams { 827 | f(s) 828 | } 829 | } 830 | -------------------------------------------------------------------------------- /stream_test.go: -------------------------------------------------------------------------------- 1 | package minq 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "runtime" 7 | "testing" 8 | ) 9 | 10 | type testStreamFixture struct { 11 | t *testing.T 12 | name string 13 | log loggingFunction 14 | r *recvStreamBase 15 | w *sendStreamBase 16 | b []byte 17 | } 18 | 19 | func (f *testStreamFixture) read() { 20 | assertX(f.t, f.r.readable, "stream should be readable") 21 | f.b = make([]byte, 1024) 22 | n, err := f.r.read(f.b) 23 | assertNotError(f.t, err, "Should be able to read bytes") 24 | f.b = f.b[:n] 25 | assertX(f.t, f.r.clearReadable(), "should have been readable") 26 | } 27 | 28 | func (f *testStreamFixture) readExpectError(exerr error) { 29 | f.b = make([]byte, 1024) 30 | n, err := f.r.read(f.b) 31 | assertError(f.t, err, "Should not be able to read bytes") 32 | assertEquals(f.t, exerr, err) 33 | assertEquals(f.t, 0, n) 34 | } 35 | 36 | var kTestString1 = []byte("abcdef") 37 | var kTestString2 = []byte("ghijkl") 38 | 39 | func newTestStreamFixture(t *testing.T) *testStreamFixture { 40 | pc, _, _, ok := runtime.Caller(1) 41 | name := "unknown" 42 | if ok { 43 | name = runtime.FuncForPC(pc).Name() 44 | } 45 | log := func(tag string, format string, args ...interface{}) { 46 | fullFormat := fmt.Sprintf("%s: %s", name, format) 47 | logf(tag, fullFormat, args...) 
48 | } 49 | 50 | fc := flowControl{false, 2048, 0} 51 | return &testStreamFixture{ 52 | t: t, 53 | name: name, 54 | log: log, 55 | r: &recvStreamBase{streamCommon: streamCommon{log: log, fc: fc}}, 56 | w: &sendStreamBase{streamCommon: streamCommon{log: log, fc: fc}}, 57 | b: nil, 58 | } 59 | } 60 | 61 | func TestStreamInputOneChunk(t *testing.T) { 62 | f := newTestStreamFixture(t) 63 | err := f.r.newFrameData(0, false, kTestString1, &flowControl{false, 2048, 0}) 64 | assertNotError(t, err, "Data should be accepted") 65 | assertEquals(t, f.r.fc.used, uint64(len(kTestString1))) 66 | assertEquals(t, RecvStreamStateRecv, f.r.state) 67 | f.read() 68 | assertByteEquals(t, f.b, kTestString1) 69 | } 70 | 71 | func TestStreamInputTwoChunks(t *testing.T) { 72 | f := newTestStreamFixture(t) 73 | err := f.r.newFrameData(0, false, kTestString1, &flowControl{false, 2048, 0}) 74 | assertNotError(t, err, "Data should be accepted") 75 | assertEquals(t, f.r.fc.used, uint64(len(kTestString1))) 76 | f.read() 77 | assertByteEquals(t, f.b, kTestString1) 78 | err = f.r.newFrameData(uint64(len(kTestString1)), false, kTestString2, &flowControl{false, 2048, 0}) 79 | assertEquals(t, f.r.fc.used, uint64(len(kTestString1)+len(kTestString2))) 80 | f.read() 81 | assertByteEquals(t, f.b, kTestString2) 82 | } 83 | 84 | func TestStreamInputCoalesceChunks(t *testing.T) { 85 | f := newTestStreamFixture(t) 86 | err := f.r.newFrameData(0, false, kTestString1[:2], &flowControl{false, 2048, 0}) 87 | assertNotError(t, err, "data should be accepted") 88 | err = f.r.newFrameData(2, false, kTestString1[2:], &flowControl{false, 2048, 0}) 89 | assertNotError(t, err, "data should be accepted") 90 | f.read() 91 | assertByteEquals(t, f.b, kTestString1) 92 | } 93 | 94 | func TestStreamInputChunksOverlap(t *testing.T) { 95 | f := newTestStreamFixture(t) 96 | err := f.r.newFrameData(0, false, kTestString1[:2], &flowControl{false, 2048, 0}) 97 | assertNotError(t, err, "data should be accepted") 98 | err = 
f.r.newFrameData(0, false, kTestString1, &flowControl{false, 2048, 0}) 99 | assertNotError(t, err, "data should be accepted") 100 | f.read() 101 | assertByteEquals(t, f.b, kTestString1) 102 | } 103 | 104 | func TestStreamInputTwoChunksWrongOrder(t *testing.T) { 105 | f := newTestStreamFixture(t) 106 | err := f.r.newFrameData(2, false, kTestString1[2:], &flowControl{false, 2048, 0}) 107 | assertNotError(t, err, "data should be accepted") 108 | assertX(t, !f.r.readable, "Stream not should be readable") 109 | assertEquals(t, f.r.fc.used, uint64(len(kTestString1))) 110 | f.readExpectError(ErrorWouldBlock) 111 | err = f.r.newFrameData(0, false, kTestString1[:2], &flowControl{false, 2048, 0}) 112 | assertNotError(t, err, "data should be accepted") 113 | f.read() 114 | assertByteEquals(t, f.b, kTestString1) 115 | } 116 | 117 | func TestStreamInputChunk1FinChunk2(t *testing.T) { 118 | f := newTestStreamFixture(t) 119 | err := f.r.newFrameData(0, true, kTestString1, &flowControl{false, 2048, 0}) 120 | assertNotError(t, err, "data should be accepted") 121 | assertEquals(t, RecvStreamStateSizeKnown, f.r.state) 122 | f.read() 123 | assertByteEquals(t, f.b, kTestString1) 124 | assertEquals(t, RecvStreamStateDataRead, f.r.state) 125 | err = f.r.newFrameData(uint64(len(kTestString1)), false, kTestString2, &flowControl{false, 2048, 0}) 126 | assertEquals(t, err, ErrorFlowControlError) 127 | assertX(t, !f.r.readable, "Stream not be readable") 128 | f.readExpectError(io.EOF) 129 | } 130 | 131 | func TestStreamInputShortFinChunkAfterFin(t *testing.T) { 132 | f := newTestStreamFixture(t) 133 | err := f.r.newFrameData(0, true, kTestString1, &flowControl{false, 2048, 0}) 134 | assertNotError(t, err, "data should be accepted") 135 | assertEquals(t, RecvStreamStateSizeKnown, f.r.state) 136 | f.read() 137 | err = f.r.newFrameData(0, true, kTestString1[:2], &flowControl{false, 2048, 0}) 138 | assertNotError(t, err, "overlapping data can be discarded") 139 | } 140 | 141 | func 
TestStreamReadReset(t *testing.T) { 142 | f := newTestStreamFixture(t) 143 | err := f.r.handleReset(10) 144 | assertNotError(t, err, "should accept the reset") 145 | assertEquals(t, RecvStreamStateResetRecvd, f.r.state) 146 | } 147 | 148 | func TestStreamWriteClose(t *testing.T) { 149 | f := newTestStreamFixture(t) 150 | f.w.close() 151 | assertEquals(t, SendStreamStateCloseQueued, f.w.state) 152 | } 153 | 154 | func TestStreamIncreaseFlowControl(t *testing.T) { 155 | f := newTestStreamFixture(t) 156 | f.w.processMaxStreamData(2050) 157 | f.w.processMaxStreamData(2000) 158 | assertEquals(t, uint64(2050), f.w.fc.max) 159 | } 160 | 161 | func countChunkLens(chunks []streamChunk) int { 162 | ct := 0 163 | for _, ch := range chunks { 164 | ct += len(ch.data) 165 | } 166 | return ct 167 | } 168 | 169 | func TestStreamBlockRelease(t *testing.T) { 170 | f := newTestStreamFixture(t) 171 | b := make([]byte, 5000) 172 | connFc := &flowControl{false, uint64(len(b)), 0} 173 | n, err := f.w.write(b, connFc) 174 | assertEquals(t, nil, err) 175 | chunks := f.w.outputWritable() 176 | assertEquals(t, 2048, countChunkLens(chunks)) 177 | assertEquals(t, 2048, n) 178 | assertEquals(t, uint64(2048), connFc.used) 179 | // Calling output writable again returns 0 chunks 180 | chunks = f.w.outputWritable() 181 | assertEquals(t, 0, countChunkLens(chunks)) 182 | 183 | // Writing again blocks 184 | _, err = f.w.write(b[n:], connFc) 185 | assertEquals(t, ErrorWouldBlock, err) 186 | 187 | // Increasing the limit should let us write. 
188 | f.w.processMaxStreamData(8192) 189 | n, err = f.w.write(b[n:], connFc) 190 | assertNotError(t, err, "Writing works") 191 | assertEquals(t, 2952, n) 192 | assertEquals(t, connFc.max, connFc.used) 193 | chunks = f.w.outputWritable() 194 | assertEquals(t, 2952, countChunkLens(chunks)) 195 | } 196 | -------------------------------------------------------------------------------- /timer.go: -------------------------------------------------------------------------------- 1 | package minq 2 | 3 | import ( 4 | "time" 5 | ) 6 | 7 | type timerCb func() 8 | 9 | type timer struct { 10 | ts *timerSet 11 | cb timerCb 12 | deadline time.Time 13 | } 14 | 15 | // This is a simple implementation of unsorted timers. 16 | // TODO(ekr@rtfm.com): Need a better data structure. 17 | type timerSet struct { 18 | ts []*timer 19 | } 20 | 21 | func newTimers() *timerSet { 22 | return &timerSet{nil} 23 | } 24 | 25 | func (ts *timerSet) start(cb timerCb, delayMs uint32) *timer { 26 | t := timer{ 27 | ts, 28 | cb, 29 | time.Now().Add(time.Millisecond * time.Duration(delayMs)), 30 | } 31 | 32 | ts.ts = append(ts.ts, &t) 33 | 34 | return &t 35 | } 36 | 37 | func (ts *timerSet) check(now time.Time) { 38 | for i, t := range ts.ts { 39 | if now.After(t.deadline) { 40 | ts.ts = append(ts.ts[:i], ts.ts[:i+1]...) 
41 | if t.cb != nil { 42 | t.cb() 43 | } 44 | } 45 | } 46 | } 47 | 48 | func (t *timer) cancel() { 49 | t.cb = nil 50 | } 51 | -------------------------------------------------------------------------------- /tls.go: -------------------------------------------------------------------------------- 1 | package minq 2 | 3 | import ( 4 | "crypto" 5 | "crypto/x509" 6 | "fmt" 7 | "log" 8 | 9 | "github.com/bifurcation/mint" 10 | ) 11 | 12 | type TlsConfig struct { 13 | ServerName string 14 | CertificateChain []*x509.Certificate 15 | Key crypto.Signer 16 | mintConfig *mint.Config 17 | ForceHrr bool 18 | } 19 | 20 | func (c *TlsConfig) init() { 21 | _ = c.toMint() 22 | } 23 | 24 | func (c *TlsConfig) toMint() *mint.Config { 25 | if c.mintConfig == nil { 26 | // TODO(ekr@rtfm.com): Provide a real config 27 | config := mint.Config{ 28 | ServerName: c.ServerName, 29 | NonBlocking: true, 30 | NextProtos: []string{kQuicALPNToken}, 31 | SendSessionTickets: true, 32 | AllowEarlyData: true, 33 | } 34 | 35 | if c.ForceHrr { 36 | config.RequireCookie = true 37 | } 38 | 39 | config.CookieProtector, _ = mint.NewDefaultCookieProtector() 40 | config.InsecureSkipVerify = true // TODO(ekr@rtfm.com): This is horribly insecure, but Minq is right now for testing 41 | 42 | if c.CertificateChain != nil && c.Key != nil { 43 | config.Certificates = 44 | []*mint.Certificate{ 45 | &mint.Certificate{ 46 | Chain: c.CertificateChain, 47 | PrivateKey: c.Key, 48 | }, 49 | } 50 | } else { 51 | priv, cert, err := mint.MakeNewSelfSignedCert(c.ServerName, mint.ECDSA_P256_SHA256) 52 | if err != nil { 53 | log.Fatalf("Couldn't make self-signed cert %v", err) 54 | } 55 | config.Certificates = []*mint.Certificate{ 56 | { 57 | Chain: []*x509.Certificate{cert}, 58 | PrivateKey: priv, 59 | }, 60 | } 61 | } 62 | config.Init(false) 63 | c.mintConfig = &config 64 | } 65 | return c.mintConfig.Clone() 66 | } 67 | 68 | func NewTlsConfig(serverName string) TlsConfig { 69 | return TlsConfig{ 70 | ServerName: serverName,
71 | } 72 | } 73 | 74 | type tlsConn struct { 75 | config *TlsConfig 76 | conn *Connection 77 | mintConfig *mint.Config 78 | tls *mint.Conn 79 | finished bool 80 | cs *mint.CipherSuiteParams 81 | } 82 | 83 | func newTlsConn(conn *Connection, conf *TlsConfig, role Role) *tlsConn { 84 | isClient := true 85 | if role == RoleServer { 86 | isClient = false 87 | } 88 | 89 | mc := conf.toMint() 90 | mc.RecordLayer = newRecordLayerFactory(conn) 91 | return &tlsConn{ 92 | conf, 93 | conn, 94 | mc, 95 | mint.NewConn(nil, mc, isClient), 96 | false, 97 | nil, 98 | } 99 | } 100 | 101 | func (c *tlsConn) setTransportParametersHandler(h *transportParametersHandler) { 102 | c.mintConfig.ExtensionHandler = h 103 | } 104 | 105 | func (c *tlsConn) handshake() error { 106 | outer: 107 | for { 108 | alert := c.tls.Handshake() 109 | hst := c.tls.GetHsState() 110 | switch alert { 111 | case mint.AlertNoAlert, mint.AlertStatelessRetry: 112 | if hst == mint.StateServerConnected || hst == mint.StateClientConnected { 113 | st := c.tls.ConnectionState() 114 | 115 | logf(logTypeTls, "TLS handshake complete") 116 | logf(logTypeTls, "Negotiated ALPN = %v", st.NextProto) 117 | // TODO(ekr@rtfm.com): Abort on ALPN mismatch when others do.
118 | if st.NextProto != kQuicALPNToken { 119 | logf(logTypeTls, "ALPN mismatch %v != %v", st.NextProto, kQuicALPNToken) 120 | } 121 | cs := st.CipherSuite 122 | c.cs = &cs 123 | c.finished = true 124 | 125 | break outer 126 | } 127 | // Loop 128 | case mint.AlertWouldBlock: 129 | logf(logTypeTls, "TLS would have blocked") 130 | break outer 131 | default: 132 | return fmt.Errorf("TLS sent an alert %v", alert) 133 | } 134 | } 135 | return nil 136 | } 137 | 138 | func (c *tlsConn) postHandshake() error { 139 | b := make([]byte, 1) 140 | 141 | n, err := c.tls.Read(b) 142 | assert(n == 0) // This can't happen 143 | if err == nil || err == mint.AlertWouldBlock { 144 | return nil 145 | } 146 | return ErrorProtocolViolation 147 | } 148 | 149 | func (c *tlsConn) getHsState() string { 150 | return c.tls.GetHsState().String() 151 | } 152 | -------------------------------------------------------------------------------- /tracking.go: -------------------------------------------------------------------------------- 1 | // Internal structure indicating packets we have 2 | // received 3 | package minq 4 | 5 | import ( 6 | "fmt" 7 | "github.com/bifurcation/mint" 8 | "time" 9 | ) 10 | 11 | type packetData struct { 12 | protected bool 13 | nonAcks bool 14 | pn uint64 15 | t time.Time 16 | acked2 bool 17 | } 18 | 19 | type recvdPackets struct { 20 | log loggingFunction 21 | initted bool 22 | minReceived uint64 23 | maxReceived uint64 24 | minNotAcked2 uint64 25 | packets map[uint64]*packetData 26 | unacked bool // Are there packets we haven't generated an ACK for 27 | } 28 | 29 | func newRecvdPackets(log loggingFunction) *recvdPackets { 30 | return &recvdPackets{ 31 | log, // loggingFunction 32 | false, // initted 33 | 0, // minReceived 34 | 0, // maxReceived 35 | 0, // minNotAcked2 36 | make(map[uint64]*packetData, 0), // packets 37 | false, // unacked 38 | } 39 | } 40 | 41 | func (p *recvdPackets) initialized() bool { 42 | return p.initted 43 | } 44 | 45 | func (p *recvdPackets)
init(pn uint64) { 46 | p.log(logTypeAck, "Initializing received packet start=%x", pn) 47 | p.initted = true 48 | p.minReceived = pn 49 | p.maxReceived = pn 50 | p.minNotAcked2 = pn 51 | } 52 | 53 | func (p *recvdPackets) packetNotReceived(pn uint64) bool { 54 | if pn < p.minReceived { 55 | return false 56 | } 57 | _, found := p.packets[pn] 58 | return !found 59 | } 60 | 61 | func (p *recvdPackets) packetSetReceived(pn uint64, protected bool, nonAcks bool) { 62 | p.log(logTypeAck, "Setting packet received=%x", pn) 63 | if pn > p.maxReceived { 64 | p.maxReceived = pn 65 | } 66 | if pn < p.minNotAcked2 { 67 | p.minNotAcked2 = pn 68 | } 69 | 70 | p.packets[pn] = &packetData{ 71 | protected, 72 | nonAcks, 73 | pn, 74 | time.Now(), 75 | false, 76 | } 77 | p.unacked = true 78 | } 79 | 80 | func (p *recvdPackets) packetSetAcked2(pn uint64) { 81 | p.log(logTypeAck, "Setting packet acked2=%v", pn) 82 | if pn >= p.minNotAcked2 { 83 | pk, ok := p.packets[pn] 84 | if ok { 85 | pk.acked2 = true 86 | } 87 | } 88 | } 89 | 90 | func (r *ackRange) String() string { 91 | return fmt.Sprintf("%x(%d)", r.lastPacket, r.count) 92 | } 93 | 94 | func (r *ackRanges) String() string { 95 | rsp := "" 96 | for _, s := range *r { 97 | if rsp != "" { 98 | rsp += ", " 99 | } 100 | rsp += s.String() 101 | } 102 | return rsp 103 | } 104 | 105 | func (p *recvdPackets) needToAck() bool { 106 | return p.unacked 107 | } 108 | 109 | // Prepare a list of the ACK ranges, starting at the highest 110 | func (p *recvdPackets) prepareAckRange(epoch mint.Epoch, allowAckOnly bool) ackRanges { 111 | p.log(logTypeAck, "Prepare ACK range epoch=%d", epoch) 112 | // Don't ACK if there's nothing new to ACK 113 | if !p.unacked { 114 | p.log(logTypeAck, "Nothing new to ACK") 115 | return nil 116 | } 117 | 118 | var last uint64 119 | var pn uint64 120 | inrange := false 121 | nonAcks := false 122 | 123 | ranges := make(ackRanges, 0) 124 | 125 | newMinNotAcked2 :=
p.maxReceived 126 | 127 | // TODO(ekr@rtfm.com): This is kind of a gross hack in case 128 | // someone sends us a 0 initial packet number. 129 | for pn = p.maxReceived; pn >= p.minNotAcked2 && pn > 0; pn-- { 130 | p.log(logTypeTrace, "Examining packet %x", pn) 131 | pk, ok := p.packets[pn] 132 | needs_ack := false 133 | 134 | // If we don't know about the packet, or if the ack has been 135 | // acked, we don't need to ack it. 136 | if ok && !pk.acked2 { 137 | needs_ack = true 138 | newMinNotAcked2 = pn 139 | } 140 | 141 | if ok && pk.acked2 { 142 | delete(p.packets, pn) 143 | } 144 | 145 | if needs_ack { 146 | p.log(logTypeTrace, "Acking packet %x", pn) 147 | } 148 | if needs_ack && pk.nonAcks { 149 | // Note if this is an ack of anything other than 150 | // acks. 151 | p.log(logTypeTrace, "Packet %x contains non-acks", pn) 152 | nonAcks = true 153 | } 154 | 155 | if inrange != needs_ack { 156 | if inrange { 157 | // This is the end of a range. 158 | ranges = append(ranges, ackRange{last, last - pn}) 159 | } else { 160 | last = pn 161 | } 162 | inrange = needs_ack 163 | } 164 | } 165 | if inrange { 166 | p.log(logTypeTrace, "Appending final range %x-%x", last, pn+1) 167 | ranges = append(ranges, ackRange{last, last - pn}) 168 | } 169 | 170 | p.minNotAcked2 = newMinNotAcked2 171 | 172 | p.log(logTypeAck, "%v ACK ranges to send", len(ranges)) 173 | for i, r := range ranges { 174 | p.log(logTypeAck, " %d = %v", i, r.String()) 175 | } 176 | 177 | if !allowAckOnly && !nonAcks { 178 | p.log(logTypeAck, "No non-ack packets and this ack is not ack-only capable") 179 | return nil 180 | } 181 | 182 | p.unacked = false 183 | return ranges 184 | } 185 | -------------------------------------------------------------------------------- /tracking_test.go: -------------------------------------------------------------------------------- 1 | package minq 2 | 3 | import ( 4 | "fmt" 5 | "github.com/bifurcation/mint" 6 | "runtime" 7 | "testing" 8 | ) 9 | 10 | type testTrackingFixture
struct { 11 | pns []uint64 12 | r *recvdPackets 13 | } 14 | 15 | func newTestTrackingFixture() *testTrackingFixture { 16 | pc, _, _, ok := runtime.Caller(1) 17 | name := "unknown" 18 | if ok { 19 | name = runtime.FuncForPC(pc).Name() 20 | } 21 | log := func(tag string, format string, args ...interface{}) { 22 | fullFormat := fmt.Sprintf("%s: %s", name, format) 23 | logf(tag, fullFormat, args...) 24 | } 25 | 26 | pns := make([]uint64, 10) 27 | for i := uint64(0); i < 10; i++ { 28 | pns[i] = uint64(0xdead0000) + i 29 | } 30 | return &testTrackingFixture{ 31 | pns, 32 | newRecvdPackets(log), 33 | } 34 | } 35 | 36 | func TestTrackingPacketsReceived(t *testing.T) { 37 | f := newTestTrackingFixture() 38 | assertEquals(t, true, f.r.packetNotReceived(f.pns[1])) 39 | f.r.init(f.pns[0]) 40 | assertEquals(t, true, f.r.packetNotReceived(f.pns[0])) 41 | assertEquals(t, true, f.r.packetNotReceived(f.pns[1])) 42 | f.r.packetSetReceived(f.pns[0], false, true) 43 | assertEquals(t, false, f.r.packetNotReceived(f.pns[0])) 44 | assertEquals(t, true, f.r.packetNotReceived(f.pns[1])) 45 | f.r.packetSetReceived(f.pns[1], true, true) 46 | assertEquals(t, false, f.r.packetNotReceived(f.pns[1])) 47 | 48 | // Check that things less than min are received 49 | assertEquals(t, false, f.r.packetNotReceived(f.pns[0]-1)) 50 | 51 | // Now make some ACKs 52 | ar := f.r.prepareAckRange(mint.EpochApplicationData, false) 53 | assertX(t, len(ar) == 1, "Should be one entry in ACK range") 54 | assertEquals(t, ar[0].lastPacket, f.pns[1]) 55 | assertEquals(t, ar[0].count, uint64(2)) 56 | 57 | f.r.packetSetReceived(f.pns[3], true, true) 58 | ar = f.r.prepareAckRange(mint.EpochApplicationData, false) 59 | assertX(t, len(ar) == 2, "Should be two entry in ACK range") 60 | assertEquals(t, ar[0].lastPacket, f.pns[3]) 61 | assertEquals(t, ar[1].lastPacket, f.pns[1]) 62 | assertEquals(t, ar[1].count, uint64(2)) 63 | 64 | // Now ack all the acks, so that we should send nothing.
65 | f.r.packetSetAcked2(f.pns[0]) 66 | f.r.packetSetAcked2(f.pns[1]) 67 | f.r.packetSetAcked2(f.pns[3]) 68 | ar = f.r.prepareAckRange(mint.EpochApplicationData, false) 69 | assertX(t, len(ar) == 0, "Should be no acks") 70 | } 71 | -------------------------------------------------------------------------------- /transport.go: -------------------------------------------------------------------------------- 1 | package minq 2 | 3 | import () 4 | 5 | // Interface for an object to send packets. Each Transport 6 | // is bound to some particular remote address (or in testing 7 | // we just use a mock which sends the packet into a queue). 8 | type Transport interface { 9 | Send(p []byte) error 10 | } 11 | -------------------------------------------------------------------------------- /transport_parameters.go: -------------------------------------------------------------------------------- 1 | package minq 2 | 3 | import ( 4 | "bytes" 5 | "crypto/rand" 6 | "encoding/hex" 7 | "fmt" 8 | 9 | "github.com/bifurcation/mint" 10 | "github.com/bifurcation/mint/syntax" 11 | ) 12 | 13 | const ( 14 | kQuicTransportParamtersXtn = mint.ExtensionType(0xffa5) 15 | ) 16 | 17 | type TransportParameterId uint16 18 | 19 | const ( 20 | kTpIdInitialMaxStreamDataBidiLocal = TransportParameterId(0x0000) 21 | kTpIdInitialMaxData = TransportParameterId(0x0001) 22 | kTpIdInitialMaxBidiStreams = TransportParameterId(0x0002) 23 | kTpIdIdleTimeout = TransportParameterId(0x0003) 24 | kTpPreferredAddress = TransportParameterId(0x0004) 25 | kTpIdMaxPacketSize = TransportParameterId(0x0005) 26 | kTpIdStatelessResetToken = TransportParameterId(0x0006) 27 | kTpIdAckDelayExponent = TransportParameterId(0x0007) 28 | kTpIdInitialMaxUniStreams = TransportParameterId(0x0008) 29 | kTpIdDisableMigration = TransportParameterId(0x0009) 30 | kTpIdInitialMaxStreamDataBidiRemote = TransportParameterId(0x000a) // initial_max_stream_data_bidi_remote(10): decimal 10, not 0x10 31 | kTpIdInitialMaxStreamDataUni = TransportParameterId(0x000b) // initial_max_stream_data_uni(11): decimal 11, not 0x11 32 | ) 33 | 34 | const ( 35 |
kTpDefaultAckDelayExponent = byte(3) 36 | ) 37 | 38 | type tpDef struct { 39 | parameter TransportParameterId 40 | val uint32 41 | size uintptr 42 | } 43 | 44 | var ( 45 | kInitialMaxData = uint64(65536) 46 | kInitialMaxStreamData = uint64(8192) 47 | kConcurrentStreamsBidi = 16 48 | kConcurrentStreamsUni = 16 49 | kTransportParameterDefaults = []tpDef{ 50 | {kTpIdInitialMaxStreamDataBidiLocal, uint32(kInitialMaxStreamData), 4}, 51 | {kTpIdInitialMaxStreamDataBidiRemote, uint32(kInitialMaxStreamData), 4}, 52 | {kTpIdInitialMaxStreamDataUni, uint32(kInitialMaxStreamData), 4}, 53 | {kTpIdInitialMaxData, uint32(kInitialMaxData), 4}, 54 | {kTpIdInitialMaxBidiStreams, uint32(kConcurrentStreamsBidi), 2}, 55 | {kTpIdIdleTimeout, 5, 2}, 56 | {kTpIdInitialMaxUniStreams, uint32(kConcurrentStreamsUni), 2}, 57 | } 58 | ) 59 | 60 | type transportParameters struct { 61 | maxStreamDataUni uint32 62 | maxStreamDataBidiLocal uint32 63 | maxStreamDataBidiRemote uint32 64 | maxData uint32 65 | maxStreamsBidi int 66 | maxStreamsUni int 67 | idleTimeout uint16 68 | ackDelayExp uint8 69 | } 70 | 71 | type TransportParameterList []transportParameter 72 | 73 | type transportParameter struct { 74 | Parameter TransportParameterId 75 | Value []byte `tls:"head=2"` 76 | } 77 | 78 | type clientHelloTransportParameters struct { 79 | InitialVersion VersionNumber 80 | Parameters TransportParameterList `tls:"head=2"` 81 | } 82 | 83 | type encryptedExtensionsTransportParameters struct { 84 | NegotiatedVersion VersionNumber 85 | SupportedVersions []VersionNumber `tls:"head=1"` 86 | Parameters TransportParameterList `tls:"head=2"` 87 | } 88 | 89 | func (tp *TransportParameterList) addUintParameter(id TransportParameterId, val uint32, size uintptr) error { 90 | var buf bytes.Buffer 91 | uintEncodeInt(&buf, uint64(val), size) 92 | *tp = append(*tp, transportParameter{ 93 | id, 94 | buf.Bytes(), 95 | }) 96 | return nil 97 | } 98 | 99 | func (tp *TransportParameterList) getParameter(id
TransportParameterId) []byte { 100 | for _, ex := range *tp { 101 | if ex.Parameter == id { 102 | return ex.Value 103 | } 104 | } 105 | return nil 106 | } 107 | 108 | func (tp *TransportParameterList) getUintParameter(id TransportParameterId, size uintptr) (uint32, error) { 109 | assert(size <= 4) 110 | 111 | b := tp.getParameter(id) 112 | if b == nil { 113 | logf(logTypeHandshake, "Missing transport parameter %v", id) 114 | return 0, ErrorMissingValue 115 | } 116 | 117 | if len(b) != int(size) { 118 | logf(logTypeHandshake, "Bogus transport parameter %v", id) 119 | return 0, ErrorInvalidEncoding 120 | } 121 | 122 | buf := bytes.NewReader(b) 123 | tmp, err := uintDecodeInt(buf, size) 124 | if err != nil { 125 | return 0, err 126 | } 127 | 128 | return uint32(tmp), nil 129 | } 130 | 131 | func (tp *TransportParameterList) getUintParameterOrDefault(id TransportParameterId, size uintptr, def uint32) (uint32, error) { 132 | assert(size <= 4) 133 | 134 | b := tp.getParameter(id) 135 | if b == nil { 136 | logf(logTypeHandshake, "Missing transport parameter %v", id) 137 | return def, nil 138 | } 139 | 140 | if len(b) != int(size) { 141 | logf(logTypeHandshake, "Bogus transport parameter %v", id) 142 | return 0, ErrorInvalidEncoding 143 | } 144 | 145 | buf := bytes.NewReader(b) 146 | tmp, err := uintDecodeInt(buf, size) 147 | if err != nil { 148 | return 0, err 149 | } 150 | 151 | return uint32(tmp), nil 152 | } 153 | 154 | func (tp *TransportParameterList) addOpaqueParameter(id TransportParameterId, b []byte) error { 155 | *tp = append(*tp, transportParameter{ 156 | id, 157 | b, 158 | }) 159 | return nil 160 | } 161 | 162 | func (tp *TransportParameterList) createCommonTransportParameters() error { 163 | for _, p := range kTransportParameterDefaults { 164 | err := tp.addUintParameter(p.parameter, p.val, p.size) 165 | if err != nil { 166 | return err 167 | } 168 | } 169 | 170 | return nil 171 | } 172 | 173 | // Implement mint.AppExtensionHandler.
174 | type transportParametersXtnBody struct { 175 | body []byte 176 | } 177 | 178 | func (t transportParametersXtnBody) Type() mint.ExtensionType { 179 | return kQuicTransportParamtersXtn 180 | } 181 | 182 | func (t transportParametersXtnBody) Marshal() ([]byte, error) { 183 | return t.body, nil 184 | } 185 | 186 | func (t *transportParametersXtnBody) Unmarshal(data []byte) (int, error) { 187 | t.body = data 188 | return len(t.body), nil 189 | } 190 | 191 | type transportParametersHandler struct { 192 | log loggingFunction 193 | role Role 194 | version VersionNumber 195 | peerParams *transportParameters 196 | } 197 | 198 | func newTransportParametersHandler(log loggingFunction, role Role, version VersionNumber) *transportParametersHandler { 199 | return &transportParametersHandler{log, role, version, nil} 200 | } 201 | 202 | func (h *transportParametersHandler) setDummyPeerParams() { 203 | h.peerParams = &transportParameters{ 204 | uint32(kInitialMaxStreamData), 205 | uint32(kInitialMaxStreamData), 206 | uint32(kInitialMaxStreamData), 207 | uint32(kInitialMaxData), 208 | kConcurrentStreamsBidi, 209 | kConcurrentStreamsUni, 210 | 600, 211 | uint8(1), 212 | } 213 | } 214 | 215 | func (h *transportParametersHandler) Send(hs mint.HandshakeType, el *mint.ExtensionList) error { 216 | if h.role == RoleClient { 217 | h.log(logTypeHandshake, "Sending transport parameters") 218 | if hs != mint.HandshakeTypeClientHello { 219 | return nil 220 | } 221 | b, err := h.createClientHelloTransportParameters() 222 | if err != nil { 223 | return err 224 | } 225 | h.log(logTypeTrace, "ClientHelloTransportParameters=%s", hex.EncodeToString(b)) 226 | el.Add(&transportParametersXtnBody{b}) 227 | return nil 228 | } 229 | 230 | if h.peerParams == nil { 231 | return nil 232 | } 233 | 234 | if hs != mint.HandshakeTypeEncryptedExtensions { 235 | return nil 236 | } 237 | 238 | h.log(logTypeHandshake, "Sending transport parameters message") 239 | b, err :=
h.createEncryptedExtensionsTransportParameters() 240 | if err != nil { 241 | return err 242 | } 243 | el.Add(&transportParametersXtnBody{b}) 244 | return nil 245 | } 246 | 247 | func (h *transportParametersHandler) Receive(hs mint.HandshakeType, el *mint.ExtensionList) error { 248 | h.log(logTypeHandshake, "%p TransportParametersHandler message=%d", h, hs) 249 | // First see if the other side sent the extension. 250 | var body transportParametersXtnBody 251 | found, err := el.Find(&body) 252 | 253 | if err != nil { 254 | return fmt.Errorf("Invalid transport parameters") 255 | } 256 | 257 | if found { 258 | h.log(logTypeTrace, "Retrieved transport parameters len=%d %v", len(body.body), hex.EncodeToString(body.body)) 259 | } 260 | 261 | var params *TransportParameterList 262 | 263 | switch hs { 264 | case mint.HandshakeTypeEncryptedExtensions: 265 | if h.role != RoleClient { 266 | return fmt.Errorf("EncryptedExtensions received but not a client") 267 | } 268 | if !found { 269 | h.log(logTypeHandshake, "Missing transport parameters") 270 | return fmt.Errorf("Missing transport parameters") 271 | } 272 | var eeParams encryptedExtensionsTransportParameters 273 | _, err = syntax.Unmarshal(body.body, &eeParams) 274 | if err != nil { 275 | h.log(logTypeHandshake, "Failed to decode parameters") 276 | return err 277 | } 278 | params = &eeParams.Parameters 279 | // TODO(ekr@rtfm.com): Process version #s 280 | case mint.HandshakeTypeClientHello: 281 | if h.role != RoleServer { 282 | return fmt.Errorf("ClientHello received but not a server") 283 | } 284 | if !found { 285 | h.log(logTypeHandshake, "Missing transport parameters") 286 | return fmt.Errorf("Missing transport parameters") 287 | } 288 | 289 | // TODO(ekr@rtfm.com): Process version #s 290 | var chParams clientHelloTransportParameters 291 | _, err = syntax.Unmarshal(body.body, &chParams) 292 | if err != nil { 293 | h.log(logTypeHandshake, "Couldn't unmarshal %v", err) 294 | return err 295 | } 296 | params =
&chParams.Parameters 297 | default: 298 | if found { 299 | return fmt.Errorf("Received quic_transport_parameters in inappropriate message %v", hs) 300 | } 301 | return nil 302 | } 303 | 304 | // Now try to process each param. 305 | // TODO(ekr@rtfm.com): Enforce that each param appears only once. 306 | var tp transportParameters 307 | h.log(logTypeHandshake, "Reading transport parameters values") 308 | 309 | tp.maxStreamDataBidiLocal, err = params.getUintParameterOrDefault(kTpIdInitialMaxStreamDataBidiLocal, 4, 0) 310 | if err != nil { 311 | return err 312 | } 313 | 314 | tp.maxStreamDataBidiRemote, err = params.getUintParameterOrDefault(kTpIdInitialMaxStreamDataBidiRemote, 4, 0) 315 | if err != nil { 316 | return err 317 | } 318 | 319 | tp.maxStreamDataUni, err = params.getUintParameterOrDefault(kTpIdInitialMaxStreamDataUni, 4, 0) 320 | if err != nil { 321 | return err 322 | } 323 | 324 | tp.maxData, err = params.getUintParameterOrDefault(kTpIdInitialMaxData, 4, 0) 325 | if err != nil { 326 | return err 327 | } 328 | 329 | tmp, err := params.getUintParameterOrDefault(kTpIdInitialMaxBidiStreams, 2, 0) 330 | if err != nil { 331 | return err 332 | } 333 | tp.maxStreamsBidi = int(tmp) 334 | 335 | if h.role == RoleClient { 336 | tp.maxStreamsBidi++ // Allow for stream 0.
337 | } 338 | 339 | tmp, err = params.getUintParameterOrDefault(kTpIdInitialMaxUniStreams, 2, 0) 340 | if err != nil { 341 | return err 342 | } 343 | tp.maxStreamsUni = int(tmp) 344 | 345 | tmp, err = params.getUintParameter(kTpIdIdleTimeout, 2) 346 | if err != nil { 347 | return err 348 | } 349 | tp.idleTimeout = uint16(tmp) 350 | 351 | tmp, err = params.getUintParameterOrDefault(kTpIdAckDelayExponent, 1, uint32(kTpDefaultAckDelayExponent)) 352 | if err != nil { 353 | return err 354 | } 355 | tp.ackDelayExp = uint8(tmp) // Keep the peer's value; an absent parameter yields kTpDefaultAckDelayExponent. 356 | h.peerParams = &tp 357 | 358 | h.log(logTypeHandshake, "Finished reading transport parameters") 359 | return nil 360 | } 361 | 362 | func (h *transportParametersHandler) createClientHelloTransportParameters() ([]byte, error) { 363 | chtp := clientHelloTransportParameters{ 364 | h.version, 365 | nil, 366 | } 367 | 368 | err := chtp.Parameters.createCommonTransportParameters() 369 | if err != nil { 370 | return nil, err 371 | } 372 | 373 | b, err := syntax.Marshal(chtp) 374 | if err != nil { 375 | return nil, err 376 | } 377 | return b, nil 378 | } 379 | 380 | func (h *transportParametersHandler) createEncryptedExtensionsTransportParameters() ([]byte, error) { 381 | eetp := encryptedExtensionsTransportParameters{ 382 | h.version, 383 | []VersionNumber{ 384 | h.version, 385 | }, 386 | nil, 387 | } 388 | 389 | err := eetp.Parameters.createCommonTransportParameters() 390 | if err != nil { 391 | return nil, err 392 | } 393 | 394 | b := make([]byte, 16) 395 | _, err = rand.Read(b) 396 | if err != nil { 397 | return nil, err 398 | } 399 | 400 | eetp.Parameters.addOpaqueParameter(kTpIdStatelessResetToken, b) 401 | 402 | b, err = syntax.Marshal(eetp) 403 | if err != nil { 404 | return nil, err 405 | } 406 | return b, nil 407 | } 408 | -------------------------------------------------------------------------------- /udp_transport.go: -------------------------------------------------------------------------------- 1 | package minq 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | ) 7 | 8 | type UdpTransport struct { 9 | u
*net.UDPConn 10 | r *net.UDPAddr 11 | } 12 | 13 | func (t *UdpTransport) Send(p []byte) error { 14 | logf(logTypeUdp, "Sending message of len %v", len(p)) 15 | n, err := t.u.WriteToUDP(p, t.r) 16 | if err != nil { 17 | return err 18 | } 19 | if n != len(p) { 20 | return fmt.Errorf("Incomplete write") 21 | } 22 | 23 | return nil 24 | } 25 | 26 | func NewUdpTransport(u *net.UDPConn, r *net.UDPAddr) *UdpTransport { 27 | return &UdpTransport{u, r} 28 | } 29 | 30 | type UdpTransportFactory struct { 31 | local *net.UDPConn 32 | } 33 | 34 | func (f *UdpTransportFactory) MakeTransport(remote *net.UDPAddr) (Transport, error) { 35 | logf(logTypeUdp, "Making transport with remote addr %v", remote) 36 | return NewUdpTransport(f.local, remote), nil 37 | } 38 | 39 | func NewUdpTransportFactory(sock *net.UDPConn) *UdpTransportFactory { 40 | return &UdpTransportFactory{sock} 41 | } 42 | --------------------------------------------------------------------------------