├── .gitignore
├── minq.png
├── deploy
│   ├── run-local.sh
│   ├── mk-localhost.sh
│   ├── logserver
│   │   ├── package.json
│   │   └── server.js
│   ├── mk-endpoint.sh
│   ├── run-looped.sh
│   └── Dockerfile
├── transport.go
├── common.go
├── timer.go
├── udp_transport.go
├── LICENSE.md
├── MINT-LICENSE.md
├── connbuffer.go
├── bin
│   ├── tester
│   │   └── main.go
│   ├── shim
│   │   └── main.go
│   ├── client
│   │   └── main.go
│   └── server
│       └── main.go
├── frame_test.go
├── crypto.go
├── aead.go
├── tracking_test.go
├── log.go
├── common_test.go
├── codec_test.go
├── errors.go
├── record-layer.go
├── aead_test.go
├── server.go
├── README.md
├── tls.go
├── server_test.go
├── tracking.go
├── stream_test.go
├── packet_test.go
├── codec.go
├── congestion.go
├── transport_parameters.go
├── packet.go
├── minq.svg
├── frame.go
└── stream.go
/.gitignore:
--------------------------------------------------------------------------------
1 | *.test
2 |
--------------------------------------------------------------------------------
/minq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ekr/minq/HEAD/minq.png
--------------------------------------------------------------------------------
/deploy/run-local.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Run the locally built "minq" image. UDP 4433 is the minq server port and
# TCP 3000 is the log viewer (deploy/logserver); --rm removes the container
# when it exits.
docker run --name minq --rm --publish 4433:4433/udp --publish 3000:3000 minq:latest
2 |
--------------------------------------------------------------------------------
/deploy/mk-localhost.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Build the "minq" image with SERVERNAME=localhost (consumed by
# deploy/Dockerfile). The build context is ".", so this expects to be run
# from the repository root.
docker build \
  -f deploy/Dockerfile \
  --no-cache \
  -t minq \
  --build-arg SERVERNAME=localhost \
  .
3 |
4 |
--------------------------------------------------------------------------------
/deploy/logserver/package.json:
--------------------------------------------------------------------------------
{
  "name": "logserver",
  "version": "0.0.1",
  "dependencies": {
    "express": "^4.16.4"
  }
}
8 |
--------------------------------------------------------------------------------
/deploy/mk-endpoint.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Build the public endpoint image and tag it with the current commit hash.
# Stop on the first failure: otherwise a failed build would still re-tag the
# previous mozilla/minq:latest with the current commit.
set -e
docker build -f deploy/Dockerfile --no-cache -t mozilla/minq --build-arg SERVERNAME=minq.dev.mozaws.net .
docker tag mozilla/minq:latest mozilla/minq:"$(git rev-parse HEAD)"
4 |
5 |
6 |
--------------------------------------------------------------------------------
/transport.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import ()
4 |
// Transport is the interface for an object that sends packets. Each
// Transport is bound to some particular remote address (or, in testing,
// is a mock which just sends the packet into a queue).
type Transport interface {
	// Send transmits the packet p, returning any error reported while
	// sending it.
	Send(p []byte) error
}
11 |
--------------------------------------------------------------------------------
/deploy/run-looped.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Container entrypoint: start the log viewer in the background, then run the
# minq server forever, restarting it whenever it exits.
nodejs /go/src/github.com/ekr/minq/deploy/logserver/server.js /tmp/minq.log &
while true; do
  printf 'Starting server as '
  echo "${SNAME}"
  MINQ_LOG=connection,packet /go/bin/server -addr 0.0.0.0:4433 -server-name "${SNAME}" -log /tmp/minq.log -http -standalone
  echo "Server crashed"
  # Back off briefly so a server that exits immediately (bad flags, port in
  # use, ...) does not turn this into a tight respawn loop.
  sleep 1
done
9 |
10 |
11 |
--------------------------------------------------------------------------------
/common.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import (
4 | "os"
5 | )
6 |
// debug records, once at package initialization, whether MINQ_DEBUG was
// set to "true".
var debug = checkDebug()

// checkDebug reports whether the MINQ_DEBUG environment variable holds
// exactly the string "true" (case-sensitive).
func checkDebug() bool {
	return os.Getenv("MINQ_DEBUG") == "true"
}
17 |
// assert panics with the value "Assert" when t is false and does nothing
// otherwise.
func assert(t bool) {
	if t {
		return
	}
	panic("Assert")
}
23 |
// dup returns a newly allocated copy of b that does not share b's backing
// array. A nil or empty input yields an empty, non-nil slice.
func dup(b []byte) []byte {
	return append(make([]byte, 0, len(b)), b...)
}
29 |
--------------------------------------------------------------------------------
/deploy/Dockerfile:
--------------------------------------------------------------------------------
FROM golang

RUN go get github.com/bifurcation/mint
# Switch mint to the quic_record_layer branch of ekr's fork (presumably the
# mint API that minq builds against).
RUN (cd /go/src/github.com/bifurcation/mint; git remote add ekr https://github.com/ekr/mint; git fetch ekr; git checkout ekr/quic_record_layer)
RUN go get github.com/cloudflare/cfssl/helpers
RUN go get github.com/ekr/minq
RUN go install github.com/ekr/minq/bin/server
RUN go install github.com/ekr/minq/bin/client
# Refresh the package index in the same layer as the install: a separate
# "RUN apt-get update" layer can be reused stale from the build cache.
RUN apt-get update && apt-get install -y tcpdump
# NOTE(review): Node 6 is end-of-life and NodeSource may no longer serve
# setup_6.x; consider moving to a supported Node release.
RUN curl -sL https://deb.nodesource.com/setup_6.x | bash -
RUN apt-get install -y nodejs
RUN (cd /go/src/github.com/ekr/minq/deploy/logserver; npm install)

ARG SERVERNAME=localhost
ENV SNAME=$SERVERNAME
ENV MINQ_LOG='connection,handshake,stream,packet'
# run-looped.sh reads the server name from $SNAME, so no CMD is needed. The
# former `CMD [$SNAME]` was not valid JSON, so Docker treated it as shell
# form and only appended stray "/bin/sh -c ..." arguments to the entrypoint.
ENTRYPOINT ["/bin/sh","/go/src/github.com/ekr/minq/deploy/run-looped.sh"]

EXPOSE 4433/udp
EXPOSE 3000/tcp
23 |
24 |
--------------------------------------------------------------------------------
/timer.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import (
4 | "time"
5 | )
6 |
// timerCb is the callback run when a timer fires.
type timerCb func()

// timer is a single one-shot timer registered in a timerSet.
type timer struct {
	ts       *timerSet
	cb       timerCb
	deadline time.Time
}

// timerSet is a simple implementation of unsorted one-shot timers; check
// finds expired timers with a linear scan.
// TODO(ekr@rtfm.com): Need a better data structure.
type timerSet struct {
	ts []*timer
}

// newTimers returns an empty timerSet.
func newTimers() *timerSet {
	return &timerSet{nil}
}

// start registers cb to run once, on the first check whose time is after
// now + delayMs milliseconds. The returned timer can be cancelled.
func (ts *timerSet) start(cb timerCb, delayMs uint32) *timer {
	t := &timer{
		ts:       ts,
		cb:       cb,
		deadline: time.Now().Add(time.Millisecond * time.Duration(delayMs)),
	}
	ts.ts = append(ts.ts, t)
	return t
}

// check removes every timer whose deadline is strictly before now and then
// runs the callbacks of those that have not been cancelled.
func (ts *timerSet) check(now time.Time) {
	// Partition first, then fire. Deleting from ts.ts while ranging over it
	// skips or duplicates entries, and a callback may itself start new
	// timers (appending to ts.ts).
	var expired, pending []*timer
	for _, t := range ts.ts {
		if now.After(t.deadline) {
			expired = append(expired, t)
		} else {
			pending = append(pending, t)
		}
	}
	ts.ts = pending

	for _, t := range expired {
		// Test cb at fire time: an earlier callback in this batch may have
		// cancelled a later timer.
		if t.cb != nil {
			t.cb()
		}
	}
}

// cancel stops t's callback from running. The timer stays in its set until
// its deadline passes, at which point check discards it.
func (t *timer) cancel() {
	t.cb = nil
}
51 |
--------------------------------------------------------------------------------
/udp_transport.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import (
4 | "fmt"
5 | "net"
6 | )
7 |
8 | type UdpTransport struct {
9 | u *net.UDPConn
10 | r *net.UDPAddr
11 | }
12 |
13 | func (t *UdpTransport) Send(p []byte) error {
14 | logf(logTypeUdp, "Sending message of len %v", len(p))
15 | n, err := t.u.WriteToUDP(p, t.r)
16 | if err != nil {
17 | return err
18 | }
19 | if n != len(p) {
20 | return fmt.Errorf("Incomplete write")
21 | }
22 |
23 | return nil
24 | }
25 |
26 | func NewUdpTransport(u *net.UDPConn, r *net.UDPAddr) *UdpTransport {
27 | return &UdpTransport{u, r}
28 | }
29 |
30 | type UdpTransportFactory struct {
31 | local *net.UDPConn
32 | }
33 |
34 | func (f *UdpTransportFactory) MakeTransport(remote *net.UDPAddr) (Transport, error) {
35 | logf(logTypeUdp, "Making transport with remote addr %v", remote)
36 | return NewUdpTransport(f.local, remote), nil
37 | }
38 |
39 | func NewUdpTransportFactory(sock *net.UDPConn) *UdpTransportFactory {
40 | return &UdpTransportFactory{sock}
41 | }
42 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2016 Eric Rescorla
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in
13 | all copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 | THE SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MINT-LICENSE.md:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2016 Richard Barnes
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in
13 | all copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 | THE SOFTWARE.
22 |
--------------------------------------------------------------------------------
/deploy/logserver/server.js:
--------------------------------------------------------------------------------
var express = require('express');
var fs = require('fs');
var readline = require('readline');
// A connection ID must consist entirely of hex digits. Both anchors matter:
// with only "$", any string that merely ends in hex (e.g. "(.*)abcd") passes.
var connid_regex = /^[0-9a-fA-F]+$/;

var app = express();

var port = process.env.PORT || 3000;

// process.argv is [node, script, logfile, ...], so the log file is argv[2].
// Arrays expose .length (the former ".len" was always undefined).
if (process.argv.length < 3) {
  console.log("Need to specify log file");
  process.exit(1);
}
var file = process.argv[2];
console.log(file);
16 |
// GET /<connid>: return, as plain text, every line of the log file that
// mentions "Conn: <connid>_".
app.get('/:connid', function(request, response) {
  var connid = request.params.connid;
  if (!connid.match(connid_regex)) {
    response.status(400).send("Bogus connid (non-hex characters)");
    return;
  }

  if (connid.length < 4) {
    response.status(400).send("Bogus connid (too short)");
    return;
  }

  connid = connid.toLowerCase();

  // Literal substring match via indexOf. String#search would compile this
  // request-derived text as a regular expression.
  var match = 'Conn: ' + connid + "_";
  var data = "";
  var input = fs.createReadStream(file);
  // Without this handler a missing or unreadable log file emits an
  // unhandled 'error' event and takes down the whole log server.
  input.on('error', function(err) {
    console.log("Error reading " + file + ": " + err);
    if (!response.headersSent) {
      response.status(500).send("Could not read log file");
    }
  });
  const rl = readline.createInterface({
    input: input,
    terminal: false
  });
  rl.on('line', function(l) {
    if (l.indexOf(match) != -1) {
      data += l;
      data += "\n";
    }
  });
  rl.on('close', function() {
    // Log lines can contain peer-controlled text; serving them as
    // text/plain keeps browsers from interpreting them as HTML.
    if (!response.headersSent) {
      response.type('text/plain').send(data);
    }
  });
});
48 |
// Start the HTTP server and report which port and log file are in use.
app.listen(port, function onListening() {
  console.log("Listening on " + port);
  console.log("Logfile = " + file);
});
53 |
--------------------------------------------------------------------------------
/connbuffer.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import (
4 | "bytes"
5 | "io"
6 | "net"
7 | "time"
8 | )
9 |
10 | type connBuffer struct {
11 | r *bytes.Buffer
12 | w *bytes.Buffer
13 | }
14 |
15 | func (p *connBuffer) Read(data []byte) (n int, err error) {
16 | logf(logTypeConnBuffer, "Reading %v", n)
17 | n, err = p.r.Read(data)
18 |
19 | // Suppress bytes.Buffer's EOF on an empty buffer
20 | if err == io.EOF {
21 | err = nil
22 | }
23 | return
24 | }
25 |
26 | func (p *connBuffer) Write(data []byte) (n int, err error) {
27 | logf(logTypeConnBuffer, "Writing %v", n)
28 | return p.w.Write(data)
29 | }
30 |
31 | func (p *connBuffer) Close() error {
32 | return nil
33 | }
34 |
35 | func (p *connBuffer) LocalAddr() net.Addr { return nil }
36 | func (p *connBuffer) RemoteAddr() net.Addr { return nil }
37 | func (p *connBuffer) SetDeadline(t time.Time) error { return nil }
38 | func (p *connBuffer) SetReadDeadline(t time.Time) error { return nil }
39 | func (p *connBuffer) SetWriteDeadline(t time.Time) error { return nil }
40 |
41 | func newConnBuffer() *connBuffer {
42 | return &connBuffer{
43 | bytes.NewBuffer(nil),
44 | bytes.NewBuffer(nil),
45 | }
46 | }
47 |
48 | func (p *connBuffer) input(data []byte) error {
49 | logf(logTypeConnBuffer, "input %v", len(data))
50 | _, err := p.r.Write(data)
51 | return err
52 | }
53 |
54 | func (p *connBuffer) getOutput() []byte {
55 | b := p.w.Bytes()
56 | p.w.Reset()
57 | return b
58 | }
59 |
60 | func (p *connBuffer) OutputLen() int {
61 | return p.w.Len()
62 | }
63 |
--------------------------------------------------------------------------------
/bin/tester/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "encoding/hex"
5 | "flag"
6 | "fmt"
7 | "io/ioutil"
8 | "strings"
9 |
10 | "github.com/ekr/minq"
11 | )
12 |
// Command-line flags, registered in main.
var (
	infile     string
	serverName string
	dehex      bool
)

// stdoutTransport is a transport that hex-dumps every outgoing packet to
// stdout instead of sending it.
type stdoutTransport struct{}

// Send prints p as a hex dump and always succeeds.
func (t *stdoutTransport) Send(p []byte) error {
	dump := hex.Dump(p)
	fmt.Printf("Output=%v", dump)
	return nil
}
24 |
25 | type connHandler struct {
26 | }
27 |
28 | func (h *connHandler) StateChanged(s minq.State) {
29 | fmt.Println("State changed to ", s)
30 | }
31 |
32 | func (h *connHandler) NewStream(s *minq.Stream) {
33 | fmt.Println("New stream")
34 | }
35 |
36 | func (h *connHandler) StreamReadable(s *minq.Stream) {
37 | fmt.Println("Stream readable")
38 | }
39 |
40 | func main() {
41 | flag.StringVar(&infile, "infile", "input", "input file")
42 | flag.StringVar(&serverName, "server-name", "", "SNI")
43 | flag.BoolVar(&dehex, "hex", false, "file is in hex")
44 | flag.Parse()
45 |
46 | in, err := ioutil.ReadFile(infile)
47 | if err != nil {
48 | fmt.Println("Couldn't read file")
49 | }
50 |
51 | if dehex {
52 | s := string(in)
53 | s = strings.Replace(s, " ", "", -1)
54 | s = strings.Replace(s, "\n", "", -1)
55 | in, err = hex.DecodeString(s)
56 | if err != nil {
57 | fmt.Println("Couldn't hex decode input")
58 | }
59 |
60 | }
61 |
62 | strans := &stdoutTransport{}
63 | config := minq.NewTlsConfig(serverName)
64 | conn := minq.NewConnection(strans, minq.RoleServer, &config, nil)
65 | err = conn.Input(in)
66 | if err != nil {
67 | fmt.Println("Couldn't process input: ", err)
68 | }
69 | }
70 |
--------------------------------------------------------------------------------
/frame_test.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | "testing"
7 | )
8 |
9 | func testEncodeDecodeEncode(t *testing.T, f *frame) {
10 | err := f.encode()
11 | assertNotError(t, err, "Encode failed")
12 | fmt.Printf("Encoded: [%x]\n", f.encoded)
13 |
14 | consumed, f2, err := decodeFrame(f.encoded)
15 | assertNotError(t, err, "Failed to decode frame")
16 | assertEquals(t, len(f.encoded), int(consumed))
17 | f2.encoded = nil // So we re-encode
18 |
19 | err = f2.encode()
20 | assertNotError(t, err, "Encode failed")
21 | assertByteEquals(t, f.encoded, f2.encoded)
22 |
23 | fmt.Printf("%+v\n", f2)
24 | }
25 |
26 | func TestStreamFrame(t *testing.T) {
27 | s := newStreamFrame(1, 0,
28 | bytes.Repeat([]byte{0xa0}, 100), false)
29 | testEncodeDecodeEncode(t, s)
30 | }
31 |
32 | func TestAckFrameOneRange(t *testing.T) {
33 | ar := []ackRange{{0xdeadbeef, 2}}
34 |
35 | recvd := newRecvdPackets(logf)
36 | recvd.init(ar[0].lastPacket)
37 | recvd.packetSetReceived(ar[0].lastPacket, false, false)
38 |
39 | f, _, err := newAckFrame(recvd, ar, 33)
40 | assertNotError(t, err, "Couldn't make ack frame")
41 |
42 | testEncodeDecodeEncode(t, f)
43 | }
44 |
45 | func TestAckFrameTwoRanges(t *testing.T) {
46 | ar := []ackRange{{0xdeadbeef, 2}, {0xdeadbee0, 1}}
47 |
48 | recvd := newRecvdPackets(logf)
49 | recvd.init(ar[0].lastPacket)
50 | recvd.packetSetReceived(ar[0].lastPacket, false, false)
51 |
52 | f, _, err := newAckFrame(recvd, ar, 49)
53 | assertNotError(t, err, "Couldn't make ack frame")
54 |
55 | testEncodeDecodeEncode(t, f)
56 | }
57 |
58 | func TestFixedSizedData(t *testing.T) {
59 | f := newPathChallengeFrame([]byte{1, 2, 3, 4, 5, 6, 7, 8})
60 | testEncodeDecodeEncode(t, f)
61 | f = newPathResponseFrame([]byte{10, 9, 8, 7, 6, 5, 4, 3})
62 | testEncodeDecodeEncode(t, f)
63 | }
64 |
--------------------------------------------------------------------------------
/crypto.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import (
4 | "crypto/cipher"
5 | "encoding/hex"
6 | "github.com/bifurcation/mint"
7 | )
8 |
9 | type cryptoState struct {
10 | aead cipher.AEAD
11 | pne pneCipherFactory
12 | }
13 |
14 | func infallibleHexDecode(s string) []byte {
15 | b, err := hex.DecodeString(s)
16 | if err != nil {
17 | panic("didn't hex decode " + s)
18 | }
19 | return b
20 | }
21 |
22 | var kQuicVersionSalt = infallibleHexDecode("9c108f98520a5c5c32968e950e8a2c5fe06d6c38")
23 |
24 | const clientCtSecretLabel = "client in"
25 | const serverCtSecretLabel = "server in"
26 |
27 | const clientPpSecretLabel = "EXPORTER-QUIC client 1rtt"
28 | const serverPpSecretLabel = "EXPORTER-QUIC server 1rtt"
29 |
30 | func newCryptoStateInner(secret []byte, cs *mint.CipherSuiteParams) (*cryptoState, error) {
31 | var st cryptoState
32 | var err error
33 |
34 | k := mint.HkdfExpandLabel(cs.Hash, secret, "key", []byte{}, cs.KeyLen)
35 | iv := mint.HkdfExpandLabel(cs.Hash, secret, "iv", []byte{}, cs.IvLen)
36 | pn := mint.HkdfExpandLabel(cs.Hash, secret, "pn", []byte{}, cs.KeyLen)
37 | logf(logTypeAead, "key=%x iv=%x pn=%x", k, iv, pn)
38 | st.aead, err = newWrappedAESGCM(k, iv)
39 | if err != nil {
40 | return nil, err
41 | }
42 | st.pne = newPneCipherFactoryAES(pn)
43 |
44 | return &st, nil
45 | }
46 |
47 | func generateCleartextKeys(secret []byte, label string, cs *mint.CipherSuiteParams) (*cryptoState, error) {
48 | logf(logTypeTls, "Cleartext keys: cid=%x initial_salt=%x", secret, kQuicVersionSalt)
49 | extracted := mint.HkdfExtract(cs.Hash, kQuicVersionSalt, secret)
50 | inner := mint.HkdfExpandLabel(cs.Hash, extracted, label, []byte{}, cs.Hash.Size())
51 | logf(logTypeAead, "initial_secret (%s) = %x", label, inner)
52 | return newCryptoStateInner(inner, cs)
53 | }
54 |
55 | func newCryptoStateFromTls(t *tlsConn, label string) (*cryptoState, error) {
56 | panic("TODO")
57 | }
58 |
--------------------------------------------------------------------------------
/aead.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import (
4 | "crypto/aes"
5 | "crypto/cipher"
6 | )
7 |
8 | // aeadWrapper contains an existing AEAD object and does the
9 | // QUIC nonce masking.
10 | type aeadWrapper struct {
11 | iv []byte
12 | cipher cipher.AEAD
13 | }
14 |
15 | func (a *aeadWrapper) NonceSize() int {
16 | return a.cipher.NonceSize()
17 | }
18 | func (a *aeadWrapper) Overhead() int {
19 | return a.cipher.Overhead()
20 | }
21 |
22 | func (a *aeadWrapper) fmtNonce(in []byte) []byte {
23 | // The input nonce is actually a packet number.
24 | assert(len(in) == 8)
25 | assert(a.NonceSize() == 12)
26 | assert(len(a.iv) == a.NonceSize())
27 |
28 | nonce := make([]byte, a.NonceSize())
29 | copy(nonce[len(nonce)-len(in):], in)
30 | for i, b := range a.iv {
31 | nonce[i] ^= b
32 | }
33 |
34 | logf(logTypeAead, "Nonce=%x", nonce)
35 | return nonce
36 | }
37 |
38 | func (a *aeadWrapper) Seal(dst []byte, nonce []byte, plaintext []byte, aad []byte) []byte {
39 | logf(logTypeAead, "AES protecting aad len=%d, plaintext len=%d", len(aad), len(plaintext))
40 | logf(logTypeTrace, "AES input AAD=%x P=%x", aad, plaintext)
41 | ret := a.cipher.Seal(dst, a.fmtNonce(nonce), plaintext, aad)
42 | logf(logTypeTrace, "AES output %x", ret)
43 |
44 | return ret
45 | }
46 |
47 | func (a *aeadWrapper) Open(dst []byte, nonce []byte, ciphertext []byte, aad []byte) ([]byte, error) {
48 | logf(logTypeAead, "AES unprotecting aad len=%d, ciphertext len=%d", len(aad), len(ciphertext))
49 | logf(logTypeTrace, "AES input AAD=%x C=%x", aad, ciphertext)
50 | ret, err := a.cipher.Open(dst, a.fmtNonce(nonce), ciphertext, aad)
51 | if err != nil {
52 | return nil, err
53 | }
54 | logf(logTypeTrace, "AES output %x", ret)
55 | return ret, err
56 | }
57 |
58 | func newWrappedAESGCM(key []byte, iv []byte) (cipher.AEAD, error) {
59 | logf(logTypeAead, "New AES GCM context: key=%x iv=%x", key, iv)
60 | a, err := aes.NewCipher(key)
61 | if err != nil {
62 | return nil, err
63 | }
64 |
65 | aead, err := cipher.NewGCM(a)
66 | if err != nil {
67 | return nil, err
68 | }
69 |
70 | return &aeadWrapper{iv, aead}, nil
71 | }
72 |
--------------------------------------------------------------------------------
/tracking_test.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import (
4 | "fmt"
5 | "github.com/bifurcation/mint"
6 | "runtime"
7 | "testing"
8 | )
9 |
10 | type testTrackingFixture struct {
11 | pns []uint64
12 | r *recvdPackets
13 | }
14 |
15 | func newTestTrackingFixture() *testTrackingFixture {
16 | pc, _, _, ok := runtime.Caller(1)
17 | name := "unknown"
18 | if ok {
19 | name = runtime.FuncForPC(pc).Name()
20 | }
21 | log := func(tag string, format string, args ...interface{}) {
22 | fullFormat := fmt.Sprintf("%s: %s", name, format)
23 | logf(tag, fullFormat, args...)
24 | }
25 |
26 | pns := make([]uint64, 10)
27 | for i := uint64(0); i < 10; i++ {
28 | pns[i] = uint64(0xdead0000) + i
29 | }
30 | return &testTrackingFixture{
31 | pns,
32 | newRecvdPackets(log),
33 | }
34 | }
35 |
36 | func TestTrackingPacketsReceived(t *testing.T) {
37 | f := newTestTrackingFixture()
38 | assertEquals(t, true, f.r.packetNotReceived(f.pns[1]))
39 | f.r.init(f.pns[0])
40 | assertEquals(t, true, f.r.packetNotReceived(f.pns[0]))
41 | assertEquals(t, true, f.r.packetNotReceived(f.pns[1]))
42 | f.r.packetSetReceived(f.pns[0], false, true)
43 | assertEquals(t, false, f.r.packetNotReceived(f.pns[0]))
44 | assertEquals(t, true, f.r.packetNotReceived(f.pns[1]))
45 | f.r.packetSetReceived(f.pns[1], true, true)
46 | assertEquals(t, false, f.r.packetNotReceived(f.pns[1]))
47 |
48 | // Check that things less than min are received
49 | assertEquals(t, false, f.r.packetNotReceived(f.pns[0]-1))
50 |
51 | // Now make some ACKs
52 | ar := f.r.prepareAckRange(mint.EpochApplicationData, false)
53 | assertX(t, len(ar) == 1, "Should be one entry in ACK range")
54 | assertEquals(t, ar[0].lastPacket, f.pns[1])
55 | assertEquals(t, ar[0].count, uint64(2))
56 |
57 | f.r.packetSetReceived(f.pns[3], true, true)
58 | ar = f.r.prepareAckRange(mint.EpochApplicationData, false)
59 | assertX(t, len(ar) == 2, "Should be two entry in ACK range")
60 | assertEquals(t, ar[0].lastPacket, f.pns[3])
61 | assertEquals(t, ar[1].lastPacket, f.pns[1])
62 | assertEquals(t, ar[1].count, uint64(2))
63 |
64 | // Now ack all the acks, so that we should send nothing.
65 | f.r.packetSetAcked2(f.pns[0])
66 | f.r.packetSetAcked2(f.pns[1])
67 | f.r.packetSetAcked2(f.pns[3])
68 | ar = f.r.prepareAckRange(mint.EpochApplicationData, false)
69 | assertX(t, len(ar) == 0, "Should be no acks")
70 | }
71 |
--------------------------------------------------------------------------------
/log.go:
--------------------------------------------------------------------------------
1 | // Lightly modified from Mint
2 |
3 | package minq
4 |
5 | import (
6 | "fmt"
7 | "log"
8 | "os"
9 | "strings"
10 | )
11 |
12 | // We use this environment variable to control logging. It should be a
13 | // comma-separated list of log tags (see below) or "*" to enable all logging.
14 | const logConfigVar = "MINQ_LOG"
15 |
16 | // Pre-defined log types
17 | const (
18 | logTypeAead = "aead"
19 | logTypeCodec = "codec"
20 | logTypeConnBuffer = "connbuffer"
21 | logTypeConnection = "connection"
22 | logTypeAck = "ack"
23 | logTypeFrame = "frame"
24 | logTypeHandshake = "handshake"
25 | logTypeTls = "tls"
26 | logTypeTrace = "trace"
27 | logTypeServer = "server"
28 | logTypeUdp = "udp"
29 | logTypeStream = "stream"
30 | logTypeFlowControl = "flow"
31 | logTypePacket = "packet" // Just send notes on which packets are sent and received
32 | logTypeCongestion = "congestion"
33 | )
34 |
35 | var (
36 | logFunction = log.Printf
37 | logAll = false
38 | logSettings = map[string]bool{}
39 | )
40 |
41 | func init() {
42 | parseLogEnv(os.Environ())
43 | }
44 |
45 | func parseLogEnv(env []string) {
46 | for _, stmt := range env {
47 | if strings.HasPrefix(stmt, logConfigVar+"=") {
48 | val := stmt[len(logConfigVar)+1:]
49 |
50 | if val == "*" {
51 | logAll = true
52 | } else {
53 | for _, t := range strings.Split(val, ",") {
54 | logSettings[t] = true
55 | }
56 | }
57 | }
58 | }
59 | }
60 |
61 | func logf(tag string, format string, args ...interface{}) {
62 | if logAll || logSettings[tag] {
63 | fullFormat := fmt.Sprintf("[%s] %s", tag, format)
64 | logFunction(fullFormat, args...)
65 | }
66 | }
67 |
68 | type loggingFunction func(string, string, ...interface{})
69 |
70 | func SetLogOutput(f func(string, ...interface{})) {
71 | logFunction = f
72 | }
73 |
74 | func newConnectionLogger(c *Connection) loggingFunction {
75 | return func(tag string, format string, args ...interface{}) {
76 | if logAll || logSettings[tag] {
77 | logf(tag, c.String()+": "+format, args...)
78 | }
79 | }
80 | }
81 |
82 | func newStreamLogger(id uint64, dir string, f loggingFunction) loggingFunction {
83 | extra := fmt.Sprintf("%s stream %d: ", dir, id)
84 | return func(tag string, format string, args ...interface{}) {
85 | f(tag, extra+format, args...)
86 | }
87 | }
88 |
--------------------------------------------------------------------------------
/bin/shim/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "flag"
5 | "fmt"
6 | "github.com/ekr/minq"
7 | "net"
8 | "time"
9 | )
10 |
11 | var addr string
12 | var server bool
13 |
14 | type connHandler struct {
15 | }
16 |
17 | func (h *connHandler) StateChanged(s minq.State) {
18 | fmt.Println("State changed to ", s)
19 | }
20 |
21 | func (h *connHandler) NewStream(s *minq.Stream) {
22 | }
23 |
24 | func (h *connHandler) StreamReadable(s *minq.Stream) {
25 | }
26 |
27 | func readUDP(s *net.UDPConn) ([]byte, error) {
28 | b := make([]byte, 8192)
29 |
30 | s.SetReadDeadline(time.Now().Add(time.Second))
31 | n, _, err := s.ReadFromUDP(b)
32 | if err != nil {
33 | e, o := err.(net.Error)
34 | if o && e.Timeout() {
35 | return nil, minq.ErrorWouldBlock
36 | }
37 | fmt.Println("Error reading from UDP socket: ", err)
38 | return nil, err
39 | }
40 |
41 | if n == len(b) {
42 | fmt.Println("Underread from UDP socket")
43 | return nil, err
44 | }
45 | b = b[:n]
46 | return b, nil
47 | }
48 |
49 | func main() {
50 | flag.StringVar(&addr, "addr", "localhost:4433", "[host:port]")
51 | flag.BoolVar(&server, "server", false, "Run as server]")
52 | flag.Parse()
53 |
54 | uaddr, err := net.ResolveUDPAddr("udp", addr)
55 | if err != nil {
56 | fmt.Println("Invalid UDP addr", err)
57 | return
58 | }
59 |
60 | usock, err := net.ListenUDP("udp", nil)
61 | if err != nil {
62 | fmt.Println("Couldn't create connected UDP socket")
63 | return
64 | }
65 |
66 | role := minq.RoleClient
67 | if server {
68 | _, port, err := net.SplitHostPort(usock.LocalAddr().String())
69 | if err != nil {
70 | return
71 | }
72 | fmt.Println(port)
73 | role = minq.RoleServer
74 | }
75 | fmt.Printf("Remote addr=%v\n", addr)
76 | utrans := minq.NewUdpTransport(usock, uaddr)
77 | config := minq.NewTlsConfig("localhost")
78 |
79 | conn := minq.NewConnection(utrans, role, &config, nil)
80 |
81 | // Start things off.
82 | fmt.Println("Starting")
83 | _, err = conn.CheckTimer()
84 |
85 | for conn.GetState() != minq.StateEstablished {
86 | b, err := readUDP(usock)
87 | if err != nil {
88 | if err == minq.ErrorWouldBlock {
89 | _, err = conn.CheckTimer()
90 | if err != nil {
91 | return
92 | }
93 | continue
94 | }
95 | return
96 | }
97 |
98 | err = conn.Input(b)
99 | if err != nil {
100 | fmt.Println("Error", err)
101 | return
102 | }
103 | }
104 |
105 | fmt.Println("Connection established")
106 | }
107 |
--------------------------------------------------------------------------------
/common_test.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import (
4 | "bytes"
5 | "encoding/hex"
6 | "fmt"
7 | "runtime"
8 | "testing"
9 | )
10 |
11 | /* STOLEN FROM MINT.
12 | The MIT License (MIT)
13 |
14 | Copyright (c) 2016 Richard Barnes
15 |
16 | Permission is hereby granted, free of charge, to any person obtaining a copy
17 | of this software and associated documentation files (the "Software"), to deal
18 | in the Software without restriction, including without limitation the rights
19 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
20 | copies of the Software, and to permit persons to whom the Software is
21 | furnished to do so, subject to the following conditions:
22 |
23 | The above copyright notice and this permission notice shall be included in
24 | all copies or substantial portions of the Software.
25 |
26 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
27 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
28 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
29 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
30 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
31 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
32 | THE SOFTWARE.
33 | */
34 |
// unhex decodes a hex test vector, panicking with the decode error if h is
// not valid hex.
func unhex(h string) []byte {
	decoded, err := hex.DecodeString(h)
	if err == nil {
		return decoded
	}
	panic(err)
}
42 |
// assertX fails the test immediately with msg when test is false. The
// message is preceded by the full call stack, outermost frame first, one
// "file: line" entry per line.
func assertX(t *testing.T, test bool, msg string) {
	t.Helper()
	if test {
		// Only walk the stack when there is a failure to report.
		return
	}
	prefix := ""
	for i := 1; ; i++ {
		_, file, line, ok := runtime.Caller(i)
		if !ok {
			break
		}
		prefix = fmt.Sprintf("%v: %d\n", file, line) + prefix
	}
	// Fatal, not Fatalf: msg is caller-supplied (often formatted values)
	// and any '%' in it must not be re-interpreted as a format verb.
	t.Fatal(prefix + msg)
}
56 |
57 | func assertError(t *testing.T, err error, msg string) {
58 | assertX(t, err != nil, msg)
59 | }
60 |
61 | func assertNotError(t *testing.T, err error, msg string) {
62 | if err != nil {
63 | msg += ": " + err.Error()
64 | }
65 | assertX(t, err == nil, msg)
66 | }
67 |
68 | func assertNotNil(t *testing.T, x interface{}, msg string) {
69 | assertX(t, x != nil, msg)
70 | }
71 |
72 | func assertEquals(t *testing.T, a, b interface{}) {
73 | assertX(t, a == b, fmt.Sprintf("%+v != %+v", a, b))
74 | }
75 |
76 | func assertByteEquals(t *testing.T, a, b []byte) {
77 | assertX(t, bytes.Equal(a, b), fmt.Sprintf("%+v != %+v", hex.EncodeToString(a), hex.EncodeToString(b)))
78 | }
79 |
80 | func assertNotByteEquals(t *testing.T, a, b []byte) {
81 | assertX(t, !bytes.Equal(a, b), fmt.Sprintf("%+v == %+v", hex.EncodeToString(a), hex.EncodeToString(b)))
82 | }
83 |
84 | /* END STOLEN FROM MINT. */
85 |
--------------------------------------------------------------------------------
/codec_test.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import (
4 | "encoding/hex"
5 | "fmt"
6 | "reflect"
7 | "testing"
8 | )
9 |
// Uint8Indirect is a named uint8 type, letting the codec tests exercise a
// named integer type rather than a builtin one.
type Uint8Indirect uint8

// TestStructDefaultLengths is encoded with the codec's default field
// lengths. TestCodecDefaultEncode expects {1, 2, "abc"} to encode to 6
// bytes. Field order matters: the encoding follows the struct layout.
type TestStructDefaultLengths struct {
	U8  Uint8Indirect
	U16 uint16
	B   []byte
}

// TestStructOverrideLengths has the same fields but overrides two of their
// encoded lengths through the <Field>__length methods below.
// NOTE(review): the codec presumably finds these methods by name via
// reflection, so the method names must match the field names exactly.
type TestStructOverrideLengths struct {
	U8  uint8
	U16 uint16
	B   []byte
}

// U16__length overrides the encoded length of U16 to 1.
func (t TestStructOverrideLengths) U16__length() uintptr {
	return 1
}

// B__length overrides the encoded length of B to 3, so with U8 and U16
// the struct encodes to 5 bytes (see TestCodecOverrideEncode).
func (t TestStructOverrideLengths) B__length() uintptr {
	return 3
}
31 |
32 | func codecEDE(t *testing.T, s interface{}, s2 interface{}, expectedLen uintptr) {
33 | res, err := encode(s)
34 | assertNotError(t, err, "Could not encode")
35 |
36 | fmt.Println("Result = ", hex.EncodeToString(res))
37 | // TODO(ekr@rtfm.com). What is the type of len().
38 | assertEquals(t, uintptr(expectedLen), uintptr(len(res)))
39 |
40 | _, err = decode(s2, res)
41 | assertNotError(t, err, "Could not decode")
42 |
43 | res2, err := encode(s2)
44 | assertNotError(t, err, "Could not re-encode")
45 | fmt.Println("Result2 = ", hex.EncodeToString(res2))
46 | assertByteEquals(t, res, res2)
47 | }
48 |
49 | func TestCodecDefaultEncode(t *testing.T) {
50 | s := TestStructDefaultLengths{1, 2, []byte{'a', 'b', 'c'}}
51 | var s2 TestStructDefaultLengths
52 |
53 | codecEDE(t, &s, &s2, 6)
54 | }
55 |
56 | func TestCodecOverrideEncode(t *testing.T) {
57 | s := TestStructOverrideLengths{1, 2, []byte{'a', 'b', 'c'}}
58 | var s2 TestStructOverrideLengths
59 |
60 | codecEDE(t, &s, &s2, 5)
61 | }
62 |
63 | func TestCodecOverrideDecodeLength(t *testing.T) {
64 | s := TestStructOverrideLengths{1, 2, []byte{'a', 'b', 'c'}}
65 | var s2 TestStructOverrideLengths
66 |
67 | res, err := encode(&s)
68 | assertNotError(t, err, "Could not encode")
69 |
70 | modified := append(res, 'd')
71 | _, err = decode(&s2, modified)
72 | assertNotError(t, err, "Could not decode")
73 |
74 | fmt.Println(s2)
75 |
76 | res2, err := encode(&s2)
77 | assertNotError(t, err, "Could not re-encode")
78 |
79 | assertByteEquals(t, res, res2)
80 | }
81 |
82 | func TestParseLengthSpec(t *testing.T) {
83 | // 1 bit, 2 values
84 | spec, err := parseLengthSpecification("1:8,16")
85 | assertNotError(t, err, "Couldn't parse single bit value")
86 | fmt.Println(*spec)
87 | assertX(t, reflect.DeepEqual(*spec, lengthSpec{1, 1, []int{8, 16}}),
88 | "Spec parsed correctly")
89 |
90 | // 2 bit, 4 values
91 | spec, err = parseLengthSpecification("3:8,16,24,32")
92 | assertNotError(t, err, "Couldn't parse two bit value")
93 | fmt.Println(*spec)
94 | assertX(t, reflect.DeepEqual(*spec, lengthSpec{3, 2, []int{8, 16, 24, 32}}),
95 | "Spec parsed correctly")
96 |
97 | }
98 |
--------------------------------------------------------------------------------
/errors.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import (
4 | "fmt"
5 | )
6 |
// intError is minq's internal error type. These errors don't necessarily
// cause connection teardown: unlike foreign errors, which isFatalError
// always treats as fatal, an intError carries its own fatal flag.
type intError struct {
	err   string // primary message
	sub   string // optional underlying error text, as recorded by wrapE
	fatal bool   // whether isFatalError reports this error as fatal
}

// Error returns the primary message. When an underlying error was recorded
// in sub, it is appended after ": " so the cause is not lost. Sentinel
// errors have an empty sub, so their text is just the primary message.
func (e intError) Error() string {
	if e.sub == "" {
		return e.err
	}
	return e.err + ": " + e.sub
}
17 |
18 | func fatalError(format string, args ...interface{}) error {
19 | return intError{
20 | fmt.Sprintf(format, args...),
21 | "",
22 | true,
23 | }
24 | }
25 |
26 | func internalError(format string, args ...interface{}) error {
27 | str := fmt.Sprintf(format, args...)
28 | if debug {
29 | panic("Internal error: " + str)
30 | }
31 |
32 | return intError{
33 | str,
34 | "",
35 | true,
36 | }
37 | }
38 |
39 | func nonFatalError(format string, args ...interface{}) error {
40 | return intError{
41 | fmt.Sprintf(format, args...),
42 | "",
43 | false,
44 | }
45 | }
46 |
// err2string returns the text of err, which must be either an error (its
// Error() text is used) or a plain string (returned as-is). Any other
// type is a programming error and panics.
func err2string(err interface{}) string {
	if s, isString := err.(string); isString {
		return s
	}
	if e, isError := err.(error); isError {
		return e.Error()
	}
	panic("Bogus argument to err2string")
}
57 |
58 | func wrapE(err interface{}, sub interface{}) error {
59 | return intError{
60 | err2string(err),
61 | err2string(sub),
62 | isFatalError(err),
63 | }
64 | }
65 |
66 | // An error is fatal if either.
67 | //
68 | // It's a regular error (i.e., not an intError)
69 | // e.fatal is true
70 | func isFatalError(e interface{}) bool {
71 | if e == nil {
72 | return false
73 | }
74 |
75 | i, ok := e.(intError)
76 | if !ok {
77 | return true
78 | }
79 |
80 | return i.fatal
81 | }
82 |
83 | // Return codes.
84 | var ErrorWouldBlock = nonFatalError("Would have blocked (QUIC)")
85 | var ErrorDestroyConnection = fatalError("Terminate connection")
86 | var ErrorReceivedVersionNegotiation = fatalError("Received a version negotiation packet advertising a different version than ours")
87 | var ErrorConnIsClosed = fatalError("Connection is closed")
88 | var ErrorConnIsClosing = nonFatalError("Connection is closing")
89 | var ErrorStreamReset = fatalError("Stream was reset")
90 | var ErrorStreamIsClosed = fatalError("Stream is closed")
91 | var ErrorInvalidPacket = nonFatalError("Invalid packet")
92 | var ErrorConnectionTimedOut = fatalError("Connection timed out")
93 | var ErrorMissingValue = fatalError("Expected value is missing")
94 | var ErrorInvalidEncoding = fatalError("Invalid encoding")
95 | var ErrorProtocolViolation = fatalError("Protocol violation")
96 | var ErrorFrameFormatError = fatalError("Frame format error")
97 | var ErrorFlowControlError = fatalError("Flow control error")
98 |
// Protocol errors

// ErrorCode is a numeric QUIC protocol error code (16 bits wide).
type ErrorCode uint16

const (
	kQuicErrorNoError           ErrorCode = 0x0000
	kQuicErrorProtocolViolation ErrorCode = 0x000A
)
106 |
--------------------------------------------------------------------------------
/record-layer.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import (
4 | "github.com/bifurcation/mint"
5 | "io"
6 | "sync"
7 | )
8 |
// RecordLayerImpl is minq's implementation of mint's record layer. Instead
// of framing TLS records, it reads handshake bytes from the receive crypto
// stream of the current epoch's encryption level and writes them to the
// send crypto stream. Rekey installs packet protection keys on the
// Connection. Each instance serves a single direction (dir), as chosen by
// RecordLayerFactoryImpl.NewLayer.
type RecordLayerImpl struct {
	// NOTE(review): no method in this file locks the embedded Mutex; confirm
	// whether anything relies on it.
	sync.Mutex
	conn   *Connection
	epoch  mint.Epoch     // epoch whose crypto stream is used; updated by Rekey
	dir    mint.Direction // mint.DirectionRead selects the receive side in Rekey
	buffer []byte         // bytes read by PeekRecordType, handed out by the next ReadRecord
}
16 |
17 | func (r *RecordLayerImpl) SetVersion(v uint16) {
18 | // Do nothing
19 | }
20 |
21 | func (r *RecordLayerImpl) SetLabel(s string) {
22 | // Do nothing
23 | }
24 |
25 | func (r *RecordLayerImpl) Rekey(epoch mint.Epoch, factory mint.AeadFactory, keys *mint.KeySet) error {
26 | logf(logTypeTls, "Rekey epoch=%v", epoch)
27 | // TODO(ekr@rtfm.com): Check to see if it's GCM.
28 | aead, err := newWrappedAESGCM(keys.Key, keys.Iv)
29 | if err != nil {
30 | return mint.AlertInternalError
31 | }
32 |
33 | st := cryptoState{
34 | aead: aead,
35 | pne: newPneCipherFactoryAES(keys.Pn),
36 | }
37 |
38 | if r.dir == mint.DirectionRead {
39 | r.conn.encryptionLevels[epoch].recvCipher = &st
40 | } else {
41 | r.conn.encryptionLevels[epoch].sendCipher = &st
42 | }
43 | r.epoch = epoch
44 | return nil
45 | }
46 |
47 | func (r *RecordLayerImpl) ResetClear(seq uint64) {
48 | panic("UNIMPLEMENTED")
49 | }
50 | func (r *RecordLayerImpl) DiscardReadKey(epoch mint.Epoch) {
51 | // Do nothing
52 | }
53 |
54 | func (r *RecordLayerImpl) readBytes() ([]byte, error) {
55 | str := &(r.conn.encryptionLevels[r.epoch].recvCryptoStream.(*recvStream).recvStreamBase)
56 |
57 | b := make([]byte, 16384)
58 | n, err := str.read(b)
59 | logf(logTypeStream, "EKR: n=%d err=%v\n", n, err)
60 | if err == ErrorWouldBlock {
61 | return nil, mint.AlertWouldBlock
62 | }
63 | if err != nil {
64 | return nil, mint.AlertInternalError
65 | }
66 |
67 | return b[:n], nil
68 | }
69 | func (r *RecordLayerImpl) PeekRecordType(block bool) (mint.RecordType, error) {
70 | assert(r.buffer == nil)
71 | var err error
72 | r.buffer, err = r.readBytes()
73 | if err != nil {
74 | return 0, err
75 | }
76 | return mint.RecordTypeHandshake, nil
77 | }
78 |
79 | func (r *RecordLayerImpl) ReadRecord() (*mint.TLSPlaintext, error) {
80 | var b []byte
81 | var err error
82 | if r.buffer != nil {
83 | b = r.buffer
84 | r.buffer = nil
85 | } else {
86 | b, err = r.readBytes()
87 | if err != nil {
88 | return nil, err
89 | }
90 | }
91 | return mint.NewTLSPlaintext(mint.RecordTypeHandshake, r.epoch, b), nil
92 | }
93 |
94 | func (r *RecordLayerImpl) WriteRecord(pt *mint.TLSPlaintext) error {
95 | logf(logTypeTls, "WriteRecord(epoch=%v, len=%v)", r.epoch, len(pt.Fragment()))
96 | _, err := r.conn.encryptionLevels[r.epoch].sendCryptoStream.(*sendStream).write(pt.Fragment(), nil)
97 | return err
98 | }
99 |
100 | func (r *RecordLayerImpl) Epoch() mint.Epoch {
101 | return r.epoch
102 | }
103 |
104 | type RecordLayerFactoryImpl struct {
105 | conn *Connection
106 | }
107 |
108 | func newRecordLayerFactory(conn *Connection) mint.RecordLayerFactory {
109 | return &RecordLayerFactoryImpl{conn: conn}
110 | }
111 |
112 | func (f *RecordLayerFactoryImpl) NewLayer(conn io.ReadWriter, dir mint.Direction) mint.RecordLayer {
113 | return &RecordLayerImpl{
114 | dir: dir,
115 | conn: f.conn,
116 | buffer: nil,
117 | }
118 | }
119 |
--------------------------------------------------------------------------------
/aead_test.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import (
4 | "crypto/cipher"
5 | "testing"
6 | )
7 |
// AEAD test fixtures. Each "2" variant differs from its "1" counterpart
// only in the last byte, so a mix-up is always detectable.
var (
	kTestKey1 = make([]byte, 16)                  // 16 zero bytes
	kTestIV1  = []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}
	kTestKey2 = append(make([]byte, 15), 1)       // 15 zeros then 1
	kTestIV2  = []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13}

	ktestAeadHdr1  = []byte{1, 2, 3}
	ktestAeadHdr2  = []byte{1, 2, 4}
	ktestAeadBody1 = []byte{5, 6, 7}
	ktestAeadBody2 = []byte{5, 6, 8}

	kNonce0 = make([]byte, 8)                     // 8 zero bytes
	kNonce1 = []byte{0, 0, 0, 0, 0, 0, 0, 1}
)
20 |
21 | func testAeadSuccess(t *testing.T, aead cipher.AEAD) {
22 | ct := aead.Seal(nil, kNonce0, ktestAeadBody1, ktestAeadHdr1)
23 |
24 | pt, err := aead.Open(nil, kNonce0, ct, ktestAeadHdr1)
25 | assertNotError(t, err, "Could not unprotect")
26 |
27 | assertByteEquals(t, pt, ktestAeadBody1)
28 | }
29 |
30 | func testAeadWrongPacketNumber(t *testing.T, aead cipher.AEAD) {
31 | ct := aead.Seal(nil, kNonce0, ktestAeadBody1, ktestAeadHdr1)
32 |
33 | _, err := aead.Open(nil, kNonce1, ct, ktestAeadHdr1)
34 | assertError(t, err, "Shouldn't have unprotected")
35 | }
36 |
37 | func testAeadWrongHeader(t *testing.T, aead cipher.AEAD) {
38 |
39 | ct := aead.Seal(nil, kNonce0, ktestAeadBody1, ktestAeadHdr1)
40 |
41 | _, err := aead.Open(nil, kNonce0, ct, ktestAeadHdr2)
42 | assertError(t, err, "Shouldn't have unprotected")
43 | }
44 |
45 | func testAeadCorruptCT(t *testing.T, aead cipher.AEAD) {
46 | ct := aead.Seal(nil, kNonce0, ktestAeadBody1, ktestAeadHdr1)
47 |
48 | ct[0]++
49 | _, err := aead.Open(nil, kNonce0, ct, ktestAeadHdr1)
50 | assertError(t, err, "Shouldn't have unprotected")
51 | }
52 |
53 | func testAeadCorruptTag(t *testing.T, aead cipher.AEAD) {
54 | ct := aead.Seal(nil, kNonce0, ktestAeadBody1, ktestAeadHdr1)
55 | ct[len(ct)-1]++
56 | _, err := aead.Open(nil, kNonce0, ct, ktestAeadHdr1)
57 | assertError(t, err, "Shouldn't have unprotected")
58 | }
59 |
60 | func testAeadWrongAead(t *testing.T, aead cipher.AEAD, aead2 cipher.AEAD) {
61 | ct := aead.Seal(nil, kNonce0, ktestAeadBody1, ktestAeadHdr1)
62 | _, err := aead2.Open(nil, kNonce0, ct, ktestAeadHdr1)
63 | assertError(t, err, "Shouldn't have unprotected")
64 | }
65 |
66 | func testAeadAll(t *testing.T, aead cipher.AEAD) {
67 | t.Run("Success", func(t *testing.T) { testAeadSuccess(t, aead) })
68 | t.Run("WrongHeader", func(t *testing.T) { testAeadWrongHeader(t, aead) })
69 | t.Run("CorruptCT", func(t *testing.T) { testAeadCorruptCT(t, aead) })
70 | t.Run("CorruptTag", func(t *testing.T) { testAeadCorruptTag(t, aead) })
71 | }
72 |
73 | func makeWrappedAead(t *testing.T, key []byte, iv []byte) cipher.AEAD {
74 | a, err := newWrappedAESGCM(key, iv)
75 | assertNotError(t, err, "Couldn't make AEAD")
76 | return a
77 | }
78 |
79 | func TestAeadAES128GCM(t *testing.T) {
80 | a1 := makeWrappedAead(t, kTestKey1, kTestIV1)
81 | a2 := makeWrappedAead(t, kTestKey2, kTestIV1)
82 | a3 := makeWrappedAead(t, kTestKey1, kTestIV2)
83 |
84 | testAeadAll(t, a1)
85 | t.Run("WrongKey", func(t *testing.T) { testAeadWrongAead(t, a1, a2) })
86 | t.Run("WrongIV", func(t *testing.T) { testAeadWrongAead(t, a1, a3) })
87 | t.Run("WrongPacketNumber", func(t *testing.T) { testAeadWrongPacketNumber(t, a1) })
88 | }
89 |
--------------------------------------------------------------------------------
/server.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import (
4 | "net"
5 | )
6 |
// TransportFactory makes transports bound to a specific remote
// address. Server.Input calls it when a packet matches no existing
// connection (by connection ID or by remote address).
type TransportFactory interface {
	// MakeTransport makes a transport object bound to |remote|. A non-nil
	// error makes Server.Input return that error without creating a
	// connection.
	MakeTransport(remote *net.UDPAddr) (Transport, error)
}
13 |
// Server represents a QUIC server. A server can be fed an arbitrary
// number of packets and will create Connections as needed, passing
// each packet to the right connection.
type Server struct {
	handler      ServerHandler          // told about new connections; may be nil
	transFactory TransportFactory       // makes the Transport for each new connection
	tls          *TlsConfig             // TLS configuration shared by all server connections
	addrTable    map[string]*Connection // connections keyed by remote address (addr.String())
	idTable      map[string]*Connection // connections keyed by server connection ID (String())
}
24 |
// ServerHandler is the interface for the handler object which the Server
// will call to notify of events.
type ServerHandler interface {
	// NewConnection is called when a new connection has been created; it
	// can be found in |c|. Server.Input calls this only after the
	// connection's first packet was processed without a fatal error.
	NewConnection(c *Connection)
}
31 |
32 | // SetHandler sets a handler function.
33 | func (s *Server) SetHandler(h ServerHandler) {
34 | s.handler = h
35 | }
36 |
37 | // Input passes an incoming packet to the Server.
38 | func (s *Server) Input(addr *net.UDPAddr, data []byte) (*Connection, error) {
39 | logf(logTypeServer, "Received packet from %v", addr)
40 | hdr := packetHeader{shortCidLength: kCidDefaultLength}
41 | newConn := false
42 |
43 | _, err := decode(&hdr, data)
44 | if err != nil {
45 | return nil, err
46 | }
47 |
48 | var conn *Connection
49 |
50 | if len(hdr.DestinationConnectionID) > 0 {
51 | logf(logTypeServer, "Received conn id %v", hdr.DestinationConnectionID)
52 | conn = s.idTable[hdr.DestinationConnectionID.String()]
53 | if conn != nil {
54 | logf(logTypeServer, "Found by conn id")
55 | }
56 | }
57 |
58 | if conn == nil {
59 | conn = s.addrTable[addr.String()]
60 | }
61 |
62 | if conn == nil {
63 | logf(logTypeServer, "New server connection from addr %v", addr)
64 | trans, err := s.transFactory.MakeTransport(addr)
65 | if err != nil {
66 | return nil, err
67 | }
68 | conn = NewConnection(trans, RoleServer, s.tls, nil)
69 | newConn = true
70 | }
71 |
72 | err = conn.Input(data)
73 | if isFatalError(err) {
74 | logf(logTypeServer, "Fatal Error %v killing connection %v", err, conn)
75 | return nil, nil
76 | }
77 |
78 | if newConn {
79 | // Wait until handling the first packet before the connection is added
80 | // to the table. Firstly, to avoid having to remove it if there is an
81 | // error, but also because the server-chosen connection ID isn't set
82 | // until after the Initial is handled.
83 | s.idTable[conn.serverConnectionId.String()] = conn
84 | s.addrTable[addr.String()] = conn
85 | if s.handler != nil {
86 | s.handler.NewConnection(conn)
87 | }
88 | }
89 |
90 | return conn, nil
91 | }
92 |
93 | // Check the server timers.
94 | func (s *Server) CheckTimer() error {
95 | for _, conn := range s.idTable {
96 | _, err := conn.CheckTimer()
97 | if isFatalError(err) {
98 | logf(logTypeServer, "Fatal Error %v killing connection %v", err, conn)
99 | delete(s.idTable, conn.serverConnectionId.String())
100 | // TODO(ekr@rtfm.com): Delete this from the addr table.
101 | }
102 | }
103 | return nil
104 | }
105 |
106 | // How many connections do we have?
107 | func (s *Server) ConnectionCount() int {
108 | return len(s.idTable)
109 | }
110 |
111 | // Create a new QUIC server with the provide TLS config.
112 | func NewServer(factory TransportFactory, tls *TlsConfig, handler ServerHandler) *Server {
113 | s := Server{
114 | handler,
115 | factory,
116 | tls,
117 | make(map[string]*Connection),
118 | make(map[string]*Connection),
119 | }
120 | s.tls.init()
121 | return &s
122 | }
123 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # WARNING
2 |
3 | **This implementation is not under active development, and has diverged from the QUIC specification.**
4 |
5 | The QUIC WG maintains [a list of active implementations](https://github.com/quicwg/base-drafts/wiki/Implementations).
6 |
7 | 
8 |
9 | -----
10 |
11 | minq -- A minimal QUIC stack
12 | ============================
13 | Minq is a minimal implementation of QUIC, as documented at
14 | https://quicwg.github.io/. Minq partly implements draft-05
15 | (it advertises -04 but it's actually more like the editor's copy)
16 | with TLS 1.3 draft-20 or draft-21.
17 |
18 | Currently it will do:
19 |
20 | - A 1-RTT handshake (with self-generated and unverified certificates)
21 | - Some ACK processing
22 | - Primitive retransmission (manual, no timers)
23 | - 1-RTT application data
24 | - Exchange of stream close (though this doesn't really have much impact)
25 |
26 | Important missing pieces for the first implementation draft include:
27 |
28 | - Handling ACK ranges
29 | - Real timeout and retransmission support
30 |
31 | Other defects include:
32 |
33 | - Doesn't properly clean up state, so things will just grow without bound
34 | - TLS configuration and verification
35 | - A huge other pile of unknown and known defects.
36 |
37 |
38 | ## WARNING
39 |
40 | Minq is absolutely not suitable for any kind of production use and should
41 | only be used for testing. In particular, it explicitly doesn't validate
42 | certificates.
43 |
44 |
45 |
## Quick Start (untested, but should be roughly right)
47 |
48 | cd ${GOPATH}/src
49 | go get github.com/ekr/minq
50 | cd github.com/bifurcation/mint
51 | git remote add ekr https://github.com/ekr/mint
52 | git fetch ekr
53 | git checkout ekr/quic_record_layer
54 | cd ../../ekr/minq
55 | go test
56 |
57 | This should produce something like this:
58 |
59 | Result = 010002616263
60 | Result2 = 010002616263
61 | Result = 0102616263
62 | Result2 = 0102616263
63 | {1 2 [97 98 99]}
64 | {1 1 [8 16]}
65 | {3 2 [8 16 24 32]}
66 | Checking client state
67 | Checking server state
68 | Encoded frame ab00deadbeef0000000000000001
69 | Encoded frame bb0100deadbeef00000000000000010e00000001
70 | Result = 820123456789abcdefdeadbeefff000001
71 | Result2 = 820123456789abcdefdeadbeefff000001
72 | PASS
73 | ok github.com/ekr/minq 1.285s
74 |
75 | It's the "ok" at the end that's important.
76 |
77 | There are two test programs that live in ```minq/bin/client``` and
78 | ```minq/bin/server```. The server is an echo server that upcases the
79 | returned data. The client is just a passthrough.
80 |
81 | In ```${GOPATH}/src/github.com/ekr```, doing
82 |
83 | go run minq/bin/server/main.go
84 | go run minq/bin/client/main.go
85 |
in separate windows should produce the desired result.
87 |
88 |
89 | ## Logging
90 |
91 | To enable logging, set the ```MINQ_LOG``` environment variable, as
in ```MINQ_LOG=connection go test```. Valid values include the following (this list may not be exhaustive):
93 |
94 | // Pre-defined log types
95 | const (
96 | logTypeAead = "aead"
97 | logTypeCodec = "codec"
98 | logTypeConnBuffer = "connbuffer"
99 | logTypeConnection = "connection"
100 | logTypeAck = "ack"
101 | logTypeFrame = "frame"
102 | logTypeHandshake = "handshake"
103 | logTypeTls = "tls"
104 | logTypeTrace = "trace"
105 | logTypeServer = "server"
106 | logTypeUdp = "udp"
107 | )
108 |
Multiple log types can be enabled at once by separating them with commas.
110 |
111 | ## Mint
112 |
113 | Minq depends on Mint (https://www.github.com/bifurcation/mint) for TLS.
114 | Right now we are on the following branch:
115 |
116 | https://github.com/ekr/mint/tree/quic_record_layer
117 |
118 | This branch is more experimental than usual.
119 |
120 |
--------------------------------------------------------------------------------
/tls.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import (
4 | "crypto"
5 | "crypto/x509"
6 | "fmt"
7 | "log"
8 |
9 | "github.com/bifurcation/mint"
10 | )
11 |
// TlsConfig is the user-facing TLS configuration for a QUIC endpoint. The
// mint configuration derived from it is built lazily by toMint and cached
// in mintConfig.
type TlsConfig struct {
	ServerName       string              // server name; also used for the self-signed cert when no chain is given
	CertificateChain []*x509.Certificate // optional; only used when Key is also set
	Key              crypto.Signer       // optional private key for CertificateChain
	mintConfig       *mint.Config        // cached mint config; nil until toMint first runs
	ForceHrr         bool                // sets mint's RequireCookie (presumably to force a HelloRetryRequest)
}
19 |
20 | func (c *TlsConfig) init() {
21 | _ = c.toMint()
22 | }
23 |
24 | func (c *TlsConfig) toMint() *mint.Config {
25 | if c.mintConfig == nil {
26 | // TODO(ekr@rtfm.com): Provide a real config
27 | config := mint.Config{
28 | ServerName: c.ServerName,
29 | NonBlocking: true,
30 | NextProtos: []string{kQuicALPNToken},
31 | SendSessionTickets: true,
32 | AllowEarlyData: true,
33 | }
34 |
35 | if c.ForceHrr {
36 | config.RequireCookie = true
37 | }
38 |
39 | config.CookieProtector, _ = mint.NewDefaultCookieProtector()
40 | config.InsecureSkipVerify = true // TODO(ekr@rtfm.com): This is horribly insecure, but Minq is right now for testing
41 |
42 | if c.CertificateChain != nil && c.Key != nil {
43 | config.Certificates =
44 | []*mint.Certificate{
45 | &mint.Certificate{
46 | Chain: c.CertificateChain,
47 | PrivateKey: c.Key,
48 | },
49 | }
50 | } else {
51 | priv, cert, err := mint.MakeNewSelfSignedCert(c.ServerName, mint.ECDSA_P256_SHA256)
52 | if err != nil {
53 | log.Fatalf("Couldn't make self-signed cert %v", err)
54 | }
55 | config.Certificates = []*mint.Certificate{
56 | {
57 | Chain: []*x509.Certificate{cert},
58 | PrivateKey: priv,
59 | },
60 | }
61 | }
62 | config.Init(false)
63 | c.mintConfig = &config
64 | }
65 | return c.mintConfig.Clone()
66 | }
67 |
68 | func NewTlsConfig(serverName string) TlsConfig {
69 | return TlsConfig{
70 | ServerName: serverName,
71 | }
72 | }
73 |
// tlsConn drives the mint TLS handshake on behalf of a QUIC Connection.
type tlsConn struct {
	config     *TlsConfig
	conn       *Connection
	mintConfig *mint.Config            // this connection's clone from config.toMint()
	tls        *mint.Conn
	finished   bool                    // set by handshake once the TLS handshake completes
	cs         *mint.CipherSuiteParams // negotiated cipher suite; nil until finished
}
82 |
83 | func newTlsConn(conn *Connection, conf *TlsConfig, role Role) *tlsConn {
84 | isClient := true
85 | if role == RoleServer {
86 | isClient = false
87 | }
88 |
89 | mc := conf.toMint()
90 | mc.RecordLayer = newRecordLayerFactory(conn)
91 | return &tlsConn{
92 | conf,
93 | conn,
94 | mc,
95 | mint.NewConn(nil, mc, isClient),
96 | false,
97 | nil,
98 | }
99 | }
100 |
101 | func (c *tlsConn) setTransportParametersHandler(h *transportParametersHandler) {
102 | c.mintConfig.ExtensionHandler = h
103 | }
104 |
105 | func (c *tlsConn) handshake() error {
106 | outer:
107 | for {
108 | alert := c.tls.Handshake()
109 | hst := c.tls.GetHsState()
110 | switch alert {
111 | case mint.AlertNoAlert, mint.AlertStatelessRetry:
112 | if hst == mint.StateServerConnected || hst == mint.StateClientConnected {
113 | st := c.tls.ConnectionState()
114 |
115 | logf(logTypeTls, "TLS handshake complete")
116 | logf(logTypeTls, "Negotiated ALPN = %v", st.NextProto)
117 | // TODO(ekr@rtfm.com): Abort on ALPN mismatch when others do.
118 | if st.NextProto != kQuicALPNToken {
119 | logf(logTypeTls, "ALPN mismatch %v != %v", st.NextProto, kQuicALPNToken)
120 | }
121 | cs := st.CipherSuite
122 | c.cs = &cs
123 | c.finished = true
124 |
125 | break outer
126 | }
127 | // Loop
128 | case mint.AlertWouldBlock:
129 | logf(logTypeTls, "TLS would have blocked")
130 | break outer
131 | default:
132 | return fmt.Errorf("TLS sent an alert %v", alert)
133 | }
134 | }
135 | return nil
136 | }
137 |
138 | func (c *tlsConn) postHandshake() error {
139 | b := make([]byte, 1)
140 |
141 | n, err := c.tls.Read(b)
142 | assert(n == 0) // This can't happen
143 | if err == nil || err == mint.AlertWouldBlock {
144 | return nil
145 | }
146 | return ErrorProtocolViolation
147 | }
148 |
149 | func (c *tlsConn) getHsState() string {
150 | return c.tls.GetHsState().String()
151 | }
152 |
--------------------------------------------------------------------------------
/server_test.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import (
4 | "net"
5 | "testing"
6 | "time"
7 | )
8 |
9 | // fake TransportFactory that comes populated with
10 | // a set of pre-fab transports keyed by name.
11 | type testTransportFactory struct {
12 | transports map[string]*testTransport
13 | }
14 |
15 | func (f *testTransportFactory) MakeTransport(remote *net.UDPAddr) (Transport, error) {
16 | return f.transports[remote.String()], nil
17 | }
18 |
19 | func (f *testTransportFactory) addTransport(remote *net.UDPAddr, t *testTransport) {
20 | f.transports[remote.String()] = t
21 | }
22 |
23 | func serverInputAll(t *testing.T, trans *testTransport, s *Server, u net.UDPAddr) (*Connection, error) {
24 | var clast *Connection
25 |
26 | for {
27 | p, err := trans.Recv()
28 | if err != nil && err != ErrorWouldBlock {
29 | return nil, err
30 | }
31 |
32 | if p == nil {
33 | return clast, nil
34 | }
35 |
36 | c, err := s.Input(&u, p)
37 | if err != nil {
38 | return nil, err
39 | }
40 |
41 | if clast == nil {
42 | clast = c
43 | }
44 | assertEquals(t, c, clast)
45 | }
46 | }
47 |
48 | func TestServer(t *testing.T) {
49 | // Have the client and server do a handshake.
50 | u, _ := net.ResolveUDPAddr("udp", "127.0.0.1:4443") // Just a fixed address
51 |
52 | cTrans, sTrans := newTestTransportPair(true)
53 | factory := &testTransportFactory{make(map[string]*testTransport)}
54 | factory.addTransport(u, sTrans)
55 |
56 | server := NewServer(factory, testTlsConfig(), nil)
57 | assertNotNil(t, server, "Couldn't make server")
58 |
59 | client := NewConnection(cTrans, RoleClient, testTlsConfig(), nil)
60 | assertNotNil(t, client, "Couldn't make client")
61 |
62 | n, err := client.CheckTimer()
63 | assertEquals(t, 1, n)
64 | assertNotError(t, err, "Couldn't send client initial")
65 |
66 | s1, err := serverInputAll(t, sTrans, server, *u)
67 | assertNotError(t, err, "Couldn't consume client initial")
68 |
69 | err = inputAll(client)
70 | assertNotError(t, err, "Error processing SH")
71 |
72 | s2, err := serverInputAll(t, sTrans, server, *u)
73 | assertNotError(t, err, "Error processing CFIN")
74 | // Make sure we get the same server back.
75 | assertEquals(t, s1, s2)
76 |
77 | // Now make a new client and ensure we get a different server connection
78 | u2, _ := net.ResolveUDPAddr("udp", "127.0.0.1:4444") // Just a fixed address
79 | cTrans2, sTrans2 := newTestTransportPair(true)
80 | factory.addTransport(u2, sTrans2)
81 | client = NewConnection(cTrans2, RoleClient, testTlsConfig(), nil)
82 | assertNotNil(t, client, "Couldn't make client")
83 |
84 | n, err = client.CheckTimer()
85 | assertEquals(t, 1, n)
86 | assertNotError(t, err, "Couldn't send client initial")
87 |
88 | s3, err := serverInputAll(t, sTrans2, server, *u2)
89 | assertNotError(t, err, "Couldn't consume client initial")
90 |
91 | assertX(t, s1 != s3, "Got the same server connection back with a different address")
92 | assertEquals(t, 2, len(server.addrTable))
93 | }
94 |
95 | func TestServerIdleTimeout(t *testing.T) {
96 | // Have the client and server do a handshake.
97 | u, _ := net.ResolveUDPAddr("udp", "127.0.0.1:4443") // Just a fixed address
98 |
99 | cTrans, sTrans := newTestTransportPair(true)
100 | factory := &testTransportFactory{make(map[string]*testTransport)}
101 | factory.addTransport(u, sTrans)
102 |
103 | server := NewServer(factory, testTlsConfig(), nil)
104 | assertNotNil(t, server, "Couldn't make server")
105 |
106 | client := NewConnection(cTrans, RoleClient, testTlsConfig(), nil)
107 | assertNotNil(t, client, "Couldn't make client")
108 |
109 | n, err := client.CheckTimer()
110 | assertEquals(t, 1, n)
111 | assertNotError(t, err, "Couldn't send client initial")
112 |
113 | sconn, err := serverInputAll(t, sTrans, server, *u)
114 | assertNotError(t, err, "Couldn't consume client initial")
115 | assertNotNil(t, sconn, "no server connection")
116 |
117 | assertEquals(t, 1, server.ConnectionCount())
118 |
119 | // This pokes into internal state of the server to avoid having to include
120 | // sleep calls in tests. Don't do this at home kids.
121 | // Wind the timer on the connection back to short-circuit the idle timeout.
122 | sconn.lastInput = sconn.lastInput.Add(-1 - sconn.idleTimeout)
123 | server.CheckTimer()
124 | // A second nap to allow for draining period.
125 | sconn.closingEnd = sconn.closingEnd.Add(-1 - time.Second)
126 | server.CheckTimer()
127 |
128 | assertEquals(t, 0, server.ConnectionCount())
129 | }
130 |
--------------------------------------------------------------------------------
/tracking.go:
--------------------------------------------------------------------------------
1 | // Internal structure indicating packets we have
2 | // received
3 | package minq
4 |
5 | import (
6 | "fmt"
7 | "github.com/bifurcation/mint"
8 | "time"
9 | )
10 |
// packetData records one packet we received.
type packetData struct {
	protected bool      // as passed to packetSetReceived (presumably: arrived protected)
	nonAcks   bool      // the packet carried something other than ACKs
	pn        uint64    // packet number
	t         time.Time // when packetSetReceived recorded it
	acked2    bool      // our ACK of this packet has itself been acknowledged
}

// recvdPackets tracks received packet numbers so that ACK ranges can be
// built by prepareAckRange.
type recvdPackets struct {
	log          loggingFunction
	initted      bool                   // set by init
	minReceived  uint64                 // packet numbers below this are reported as already received
	maxReceived  uint64                 // highest packet number recorded
	minNotAcked2 uint64                 // prepareAckRange scans no lower than this
	packets      map[uint64]*packetData // recorded packets; pruned once acked2
	unacked      bool                   // Are there packets we haven't generated an ACK for
}
28 |
29 | func newRecvdPackets(log loggingFunction) *recvdPackets {
30 | return &recvdPackets{
31 | log, // loggingFunction
32 | false, // initted
33 | 0, // minReceived
34 | 0, // maxReceived
35 | 0, // minNotAcked2
36 | make(map[uint64]*packetData, 0), // packets
37 | false, // unacked
38 | }
39 | }
40 |
41 | func (p *recvdPackets) initialized() bool {
42 | return p.initted
43 | }
44 |
45 | func (p *recvdPackets) init(pn uint64) {
46 | p.log(logTypeAck, "Initializing received packet start=%x", pn)
47 | p.initted = true
48 | p.minReceived = pn
49 | p.maxReceived = pn
50 | p.minNotAcked2 = pn
51 | }
52 |
53 | func (p *recvdPackets) packetNotReceived(pn uint64) bool {
54 | if pn < p.minReceived {
55 | return false
56 | }
57 | _, found := p.packets[pn]
58 | return !found
59 | }
60 |
61 | func (p *recvdPackets) packetSetReceived(pn uint64, protected bool, nonAcks bool) {
62 | p.log(logTypeAck, "Setting packet received=%x", pn)
63 | if pn > p.maxReceived {
64 | p.maxReceived = pn
65 | }
66 | if pn < p.minNotAcked2 {
67 | p.minNotAcked2 = pn
68 | }
69 | p.log(logTypeAck, "Setting packet received=%x", pn)
70 | p.packets[pn] = &packetData{
71 | protected,
72 | nonAcks,
73 | pn,
74 | time.Now(),
75 | false,
76 | }
77 | p.unacked = true
78 | }
79 |
80 | func (p *recvdPackets) packetSetAcked2(pn uint64) {
81 | p.log(logTypeAck, "Setting packet acked2=%v", pn)
82 | if pn >= p.minNotAcked2 {
83 | pk, ok := p.packets[pn]
84 | if ok {
85 | pk.acked2 = true
86 | }
87 | }
88 | }
89 |
90 | func (r *ackRange) String() string {
91 | return fmt.Sprintf("%x(%d)", r.lastPacket, r.count)
92 | }
93 |
94 | func (r *ackRanges) String() string {
95 | rsp := ""
96 | for _, s := range *r {
97 | if rsp != "" {
98 | rsp += ", "
99 | }
100 | rsp += s.String()
101 | }
102 | return rsp
103 | }
104 |
105 | func (p *recvdPackets) needToAck() bool {
106 | return p.unacked
107 | }
108 |
// prepareAckRange builds the list of ACK ranges to send. It starts from the
// highest received packet number and walks down to minNotAcked2.
//
// Each range is built as ackRange{last, last - pn}: the highest packet
// number of a run of consecutive packets needing an ACK, and the number of
// packets in that run (presumably the lastPacket and count fields). A
// packet needs an ACK if it was recorded and its ACK has not itself been
// acknowledged (acked2). Packets with acked2 set are deleted from p.packets
// as the walk passes them. minNotAcked2 is then raised to the lowest packet
// number that still needed an ACK, or to maxReceived if none did.
//
// It returns nil when there is nothing new to ACK. It also returns nil when
// allowAckOnly is false and no acked packet carried anything other than
// ACKs; in that case the pruning and the minNotAcked2 update have already
// happened, but p.unacked stays true. The epoch argument is only logged.
func (p *recvdPackets) prepareAckRange(epoch mint.Epoch, allowAckOnly bool) ackRanges {
	p.log(logTypeAck, "Prepare ACK range epoch=%d", epoch)
	// Don't ACK if there's nothing new to ACK
	if !p.unacked {
		p.log(logTypeAck, "Nothing new to ACK")
		return nil
	}

	var last uint64  // highest packet number of the range being built
	var pn uint64    // kept outside the loop: the final range needs its last value
	inrange := false // whether the previously examined packet needed an ACK
	nonAcks := false // whether any packet being acked carried non-ACK content

	ranges := make(ackRanges, 0)

	newMinNotAcked2 := p.maxReceived

	// TODO(ekr@rtfm.com): This is kind of a gross hack in case
	// someone sends us a 0 initial packet number.
	// NOTE: because of "pn > 0", packet number 0 is never examined (or
	// acked); the same check also stops the uint64 decrement from wrapping.
	for pn = p.maxReceived; pn >= p.minNotAcked2 && pn > 0; pn-- {
		p.log(logTypeTrace, "Examining packet %x", pn)
		pk, ok := p.packets[pn]
		needs_ack := false

		// If we don't know about the packet, or if the ack has been
		// acked, we don't need to ack it.
		if ok && !pk.acked2 {
			needs_ack = true
			newMinNotAcked2 = pn
		}

		// Our ACK of this packet was itself acknowledged: forget the packet.
		if ok && pk.acked2 {
			delete(p.packets, pn)
		}

		if needs_ack {
			p.log(logTypeTrace, "Acking packet %x", pn)
		}
		if needs_ack && pk.nonAcks {
			// Note if this is an ack of anything other than
			// acks.
			p.log(logTypeTrace, "Packet %x contains non-acks", pn)
			nonAcks = true
		}

		// Range boundaries are where needs_ack changes while walking down.
		if inrange != needs_ack {
			if inrange {
				// This is the end of a range. pn is the first packet below
				// it, so last - pn is the number of packets in the range.
				ranges = append(ranges, ackRange{last, last - pn})
			} else {
				last = pn
			}
			inrange = needs_ack
		}
	}
	if inrange {
		// The walk stopped inside a range; pn is again one below its
		// lowest packet.
		p.log(logTypeTrace, "Appending final range %x-%x", last, pn+1)
		ranges = append(ranges, ackRange{last, last - pn})
	}

	p.minNotAcked2 = newMinNotAcked2

	p.log(logTypeAck, "%v ACK ranges to send", len(ranges))
	for i, r := range ranges {
		p.log(logTypeAck, "  %d = %v", i, r.String())
	}

	// Don't send an ACK that only acknowledges ACK-only packets unless the
	// caller allows it.
	if !allowAckOnly && !nonAcks {
		p.log(logTypeAck, "No non-ack packets and this ack is not ack-only capable")
		return nil
	}

	p.unacked = false
	return ranges
}
185 |
--------------------------------------------------------------------------------
/stream_test.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import (
4 | "fmt"
5 | "io"
6 | "runtime"
7 | "testing"
8 | )
9 |
10 | type testStreamFixture struct {
11 | t *testing.T
12 | name string
13 | log loggingFunction
14 | r *recvStreamBase
15 | w *sendStreamBase
16 | b []byte
17 | }
18 |
19 | func (f *testStreamFixture) read() {
20 | assertX(f.t, f.r.readable, "stream should be readable")
21 | f.b = make([]byte, 1024)
22 | n, err := f.r.read(f.b)
23 | assertNotError(f.t, err, "Should be able to read bytes")
24 | f.b = f.b[:n]
25 | assertX(f.t, f.r.clearReadable(), "should have been readable")
26 | }
27 |
28 | func (f *testStreamFixture) readExpectError(exerr error) {
29 | f.b = make([]byte, 1024)
30 | n, err := f.r.read(f.b)
31 | assertError(f.t, err, "Should not be able to read bytes")
32 | assertEquals(f.t, exerr, err)
33 | assertEquals(f.t, 0, n)
34 | }
35 |
36 | var kTestString1 = []byte("abcdef")
37 | var kTestString2 = []byte("ghijkl")
38 |
39 | func newTestStreamFixture(t *testing.T) *testStreamFixture {
40 | pc, _, _, ok := runtime.Caller(1)
41 | name := "unknown"
42 | if ok {
43 | name = runtime.FuncForPC(pc).Name()
44 | }
45 | log := func(tag string, format string, args ...interface{}) {
46 | fullFormat := fmt.Sprintf("%s: %s", name, format)
47 | logf(tag, fullFormat, args...)
48 | }
49 |
50 | fc := flowControl{false, 2048, 0}
51 | return &testStreamFixture{
52 | t: t,
53 | name: name,
54 | log: log,
55 | r: &recvStreamBase{streamCommon: streamCommon{log: log, fc: fc}},
56 | w: &sendStreamBase{streamCommon: streamCommon{log: log, fc: fc}},
57 | b: nil,
58 | }
59 | }
60 |
61 | func TestStreamInputOneChunk(t *testing.T) {
62 | f := newTestStreamFixture(t)
63 | err := f.r.newFrameData(0, false, kTestString1, &flowControl{false, 2048, 0})
64 | assertNotError(t, err, "Data should be accepted")
65 | assertEquals(t, f.r.fc.used, uint64(len(kTestString1)))
66 | assertEquals(t, RecvStreamStateRecv, f.r.state)
67 | f.read()
68 | assertByteEquals(t, f.b, kTestString1)
69 | }
70 |
71 | func TestStreamInputTwoChunks(t *testing.T) {
72 | f := newTestStreamFixture(t)
73 | err := f.r.newFrameData(0, false, kTestString1, &flowControl{false, 2048, 0})
74 | assertNotError(t, err, "Data should be accepted")
75 | assertEquals(t, f.r.fc.used, uint64(len(kTestString1)))
76 | f.read()
77 | assertByteEquals(t, f.b, kTestString1)
78 | err = f.r.newFrameData(uint64(len(kTestString1)), false, kTestString2, &flowControl{false, 2048, 0})
79 | assertEquals(t, f.r.fc.used, uint64(len(kTestString1)+len(kTestString2)))
80 | f.read()
81 | assertByteEquals(t, f.b, kTestString2)
82 | }
83 |
84 | func TestStreamInputCoalesceChunks(t *testing.T) {
85 | f := newTestStreamFixture(t)
86 | err := f.r.newFrameData(0, false, kTestString1[:2], &flowControl{false, 2048, 0})
87 | assertNotError(t, err, "data should be accepted")
88 | err = f.r.newFrameData(2, false, kTestString1[2:], &flowControl{false, 2048, 0})
89 | assertNotError(t, err, "data should be accepted")
90 | f.read()
91 | assertByteEquals(t, f.b, kTestString1)
92 | }
93 |
94 | func TestStreamInputChunksOverlap(t *testing.T) {
95 | f := newTestStreamFixture(t)
96 | err := f.r.newFrameData(0, false, kTestString1[:2], &flowControl{false, 2048, 0})
97 | assertNotError(t, err, "data should be accepted")
98 | err = f.r.newFrameData(0, false, kTestString1, &flowControl{false, 2048, 0})
99 | assertNotError(t, err, "data should be accepted")
100 | f.read()
101 | assertByteEquals(t, f.b, kTestString1)
102 | }
103 |
104 | func TestStreamInputTwoChunksWrongOrder(t *testing.T) {
105 | f := newTestStreamFixture(t)
106 | err := f.r.newFrameData(2, false, kTestString1[2:], &flowControl{false, 2048, 0})
107 | assertNotError(t, err, "data should be accepted")
108 | assertX(t, !f.r.readable, "Stream not should be readable")
109 | assertEquals(t, f.r.fc.used, uint64(len(kTestString1)))
110 | f.readExpectError(ErrorWouldBlock)
111 | err = f.r.newFrameData(0, false, kTestString1[:2], &flowControl{false, 2048, 0})
112 | assertNotError(t, err, "data should be accepted")
113 | f.read()
114 | assertByteEquals(t, f.b, kTestString1)
115 | }
116 |
117 | func TestStreamInputChunk1FinChunk2(t *testing.T) {
118 | f := newTestStreamFixture(t)
119 | err := f.r.newFrameData(0, true, kTestString1, &flowControl{false, 2048, 0})
120 | assertNotError(t, err, "data should be accepted")
121 | assertEquals(t, RecvStreamStateSizeKnown, f.r.state)
122 | f.read()
123 | assertByteEquals(t, f.b, kTestString1)
124 | assertEquals(t, RecvStreamStateDataRead, f.r.state)
125 | err = f.r.newFrameData(uint64(len(kTestString1)), false, kTestString2, &flowControl{false, 2048, 0})
126 | assertEquals(t, err, ErrorFlowControlError)
127 | assertX(t, !f.r.readable, "Stream not be readable")
128 | f.readExpectError(io.EOF)
129 | }
130 |
131 | func TestStreamInputShortFinChunkAfterFin(t *testing.T) {
132 | f := newTestStreamFixture(t)
133 | err := f.r.newFrameData(0, true, kTestString1, &flowControl{false, 2048, 0})
134 | assertNotError(t, err, "data should be accepted")
135 | assertEquals(t, RecvStreamStateSizeKnown, f.r.state)
136 | f.read()
137 | err = f.r.newFrameData(0, true, kTestString1[:2], &flowControl{false, 2048, 0})
138 | assertNotError(t, err, "overlapping data can be discarded")
139 | }
140 |
141 | func TestStreamReadReset(t *testing.T) {
142 | f := newTestStreamFixture(t)
143 | err := f.r.handleReset(10)
144 | assertNotError(t, err, "should accept the reset")
145 | assertEquals(t, RecvStreamStateResetRecvd, f.r.state)
146 | }
147 |
148 | func TestStreamWriteClose(t *testing.T) {
149 | f := newTestStreamFixture(t)
150 | f.w.close()
151 | assertEquals(t, SendStreamStateCloseQueued, f.w.state)
152 | }
153 |
154 | func TestStreamIncreaseFlowControl(t *testing.T) {
155 | f := newTestStreamFixture(t)
156 | f.w.processMaxStreamData(2050)
157 | f.w.processMaxStreamData(2000)
158 | assertEquals(t, uint64(2050), f.w.fc.max)
159 | }
160 |
161 | func countChunkLens(chunks []streamChunk) int {
162 | ct := 0
163 | for _, ch := range chunks {
164 | ct += len(ch.data)
165 | }
166 | return ct
167 | }
168 |
169 | func TestStreamBlockRelease(t *testing.T) {
170 | f := newTestStreamFixture(t)
171 | b := make([]byte, 5000)
172 | connFc := &flowControl{false, uint64(len(b)), 0}
173 | n, err := f.w.write(b, connFc)
174 | assertEquals(t, nil, err)
175 | chunks := f.w.outputWritable()
176 | assertEquals(t, 2048, countChunkLens(chunks))
177 | assertEquals(t, 2048, n)
178 | assertEquals(t, uint64(2048), connFc.used)
179 | // Calling output writable again returns 0 chunks
180 | chunks = f.w.outputWritable()
181 | assertEquals(t, 0, countChunkLens(chunks))
182 |
183 | // Writing again blocks
184 | _, err = f.w.write(b[n:], connFc)
185 | assertEquals(t, ErrorWouldBlock, err)
186 |
187 | // Increasing the limit should let us write.
188 | f.w.processMaxStreamData(8192)
189 | n, err = f.w.write(b[n:], connFc)
190 | assertNotError(t, err, "Writing works")
191 | assertEquals(t, 2952, n)
192 | assertEquals(t, connFc.max, connFc.used)
193 | chunks = f.w.outputWritable()
194 | assertEquals(t, 2952, countChunkLens(chunks))
195 | }
196 |
--------------------------------------------------------------------------------
/bin/client/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "flag"
5 | "fmt"
6 | "github.com/ekr/minq"
7 | "log"
8 | "net"
9 | "os"
10 | "runtime/pprof"
11 | "time"
12 | )
13 |
14 | var addr string
15 | var serverName string
16 | var doHttp string
17 | var httpCount int
18 | var heartbeat int
19 | var cpuProfile string
20 | var resume bool
21 | var httpLeft int
22 | var zeroRtt bool
23 |
24 | type connHandler struct {
25 | bytesRead int
26 | }
27 |
28 | func (h *connHandler) StateChanged(s minq.State) {
29 | log.Println("State changed to ", s)
30 | }
31 |
32 | func (h *connHandler) NewStream(s minq.Stream) {
33 | }
34 |
35 | func (h *connHandler) NewRecvStream(s minq.RecvStream) {
36 | }
37 |
38 | func (h *connHandler) StreamReadable(s minq.RecvStream) {
39 | for {
40 | b := make([]byte, 1024)
41 |
42 | n, err := s.Read(b)
43 | switch err {
44 | case nil:
45 | break
46 | case minq.ErrorWouldBlock:
47 | return
48 | case minq.ErrorStreamIsClosed, minq.ErrorConnIsClosed:
49 | log.Println("")
50 | httpLeft--
51 | return
52 | default:
53 | log.Println("Error: ", err)
54 | httpLeft--
55 | return
56 | }
57 | b = b[:n]
58 | h.bytesRead += n
59 | os.Stdout.Write(b)
60 | os.Stderr.Write([]byte(fmt.Sprintf("Total bytes read = %d\n", h.bytesRead)))
61 | }
62 | }
63 |
64 | func readUDP(s *net.UDPConn) ([]byte, error) {
65 | b := make([]byte, 8192)
66 |
67 | s.SetReadDeadline(time.Now().Add(time.Second))
68 | n, _, err := s.ReadFromUDP(b)
69 | if err != nil {
70 | e, o := err.(net.Error)
71 | if o && e.Timeout() {
72 | return nil, minq.ErrorWouldBlock
73 | }
74 | log.Println("Error reading from UDP socket: ", err)
75 | return nil, err
76 | }
77 |
78 | if n == len(b) {
79 | log.Println("Underread from UDP socket")
80 | return nil, err
81 | }
82 | b = b[:n]
83 | return b, nil
84 | }
85 |
86 | func makeConnection(config *minq.TlsConfig, uaddr *net.UDPAddr) (*net.UDPConn, *minq.Connection) {
87 | usock, err := net.ListenUDP("udp", nil)
88 | if err != nil {
89 | log.Println("Couldn't create connected UDP socket")
90 | return nil, nil
91 | }
92 |
93 | utrans := minq.NewUdpTransport(usock, uaddr)
94 |
95 | conn := minq.NewConnection(utrans, minq.RoleClient,
96 | config, &connHandler{})
97 |
98 | log.Printf("Client conn id=%v\n", conn.ClientId())
99 |
100 | // Start things off.
101 | _, err = conn.CheckTimer()
102 |
103 | return usock, conn
104 | }
105 |
106 | func completeConnection(usock *net.UDPConn, conn *minq.Connection) error {
107 | for conn.GetState() != minq.StateEstablished {
108 | b, err := readUDP(usock)
109 | if err != nil {
110 | if err == minq.ErrorWouldBlock {
111 | _, err = conn.CheckTimer()
112 | if err != nil {
113 | return err
114 | }
115 | continue
116 | }
117 | return err
118 | }
119 |
120 | err = conn.Input(b)
121 | if err != nil {
122 | log.Println("Error", err)
123 | return err
124 | }
125 | }
126 |
127 | log.Printf("Connection established server CID = %v\n", conn.ServerId())
128 | return nil
129 | }
130 |
131 | func main() {
132 | log.Println("PID=", os.Getpid())
133 | flag.StringVar(&addr, "addr", "localhost:4433", "[host:port]")
134 | flag.StringVar(&serverName, "server-name", "", "SNI")
135 | flag.StringVar(&doHttp, "http", "", "Do HTTP/0.9 with provided URL")
136 | flag.IntVar(&httpCount, "httpCount", 1, "Number of parallel HTTP requests to start")
137 | flag.IntVar(&heartbeat, "heartbeat", 0, "heartbeat frequency [ms]")
138 | flag.StringVar(&cpuProfile, "cpuprofile", "", "write cpu profile to file")
139 | flag.BoolVar(&resume, "resume", false, "Test resumption")
140 | flag.BoolVar(&zeroRtt, "zerortt", false, "Test 0-RTT")
141 | flag.Parse()
142 |
143 | if zeroRtt {
144 | resume = true
145 | if doHttp == "" {
146 | log.Printf("Need HTTP to do 0-RTT")
147 | return
148 | }
149 | }
150 | if cpuProfile != "" {
151 | f, err := os.Create(cpuProfile)
152 | if err != nil {
153 | log.Printf("Could not create CPU profile file %v err=%v\n", cpuProfile, err)
154 | return
155 | }
156 | pprof.StartCPUProfile(f)
157 | log.Println("CPU profiler started")
158 | defer pprof.StopCPUProfile()
159 | }
160 |
161 | // Default to the host component of addr.
162 | if serverName == "" {
163 | host, _, err := net.SplitHostPort(addr)
164 | if err != nil {
165 | log.Println("Couldn't split host/port", err)
166 | }
167 | serverName = host
168 | }
169 | config := minq.NewTlsConfig(serverName)
170 |
171 | inner_main(&config, false)
172 | if resume {
173 | inner_main(&config, true)
174 | }
175 | }
176 | func inner_main(config *minq.TlsConfig, resuming bool) {
177 |
178 | uaddr, err := net.ResolveUDPAddr("udp", addr)
179 | if err != nil {
180 | log.Println("Invalid UDP addr", err)
181 | return
182 | }
183 |
184 | usock, conn := makeConnection(config, uaddr)
185 | if conn == nil {
186 | return
187 | }
188 |
189 | if !resuming || !zeroRtt {
190 | err = completeConnection(usock, conn)
191 | if err != nil {
192 | return
193 | }
194 | }
195 |
196 | // Hopefully reduce the risk of reordering
197 | time.Sleep(100 * time.Millisecond)
198 |
199 | // Make all the streams we need
200 | streams := make([]minq.Stream, httpCount)
201 | for i := 0; i < httpCount; i++ {
202 | streams[i] = conn.CreateStream()
203 | if streams[i] == nil {
204 | log.Println("Couldn't create stream")
205 | return
206 | }
207 | }
208 | httpLeft = httpCount
209 |
210 | udpin := make(chan []byte)
211 | stdin := make(chan []byte)
212 |
213 | // Read from the UDP socket.
214 | go func() {
215 | for {
216 | b, err := readUDP(usock)
217 | if err == minq.ErrorWouldBlock {
218 | udpin <- make([]byte, 0)
219 | continue
220 | }
221 | udpin <- b
222 | if b == nil {
223 | return
224 | }
225 | }
226 | }()
227 |
228 | if heartbeat > 0 && doHttp == "" {
229 | ticker := time.NewTicker(time.Millisecond * time.Duration(heartbeat))
230 | go func() {
231 | for t := range ticker.C {
232 | stdin <- []byte(fmt.Sprintf("Heartbeat at %v\n", t))
233 | }
234 | }()
235 | }
236 |
237 | if doHttp != "" {
238 | req := "GET " + doHttp + "\r\n"
239 | for _, str := range streams {
240 | str.Write([]byte(req))
241 | str.Close()
242 | }
243 | }
244 |
245 | if resuming && zeroRtt {
246 | log.Println("Completing connection after we sent 0-RTT send in 0-RTT")
247 | err = completeConnection(usock, conn)
248 | if err != nil {
249 | return
250 | }
251 | }
252 |
253 | if doHttp == "" {
254 | // Read from stdin.
255 | go func() {
256 | for {
257 | b := make([]byte, 1024)
258 | n, err := os.Stdin.Read(b)
259 | if err != nil {
260 | stdin <- nil
261 | return
262 | }
263 | b = b[:n]
264 | stdin <- b
265 | }
266 | }()
267 | }
268 | for {
269 | select {
270 | case u := <-udpin:
271 | if len(u) == 0 {
272 | _, err = conn.CheckTimer()
273 | } else {
274 | err = conn.Input(u)
275 | }
276 | if err != nil {
277 | log.Println("Error", err)
278 | return
279 | }
280 | if doHttp != "" && httpLeft == 0 {
281 | return
282 | }
283 | case i := <-stdin:
284 | if i == nil {
285 | // TODO(piet@devae.re) close the apropriate stream(s)
286 | }
287 | streams[0].Write(i)
288 | if err != nil {
289 | log.Println("Error", err)
290 | return
291 | }
292 | }
293 |
294 | }
295 | }
296 |
--------------------------------------------------------------------------------
/packet_test.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import (
4 | "crypto"
5 | "encoding/hex"
6 | "fmt"
7 | "github.com/bifurcation/mint"
8 | "testing"
9 | )
10 |
11 | var (
12 | testCid7 = ConnectionId([]byte{7, 7, 7, 7, 7, 7, 7})
13 | testCid4 = ConnectionId([]byte{4, 4, 4, 4})
14 | testCid5 = ConnectionId([]byte{5, 5, 5, 5, 5})
15 | testVersion = VersionNumber(0xdeadbeef)
16 | testPn = uint64(0xff000001)
17 | )
18 |
19 | // Packet header tests.
20 | func packetHeaderEDE(t *testing.T, p *packetHeader, cidLen uintptr) {
21 | res, err := encode(p)
22 | assertNotError(t, err, "Could not encode")
23 | fmt.Println("Encoded = ", hex.EncodeToString(res))
24 |
25 | var p2 packetHeader
26 | p2.shortCidLength = cidLen
27 | _, err = decode(&p2, res)
28 | assertNotError(t, err, "Could not decode")
29 | fmt.Println("Decoded = ", p2)
30 |
31 | res2, err := encode(&p2)
32 | assertNotError(t, err, "Could not re-encode")
33 | fmt.Println("Encoded2 =", hex.EncodeToString(res2))
34 | assertByteEquals(t, res, res2)
35 | }
36 |
37 | func TestLongHeader(t *testing.T) {
38 | p := newPacket(packetTypeInitial, testCid7, testCid4, testVersion,
39 | testPn, make([]byte, 65), 16)
40 | p.Token = []byte{1, 2, 3}
41 | p.TokenLength = uint8(len(p.Token))
42 | packetHeaderEDE(t, &p.packetHeader, 0)
43 | }
44 |
45 | func TestShortHeader(t *testing.T) {
46 | p := newPacket(packetTypeProtectedShort, testCid7, testCid4, testVersion,
47 | testPn, make([]byte, 65), 16)
48 |
49 | // We have to provide assistance to the decoder for short headers.
50 | // Otherwise, it can't know long the destination connection ID is.
51 | packetHeaderEDE(t, &p.packetHeader, uintptr(len(p.DestinationConnectionID)))
52 | }
53 |
54 | func testPNEDecrypt(t *testing.T, pbytes []byte, pn uint64, pnLen int, pnef pneCipherFactory) {
55 | // Now decode the packet.
56 | hdr2 := packetHeader{shortCidLength: kCidDefaultLength}
57 |
58 | hdrlen2, err := decode(&hdr2, pbytes)
59 | assertNotError(t, err, "Couldn't decode encrypted packet")
60 |
61 | dpn := make([]byte, 4)
62 | err = xorPacketNumber(&hdr2, int(hdrlen2), dpn, pbytes, pnef)
63 | assertNotError(t, err, "Couldn't XOR the packet number")
64 | assertEquals(t, 4, len(dpn))
65 |
66 | pn2, l2, err := decodePacketNumber(dpn)
67 | assertNotError(t, err, "Couldn't decode packet number")
68 | assertEquals(t, l2, pnLen)
69 | assertEquals(t, pn2, pn)
70 | }
71 |
72 | func DISABLED_TestPNEVector(t *testing.T) {
73 | kPacketHex := "ffff00000d5006b858ec6f80452b0044efa5d8d307c2973fa0d63fd9b03a4e163b990dd778894a9edc8eacfbe4aa6fbf4a22ec7f906b5e8b8ae12e5fcc7924dfeee813842bb2149b805e55895084e8393200bb3fc618af7d08281485d914ce42303f5d772b200508a0c00253e332e36a84f657321ac4c8e2cc8a117e95871f12b1f36be8c4b76fa433dc4d3142e6547f4598bf4b192130aea6fc20da5158b2162b5a899957da05ded5c70907298fd885847f22a1ecb0a814fe0170e23cad20af64f05cc13c74e91824101afdcf5f1532fc2fde936a3a159f76283a26c738f778c76e6ca41fa7f134401d39027fd81de17a8021a9c0aaa9b4478fe5c0647941618f3bee410caf94c248d2a64b5e45845cd77de13a5ed94034d2bc5f457887351993c1ecfa34fd0c658fea3f8086d26808eef976262ecf0ad646b627945511dde83e26609cd5cfd7ed9f6207d76618b44c48bf623bf420dc7c127e5d5f529f083b71a17b17da329bfc38a74bf8cfcf315c7c070b71ebfae3ab351341a767adfdd9e57c738f5de9da53711e886d1472310b917a1c9798e3e9b13c7c74beb8d1b82345bea1349415679a9c64b0433b68c871ae08092a1f6106bc06337cd343866ee8185c03fcf3bb0666453f847905547199414c1e57535747be61cdf6778378f121d68df0181ee9e8d9932c1c593c0f8c0a1af0f5262b86205002dced9ecdaee2d0aa07dd4c14f98571e4bea72f8474f63697043e936ebb2bf9716ed0efbdc13005a75cee3a49babc61b9677764510eb19828df4e10fb38b79a1efbf04cc2d571949d5403f797361743dcc5e3bf3b4396f7ae1a3affbc9f72e540d920363970307e0725fa838d611803251a4a08ccca1983d5b29a583758be63343e88f5591d885b8af695f33adbdd0d941d260287e32ef5a98fd55ac137211021fdc23b5d7a5469f578bf7aff6529117996f9ebab5e6dc7b047b356332fea82fdd620eb86f3c1d3855c8b8075da59a7662f4a11b977d996b8b3c7657ad4a82a20a7f76ce376c0320086ed029dd615399307983113cc0aa973ecba691e7e4cdc80aefa7e8c8347baba050eaca7dc35a21aa854e531dc7758d7d10b8c8e42c1be3bbf266d055ac25c37279ebefa28bbe89a34ad1ab3d23d7a66d1c216a57650e6ec9fc8ba7adfb38e57f20c467166c8fe7944e67f82138160002004812c78ba4b5f0da917da4cc14cf8fc10dba3f533facb11ef06d8b8f178ea9c5e8acbbca7b7f0e1f6b7a70ec2d5108cc41178056295793bed357accbb03c0582dc69bc77a34030f38cce256c5a9cec6e862146e3f0463f10dd5833257d0a0359166a7e2027d98eaf26cf0d5a4a05f6ef8b742f5d314a31deeeabe4ebc3106547e79c6cb933105d907b4c8c6
0443e97a154694bab5edfc781a438675b9de6ed03c77f51458eab61ca2e80ac02cc8c037d8fb3cf129d7107f618d66032cc02238a211f78bfa44e7c1bbcfcc627771c188d1b3713ce5e75cd2325a0a2ba08268cad13b27d97696ef678b592d0ac80ad1bacb4a1ba75bea8c477f39fc32c2aa20f352bb0da1c49b7d3927bcd9dfaf229237081d5fa08924fefd923ff0ac6baad6864b7c10dc73379a5ebd9e4678a0c26517656e8e51fca2a51a33fb2cdd5d76d12674c240ba9a4893c1af69b8f2c4adf37c4a47551eb2006a732f6b3b2f338c078ede33946dfe4a55bf644d3b98848693ada1fcb6fc16cac339ee65c24dc64b0ae92005354af00ade71e6c5e2efd85c46131d948ff14096b0f06a41d83c8522f30beb4eaaf4a6f908fe2a6ee754c896"
74 | kPacket := unhex(kPacketHex)
75 | kPNLen := 4
76 | kPN := 0
77 |
78 | hdr2 := packetHeader{shortCidLength: kCidDefaultLength}
79 | _, err := decode(&hdr2, kPacket)
80 | assertNotError(t, err, "Couldn't decode encrypted packet")
81 |
82 | params := mint.CipherSuiteParams{
83 | Suite: mint.TLS_AES_128_GCM_SHA256,
84 | Cipher: nil,
85 | Hash: crypto.SHA256,
86 | KeyLen: 16,
87 | IvLen: 12,
88 | }
89 |
90 | cs, err := generateCleartextKeys(hdr2.DestinationConnectionID, clientCtSecretLabel,
91 | ¶ms)
92 | assertNotError(t, err, "Couldn't generate cleartext keys")
93 | testPNEDecrypt(t, kPacket, uint64(kPN), kPNLen, cs.pne)
94 | }
95 |
96 | func testPNE(t *testing.T, pt packetType) {
97 | key := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
98 | payload := make([]byte, 65)
99 | p := newPacket(pt, testCid5, testCid4, testVersion,
100 | 0xfe, payload, 16)
101 |
102 | hdr, err := encode(&p.packetHeader)
103 | assertNotError(t, err, "Couldn't encode packet header")
104 |
105 | pnbytes := encodePacketNumber(p.PacketNumber, 2)
106 |
107 | pbytes := append(hdr, pnbytes...)
108 | pbytes = append(pbytes, payload...)
109 |
110 | pnef := newPneCipherFactoryAES(key)
111 |
112 | // Encode the packet in place.
113 | err = xorPacketNumber(&p.packetHeader, len(hdr), pbytes[len(hdr):len(hdr)+len(pnbytes)], pbytes, pnef)
114 | assertNotError(t, err, "Couldn't XOR the packet number")
115 |
116 | // Now decode the packet.
117 | testPNEDecrypt(t, pbytes, p.PacketNumber, len(pnbytes), pnef)
118 | }
119 |
120 | func TestPNE(t *testing.T) {
121 | t.Run("Long", func(t *testing.T) {
122 | testPNE(t, packetTypeInitial)
123 | })
124 | t.Run("Short", func(t *testing.T) {
125 | testPNE(t, packetTypeProtectedShort)
126 | })
127 | }
128 |
129 | /*
130 | * TODO(ekr@rtfm.com): Rewrite this code and merge it into
131 | * connection.go
132 | // Mock for connection state
133 | type ConnectionStateMock struct {
134 | aead aeadFNV
135 | }
136 |
137 | func (c *ConnectionStateMock) established() bool { return false }
138 | func (c *ConnectionStateMock) zeroRttAllowed() bool { return false }
139 | func (c *ConnectionStateMock) expandPacketNumber(pn uint64) uint64 {
140 | return pn
141 | }
142 |
143 | func TestEDEPacket(t *testing.T) {
144 | var c ConnectionStateMock
145 |
146 | p := Packet{
147 | kTestpacketHeader,
148 | []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g'},
149 | }
150 |
151 | encoded, err := encodePacket(&c, &c.aead, &p)
152 | assertNotError(t, err, "Could not encode packet")
153 |
154 | p2, err := decodePacket(&c, &c.aead, encoded)
155 | assertNotError(t, err, "Could not decode packet")
156 |
157 | encoded2, err := encodePacket(&c, &c.aead, p2)
158 | assertNotError(t, err, "Could not re-encode packet")
159 |
160 | assertByteEquals(t, encoded, encoded2)
161 | }
162 | */
163 |
164 | func testPacketNumberED(t *testing.T, pn uint64, l int) {
165 | b := encodePacketNumber(pn, l)
166 | assertEquals(t, l, len(b))
167 |
168 | pn2, l2, err := decodePacketNumber(b)
169 | assertNotError(t, err, "Error decoding packet number")
170 | assertEquals(t, l2, l)
171 |
172 | mask := uint64(0)
173 | for i := 0; i < l; i++ {
174 | mask <<= 8
175 | mask |= 0xff
176 | }
177 | assertEquals(t, mask&pn, pn2)
178 | }
179 |
180 | func TestPacketNumberED(t *testing.T) {
181 | val := uint64(0x04030201)
182 |
183 | for _, i := range []int{1, 2, 4} {
184 | t.Run(fmt.Sprintf("%v", i), func(t *testing.T) {
185 | testPacketNumberED(t, val, i)
186 | })
187 | }
188 | }
189 |
--------------------------------------------------------------------------------
/codec.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | "io"
7 | "reflect"
8 | "runtime"
9 | "strconv"
10 | "strings"
11 | "unicode"
12 | )
13 |
14 | const (
15 | codecDefaultSize = ^uintptr(0)
16 | )
17 |
18 | func uintEncode(buf *bytes.Buffer, v reflect.Value, encodingSize uintptr) error {
19 | size := v.Type().Size()
20 | if encodingSize != codecDefaultSize {
21 | if encodingSize > size {
22 | return fmt.Errorf("Requested a length longer than the native type")
23 | }
24 | size = encodingSize
25 | }
26 |
27 | uintEncodeInt(buf, v.Uint(), size)
28 | return nil
29 | }
30 |
31 | func uintEncodeInt(buf *bytes.Buffer, val uint64, size uintptr) {
32 | // Now encode the low-order bytes of the value.
33 | for b := size; b > 0; b -= 1 {
34 | buf.WriteByte(byte(val >> ((b - 1) * 8)))
35 | }
36 | }
37 |
38 | // isVarint determines if the field is a varint. This reads the mint/syntax tag
39 | // for the field, but only supports a simple "varint".
40 | func isVarint(f reflect.StructField) bool {
41 | return f.Tag.Get("tls") == "varint"
42 | }
43 |
44 | func varintEncode(buf *bytes.Buffer, v uint64) {
45 | switch {
46 | case v < (uint64(1) << 6):
47 | uintEncodeInt(buf, v, 1)
48 | case v < (uint64(1) << 14):
49 | uintEncodeInt(buf, v|(1<<14), 2)
50 | case v < (uint64(1) << 30):
51 | uintEncodeInt(buf, v|(2<<30), 4)
52 | case v < (uint64(1) << 62):
53 | uintEncodeInt(buf, v|(3<<62), 8)
54 | default:
55 | panic("varint value is too large")
56 | }
57 | }
58 |
59 | func arrayEncode(buf *bytes.Buffer, v reflect.Value) error {
60 | b := v.Bytes()
61 | logf(logTypeCodec, "Encoding array length=%d", len(b))
62 | buf.Write(b)
63 |
64 | return nil
65 | }
66 |
67 | // Check to see if fields
68 | func ignoreField(name string) bool {
69 | return unicode.IsLower(rune(name[0]))
70 | }
71 |
72 | // Length specifications are of the form:
73 | //
74 | // lengthbits: "B:L1,L2,...LN
75 | //
76 | // where B is the rightmost bit of the length bits and
77 | // L_n are the various lengths (in bytes) indicated by
78 | // the bit values in sequence. N must be a power of 2
79 | // and the right number of bytes is drawn to compute it.
80 | type lengthSpec struct {
81 | rightBit uint
82 | numBits uint
83 | values []int
84 | }
85 |
86 | func parseLengthSpecification(spec string) (*lengthSpec, error) {
87 | spl := strings.Split(spec, ":")
88 | assert(len(spl) == 2)
89 |
90 | // Rightmost bit.
91 | p, err := strconv.ParseUint(spl[0], 10, 8)
92 | if err != nil {
93 | return nil, err
94 | }
95 | bitr := uint(p)
96 | vals := strings.Split(spl[1], ",")
97 |
98 | // Figure out how many bits we need.
99 | nvals := int(1)
100 | var bits int
101 | for bits = 1; bits <= 8; bits++ {
102 | nvals <<= 1
103 | if nvals == len(vals) {
104 | break
105 | }
106 | }
107 | assert(bits < 9)
108 |
109 | // Now compute the values
110 | valArr := make([]int, nvals)
111 | for i, v := range vals {
112 | valArr[i], err = strconv.Atoi(v)
113 | if err != nil {
114 | return nil, err
115 | }
116 | }
117 |
118 | return &lengthSpec{
119 | bitr,
120 | uint(bits),
121 | valArr,
122 | }, nil
123 | }
124 |
125 | func computeLengthFromSpec(t byte, f reflect.StructField) uintptr {
126 | st := f.Tag.Get("lengthbits")
127 | if st == "" {
128 | return codecDefaultSize
129 | }
130 |
131 | spec, err := parseLengthSpecification(st)
132 | assert(err == nil)
133 |
134 | mask := byte(0)
135 | bit := uint(0)
136 | for ; bit < spec.numBits; bit++ {
137 | mask |= (1 << bit)
138 | }
139 | idx := int(t >> (spec.rightBit - 1) & mask)
140 |
141 | return uintptr(spec.values[idx])
142 | }
143 |
144 | // Encode all the fields of a struct to a bytestring.
145 | func encode(i interface{}) (ret []byte, err error) {
146 | var buf bytes.Buffer
147 | var res error
148 | reflected := reflect.ValueOf(i).Elem()
149 | fields := reflected.NumField()
150 |
151 | for j := 0; j < fields; j += 1 {
152 | field := reflected.Field(j)
153 | tipe := reflected.Type().Field(j)
154 |
155 | if ignoreField(tipe.Name) {
156 | continue
157 | }
158 |
159 | logf(logTypeCodec, "Type name %s Kind=%v", tipe.Name, field.Kind())
160 |
161 | switch field.Kind() {
162 | case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
163 | // Call the length overrider to tell us if we shoud be using a shorter
164 | // encoding.
165 | encodingSize := uintptr(codecDefaultSize)
166 | lFunc, getLength := reflected.Type().MethodByName(tipe.Name + "__length")
167 | logf(logTypeCodec, "Looking for length overrider for type %v", tipe.Name)
168 | if getLength {
169 | lengthResult := lFunc.Func.Call([]reflect.Value{reflect.ValueOf(i).Elem()})
170 | encodingSize = uintptr(lengthResult[0].Uint())
171 | logf(logTypeCodec, "Overriden length to %v", encodingSize)
172 | }
173 | if isVarint(tipe) {
174 | if encodingSize != 0 {
175 | varintEncode(&buf, field.Uint())
176 | }
177 | res = nil
178 | break
179 | }
180 |
181 | res = uintEncode(&buf, field, encodingSize)
182 | case reflect.Array, reflect.Slice:
183 | res = arrayEncode(&buf, field)
184 | default:
185 | return nil, fmt.Errorf("Unknown type")
186 | }
187 |
188 | if res != nil {
189 | return nil, res
190 | }
191 | }
192 |
193 | ret = buf.Bytes()
194 | logf(logTypeCodec, "Total encoded length = %v", len(ret))
195 | return ret, nil
196 | }
197 |
198 | func uintDecodeIntBuf(val []byte) uint64 {
199 | tmp := uint64(0)
200 | for b := 0; b < len(val); b++ {
201 | tmp = (tmp << 8) + uint64(val[b])
202 | }
203 | return tmp
204 | }
205 |
206 | func uintDecodeInt(r io.Reader, size uintptr) (uint64, error) {
207 | val := make([]byte, size)
208 | _, err := io.ReadFull(r, val)
209 | if err != nil {
210 | return 0, err
211 | }
212 |
213 | return uintDecodeIntBuf(val), nil
214 | }
215 |
216 | func uintDecode(r io.Reader, v reflect.Value, encodingSize uintptr) (uintptr, error) {
217 | size := v.Type().Size()
218 | if encodingSize != codecDefaultSize {
219 | if encodingSize > size {
220 | return 0, fmt.Errorf("Requested a length longer than the native type")
221 | }
222 | size = encodingSize
223 | }
224 |
225 | tmp, err := uintDecodeInt(r, size)
226 | if err != nil {
227 | return 0, err
228 | }
229 |
230 | v.SetUint(tmp)
231 |
232 | return size, nil
233 | }
234 |
235 | func varintDecode(r io.Reader, v reflect.Value) (uintptr, error) {
236 | p := make([]byte, 8)
237 | _, err := r.Read(p[:1])
238 | if err != nil {
239 | return 0, err
240 | }
241 |
242 | value := uint64(p[0] & 0x3f)
243 | extra := uintptr(1<<(p[0]>>6)) - 1
244 | if extra > 0 {
245 | tail, err := uintDecodeInt(r, extra)
246 | if err != nil {
247 | return 0, err
248 | }
249 | value = (value << (8 * extra)) | tail
250 | }
251 |
252 | v.SetUint(value)
253 | return 1 + extra, nil
254 | }
255 |
256 | func encodeArgs(args ...interface{}) []byte {
257 | var buf bytes.Buffer
258 | var res error
259 |
260 | for _, arg := range args {
261 | reflected := reflect.ValueOf(arg)
262 | switch reflected.Kind() {
263 | case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
264 | res = uintEncode(&buf, reflected, codecDefaultSize)
265 | case reflect.Array, reflect.Slice:
266 | res = arrayEncode(&buf, reflected)
267 | default:
268 | panic(fmt.Sprintf("Unknown type"))
269 | }
270 | if res != nil {
271 | panic(fmt.Sprintf("Encoding error"))
272 | }
273 | }
274 |
275 | return buf.Bytes()
276 | }
277 |
278 | func arrayDecode(r io.Reader, v reflect.Value, encodingSize uintptr) (uintptr, error) {
279 | logf(logTypeCodec, "encodingSize = %v", encodingSize)
280 |
281 | val := make([]byte, encodingSize)
282 |
283 | logf(logTypeCodec, "Reading array of size %v", encodingSize)
284 |
285 | // Go will return EOF if you try to read 0 bytes off a closed stream.
286 | if encodingSize == 0 {
287 | return 0, nil
288 | }
289 | _, err := io.ReadFull(r, val)
290 | if err != nil {
291 | return 0, err
292 | }
293 |
294 | v.SetBytes(val)
295 | return encodingSize, nil
296 | }
297 |
298 | // Decode all the fields of a struct from a bytestring. Takes
299 | // a pointer to the struct to fill in
300 | func decode(i interface{}, data []byte) (uintptr, error) {
301 | buf := bytes.NewReader(data)
302 | var res error
303 | reflected := reflect.ValueOf(i).Elem()
304 | fields := reflected.NumField()
305 | bytesread := uintptr(0)
306 |
307 | for j := 0; j < fields; j++ {
308 | br := uintptr(0)
309 | field := reflected.Field(j)
310 | tipe := reflected.Type().Field(j)
311 |
312 | if ignoreField(tipe.Name) {
313 | continue
314 | }
315 |
316 | // Call the length overrider to tell us if we should be using a shorter
317 | // encoding.
318 | encodingSize := uintptr(codecDefaultSize)
319 | lFunc, getLength := reflected.Type().MethodByName(tipe.Name + "__length")
320 | if getLength {
321 | lengthResult := lFunc.Func.Call([]reflect.Value{reflect.ValueOf(i).Elem()})
322 | encodingSize = uintptr(lengthResult[0].Uint())
323 | logf(logTypeCodec, "Length overrider for %s returns %v", tipe.Name, encodingSize)
324 | }
325 |
326 | switch field.Kind() {
327 | case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
328 | if isVarint(tipe) && encodingSize != 0 {
329 | br, res = varintDecode(buf, field)
330 | } else {
331 | br, res = uintDecode(buf, field, encodingSize)
332 | }
333 | case reflect.Array, reflect.Slice:
334 | if encodingSize == codecDefaultSize {
335 | encodingSize = uintptr(buf.Len())
336 | }
337 | br, res = arrayDecode(buf, field, encodingSize)
338 | default:
339 | return 0, fmt.Errorf("Unknown type")
340 | }
341 | if res != nil {
342 | logf(logTypeCodec, "Error while reading field %v: %v", tipe.Name, res)
343 | return bytesread, res
344 | }
345 | bytesread += br
346 | }
347 |
348 | return bytesread, nil
349 | }
350 |
// backtrace returns the caller's stack as text, one "file: line" entry per
// line. The outermost frame comes first and the immediate caller of
// backtrace comes last.
func backtrace() string {
	var frames []string
	for depth := 1; ; depth++ {
		_, file, line, ok := runtime.Caller(depth)
		if !ok {
			break
		}
		frames = append(frames, fmt.Sprintf("%v: %d\n", file, line))
	}

	// Frames were collected innermost-first; emit them in reverse.
	var out bytes.Buffer
	for k := len(frames) - 1; k >= 0; k-- {
		out.WriteString(frames[k])
	}
	return out.String()
}
362 |
--------------------------------------------------------------------------------
/bin/server/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "bytes"
5 | "crypto"
6 | "crypto/x509"
7 | "flag"
8 | "fmt"
9 | "github.com/cloudflare/cfssl/helpers"
10 | "github.com/ekr/minq"
11 | "io"
12 | "io/ioutil"
13 | "log"
14 | "net"
15 | "os"
16 | "runtime/pprof"
17 | "strconv"
18 | "strings"
19 | "time"
20 | )
21 |
// Command-line settings, populated by flag parsing in main, plus the log
// destination opened when -log is given.
var (
	addr           string
	serverName     string
	keyFile        string
	certFile       string
	logFile        string
	logOut         *os.File // target of logFunc
	doHttp         bool
	statelessReset bool
	cpuProfile     string
	echo           bool
	standalone     bool
)
33 |
34 | // Shared data structures.
35 | type conn struct {
36 | conn *minq.Connection
37 | last time.Time
38 | }
39 |
40 | func (c *conn) checkTimer() {
41 | t := time.Now()
42 | if t.After(c.last.Add(time.Second)) {
43 | c.conn.CheckTimer()
44 | c.last = time.Now()
45 | }
46 | }
47 |
48 | var conns = make(map[string]*conn)
49 |
50 | // An feed through server.
51 | type feedthroughServerHandler struct {
52 | echo bool
53 | }
54 |
55 | func (h *feedthroughServerHandler) NewConnection(c *minq.Connection) {
56 | log.Println("New connection")
57 | c.SetHandler(&feedthroughConnHandler{echo, 0})
58 | conns[c.ServerId().String()] = &conn{c, time.Now()}
59 | }
60 |
61 | type feedthroughConnHandler struct {
62 | echo bool
63 | bytesRead int
64 | }
65 |
66 | func (h *feedthroughConnHandler) StateChanged(s minq.State) {
67 | log.Println("State changed to ", s)
68 | }
69 |
70 | func (h *feedthroughConnHandler) NewStream(s minq.Stream) {
71 | log.Println("Created new stream id=", s.Id())
72 | }
73 | func (h *feedthroughConnHandler) NewRecvStream(s minq.RecvStream) {
74 | log.Println("Created new stream id=", s.Id())
75 | }
76 |
77 | func (h *feedthroughConnHandler) StreamReadable(s minq.RecvStream) {
78 | log.Println("Ready to read for stream id=", s.Id())
79 | for {
80 | b := make([]byte, 1024)
81 |
82 | n, err := s.Read(b)
83 | switch err {
84 | case nil:
85 | break
86 | case minq.ErrorWouldBlock:
87 | return
88 | case minq.ErrorStreamIsClosed, minq.ErrorConnIsClosed:
89 | log.Println("")
90 | return
91 | default:
92 | log.Println("Error: ", err)
93 | return
94 | }
95 | b = b[:n]
96 | h.bytesRead += n
97 | os.Stdout.Write(b)
98 | log.Printf("Total bytes read = %d\n", h.bytesRead)
99 |
100 | if echo {
101 | // Flip the case so we can distinguish echo
102 | for i := range b {
103 | if b[i] > 0x40 {
104 | b[i] ^= 0x20
105 | }
106 | }
107 | // This isn't really going to work but for now.
108 | s.(minq.SendStream).Write(b)
109 | }
110 | }
111 | }
112 |
113 | // An HTTP 0.9 Handler
114 | type httpServerHandler struct {
115 | }
116 |
117 | func (h *httpServerHandler) NewConnection(c *minq.Connection) {
118 | log.Println("New connection")
119 | c.SetHandler(&httpConnHandler{make(map[uint64]*httpStream, 0)})
120 | conns[c.ServerId().String()] = &conn{c, time.Now()}
121 | }
122 |
123 | type httpStream struct {
124 | s minq.Stream
125 | buf []byte
126 | closed bool
127 | }
128 |
129 | type httpConnHandler struct {
130 | streams map[uint64]*httpStream
131 | }
132 |
133 | func (h *httpConnHandler) StateChanged(s minq.State) {
134 | log.Println("State changed to ", s)
135 | }
136 |
137 | func (h *httpConnHandler) NewStream(s minq.Stream) {
138 | h.streams[s.Id()] = &httpStream{s, nil, false}
139 | }
140 |
141 | func (h *httpConnHandler) NewRecvStream(s minq.RecvStream) {
142 | log.Println("For some reason some opened a unidirectional stream. Ignoring")
143 | }
144 |
145 | func (h *httpStream) Respond(val []byte) {
146 | h.s.Write(val)
147 | h.s.Close()
148 | h.closed = true
149 | }
150 |
151 | func (h *httpStream) Error(err string) {
152 | h.Respond([]byte(err))
153 | }
154 |
155 | // We expect the URL to be one of two things:
156 | //
157 | // A number, in which case we respond with that number of
158 | // Xs, up to 10,000
159 | // A non-number, in which case we respond with 10 repetitions
160 | // of that value.
161 | func (h *httpConnHandler) StreamReadable(s minq.RecvStream) {
162 | log.Println("Ready to read for stream id=", s.Id())
163 | st := h.streams[s.Id()]
164 | if st.closed {
165 | return
166 | }
167 |
168 | b := make([]byte, 1024)
169 | n, err := s.Read(b)
170 | if err != nil && err != minq.ErrorWouldBlock {
171 | log.Println("Error reading")
172 | return
173 | }
174 | b = b[:n]
175 | log.Printf("Read %v bytes from peer %x\n", n, b)
176 |
177 | st.buf = append(st.buf, b...)
178 |
179 | // See if we received a complete LF
180 | str := string(st.buf)
181 | idx := strings.IndexRune(str, '\n')
182 | if idx == -1 {
183 | return
184 | }
185 | str = str[:idx]
186 |
187 | // OK, we have a complete line.
188 | toks := strings.Split(str, " ")
189 | if toks[0] != "GET" {
190 | st.Error(fmt.Sprintf("Bogus method: %v", toks[0]))
191 | return
192 | }
193 | if len(toks) < 2 {
194 | st.Error("No resource")
195 | return
196 | }
197 |
198 | val := strings.TrimSpace(toks[1])
199 |
200 | if val[0] != '/' {
201 | st.Error(fmt.Sprintf("Bad value: %v", val))
202 | return
203 | }
204 | val = val[1:]
205 |
206 | count, err := strconv.ParseUint(val, 10, 32)
207 | var rsp []byte
208 | if err == nil {
209 | if count > 10000 {
210 | count = 10000
211 | }
212 | rsp = bytes.Repeat([]byte{'X'}, int(count))
213 | } else {
214 | rspstr := ""
215 | for i := 0; i < 10; i++ {
216 | rspstr += val
217 | rspstr += "--"
218 | }
219 | rspstr += "\n"
220 | rsp = []byte(rspstr)
221 | }
222 | st.Respond(rsp)
223 | }
224 |
225 | func logFunc(format string, args ...interface{}) {
226 | fmt.Fprintf(logOut, format, args...)
227 | fmt.Fprintf(logOut, "\n")
228 | }
229 |
230 | func main() {
231 | flag.StringVar(&addr, "addr", "localhost:4433", "[host:port]")
232 | flag.StringVar(&serverName, "server-name", "localhost", "[SNI]")
233 | flag.StringVar(&keyFile, "key", "", "Key file")
234 | flag.StringVar(&certFile, "cert", "", "Cert file")
235 | flag.StringVar(&logFile, "log", "", "Log file")
236 | flag.BoolVar(&doHttp, "http", false, "Do HTTP/0.9")
237 | flag.BoolVar(&echo, "echo", false, "Run as an echo server")
238 | flag.BoolVar(&statelessReset, "stateless-reset", false, "Do stateless reset")
239 | flag.StringVar(&cpuProfile, "cpuprofile", "", "write cpu profile to file")
240 | flag.BoolVar(&standalone, "standalone", false, "Run standalone")
241 | flag.Parse()
242 |
243 | var key crypto.Signer
244 | var certChain []*x509.Certificate
245 |
246 | if cpuProfile != "" {
247 | f, err := os.Create(cpuProfile)
248 | if err != nil {
249 | log.Printf("Could not create CPU profile file %v err=%v\n", cpuProfile, err)
250 | return
251 | }
252 | pprof.StartCPUProfile(f)
253 | log.Println("CPU profiler started")
254 | defer pprof.StopCPUProfile()
255 | }
256 |
257 | config := minq.NewTlsConfig(serverName)
258 | config.ForceHrr = statelessReset
259 |
260 | if keyFile != "" && certFile == "" {
261 | log.Println("Can't specify -key without -cert")
262 | return
263 | }
264 |
265 | if keyFile == "" && certFile != "" {
266 | log.Println("Can't specify -cert without -key")
267 | return
268 | }
269 |
270 | if keyFile != "" && certFile != "" {
271 | keyPEM, err := ioutil.ReadFile(keyFile)
272 | if err != nil {
273 | log.Printf("Couldn't open keyFile %v err=%v", keyFile, err)
274 | return
275 | }
276 | key, err = helpers.ParsePrivateKeyPEM(keyPEM)
277 | if err != nil {
278 | log.Println("Couldn't parse private key: ", err)
279 | return
280 | }
281 |
282 | certPEM, err := ioutil.ReadFile(certFile)
283 | if err != nil {
284 | log.Printf("Couldn't open certFile %v err=%v", certFile, err)
285 | return
286 | }
287 | certChain, err = helpers.ParseCertificatesPEM(certPEM)
288 | if err != nil {
289 | log.Println("Couldn't parse certificates: ", err)
290 | return
291 | }
292 | config.CertificateChain = certChain
293 | config.Key = key
294 | }
295 |
296 | if logFile != "" {
297 | var err error
298 | logOut, err = os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
299 | if err != nil {
300 | log.Println("Couldn't open file")
301 | return
302 | }
303 | minq.SetLogOutput(logFunc)
304 | }
305 | uaddr, err := net.ResolveUDPAddr("udp", addr)
306 | if err != nil {
307 | log.Println("Invalid UDP addr: ", err)
308 | return
309 | }
310 |
311 | usock, err := net.ListenUDP("udp", uaddr)
312 | if err != nil {
313 | log.Println("Couldn't listen on UDP: ", err)
314 | return
315 | }
316 |
317 | var handler minq.ServerHandler
318 | if doHttp {
319 | handler = &httpServerHandler{}
320 | } else {
321 | handler = &feedthroughServerHandler{echo}
322 | }
323 | server := minq.NewServer(minq.NewUdpTransportFactory(usock), &config, handler)
324 |
325 | stdin := make(chan []byte)
326 | if !standalone {
327 | go func() {
328 | for {
329 | b := make([]byte, 1024)
330 | n, err := os.Stdin.Read(b)
331 | if err == io.EOF {
332 | log.Println("EOF received")
333 | close(stdin)
334 | return
335 | } else if err != nil {
336 | log.Println("Error reading from stdin")
337 | return
338 | }
339 | b = b[:n]
340 | stdin <- b
341 | }
342 | }()
343 | }
344 |
345 | for {
346 |
347 | select {
348 | case _, open := <-stdin:
349 | if open == false {
350 | log.Println("Shutdown signal received from stdin. Goodnight.")
351 | return
352 | }
353 | default:
354 | }
355 |
356 | b := make([]byte, 8192)
357 |
358 | usock.SetDeadline(time.Now().Add(time.Second))
359 | n, addr, err := usock.ReadFromUDP(b)
360 | if err != nil {
361 | e, o := err.(net.Error)
362 | if !o || !e.Timeout() {
363 | log.Println("Error reading from UDP socket: ", err)
364 | return
365 | }
366 | n = 0
367 | }
368 |
369 | // If we read data, process it.
370 | if n > 0 {
371 | if n == len(b) {
372 | log.Println("Underread from UDP socket")
373 | return
374 | }
375 | b = b[:n]
376 |
377 | _, err = server.Input(addr, b)
378 | if err != nil {
379 | log.Println("server.Input returned error: ", err)
380 | return
381 | }
382 | }
383 |
384 | // Check the timers.
385 | server.CheckTimer()
386 | }
387 | }
388 |
--------------------------------------------------------------------------------
/congestion.go:
--------------------------------------------------------------------------------
1 | /*
2 | Package minq is a minimal implementation of QUIC, as documented at
3 | https://quicwg.github.io/. Minq partly implements draft-04.
4 |
5 | */
6 | package minq
7 |
8 | import (
9 | "math"
10 | "time"
11 | // "fmt"
12 | )
13 |
// Congestion control related constants. Window sizes are in bytes.
const (
	// kDefaultMss is the maximum segment size used to scale window sizes
	// and the congestion-avoidance increment.
	kDefaultMss = 1460 // bytes
	// kInitalWindow is the congestion window a new controller starts with.
	kInitalWindow = 10 * kDefaultMss
	// kMinimumWindow is the floor applied after a loss-triggered reduction.
	kMinimumWindow = 2 * kDefaultMss
	// kMaximumWindow is not referenced in congestion.go; the window is not
	// capped there.
	kMaximumWindow = kInitalWindow
	// kLossReductionFactor multiplies the window when a new recovery epoch
	// starts.
	kLossReductionFactor = 0.5
)

// Loss detection related constants.
// NOTE(review): kMaxTLPs, kTimeReorderingFraction, kMinTLPTimeout,
// kDelayedAckTimeout and kDefaultInitialRtt are not referenced in
// congestion.go; they may be used elsewhere or be reserved for the
// unimplemented loss-detection alarm.
const (
	kMaxTLPs = 2
	// kReorderingThreshold is the initial packet-number gap beyond which an
	// unacknowledged packet below the largest acked one is declared lost.
	kReorderingThreshold    = 3
	kTimeReorderingFraction = 0.125
	kMinTLPTimeout          = 10 * time.Millisecond
	// kMinRTOTimeout is the lower bound on the retransmission timeout.
	kMinRTOTimeout      = 200 * time.Millisecond
	kDelayedAckTimeout  = 25 * time.Millisecond
	kDefaultInitialRtt  = 100 * time.Millisecond
)
33 |
// CongestionController abstracts congestion control and loss detection.
// Implementations are told about every sent packet and every received ACK.
// In return they report how many bytes may currently be sent and which
// retransmission timeout to use.
type CongestionController interface {
	// onPacketSent records packet pn of sentBytes bytes; isAckOnly marks a
	// packet that carries only acknowledgements.
	onPacketSent(pn uint64, isAckOnly bool, sentBytes int)
	// onAckReceived processes the ranges of a received ACK; delay is the
	// ACK delay reported with it.
	onAckReceived(acks ackRanges, delay time.Duration)
	// bytesAllowedToSend returns the remaining send allowance in bytes.
	bytesAllowedToSend() int
	// setLostPacketHandler registers a callback invoked with the packet
	// number of each packet declared lost.
	setLostPacketHandler(handler func(pn uint64))
	// rto returns the current retransmission timeout.
	rto() time.Duration
}
41 |
42 | /*
43 | * DUMMY congestion controller
44 | */
45 |
46 | type CongestionControllerDummy struct {
47 | }
48 |
49 | func (cc *CongestionControllerDummy) onPacketSent(pn uint64, isAckOnly bool, sentBytes int) {
50 | }
51 |
52 | func (cc *CongestionControllerDummy) onAckReceived(acks ackRanges, delay time.Duration) {
53 | }
54 |
55 | func (cc *CongestionControllerDummy) bytesAllowedToSend() int {
56 | /* return the the maximum int value */
57 | return int(^uint(0) >> 1)
58 | }
59 |
60 | func (cc *CongestionControllerDummy) setLostPacketHandler(handler func(pn uint64)) {
61 | }
62 |
63 | func (cc *CongestionControllerDummy) rto() time.Duration {
64 | return kMinRTOTimeout
65 | }
66 |
/*
 * draft-ietf-quic-recovery congestion controller
 */

// CongestionControllerIetf is a CongestionController modelled on
// draft-ietf-quic-recovery. Loss detection is currently based only on a
// packet-number reordering threshold; the alarm-related fields are
// placeholders (see the TODOs on setLossDetectionAlarm/onLossDetectionAlarm).
type CongestionControllerIetf struct {
	// Congestion control related
	bytesInFlight    int    // bytes of sent, non-ack-only packets not yet acked or declared lost
	congestionWindow int    // bytes allowed in flight
	endOfRecovery    uint64 // largest pn sent when the current recovery epoch began
	sstresh          int    // slow-start threshold; window grows by full acked bytes below it

	// Loss detection related
	lossDetectionAlarm   int //TODO(ekr@rtfm.com) set this to the right type
	handshakeCount       int
	tlpCount             int
	rtoCount             int
	largestSendBeforeRto uint64
	timeOfLastSentPacket time.Time     // time of the most recent onPacketSent
	largestSendPacket    uint64        // pn passed to the most recent onPacketSent
	largestAckedPacket   uint64        // largest pn acknowledged so far
	maxAckDelay          time.Duration // largest ACK delay subtracted for a non-ack-only packet; added into rto()
	minRtt               time.Duration // minimum raw RTT sample (not reduced by ACK delay)
	// largestRtt time.Duration
	smoothedRtt            time.Duration // smoothed RTT from ACK-delay-adjusted samples; used by rto()
	rttVar                 time.Duration // RTT variation paired with smoothedRtt
	smoothedRttTcp         time.Duration // smoothed RTT from raw samples (TCP-style)
	rttVarTcp              time.Duration // RTT variation paired with smoothedRttTcp
	reorderingThreshold    int           // pn gap beyond which a packet is declared lost
	timeReorderingFraction float32
	lossTime               time.Time
	sentPackets            map[uint64]packetEntry // outstanding packets keyed by packet number

	// others
	lostPacketHandler func(pn uint64) // invoked for each packet declared lost, if non-nil
	conn              *Connection     // owning connection; used for logging
}

// packetEntry records one sent packet until it is acknowledged or declared
// lost.
type packetEntry struct {
	pn      uint64
	txTime  time.Time // when the packet was sent
	bytes   int       // packet size; 0 for ack-only packets
	ackOnly bool
}
110 |
111 | func (cc *CongestionControllerIetf) onPacketSent(pn uint64, isAckOnly bool, sentBytes int) {
112 | cc.timeOfLastSentPacket = time.Now()
113 | cc.largestSendPacket = pn
114 | packetData := packetEntry{pn, time.Now(), 0, isAckOnly}
115 | cc.conn.log(logTypeCongestion, "Packet send pn: %d len:%d ackonly: %v\n", pn, sentBytes, isAckOnly)
116 | if !isAckOnly {
117 | cc.onPacketSentCC(sentBytes)
118 | packetData.bytes = sentBytes
119 | cc.setLossDetectionAlarm()
120 | }
121 | cc.sentPackets[pn] = packetData
122 | }
123 |
124 | // acks is received to be a sorted list, where the largest packet numbers are at the beginning
125 | func (cc *CongestionControllerIetf) onAckReceived(acks ackRanges, ackDelay time.Duration) {
126 |
127 | // keep track of largest packet acked overall
128 | if acks[0].lastPacket > cc.largestAckedPacket {
129 | cc.largestAckedPacket = acks[0].lastPacket
130 | }
131 |
132 | // If the largest acked is newly acked update rtt
133 | lastPacket, present := cc.sentPackets[acks[0].lastPacket]
134 | if present {
135 | latestRtt := time.Since(cc.sentPackets[acks[0].lastPacket].txTime)
136 | cc.conn.log(logTypeCongestion, "latestRtt: %v, ackDelay: %v", latestRtt, ackDelay)
137 | cc.updateRttTcp(latestRtt)
138 |
139 | // Update the minRtt, but ignore ackDelay.
140 | if latestRtt < cc.minRtt {
141 | cc.minRtt = latestRtt
142 | }
143 |
144 | // Now reduce by ackDelay if it doesn't reduce the RTT below the minimum.
145 | if latestRtt-cc.minRtt > ackDelay {
146 | latestRtt -= ackDelay
147 | // And update the maximum observed ACK delay.
148 | if !lastPacket.ackOnly && ackDelay > cc.maxAckDelay {
149 | cc.maxAckDelay = ackDelay
150 | }
151 | }
152 |
153 | cc.updateRtt(latestRtt)
154 | }
155 |
156 | // find and proccess newly acked packets
157 | for _, ackBlock := range acks {
158 | for pn := ackBlock.lastPacket; pn > (ackBlock.lastPacket - ackBlock.count); pn-- {
159 | cc.conn.log(logTypeCongestion, "Ack for pn %d received", pn)
160 | _, present := cc.sentPackets[pn]
161 | if present {
162 | cc.conn.log(logTypeCongestion, "First ack for pn %d received", pn)
163 | cc.onPacketAcked(pn)
164 | }
165 | }
166 | }
167 |
168 | cc.detectLostPackets()
169 | cc.setLossDetectionAlarm()
170 | }
171 |
172 | func (cc *CongestionControllerIetf) setLostPacketHandler(handler func(pn uint64)) {
173 | cc.lostPacketHandler = handler
174 | }
175 |
176 | func (cc *CongestionControllerIetf) updateRtt(latestRtt time.Duration) {
177 | if cc.smoothedRtt == 0 {
178 | cc.smoothedRtt = latestRtt
179 | cc.rttVar = time.Duration(int64(latestRtt) / 2)
180 | } else {
181 | rttDelta := cc.smoothedRtt - latestRtt
182 | if rttDelta < 0 {
183 | rttDelta = -rttDelta
184 | }
185 | cc.rttVar = time.Duration(int64(cc.rttVar)*3/4 + int64(rttDelta)*1/4)
186 | cc.smoothedRtt = time.Duration(int64(cc.smoothedRtt)*7/8 + int64(latestRtt)*1/8)
187 | }
188 | cc.conn.log(logTypeCongestion, "New RTT estimate: %v, variance: %v", cc.smoothedRtt, cc.rttVar)
189 | }
190 |
191 | func (cc *CongestionControllerIetf) updateRttTcp(latestRtt time.Duration) {
192 | if cc.smoothedRttTcp == 0 {
193 | cc.smoothedRttTcp = latestRtt
194 | cc.rttVarTcp = time.Duration(int64(latestRtt) / 2)
195 | } else {
196 | rttDelta := cc.smoothedRttTcp - latestRtt
197 | if rttDelta < 0 {
198 | rttDelta = -rttDelta
199 | }
200 | cc.rttVarTcp = time.Duration(int64(cc.rttVarTcp)*3/4 + int64(rttDelta)*3/4)
201 | cc.smoothedRttTcp = time.Duration(int64(cc.smoothedRttTcp)*7/8 + int64(latestRtt)*1/8)
202 | }
203 | cc.conn.log(logTypeCongestion, "New RTT(TCP) estimate: %v, variance: %v", cc.smoothedRttTcp, cc.rttVarTcp)
204 | }
205 |
206 | func (cc *CongestionControllerIetf) rto() time.Duration {
207 | // max(SRTT + 4*RTTVAR + MaxAckDelay, minRTO)
208 | rto := cc.smoothedRtt + 4*cc.rttVar + cc.maxAckDelay
209 | if rto < kMinRTOTimeout {
210 | return kMinRTOTimeout
211 | }
212 | return rto
213 | }
214 |
215 | func (cc *CongestionControllerIetf) onPacketAcked(pn uint64) {
216 | cc.onPacketAckedCC(pn)
217 | //TODO(ekr@rtfm.com) some RTO stuff here
218 | delete(cc.sentPackets, pn)
219 | }
220 |
221 | func (cc *CongestionControllerIetf) setLossDetectionAlarm() {
222 | //TODO(ekr@rtfm.com)
223 | }
224 |
225 | func (cc *CongestionControllerIetf) onLossDetectionAlarm() {
226 | //TODO(ekr@rtfm.com)
227 | }
228 |
229 | func (cc *CongestionControllerIetf) detectLostPackets() {
230 | var lostPackets []packetEntry
231 | //TODO(ekr@rtfm.com) implement loss detection different from reorderingThreshold
232 | for _, packet := range cc.sentPackets {
233 | if (cc.largestAckedPacket > packet.pn) &&
234 | (cc.largestAckedPacket-packet.pn > uint64(cc.reorderingThreshold)) {
235 | lostPackets = append(lostPackets, packet)
236 | }
237 | }
238 |
239 | if len(lostPackets) > 0 {
240 | cc.onPacketsLost(lostPackets)
241 | }
242 | for _, packet := range lostPackets {
243 | delete(cc.sentPackets, packet.pn)
244 | }
245 | }
246 |
247 | func (cc *CongestionControllerIetf) onPacketSentCC(bytes_sent int) {
248 | cc.bytesInFlight += bytes_sent
249 | cc.conn.log(logTypeCongestion, "%d bytes added to bytesInFlight", bytes_sent)
250 | }
251 |
252 | func (cc *CongestionControllerIetf) onPacketAckedCC(pn uint64) {
253 | cc.bytesInFlight -= cc.sentPackets[pn].bytes
254 | cc.conn.log(logTypeCongestion, "%d bytes from packet %d removed from bytesInFlight", cc.sentPackets[pn].bytes, pn)
255 |
256 | if pn < cc.endOfRecovery {
257 | // Do not increase window size during recovery
258 | return
259 | }
260 | if cc.congestionWindow < cc.sstresh {
261 | // Slow start
262 | cc.congestionWindow += cc.sentPackets[pn].bytes
263 | cc.conn.log(logTypeCongestion, "PDV Slow Start: increasing window size with %d bytes to %d",
264 | cc.sentPackets[pn].bytes, cc.congestionWindow)
265 | } else {
266 |
267 | // Congestion avoidance
268 | cc.congestionWindow += kDefaultMss * cc.sentPackets[pn].bytes / cc.congestionWindow
269 | cc.conn.log(logTypeCongestion, "PDV Congestion Avoidance: increasing window size to %d",
270 | cc.congestionWindow)
271 | }
272 | }
273 |
274 | func (cc *CongestionControllerIetf) onPacketsLost(packets []packetEntry) {
275 | var largestLostPn uint64 = 0
276 | for _, packet := range packets {
277 |
278 | // First remove lost packets from bytesInFlight and inform the connection
279 | // of the loss
280 | cc.conn.log(logTypeCongestion, "Packet pn: %d len: %d is lost", packet.pn, packet.bytes)
281 | cc.bytesInFlight -= packet.bytes
282 | if cc.lostPacketHandler != nil {
283 | cc.lostPacketHandler(packet.pn)
284 | }
285 |
286 | // and keep track of the largest lost packet
287 | if packet.pn > largestLostPn {
288 | largestLostPn = packet.pn
289 | }
290 | }
291 |
292 | // Now start a new recovery epoch if the largest lost packet is larger than the
293 | // end of the previous recovery epoch
294 | if cc.endOfRecovery < largestLostPn {
295 | cc.endOfRecovery = cc.largestSendPacket
296 | cc.congestionWindow = int(float32(cc.congestionWindow) * kLossReductionFactor)
297 | if kMinimumWindow > cc.congestionWindow {
298 | cc.congestionWindow = kMinimumWindow
299 | }
300 | cc.sstresh = cc.congestionWindow
301 | cc.conn.log(logTypeCongestion, "PDV Recovery started. Window size: %d, sstresh: %d, endOfRecovery %d",
302 | cc.congestionWindow, cc.sstresh, cc.endOfRecovery)
303 | }
304 | }
305 |
306 | func (cc *CongestionControllerIetf) bytesAllowedToSend() int {
307 | cc.conn.log(logTypeCongestion, "Remaining congestion window size: %d", cc.congestionWindow-cc.bytesInFlight)
308 | return cc.congestionWindow - cc.bytesInFlight
309 | }
310 |
311 | func newCongestionControllerIetf(conn *Connection) *CongestionControllerIetf {
312 | return &CongestionControllerIetf{
313 | 0, // bytesInFlight
314 | kInitalWindow, // congestionWindow
315 | 0, // endOfRecovery
316 | int(^uint(0) >> 1), // sstresh
317 | 0, // lossDetectionAlarm
318 | 0, // handshakeCount
319 | 0, // tlpCount
320 | 0, // rtoCount
321 | 0, // largestSendBeforeRto
322 | time.Unix(0, 0), // timeOfLastSentPacket
323 | 0, // largestSendPacket
324 | 0, // largestAckedPacket
325 | 0, // maxAckDelay
326 | 100 * time.Second, // minRtt
327 | 0, // smoothedRtt
328 | 0, // rttVar
329 | 0, // smoothedRttTcp
330 | 0, // rttVarTcp
331 | kReorderingThreshold, // reorderingThreshold
332 | math.MaxFloat32, // timeReorderingFraction
333 | time.Unix(0, 0), // lossTime
334 | make(map[uint64]packetEntry), // sentPackets
335 | nil, // lostPacketHandler
336 | conn, // conn
337 | }
338 | }
339 |
--------------------------------------------------------------------------------
/transport_parameters.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import (
4 | "bytes"
5 | "crypto/rand"
6 | "encoding/hex"
7 | "fmt"
8 |
9 | "github.com/bifurcation/mint"
10 | "github.com/bifurcation/mint/syntax"
11 | )
12 |
const (
	// kQuicTransportParamtersXtn is the TLS extension type under which QUIC
	// transport parameters are exchanged (see transportParametersXtnBody.Type).
	kQuicTransportParamtersXtn = mint.ExtensionType(0xffa5)
)

// TransportParameterId identifies a single QUIC transport parameter.
type TransportParameterId uint16

// Transport parameter identifiers.
//
// NOTE(review): the QUIC transport drafts of this era enumerate
// initial_max_stream_data_bidi_remote and initial_max_stream_data_uni as
// decimal 10 and 11 (0x000a, 0x000b). The hex values 0x0010/0x0011 below
// may be a decimal-as-hex slip; confirm against the targeted draft before
// relying on interop.
const (
	kTpIdInitialMaxStreamDataBidiLocal  = TransportParameterId(0x0000)
	kTpIdInitialMaxData                 = TransportParameterId(0x0001)
	kTpIdInitialMaxBidiStreams          = TransportParameterId(0x0002)
	kTpIdIdleTimeout                    = TransportParameterId(0x0003)
	kTpPreferredAddress                 = TransportParameterId(0x0004)
	kTpIdMaxPacketSize                  = TransportParameterId(0x0005)
	kTpIdStatelessResetToken            = TransportParameterId(0x0006)
	kTpIdAckDelayExponent               = TransportParameterId(0x0007)
	kTpIdInitialMaxUniStreams           = TransportParameterId(0x0008)
	kTpIdDisableMigration               = TransportParameterId(0x0009)
	kTpIdInitialMaxStreamDataBidiRemote = TransportParameterId(0x0010)
	kTpIdInitialMaxStreamDataUni        = TransportParameterId(0x0011)
)

const (
	kTpDefaultAckDelayExponent = byte(3)
)

// tpDef describes one unsigned-integer transport parameter to advertise: its
// ID, its value, and the width in bytes of its fixed-size encoding.
type tpDef struct {
	parameter TransportParameterId
	val       uint32
	size      uintptr
}
43 |
// Flow-control and stream-count values. createCommonTransportParameters adds
// every entry of kTransportParameterDefaults to a parameter list (presumably
// the one this endpoint sends; confirm at the callers). setDummyPeerParams
// reuses the same values for a stand-in peer.
var (
	kInitialMaxData        = uint64(65536)
	kInitialMaxStreamData  = uint64(8192)
	kConcurrentStreamsBidi = 16
	kConcurrentStreamsUni  = 16
	// Each entry is {parameter ID, value, encoded width in bytes}.
	kTransportParameterDefaults = []tpDef{
		{kTpIdInitialMaxStreamDataBidiLocal, uint32(kInitialMaxStreamData), 4},
		{kTpIdInitialMaxStreamDataBidiRemote, uint32(kInitialMaxStreamData), 4},
		{kTpIdInitialMaxStreamDataUni, uint32(kInitialMaxStreamData), 4},
		{kTpIdInitialMaxData, uint32(kInitialMaxData), 4},
		{kTpIdInitialMaxBidiStreams, uint32(kConcurrentStreamsBidi), 2},
		{kTpIdIdleTimeout, 5, 2},
		{kTpIdInitialMaxUniStreams, uint32(kConcurrentStreamsUni), 2},
	}
)

// transportParameters holds the decoded values of the peer's transport
// parameters, as populated by Receive (or by setDummyPeerParams).
type transportParameters struct {
	maxStreamDataUni        uint32
	maxStreamDataBidiLocal  uint32
	maxStreamDataBidiRemote uint32
	maxData                 uint32
	maxStreamsBidi          int
	maxStreamsUni           int
	idleTimeout             uint16
	ackDelayExp             uint8
}
70 |
// TransportParameterList is a sequence of (ID, raw value) transport
// parameters as carried in the transport parameters TLS extension. The
// tls:"head=N" struct tags are mint/syntax encoding directives giving the
// width in bytes of each variable-length field's length prefix.
type TransportParameterList []transportParameter

// transportParameter is one parameter: its ID and its still-encoded value.
type transportParameter struct {
	Parameter TransportParameterId
	Value     []byte `tls:"head=2"`
}

// clientHelloTransportParameters is the extension body carried in a
// ClientHello; Receive unmarshals it on the server side.
type clientHelloTransportParameters struct {
	InitialVersion VersionNumber
	Parameters     TransportParameterList `tls:"head=2"`
}

// encryptedExtensionsTransportParameters is the extension body carried in
// EncryptedExtensions; Receive unmarshals it on the client side.
type encryptedExtensionsTransportParameters struct {
	NegotiatedVersion VersionNumber
	SupportedVersions []VersionNumber        `tls:"head=1"`
	Parameters        TransportParameterList `tls:"head=2"`
}
88 |
89 | func (tp *TransportParameterList) addUintParameter(id TransportParameterId, val uint32, size uintptr) error {
90 | var buf bytes.Buffer
91 | uintEncodeInt(&buf, uint64(val), size)
92 | *tp = append(*tp, transportParameter{
93 | id,
94 | buf.Bytes(),
95 | })
96 | return nil
97 | }
98 |
99 | func (tp *TransportParameterList) getParameter(id TransportParameterId) []byte {
100 | for _, ex := range *tp {
101 | if ex.Parameter == id {
102 | return ex.Value
103 | }
104 | }
105 | return nil
106 | }
107 |
108 | func (tp *TransportParameterList) getUintParameter(id TransportParameterId, size uintptr) (uint32, error) {
109 | assert(size <= 4)
110 |
111 | b := tp.getParameter(id)
112 | if b == nil {
113 | logf(logTypeHandshake, "Missing transport parameter %v", id)
114 | return 0, ErrorMissingValue
115 | }
116 |
117 | if len(b) != int(size) {
118 | logf(logTypeHandshake, "Bogus transport parameter %v", id)
119 | return 0, ErrorInvalidEncoding
120 | }
121 |
122 | buf := bytes.NewReader(b)
123 | tmp, err := uintDecodeInt(buf, size)
124 | if err != nil {
125 | return 0, err
126 | }
127 |
128 | return uint32(tmp), nil
129 | }
130 |
131 | func (tp *TransportParameterList) getUintParameterOrDefault(id TransportParameterId, size uintptr, def uint32) (uint32, error) {
132 | assert(size <= 4)
133 |
134 | b := tp.getParameter(id)
135 | if b == nil {
136 | logf(logTypeHandshake, "Missing transport parameter %v", id)
137 | return def, nil
138 | }
139 |
140 | if len(b) != int(size) {
141 | logf(logTypeHandshake, "Bogus transport parameter %v", id)
142 | return 0, ErrorInvalidEncoding
143 | }
144 |
145 | buf := bytes.NewReader(b)
146 | tmp, err := uintDecodeInt(buf, size)
147 | if err != nil {
148 | return 0, err
149 | }
150 |
151 | return uint32(tmp), nil
152 | }
153 |
154 | func (tp *TransportParameterList) addOpaqueParameter(id TransportParameterId, b []byte) error {
155 | *tp = append(*tp, transportParameter{
156 | id,
157 | b,
158 | })
159 | return nil
160 | }
161 |
162 | func (tp *TransportParameterList) createCommonTransportParameters() error {
163 | for _, p := range kTransportParameterDefaults {
164 | err := tp.addUintParameter(p.parameter, p.val, p.size)
165 | if err != nil {
166 | return err
167 | }
168 | }
169 |
170 | return nil
171 | }
172 |
173 | // Implement mint.AppExtensionHandler.
174 | type transportParametersXtnBody struct {
175 | body []byte
176 | }
177 |
178 | func (t transportParametersXtnBody) Type() mint.ExtensionType {
179 | return kQuicTransportParamtersXtn
180 | }
181 |
182 | func (t transportParametersXtnBody) Marshal() ([]byte, error) {
183 | return t.body, nil
184 | }
185 |
186 | func (t *transportParametersXtnBody) Unmarshal(data []byte) (int, error) {
187 | t.body = data
188 | return len(t.body), nil
189 | }
190 |
191 | type transportParametersHandler struct {
192 | log loggingFunction
193 | role Role
194 | version VersionNumber
195 | peerParams *transportParameters
196 | }
197 |
198 | func newTransportParametersHandler(log loggingFunction, role Role, version VersionNumber) *transportParametersHandler {
199 | return &transportParametersHandler{log, role, version, nil}
200 | }
201 |
202 | func (h *transportParametersHandler) setDummyPeerParams() {
203 | h.peerParams = &transportParameters{
204 | uint32(kInitialMaxStreamData),
205 | uint32(kInitialMaxStreamData),
206 | uint32(kInitialMaxStreamData),
207 | uint32(kInitialMaxData),
208 | kConcurrentStreamsBidi,
209 | kConcurrentStreamsUni,
210 | 600,
211 | uint8(1),
212 | }
213 | }
214 |
215 | func (h *transportParametersHandler) Send(hs mint.HandshakeType, el *mint.ExtensionList) error {
216 | if h.role == RoleClient {
217 | h.log(logTypeHandshake, "Sending transport parameters")
218 | if hs != mint.HandshakeTypeClientHello {
219 | return nil
220 | }
221 | b, err := h.createClientHelloTransportParameters()
222 | if err != nil {
223 | return err
224 | }
225 | h.log(logTypeTrace, "ClientHelloTransportParameters=%s", hex.EncodeToString(b))
226 | el.Add(&transportParametersXtnBody{b})
227 | return nil
228 | }
229 |
230 | if h.peerParams == nil {
231 | return nil
232 | }
233 |
234 | if hs != mint.HandshakeTypeEncryptedExtensions {
235 | return nil
236 | }
237 |
238 | h.log(logTypeHandshake, "Sending transport parameters message")
239 | b, err := h.createEncryptedExtensionsTransportParameters()
240 | if err != nil {
241 | return err
242 | }
243 | el.Add(&transportParametersXtnBody{b})
244 | return nil
245 | }
246 |
247 | func (h *transportParametersHandler) Receive(hs mint.HandshakeType, el *mint.ExtensionList) error {
248 | h.log(logTypeHandshake, "%p TransportParametersHandler message=%d", h, hs)
249 | // First see if the other side sent the extension.
250 | var body transportParametersXtnBody
251 | found, err := el.Find(&body)
252 |
253 | if err != nil {
254 | return fmt.Errorf("Invalid transport parameters")
255 | }
256 |
257 | if found {
258 | h.log(logTypeTrace, "Retrieved transport parameters len=%d %v", len(body.body), hex.EncodeToString(body.body))
259 | }
260 |
261 | var params *TransportParameterList
262 |
263 | switch hs {
264 | case mint.HandshakeTypeEncryptedExtensions:
265 | if h.role != RoleClient {
266 | return fmt.Errorf("EncryptedExtensions received but not a client")
267 | }
268 | if !found {
269 | h.log(logTypeHandshake, "Missing transport parameters")
270 | return fmt.Errorf("Missing transport parameters")
271 | }
272 | var eeParams encryptedExtensionsTransportParameters
273 | _, err = syntax.Unmarshal(body.body, &eeParams)
274 | if err != nil {
275 | h.log(logTypeHandshake, "Failed to decode parameters")
276 | return err
277 | }
278 | params = &eeParams.Parameters
279 | // TODO(ekr@rtfm.com): Process version #s
280 | case mint.HandshakeTypeClientHello:
281 | if h.role != RoleServer {
282 | return fmt.Errorf("ClientHello received but not a server")
283 | }
284 | if !found {
285 | h.log(logTypeHandshake, "Missing transport parameters")
286 | return fmt.Errorf("Missing transport parameters")
287 | }
288 |
289 | // TODO(ekr@rtfm.com): Process version #s
290 | var chParams clientHelloTransportParameters
291 | _, err = syntax.Unmarshal(body.body, &chParams)
292 | if err != nil {
293 | h.log(logTypeHandshake, "Couldn't unmarshal %v", err)
294 | return err
295 | }
296 | params = &chParams.Parameters
297 | default:
298 | if found {
299 | return fmt.Errorf("Received quic_transport_parameters in inappropriate message %v", hs)
300 | }
301 | return nil
302 | }
303 |
304 | // Now try to process each param.
305 | // TODO(ekr@rtfm.com): Enforce that each param appears only once.
306 | var tp transportParameters
307 | h.log(logTypeHandshake, "Reading transport parameters values")
308 |
309 | tp.maxStreamDataBidiLocal, err = params.getUintParameterOrDefault(kTpIdInitialMaxStreamDataBidiLocal, 4, 0)
310 | if err != nil {
311 | return err
312 | }
313 |
314 | tp.maxStreamDataBidiRemote, err = params.getUintParameterOrDefault(kTpIdInitialMaxStreamDataBidiRemote, 4, 0)
315 | if err != nil {
316 | return err
317 | }
318 |
319 | tp.maxStreamDataUni, err = params.getUintParameterOrDefault(kTpIdInitialMaxStreamDataUni, 4, 0)
320 | if err != nil {
321 | return err
322 | }
323 |
324 | tp.maxData, err = params.getUintParameterOrDefault(kTpIdInitialMaxData, 4, 0)
325 | if err != nil {
326 | return err
327 | }
328 |
329 | tmp, err := params.getUintParameterOrDefault(kTpIdInitialMaxBidiStreams, 2, 0)
330 | if err != nil {
331 | return err
332 | }
333 | tp.maxStreamsBidi = int(tmp)
334 |
335 | if h.role == RoleClient {
336 | tp.maxStreamsBidi++ // Allow for stream 0.
337 | }
338 |
339 | tmp, err = params.getUintParameterOrDefault(kTpIdInitialMaxUniStreams, 2, 0)
340 | if err != nil {
341 | return err
342 | }
343 | tp.maxStreamsUni = int(tmp)
344 |
345 | tmp, err = params.getUintParameter(kTpIdIdleTimeout, 2)
346 | if err != nil {
347 | return err
348 | }
349 | tp.idleTimeout = uint16(tmp)
350 |
351 | tmp, err = params.getUintParameterOrDefault(kTpIdAckDelayExponent, 1, 0)
352 | if err != nil {
353 | return err
354 | }
355 |
356 | h.peerParams = &tp
357 |
358 | h.log(logTypeHandshake, "Finished reading transport parameters")
359 | return nil
360 | }
361 |
362 | func (h *transportParametersHandler) createClientHelloTransportParameters() ([]byte, error) {
363 | chtp := clientHelloTransportParameters{
364 | h.version,
365 | nil,
366 | }
367 |
368 | err := chtp.Parameters.createCommonTransportParameters()
369 | if err != nil {
370 | return nil, err
371 | }
372 |
373 | b, err := syntax.Marshal(chtp)
374 | if err != nil {
375 | return nil, err
376 | }
377 | return b, nil
378 | }
379 |
380 | func (h *transportParametersHandler) createEncryptedExtensionsTransportParameters() ([]byte, error) {
381 | eetp := encryptedExtensionsTransportParameters{
382 | h.version,
383 | []VersionNumber{
384 | h.version,
385 | },
386 | nil,
387 | }
388 |
389 | err := eetp.Parameters.createCommonTransportParameters()
390 | if err != nil {
391 | return nil, err
392 | }
393 |
394 | b := make([]byte, 16)
395 | _, err = rand.Read(b)
396 | if err != nil {
397 | return nil, err
398 | }
399 |
400 | eetp.Parameters.addOpaqueParameter(kTpIdStatelessResetToken, b)
401 |
402 | b, err = syntax.Marshal(eetp)
403 | if err != nil {
404 | return nil, err
405 | }
406 | return b, nil
407 | }
408 |
--------------------------------------------------------------------------------
/packet.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import (
4 | "bytes"
5 | "crypto/aes"
6 | "crypto/cipher"
7 | "encoding/hex"
8 | "fmt"
9 | )
10 |
11 | // Encode a QUIC packet.
12 | /*
13 | Long header
14 |
15 | 0 1 2 3
16 | 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
17 | +-+-+-+-+-+-+-+-+
18 | |1| Type (7) |
19 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
20 | | Version (32) |
21 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
22 | |DCIL(4)|SCIL(4)|
23 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
24 | | Destination Connection ID (0/32..144) ...
25 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
26 | | Source Connection ID (0/32..144) ...
27 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
28 | | Payload Length (i) ...
29 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
30 | | Packet Number (8/16/32) |
31 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
32 | | Payload (*) ...
33 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
34 |
35 | // Initial Header: same as long header but with Token
36 | +-+-+-+-+-+-+-+-+
37 | |1| 0x7f |
38 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
39 | | Version (32) |
40 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
41 | |DCIL(4)|SCIL(4)|
42 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
43 | | Destination Connection ID (0/32..144) ...
44 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
45 | | Source Connection ID (0/32..144) ...
46 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
47 | | Token Length (i) ...
48 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
49 | | Token (*) ...
50 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
51 | | Length (i) ...
52 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
53 | | Packet Number (8/16/32) |
54 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
55 | | Payload (*) ...
56 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
57 |
58 | 0 1 2 3
59 | 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
60 | +-+-+-+-+-+-+-+-+
61 | |0|K|1|1|0|R R R|
62 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
63 | | Destination Connection ID (0..144) ...
64 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
65 | | Packet Number (8/16/32) ...
66 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
67 | | Protected Payload (*) ...
68 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
69 | */
70 |
// Flag bits of the first byte of a packet (see the diagrams above).
const (
	packetFlagLongHeader  = byte(0x80) // Set on every long-header packet.
	packetFlagK           = byte(0x40) // The K (key phase) bit of the short header.
	packetFlagShortHeader = byte(0x30) // The two fixed 1 bits of the short header.
)

// packetType deliberately differs from the spec: it folds long and short
// headers into one value space.
//
// On the wire (packetHeader.Type), a long-header type byte has the
// packetFlagLongHeader bit set. getHeaderType strips that bit, yielding
// values in 0x00-0x7f such as packetTypeInitial. Short headers have no
// type field, so getHeaderType reports every short header as
// packetTypeProtectedShort (0x00), which is not a real long-header type.
type packetType byte

const (
	packetTypeInitial        = packetType(0x7f)
	packetTypeRetry          = packetType(0x7e)
	packetTypeHandshake      = packetType(0x7d)
	packetType0RTTProtected  = packetType(0x7c)
	packetTypeProtectedShort = packetType(0x00) // Not a real type
)

// isLongHeader reports whether pt, in its on-the-wire form, has the
// long-header bit set.
func (pt packetType) isLongHeader() bool {
	return pt&packetType(packetFlagLongHeader) != 0
}

// isProtected expects the on-the-wire form of the type. It returns false
// for long-header Initial, Handshake and Retry packets, and true for every
// other type, including 0-RTT and all short-header packets.
//
// A normalized value such as packetTypeInitial (0x7f) lacks the
// long-header bit and is therefore reported as protected.
func (pt packetType) isProtected() bool {
	if !pt.isLongHeader() {
		return true
	}

	switch pt & 0x7f {
	case packetTypeInitial, packetTypeHandshake, packetTypeRetry:
		return false
	}
	return true
}

// String names the normalized packet types. Any other value is rendered as
// bare hex.
func (pt packetType) String() string {
	switch pt {
	case packetTypeInitial:
		return "Initial"
	case packetTypeRetry:
		return "Retry"
	case packetTypeHandshake:
		return "Handshake"
	case packetType0RTTProtected:
		return "0-RTT"
	case packetTypeProtectedShort:
		return "1-RTT"
	default:
		return fmt.Sprintf("%x", uint8(pt))
	}
}
122 |
123 | // kCidDefaultLength is the length of connection ID we generate.
124 | // TODO: make this configurable.
125 | const kCidDefaultLength = 5
126 |
127 | // ConnectionId identifies the connection that a packet belongs to.
128 | type ConnectionId []byte
129 |
130 | // String stringifies a connection ID in the natural way.
131 | func (c ConnectionId) String() string {
132 | return hex.EncodeToString(c)
133 | }
134 |
135 | // EncodeLength produces the length encoding used in the long packet header.
136 | func (c ConnectionId) EncodeLength() byte {
137 | if len(c) == 0 {
138 | return 0
139 | }
140 | assert(len(c) >= 4 && len(c) <= 18)
141 | return byte(len(c) - 3)
142 | }
143 |
// The PDU definition for the header.
// These types are capitalized so that |codec| can use them.
//
// Which fields appear on the wire depends on the header form. The
// <Field>__length methods in this file report each field's size, with 0
// meaning absent:
//   - Version, ConnectionIDLengths, SourceConnectionID and PayloadLength
//     are present only in long headers.
//   - TokenLength and Token are present only in Initial packets.
//   - DestinationConnectionID's length comes from the high nibble of
//     ConnectionIDLengths (long headers) or from shortCidLength (short).
//
// NOTE(review): field order is presumably the encoding order used by the
// codec; do not reorder without checking codec.go.
type packetHeader struct {
	// Type is the on-the-wire form of the packet type.
	// Consult getHeaderType if you want a value that corresponds to the
	// definition of packetType.
	Type                    packetType
	Version                 VersionNumber
	ConnectionIDLengths     byte // DCIL in the high nibble, SCIL in the low nibble.
	DestinationConnectionID ConnectionId
	SourceConnectionID      ConnectionId
	// NOTE(review): TokenLength occupies one byte (TokenLength__length
	// returns 1), whereas the Initial header diagram above shows a
	// variable-length integer. Confirm which encoding is intended.
	TokenLength   uint8
	Token         []byte
	PayloadLength uint64 `tls:"varint"`

	// In order to decode a short header, the length of the connection
	// ID must be set in |shortCidLength| before decoding.
	shortCidLength uintptr
}
163 |
164 | func (p packetHeader) String() string {
165 | ht := "SHORT"
166 | if p.Type.isLongHeader() {
167 | ht = "LONG"
168 | }
169 | return fmt.Sprintf("%s PT=%v", ht, p.getHeaderType())
170 | }
171 |
172 | func (p *packetHeader) getHeaderType() packetType {
173 | if p.Type.isLongHeader() {
174 | return p.Type & 0x7f
175 | }
176 | return packetTypeProtectedShort
177 | }
178 |
// packet couples a header with its packet number and payload.
type packet struct {
	packetHeader
	PacketNumber uint64 // Never more than 32 bits on the wire.
	payload      []byte // Lower-case, so presumably not handled by the header codec.
}
184 |
185 | // This reads from p.ConnectionIDLengths.
186 | func (p packetHeader) ConnectionIDLengths__length() uintptr {
187 | if p.Type.isLongHeader() {
188 | return 1
189 | }
190 | return 0
191 | }
192 |
193 | func (p packetHeader) TokenLength__length() uintptr {
194 | if p.getHeaderType() != packetTypeInitial {
195 | assert(len(p.Token) == 0)
196 | return 0
197 | }
198 | return 1
199 | }
200 |
201 | func (p packetHeader) Token__length() uintptr {
202 | if p.getHeaderType() != packetTypeInitial {
203 | assert(len(p.Token) == 0)
204 | return 0
205 | }
206 | return uintptr(p.TokenLength)
207 | }
208 |
209 | func (p packetHeader) DestinationConnectionID__length() uintptr {
210 | if !p.Type.isLongHeader() {
211 | return p.shortCidLength
212 | }
213 | l := p.ConnectionIDLengths >> 4
214 | if l != 0 {
215 | l += 3
216 | }
217 | return uintptr(l)
218 | }
219 |
220 | func (p packetHeader) SourceConnectionID__length() uintptr {
221 | if !p.Type.isLongHeader() {
222 | return 0
223 | }
224 | l := p.ConnectionIDLengths & 0xf
225 | if l != 0 {
226 | l += 3
227 | }
228 | return uintptr(l)
229 | }
230 |
231 | func (p packetHeader) PayloadLength__length() uintptr {
232 | if p.Type.isLongHeader() {
233 | return codecDefaultSize
234 | }
235 | return 0
236 | }
237 |
238 | func (p packetHeader) Version__length() uintptr {
239 | if p.Type.isLongHeader() {
240 | return 4
241 | }
242 | return 0
243 | }
244 |
245 | func newPacket(pt packetType, destCid ConnectionId, srcCid ConnectionId, ver VersionNumber, pn uint64, payload []byte, aeadOverhead int) *packet {
246 | if pt == packetTypeProtectedShort {
247 | // Only support writing the 32-bit packet number.
248 | pt = packetType(0x2 | packetFlagShortHeader)
249 | srcCid = nil
250 | } else {
251 | pt = pt | packetType(packetFlagLongHeader)
252 | }
253 | lengths := (destCid.EncodeLength() << 4) | srcCid.EncodeLength()
254 | return &packet{
255 | packetHeader: packetHeader{
256 | Type: pt,
257 | ConnectionIDLengths: lengths,
258 | DestinationConnectionID: destCid,
259 | SourceConnectionID: srcCid,
260 | Version: ver,
261 | PayloadLength: uint64(len(payload) + 4 + aeadOverhead),
262 | },
263 | PacketNumber: pn,
264 | payload: payload,
265 | }
266 | }
267 |
268 | type versionNegotiationPacket struct {
269 | Versions []byte
270 | }
271 |
272 | func newVersionNegotiationPacket(versions []VersionNumber) *versionNegotiationPacket {
273 | var buf bytes.Buffer
274 |
275 | for _, v := range versions {
276 | buf.Write(encodeArgs(v))
277 | }
278 |
279 | return &versionNegotiationPacket{buf.Bytes()}
280 | }
281 |
282 | /*
283 | We don't use these.
284 |
285 | func encodePacket(c ConnectionState, aead Aead, p *Packet) ([]byte, error) {
286 | hdr, err := encode(&p.packetHeader)
287 | if err != nil {
288 | return nil, err
289 | }
290 |
291 | b, err := aead.protect(p.packetHeader.PacketNumber, hdr, p.payload)
292 | if err != nil {
293 | return nil, err
294 | }
295 |
296 | return encodeArgs(hdr, b), nil
297 | }
298 |
299 | func decodePacket(c ConnectionState, aead Aead, b []byte) (*Packet, error) {
300 | // Parse the header
301 | var hdr packetHeader
302 | br, err := decode(&hdr, b)
303 | if err != nil {
304 | return nil, err
305 | }
306 |
307 | hdr.PacketNumber = c.expandPacketNumber(hdr.PacketNumber)
308 | pt, err := aead.unprotect(hdr.PacketNumber, b[0:br], b[br:])
309 | if err != nil {
310 | return nil, err
311 | }
312 |
313 | return &Packet{hdr, pt}, nil
314 | }
315 | */
316 |
317 | func dumpPacket(payload []byte) string {
318 | first := true
319 | ret := fmt.Sprintf("%d=[", len(payload))
320 |
321 | for len(payload) > 0 {
322 | if !first {
323 | ret += ", "
324 | }
325 | first = false
326 | n, f, err := decodeFrame(payload)
327 | if err != nil {
328 | ret += fmt.Sprintf("Undecoded: [%x]", payload)
329 | break
330 | }
331 | payload = payload[n:]
332 | // TODO(ekr@rtfm.com): Not sure why %v doesn't work
333 | ret += f.String()
334 | }
335 | ret += "]"
336 | return ret
337 | }
338 |
339 | type pneCipherFactory interface {
340 | create(sample []byte) cipher.Stream
341 | }
342 |
343 | type pneCipherFactoryAES struct {
344 | block cipher.Block
345 | }
346 |
347 | func newPneCipherFactoryAES(key []byte) pneCipherFactory {
348 | inner, err := aes.NewCipher(key)
349 | assert(err == nil)
350 | if err != nil {
351 | return nil
352 | }
353 | return &pneCipherFactoryAES{block: inner}
354 | }
355 |
356 | func (f *pneCipherFactoryAES) create(sample []byte) cipher.Stream {
357 | if len(sample) != 16 {
358 | return nil
359 | }
360 | return cipher.NewCTR(f.block, sample)
361 | }
362 |
363 | func xorPacketNumber(hdr *packetHeader, hdrlen int, pnbuf []byte, p []byte, factory pneCipherFactory) error {
364 | logf(logTypeTrace, "PNE Operation: hdrlen=%v, hdr=%x, payload=%x", hdrlen, p[:hdrlen], p)
365 |
366 | // The packet must be at least long enough to contain
367 | // the header, plus a minimum 1-byte PN, plus the sample.
368 | sample_length := 16
369 | if sample_length > len(p)-(hdrlen+1) {
370 | logf(logTypePacket, "Packet too short")
371 | return nil
372 | }
373 |
374 | // Now compute the offset
375 | sample_offset := hdrlen + 4
376 | if sample_offset+sample_length > len(p) {
377 | sample_offset = len(p) - sample_length
378 | }
379 |
380 | sample := p[sample_offset : sample_offset+sample_length]
381 | logf(logTypeTrace, "PNE sample_offset=%d sample=%x", sample_offset, sample)
382 | stream := factory.create(sample)
383 | stream.XORKeyStream(pnbuf, p[hdrlen:hdrlen+len(pnbuf)])
384 |
385 | return nil
386 | }
387 |
// pnPatterns lists the packet number encodings. The high bits of the first
// byte select the length: a pattern applies when firstByte&mask == prefix.
//   - 0b0xxxxxxx means 1 byte.
//   - 0b10xxxxxx means 2 bytes.
//   - 0b11xxxxxx means 4 bytes.
// encodePacketNumber and decodePacketNumber scan this slice in order.
var pnPatterns = []struct {
	prefix byte
	mask   byte
	length int
}{
	{prefix: 0x00, mask: 0x80, length: 1},
	{prefix: 0x80, mask: 0xc0, length: 2},
	{prefix: 0xc0, mask: 0xc0, length: 4},
}
405 |
406 | func encodePacketNumber(pn uint64, l int) []byte {
407 | var buf bytes.Buffer
408 | i := 0
409 |
410 | for i, _ = range pnPatterns {
411 | if pnPatterns[i].length == l {
412 | break
413 | }
414 | }
415 |
416 | uintEncodeInt(&buf, pn, uintptr(l))
417 | b := buf.Bytes()
418 | b[0] &= ^pnPatterns[i].mask
419 | b[0] |= pnPatterns[i].prefix
420 |
421 | return b
422 | }
423 |
424 | func decodePacketNumber(buf []byte) (uint64, int, error) {
425 | if len(buf) < 1 {
426 | return 0, 0, fmt.Errorf("Zero-length packet number")
427 | }
428 |
429 | i := 0
430 | for i, _ = range pnPatterns {
431 | if pnPatterns[i].mask&buf[0] == pnPatterns[i].prefix {
432 | break
433 | }
434 | }
435 |
436 | pat := &pnPatterns[i]
437 | if len(buf) < pat.length {
438 | return 0, 0, fmt.Errorf("Buffer too short for packet number (%v < %v)", len(buf), pat.length)
439 | }
440 | buf = dup(buf[:pat.length])
441 | buf[0] &= ^pat.mask
442 |
443 | return uintDecodeIntBuf(buf), pat.length, nil
444 | }
445 |
--------------------------------------------------------------------------------
/minq.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/frame.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import (
4 | "fmt"
5 | "time"
6 |
7 | "github.com/bifurcation/mint/syntax"
8 | )
9 |
// frameType is the leading type byte of a frame.
type frameType uint8

// frameNonSyntax is implemented by frames that the syntax package cannot
// decode. unmarshal parses the frame from b and returns the number of
// bytes consumed.
type frameNonSyntax interface {
	unmarshal(b []byte) (int, error)
}

// Frame type codes, in numeric order.
const (
	kFrameTypePadding          frameType = 0x00
	kFrameTypeRstStream        frameType = 0x01
	kFrameTypeConnectionClose  frameType = 0x02
	kFrameTypeApplicationClose frameType = 0x03
	kFrameTypeMaxData          frameType = 0x04
	kFrameTypeMaxStreamData    frameType = 0x05
	kFrameTypeMaxStreamId      frameType = 0x06
	kFrameTypePing             frameType = 0x07
	kFrameTypeBlocked          frameType = 0x08
	kFrameTypeStreamBlocked    frameType = 0x09
	kFrameTypeStreamIdBlocked  frameType = 0x0a
	kFrameTypeNewConnectionId  frameType = 0x0b
	kFrameTypeStopSending      frameType = 0x0c
	kFrameTypePathChallenge    frameType = 0x0e
	kFrameTypePathResponse     frameType = 0x0f
	kFrameTypeStream           frameType = 0x10 // STREAM occupies 0x10 through kFrameTypeStreamMax.
	kFrameTypeStreamMax        frameType = 0x17
	kFrameTypeCryptoHs         frameType = 0x18
	kFrameTypeAck              frameType = 0x1a
	kFrameTypeAckECN           frameType = 0x1b
)

// Flag bits carried in the low bits of a STREAM frame's type byte.
const (
	kFrameTypeStreamFlagFIN frameType = 0x01
	kFrameTypeStreamFlagLEN frameType = 0x02
	kFrameTypeStreamFlagOFF frameType = 0x04
)

// Worst-case sizes used when budgeting ACK frames.
const (
	kMaxAckHeaderLength     = 33
	kMaxAckBlockEntryLength = 16
	kMaxAckGap              = 255
	kMaxAckBlocks           = 255
)

// innerFrame is implemented by every concrete frame body.
type innerFrame interface {
	getType() frameType
	String() string
}
57 |
58 | type frame struct {
59 | stream uint64
60 | f innerFrame
61 | encoded []byte
62 | pns []uint64
63 | lostPns []uint64
64 | time time.Time
65 | needsTransmit bool
66 | }
67 |
68 | func (f frame) String() string {
69 | return f.f.String()
70 | }
71 |
72 | func newFrame(stream uint64, inner innerFrame) *frame {
73 | return &frame{stream, inner, nil, nil, nil, time.Unix(0, 0), true}
74 | }
75 |
76 | // Encode internally if not already encoded.
77 | func (f *frame) encode() error {
78 | if f.encoded != nil {
79 | return nil
80 | }
81 | var err error
82 | f.encoded, err = syntax.Marshal(f.f)
83 | logf(logTypeFrame, "Frame encoded, total length=%v", len(f.encoded))
84 | return err
85 | }
86 |
87 | func (f *frame) length() (int, error) {
88 | err := f.encode()
89 | if err != nil {
90 | return 0, err
91 | }
92 | return len(f.encoded), nil
93 | }
94 |
95 | // Decode an arbitrary frame.
96 | func decodeFrame(data []byte) (uintptr, *frame, error) {
97 | var inner innerFrame
98 | var n int
99 | var err error
100 |
101 | t := data[0]
102 | logf(logTypeFrame, "Frame type byte %v", t)
103 | switch {
104 | case t == uint8(kFrameTypePadding):
105 | inner = &paddingFrame{}
106 | case t == uint8(kFrameTypeRstStream):
107 | inner = &rstStreamFrame{}
108 | case t == uint8(kFrameTypeConnectionClose):
109 | inner = &connectionCloseFrame{}
110 | case t == uint8(kFrameTypeApplicationClose):
111 | inner = &applicationCloseFrame{}
112 | case t == uint8(kFrameTypeMaxData):
113 | inner = &maxDataFrame{}
114 | case t == uint8(kFrameTypeMaxStreamData):
115 | inner = &maxStreamDataFrame{}
116 | case t == uint8(kFrameTypeMaxStreamId):
117 | inner = &maxStreamIdFrame{}
118 | case t == uint8(kFrameTypePing):
119 | inner = &pingFrame{}
120 | case t == uint8(kFrameTypeBlocked):
121 | inner = &blockedFrame{}
122 | case t == uint8(kFrameTypeStreamBlocked):
123 | inner = &streamBlockedFrame{}
124 | case t == uint8(kFrameTypeStreamIdBlocked):
125 | inner = &streamIdBlockedFrame{}
126 | case t == uint8(kFrameTypeNewConnectionId):
127 | inner = &newConnectionIdFrame{}
128 | case t == uint8(kFrameTypeStopSending):
129 | inner = &stopSendingFrame{}
130 | case t == uint8(kFrameTypeAck):
131 | inner = &ackFrame{}
132 | case t == uint8(kFrameTypePathChallenge):
133 | inner = &pathChallengeFrame{}
134 | case t == uint8(kFrameTypePathResponse):
135 | inner = &pathResponseFrame{}
136 | case t >= uint8(kFrameTypeStream) && t <= uint8(kFrameTypeStreamMax):
137 | inner = &streamFrame{}
138 | case t == uint8(kFrameTypeCryptoHs):
139 | inner = &cryptoHsFrame{}
140 | default:
141 | logf(logTypeConnection, "Unknown frame type %v", t)
142 | return 0, nil, fmt.Errorf("Received unknown frame type: %v", t)
143 | }
144 |
145 | ns, ok := inner.(frameNonSyntax)
146 | if ok {
147 | n, err = ns.unmarshal(data)
148 |
149 | } else {
150 | n, err = syntax.Unmarshal(data, inner)
151 | }
152 | if err != nil {
153 | return 0, nil, err
154 | }
155 |
156 | return uintptr(n), &frame{0, inner, data[:n], nil, nil, time.Now(), false}, nil
157 | }
158 |
159 | // Frame definitions below this point.
160 |
161 | // PADDING
162 | type paddingFrame struct {
163 | Typ frameType
164 | }
165 |
166 | func (f paddingFrame) String() string {
167 | return "P"
168 | }
169 |
170 | func (f paddingFrame) getType() frameType {
171 | return kFrameTypePadding
172 | }
173 |
174 | func newPaddingFrame(stream uint64) *frame {
175 | return newFrame(stream, &paddingFrame{0})
176 | }
177 |
178 | // RST_STREAM
179 | type rstStreamFrame struct {
180 | Type frameType
181 | StreamId uint64 `tls:"varint"`
182 | ErrorCode uint16
183 | FinalOffset uint64 `tls:"varint"`
184 | }
185 |
186 | func (f rstStreamFrame) String() string {
187 | return fmt.Sprintf("RST_STREAM stream=%x errorCode=%d finalOffset=%x", f.StreamId, f.ErrorCode, f.FinalOffset)
188 | }
189 |
190 | func (f rstStreamFrame) getType() frameType {
191 | return kFrameTypeRstStream
192 | }
193 |
194 | func newRstStreamFrame(streamId uint64, errorCode uint16, finalOffset uint64) *frame {
195 | return newFrame(streamId, &rstStreamFrame{
196 | kFrameTypeRstStream,
197 | uint64(streamId),
198 | errorCode,
199 | finalOffset})
200 | }
201 |
202 | // STOP_SENDING
203 | type stopSendingFrame struct {
204 | Type frameType
205 | StreamId uint64 `tls:"varint"`
206 | ErrorCode uint16
207 | }
208 |
209 | func (f stopSendingFrame) String() string {
210 | return fmt.Sprintf("STOP_SENDING stream=%x errorCode=%d", f.StreamId, f.ErrorCode)
211 | }
212 |
213 | func (f stopSendingFrame) getType() frameType {
214 | return kFrameTypeStopSending
215 | }
216 |
217 | func newStopSendingFrame(streamId uint64, errorCode uint16) *frame {
218 | return newFrame(streamId, &stopSendingFrame{
219 | kFrameTypeStopSending,
220 | uint64(streamId),
221 | errorCode})
222 | }
223 |
224 | // CONNECTION_CLOSE
225 | type connectionCloseFrame struct {
226 | Type frameType
227 | ErrorCode uint16
228 | ReasonPhrase []byte `tls:"head=varint"`
229 | }
230 |
231 | func (f connectionCloseFrame) String() string {
232 | return fmt.Sprintf("CONNECTION_CLOSE errorCode=%x", f.ErrorCode)
233 | }
234 |
235 | func (f connectionCloseFrame) getType() frameType {
236 | return kFrameTypeConnectionClose
237 | }
238 |
239 | func newConnectionCloseFrame(errcode ErrorCode, reason string) *frame {
240 | return newFrame(0, &connectionCloseFrame{
241 | kFrameTypeConnectionClose,
242 | uint16(errcode),
243 | []byte(reason),
244 | })
245 | }
246 |
247 | // APPLICATION_CLOSE
248 | type applicationCloseFrame struct {
249 | Type frameType
250 | ErrorCode uint16
251 | ReasonPhrase []byte `tls:"head=varint"`
252 | }
253 |
254 | func (f applicationCloseFrame) String() string {
255 | return fmt.Sprintf("APPLICATION_CLOSE errorCode=%x", f.ErrorCode)
256 | }
257 |
258 | func (f applicationCloseFrame) getType() frameType {
259 | return kFrameTypeApplicationClose
260 | }
261 |
262 | func newApplicationCloseFrame(errcode uint16, reason string) *frame {
263 | return newFrame(0, &applicationCloseFrame{
264 | kFrameTypeApplicationClose,
265 | uint16(errcode),
266 | []byte(reason),
267 | })
268 | }
269 |
270 | // MAX_DATA
271 | type maxDataFrame struct {
272 | Type frameType
273 | MaximumData uint64 `tls:"varint"`
274 | }
275 |
276 | func (f maxDataFrame) String() string {
277 | return fmt.Sprintf("MAX_DATA %d", f.MaximumData)
278 | }
279 |
280 | func (f maxDataFrame) getType() frameType {
281 | return kFrameTypeMaxData
282 | }
283 |
284 | func newMaxData(m uint64) *frame {
285 | return newFrame(0, &maxDataFrame{kFrameTypeMaxData, m})
286 | }
287 |
288 | // MAX_STREAM_DATA
289 | type maxStreamDataFrame struct {
290 | Type frameType
291 | StreamId uint64 `tls:"varint"`
292 | MaximumStreamData uint64 `tls:"varint"`
293 | }
294 |
295 | func newMaxStreamData(stream uint64, offset uint64) *frame {
296 | return newFrame(stream,
297 | &maxStreamDataFrame{
298 | kFrameTypeMaxStreamData,
299 | stream,
300 | offset,
301 | })
302 | }
303 |
304 | func (f maxStreamDataFrame) String() string {
305 | return fmt.Sprintf("MAX_STREAM_DATA stream=%d %d", f.StreamId, f.MaximumStreamData)
306 | }
307 |
308 | func (f maxStreamDataFrame) getType() frameType {
309 | return kFrameTypeMaxStreamData
310 | }
311 |
312 | // MAX_STREAM_ID
313 | type maxStreamIdFrame struct {
314 | Type frameType
315 | MaximumStreamId uint64 `tls:"varint"`
316 | }
317 |
318 | func newMaxStreamId(id uint64) *frame {
319 | return newFrame(0,
320 | &maxStreamIdFrame{
321 | kFrameTypeMaxStreamId,
322 | id,
323 | })
324 | }
325 |
326 | func (f maxStreamIdFrame) String() string {
327 | return fmt.Sprintf("MAX_STREAM_ID %d", f.MaximumStreamId)
328 | }
329 |
330 | func (f maxStreamIdFrame) getType() frameType {
331 | return kFrameTypeMaxStreamId
332 | }
333 |
334 | // PING
335 | type pingFrame struct {
336 | Type frameType
337 | }
338 |
339 | func (f pingFrame) String() string {
340 | return "PING"
341 | }
342 |
343 | func (f pingFrame) getType() frameType {
344 | return kFrameTypePing
345 | }
346 |
347 | // BLOCKED
348 | type blockedFrame struct {
349 | Type frameType
350 | Offset uint64 `tls:"varint"`
351 | }
352 |
353 | func (f blockedFrame) String() string {
354 | return "BLOCKED"
355 | }
356 |
357 | func (f blockedFrame) getType() frameType {
358 | return kFrameTypeBlocked
359 | }
360 |
361 | func newBlockedFrame(offset uint64) *frame {
362 | return newFrame(0, &blockedFrame{kFrameTypeBlocked, offset})
363 | }
364 |
365 | // STREAM_BLOCKED
366 | type streamBlockedFrame struct {
367 | Type frameType
368 | StreamId uint64 `tls:"varint"`
369 | Offset uint64 `tls:"varint"`
370 | }
371 |
372 | func (f streamBlockedFrame) String() string {
373 | return "STREAM_BLOCKED"
374 | }
375 |
376 | func (f streamBlockedFrame) getType() frameType {
377 | return kFrameTypeStreamBlocked
378 | }
379 |
380 | func newStreamBlockedFrame(id uint64, offset uint64) *frame {
381 | return newFrame(0, &streamBlockedFrame{kFrameTypeStreamBlocked, id, offset})
382 | }
383 |
384 | // STREAM_ID_BLOCKED
385 | type streamIdBlockedFrame struct {
386 | Type frameType
387 | StreamId uint64 `tls:"varint"`
388 | }
389 |
390 | func (f streamIdBlockedFrame) String() string {
391 | return "STREAM_ID_BLOCKED"
392 | }
393 |
394 | func (f streamIdBlockedFrame) getType() frameType {
395 | return kFrameTypeStreamIdBlocked
396 | }
397 |
398 | func newStreamIdBlockedFrame(id uint64) *frame {
399 | return newFrame(0, &streamIdBlockedFrame{
400 | kFrameTypeStreamIdBlocked,
401 | id})
402 | }
403 |
404 | // NEW_CONNECTION_ID
405 | type newConnectionIdFrame struct {
406 | Type frameType
407 | Sequence uint16 `tls:"varint"`
408 | ConnectionId ConnectionId
409 | ResetToken [16]byte
410 | }
411 |
412 | func (f newConnectionIdFrame) String() string {
413 | return fmt.Sprintf("NEW_CONNECTION_ID %d=%x", f.Sequence, f.ConnectionId)
414 | }
415 |
416 | func (f newConnectionIdFrame) getType() frameType {
417 | return kFrameTypeNewConnectionId
418 | }
419 |
420 | func newNewConnectionIdFrame(seq uint16, cid ConnectionId, resetToken []byte) *frame {
421 | f := &newConnectionIdFrame{
422 | Type: kFrameTypeNewConnectionId,
423 | Sequence: seq,
424 | ConnectionId: cid,
425 | }
426 | assert(len(resetToken) == len(f.ResetToken))
427 | copy(f.ResetToken[:], resetToken)
428 | return newFrame(0, f)
429 | }
430 |
431 | // ACK
432 | type ackBlock struct {
433 | Gap uint64 `tls:"varint"`
434 | Length uint64 `tls:"varint"`
435 | }
436 |
437 | type ackFrameHeader struct {
438 | Type frameType
439 | LargestAcknowledged uint64 `tls:"varint"`
440 | AckDelay uint64 `tls:"varint"`
441 | AckBlockCount uint64 `tls:"varint"`
442 | FirstAckBlock uint64 `tls:"varint"`
443 | }
444 |
445 | type ackFrame struct {
446 | ackFrameHeader
447 | AckBlockSection []*ackBlock `tls:"head=none"`
448 | }
449 |
450 | func (f ackFrame) String() string {
451 | return fmt.Sprintf("ACK numBlocks=%d largestAck=%x", f.AckBlockCount, f.LargestAcknowledged)
452 | }
453 |
454 | func (f ackFrame) getType() frameType {
455 | return kFrameTypeAck
456 | }
457 |
458 | // ACK frames can't presently be decoded with syntax, so we need
459 | // a custom decoder.
460 | func (f *ackFrame) unmarshal(buf []byte) (int, error) {
461 | // First, decode the header
462 | read := int(0)
463 | n, err := syntax.Unmarshal(buf, &f.ackFrameHeader)
464 | if err != nil {
465 | return 0, err
466 | }
467 | buf = buf[n:]
468 | read += n
469 |
470 | // Now decode each block
471 | for i := uint64(0); i < f.AckBlockCount; i++ {
472 | blk := &ackBlock{}
473 | n, err := syntax.Unmarshal(buf, blk)
474 | if err != nil {
475 | return 0, err
476 | }
477 | buf = buf[n:]
478 | read += n
479 |
480 | f.AckBlockSection = append(f.AckBlockSection, blk)
481 | }
482 |
483 | return read, nil
484 | }
485 |
486 | func newAckFrame(recvd *recvdPackets, rs ackRanges, left int) (*frame, int, error) {
487 | if left < kMaxAckHeaderLength {
488 | return nil, 0, nil
489 | }
490 | logf(logTypeFrame, "Making ACK frame %v", rs)
491 |
492 | left -= kMaxAckHeaderLength
493 |
494 | last := rs[0].lastPacket
495 | largestAckData, ok := recvd.packets[last]
496 | // Should always be there. Packets only get removed after being set to ack2,
497 | // which means we should not be acking it again.
498 | assert(ok)
499 |
500 | // FIRST, fill in the basic info of the ACK frame
501 | var f ackFrame
502 | f.Type = kFrameTypeAck
503 | f.LargestAcknowledged = last
504 | delay := time.Since(largestAckData.t).Nanoseconds()
505 | f.AckDelay = uint64(delay) / 1000 >> kTpDefaultAckDelayExponent
506 | f.AckBlockCount = 0
507 | f.FirstAckBlock = rs[0].count - 1
508 |
509 | // ...and account for the first block.
510 | last -= f.FirstAckBlock
511 | addedRanges := 1
512 |
513 | // SECOND, add the remaining ACK blocks that fit and that we have
514 | for (left > 0) && (addedRanges < len(rs)) {
515 | // calculate blocks needed for the next range
516 | gap := last - rs[addedRanges].lastPacket - 1
517 |
518 | gap = last - rs[addedRanges].lastPacket - 1
519 | b := &ackBlock{
520 | gap,
521 | rs[addedRanges].count - 1,
522 | }
523 |
524 | last = rs[addedRanges].lastPacket - rs[addedRanges].count
525 |
526 | f.AckBlockCount++
527 | f.AckBlockSection = append(f.AckBlockSection, b)
528 | addedRanges++
529 | left -= kMaxAckBlockEntryLength // Assume worst-case.
530 | }
531 |
532 | return newFrame(0, &f), addedRanges, nil
533 | }
534 |
535 | // PATH_CHALLENGE
536 | type pathChallengeFrame struct {
537 | Type frameType
538 | Data [8]byte
539 | }
540 |
541 | func (f pathChallengeFrame) String() string {
542 | return "PATH_CHALLENGE"
543 | }
544 |
545 | func (f pathChallengeFrame) getType() frameType {
546 | return kFrameTypePathChallenge
547 | }
548 |
549 | func newPathChallengeFrame(data []byte) *frame {
550 | payload := &pathChallengeFrame{Type: kFrameTypePathChallenge}
551 | assert(len(data) == len(payload.Data))
552 | copy(payload.Data[:], data)
553 | return newFrame(0, payload)
554 | }
555 |
556 | // PATH_RESPONSE
557 | type pathResponseFrame struct {
558 | Type frameType
559 | Data [8]byte
560 | }
561 |
562 | func (f pathResponseFrame) String() string {
563 | return "PATH_RESPONSE"
564 | }
565 |
566 | func (f pathResponseFrame) getType() frameType {
567 | return kFrameTypePathResponse
568 | }
569 |
570 | func newPathResponseFrame(data []byte) *frame {
571 | payload := &pathResponseFrame{Type: kFrameTypePathResponse}
572 | assert(len(data) == len(payload.Data))
573 | copy(payload.Data[:], data)
574 | return newFrame(0, payload)
575 | }
576 |
577 | // STREAM
578 | type streamFrame struct {
579 | Typ frameType
580 | StreamId uint64 `tls:"varint"`
581 | Offset uint64 `tls:"varint"`
582 | Data []byte `tls:"head=varint"`
583 | }
584 |
585 | func (f streamFrame) String() string {
586 | return fmt.Sprintf("STREAM stream=%d offset=%d len=%d FIN=%v", f.StreamId, f.Offset, len(f.Data), f.hasFin())
587 | }
588 |
589 | func (f streamFrame) getType() frameType {
590 | return kFrameTypeStream
591 | }
592 |
593 | func (f streamFrame) hasFin() bool {
594 | if f.Typ&kFrameTypeStreamFlagFIN == 0 {
595 | return false
596 | }
597 | return true
598 | }
599 |
600 | func newStreamFrame(stream uint64, offset uint64, data []byte, last bool) *frame {
601 | logf(logTypeFrame, "Creating stream frame with data length=%d", len(data))
602 | assert(len(data) <= 65535)
603 | // TODO(ekr@tfm.com): One might want to allow non
604 | // D bit, but not for now.
605 | // Set all of SSOO to 1
606 | typ := kFrameTypeStream | kFrameTypeStreamFlagLEN | kFrameTypeStreamFlagOFF
607 | if last {
608 | typ |= kFrameTypeStreamFlagFIN
609 | }
610 | return newFrame(
611 | stream,
612 | &streamFrame{
613 | typ,
614 | stream,
615 | offset,
616 | dup(data),
617 | })
618 | }
619 |
620 | func decodeVarint(buf []byte) (int, uint64, error) {
621 | var vi struct {
622 | Val uint64 `tls:"varint"`
623 | }
624 |
625 | n, err := syntax.Unmarshal(buf, &vi)
626 | if err != nil {
627 | return 0, 0, err
628 | }
629 |
630 | return n, vi.Val, nil
631 | }
632 |
633 | // Stream frames can't presently be decoded with syntax, so we need
634 | // a custom decoder.
635 | func (f *streamFrame) unmarshal(buf []byte) (int, error) {
636 | f.Typ = frameType(buf[0])
637 | buf = buf[1:]
638 | var read = int(1)
639 | var n int
640 | var err error
641 |
642 | n, f.StreamId, err = decodeVarint(buf)
643 | if err != nil {
644 | return 0, err
645 | }
646 | buf = buf[n:]
647 | read += n
648 |
649 | if f.Typ&kFrameTypeStreamFlagOFF != 0 {
650 | n, f.Offset, err = decodeVarint(buf)
651 | if err != nil {
652 | return 0, err
653 | }
654 | buf = buf[n:]
655 | read += n
656 | }
657 |
658 | if f.Typ&kFrameTypeStreamFlagLEN != 0 {
659 | var l uint64
660 | n, l, err = decodeVarint(buf)
661 | if err != nil {
662 | return 0, err
663 | }
664 | buf = buf[n:]
665 | read += n
666 |
667 | logf(logTypeFrame, "Expecting %v bytes", l)
668 |
669 | if l > uint64(len(buf)) {
670 | return 0, fmt.Errorf("Insufficient bytes left")
671 | }
672 | f.Data = dup(buf[:l])
673 | read += int(l)
674 | } else {
675 | f.Data = dup(buf)
676 | read += len(buf)
677 | }
678 |
679 | return read, nil
680 | }
681 |
682 | // CRYPTO_HS
683 | type cryptoHsFrame struct {
684 | Typ frameType
685 | Offset uint64 `tls:"varint"`
686 | Data []byte `tls:"head=varint"`
687 | }
688 |
689 | func (f cryptoHsFrame) getType() frameType {
690 | return kFrameTypeCryptoHs
691 | }
692 |
693 | func (f cryptoHsFrame) String() string {
694 | return fmt.Sprintf("CRYPTO_HS len=%d", len(f.Data))
695 | }
696 |
697 | func newCryptoHsFrame(offset uint64, data []byte) *frame {
698 | logf(logTypeFrame, "Creating crypto_hs frame with data length=%d", len(data))
699 |
700 | return newFrame(
701 | 0,
702 | &cryptoHsFrame{
703 | kFrameTypeCryptoHs,
704 | offset,
705 | dup(data),
706 | },
707 | )
708 | }
709 |
--------------------------------------------------------------------------------
/stream.go:
--------------------------------------------------------------------------------
1 | package minq
2 |
3 | import (
4 | "encoding/hex"
5 | "fmt"
6 | "io"
7 | )
8 |
// SendStreamState is the state of a SendStream.
type SendStreamState uint8

// SendStreamState values, numbered from zero in declaration order. Not all of
// these are tracked.
const (
	SendStreamStateOpen        SendStreamState = iota
	SendStreamStateSend
	SendStreamStateCloseQueued // Not in the spec
	SendStreamStateDataSent
	SendStreamStateResetSent
	SendStreamStateDataRecvd  // Not tracked
	SendStreamStateResetRecvd // Not tracked
)

// String returns the Go name of the state, e.g. "SendStreamStateOpen".
// It panics for a value outside the declared set.
func (s SendStreamState) String() string {
	names := [...]string{
		SendStreamStateOpen:        "SendStreamStateOpen",
		SendStreamStateSend:        "SendStreamStateSend",
		SendStreamStateCloseQueued: "SendStreamStateCloseQueued",
		SendStreamStateDataSent:    "SendStreamStateDataSent",
		SendStreamStateResetSent:   "SendStreamStateResetSent",
		SendStreamStateDataRecvd:   "SendStreamStateDataRecvd",
		SendStreamStateResetRecvd:  "SendStreamStateResetRecvd",
	}
	if int(s) >= len(names) {
		panic("Unknown SendStreamState")
	}
	return names[s]
}
44 |
// RecvStreamState is the state of a RecvStream.
type RecvStreamState uint8

// RecvStreamState values, numbered from zero in declaration order. Not all of
// these are tracked.
const (
	RecvStreamStateRecv      RecvStreamState = iota
	RecvStreamStateSizeKnown
	RecvStreamStateDataRecvd // Not tracked
	RecvStreamStateResetRecvd
	RecvStreamStateDataRead
	RecvStreamStateResetRead
)

// String returns the Go name of the state, e.g. "RecvStreamStateRecv".
// It panics for a value outside the declared set.
func (s RecvStreamState) String() string {
	names := [...]string{
		RecvStreamStateRecv:       "RecvStreamStateRecv",
		RecvStreamStateSizeKnown:  "RecvStreamStateSizeKnown",
		RecvStreamStateDataRecvd:  "RecvStreamStateDataRecvd",
		RecvStreamStateResetRecvd: "RecvStreamStateResetRecvd",
		RecvStreamStateDataRead:   "RecvStreamStateDataRead",
		RecvStreamStateResetRead:  "RecvStreamStateResetRead",
	}
	if int(s) >= len(names) {
		panic("Unknown RecvStreamState")
	}
	return names[s]
}
77 |
78 | // The structure here is a little convoluted.
79 | //
80 | // There are three primary interfaces: SendStream, RecvStream, and Stream. These
81 | // all implement hasIdentity and one or both (for Stream) of sendStreamMethods
82 | // or recvStreamMethods.
83 | //
84 | // The implementations are layered.
85 | //
86 | // streamCommon is at the bottom, it includes stuff common to sending and receiving.
87 | //
88 | // sendStreamBase and recvStreamBase add sending and receiving functions. These
89 | // know how to send and receive, but don't know about identifiers or
90 | // connections. This allows them to be tested in isolation.
91 | //
// Those types don't know about connections, so sendStream and recvStream add
// that by pairing the base type with a *Connection and a stream id. stream is
// the bidirectional case: it embeds both a send half and a receive half
// (sendStreamPrivate and recvStreamPrivate). These include the concrete
// implementations of the interfaces.
96 |
97 | type hasIdentity interface {
98 | Id() uint64
99 | }
100 |
101 | type sendStreamMethods interface {
102 | io.WriteCloser
103 | Reset(uint16) error
104 | SendState() SendStreamState
105 | }
106 |
107 | type sendStreamPrivateMethods interface {
108 | setSendState(SendStreamState)
109 | outstandingQueuedBytes() int
110 | processMaxStreamData(uint64)
111 | outputWritable() []streamChunk
112 | flowControl() flowControl
113 | }
114 |
115 | type recvStreamMethods interface {
116 | io.Reader
117 | StopSending(uint16) error
118 | RecvState() RecvStreamState
119 | }
120 |
121 | type recvStreamPrivateMethods interface {
122 | setRecvState(RecvStreamState)
123 | handleReset(offset uint64) error
124 | clearReadable() bool
125 | newFrameData(uint64, bool, []byte, *flowControl) error
126 | updateMaxStreamData(bool)
127 | }
128 |
129 | // SendStream can send.
130 | type SendStream interface {
131 | hasIdentity
132 | sendStreamMethods
133 | }
134 |
135 | type sendStreamPrivate interface {
136 | SendStream
137 | sendStreamPrivateMethods
138 | }
139 |
140 | // RecvStream can receive.
141 | type RecvStream interface {
142 | hasIdentity
143 | recvStreamMethods
144 | }
145 |
146 | type recvStreamPrivate interface {
147 | RecvStream
148 | recvStreamPrivateMethods
149 | }
150 |
151 | // Stream is both a send and receive stream.
152 | type Stream interface {
153 | hasIdentity
154 | sendStreamMethods
155 | recvStreamMethods
156 | }
157 |
158 | type streamPrivate interface {
159 | Stream
160 | sendStreamPrivateMethods
161 | recvStreamPrivateMethods
162 | }
163 |
// streamChunk is a piece of stream data that starts at offset. last marks
// the chunk that ends the stream. Field order matters: some callers build it
// positionally.
type streamChunk struct {
	offset uint64
	last   bool
	data   []byte
}

// String summarizes the chunk for logging; the payload bytes are omitted.
func (sc streamChunk) String() string {
	size := len(sc.data)
	return fmt.Sprintf("chunk(offset=%v, len=%v, last=%v)", sc.offset, size, sc.last)
}
173 |
174 | type streamCommon struct {
175 | log loggingFunction
176 | chunks []streamChunk
177 | fc flowControl
178 | readOffset uint64
179 | }
180 |
181 | func (s *streamCommon) insertSortedChunk(offset uint64, last bool, payload []byte) {
182 | c := streamChunk{offset, last, dup(payload)}
183 | s.log(logTypeStream, "insert %v, current offset=%v", c, s.fc.used)
184 | s.log(logTypeTrace, "payload %v", hex.EncodeToString(payload))
185 | if len(payload) == 0 && !last && offset != 0 {
186 | // Empty frame, ignore
187 | return
188 | }
189 |
190 | // First check if we can append the new slice at the end
191 | if nchunks := len(s.chunks); nchunks == 0 || offset > s.chunks[nchunks-1].offset {
192 | s.chunks = append(s.chunks, c)
193 | } else {
194 | // Otherwise find out where it should go
195 | var i int
196 | for i = 0; i < nchunks; i++ {
197 | if offset < s.chunks[i].offset {
198 | break
199 | }
200 | }
201 |
202 | // This may not be the fastest way to do this splice.
203 | tmp := make([]streamChunk, 0, nchunks+1)
204 | tmp = append(tmp, s.chunks[:i]...)
205 | tmp = append(tmp, c)
206 | tmp = append(tmp, s.chunks[i:]...)
207 | s.chunks = tmp
208 | }
209 | s.log(logTypeStream, "Stream now has %v chunks", len(s.chunks))
210 | }
211 |
212 | type sendStreamBase struct {
213 | streamCommon
214 | state SendStreamState
215 | }
216 |
217 | func (s *sendStreamBase) setSendState(state SendStreamState) {
218 | if state != s.state {
219 | s.log(logTypeStream, "set state %v->%v", s.state, state)
220 | s.state = state
221 | }
222 | }
223 |
224 | // SendState returns the current state of the receive stream.
225 | func (s *sendStreamBase) SendState() SendStreamState {
226 | return s.state
227 | }
228 |
229 | func (s *sendStreamBase) queue(payload []byte, cfc *flowControl) (int, error) {
230 | s.log(logTypeStream, "queueing %v bytes, flow control %v %v", len(payload), &s.fc, cfc)
231 | offset := s.fc.used
232 | allowed := s.fc.take(cfc, uint64(len(payload)))
233 | s.log(logTypeFlowControl, "flow control consumed %v %v", &s.fc, cfc)
234 | if allowed == 0 {
235 | s.log(logTypeFlowControl, "blocked write")
236 | return 0, ErrorWouldBlock
237 | }
238 | payload = payload[:allowed]
239 | s.insertSortedChunk(offset, false, payload)
240 | return int(allowed), nil
241 | }
242 |
243 | func (s *sendStreamBase) write(data []byte, connectionFlowControl *flowControl) (int, error) {
244 | switch s.state {
245 | case SendStreamStateOpen:
246 | s.setSendState(SendStreamStateSend)
247 | // Allow a zero-octet write on a stream that hasn't been opened.
248 | if len(data) == 0 {
249 | return s.queue(data, connectionFlowControl)
250 | }
251 | case SendStreamStateSend:
252 | // OK to send
253 | default:
254 | return 0, ErrorStreamIsClosed
255 | }
256 | written := 0
257 | for len(data) > 0 {
258 | tocpy := 1024
259 | if tocpy > len(data) {
260 | tocpy = len(data)
261 | }
262 | n, err := s.queue(data[:tocpy], connectionFlowControl)
263 | if (err == ErrorWouldBlock) && (written > 0) {
264 | s.log(logTypeFlowControl, "write flow control blocked at offset %d", s.fc.used)
265 | break
266 | }
267 | if err != nil {
268 | return written, err
269 | }
270 | written += n
271 |
272 | data = data[tocpy:]
273 | }
274 |
275 | s.log(logTypeTrace, "wrote %d bytes", written)
276 | return written, nil
277 | }
278 |
279 | func (s *sendStreamBase) outstandingQueuedBytes() int {
280 | n := 0
281 | for _, ch := range s.chunks {
282 | n += len(ch.data)
283 | }
284 | return n
285 | }
286 |
287 | func (s *sendStreamBase) flowControl() flowControl {
288 | return s.fc
289 | }
290 |
291 | // Push out all pending frames. Set the stream state if the end of the stream is available.
292 | func (s *sendStreamBase) outputWritable() []streamChunk {
293 | s.log(logTypeStream, "outputWritable, chunks=%v current max offset=%d)", len(s.chunks), s.fc.max)
294 | for _, ch := range s.chunks {
295 | if ch.last {
296 | s.setSendState(SendStreamStateDataSent)
297 | }
298 | }
299 |
300 | out := s.chunks
301 | s.chunks = nil
302 | return out
303 | }
304 |
305 | func (s *sendStreamBase) processMaxStreamData(offset uint64) {
306 | s.fc.update(offset)
307 | }
308 |
309 | func (s *sendStreamBase) close() {
310 | switch s.state {
311 | case SendStreamStateOpen, SendStreamStateSend:
312 | s.insertSortedChunk(s.fc.used, true, nil)
313 | s.setSendState(SendStreamStateCloseQueued)
314 | default:
315 | // NOOP
316 | }
317 | }
318 |
319 | type recvStreamBase struct {
320 | streamCommon
321 | state RecvStreamState
322 | readable bool
323 | }
324 |
325 | func (s *recvStreamBase) setRecvState(state RecvStreamState) {
326 | if state != s.state {
327 | s.log(logTypeStream, "set state %v->%v", s.state, state)
328 | s.state = state
329 | }
330 | }
331 |
332 | // RecvState returns the current state of the receive stream.
333 | func (s *recvStreamBase) RecvState() RecvStreamState {
334 | return s.state
335 | }
336 |
337 | // clearReadable clears the readable flag and returns true if it was set.
338 | func (s *recvStreamBase) clearReadable() bool {
339 | r := s.readable
340 | s.readable = false
341 | return r
342 | }
343 |
344 | // Add data to a stream. Return true if this is readable now.
345 | func (s *recvStreamBase) newFrameData(offset uint64, last bool, payload []byte,
346 | cfc *flowControl) error {
347 | s.log(logTypeStream, "new data offset=%d, len=%d", offset, len(payload))
348 | s.log(logTypeFlowControl, "new data flow control %v %v", &s.fc, cfc)
349 |
350 | end := offset + uint64(len(payload))
351 | if last {
352 | if end < s.fc.used {
353 | // The end can't be less than what we've received already.
354 | return ErrorFlowControlError
355 | }
356 | if s.state == RecvStreamStateRecv {
357 | s.setRecvState(RecvStreamStateSizeKnown)
358 | }
359 | } else if end > s.fc.used {
360 | if s.state != RecvStreamStateRecv {
361 | // We shouldn't be increasing used in any other state.
362 | return ErrorFlowControlError
363 | }
364 |
365 | increase := end - s.fc.used
366 | taken := increase
367 | if !s.fc.unlimited {
368 | taken := s.fc.take(cfc, increase)
369 | s.log(logTypeFlowControl, "taken flow control %d, now %v %v", taken, &s.fc, cfc)
370 | }
371 | if taken < increase {
372 | // We didn't have that much available.
373 | return ErrorFlowControlError
374 | }
375 | } else if end <= s.readOffset {
376 | // No new data here.
377 | return nil
378 | }
379 | if s.state != RecvStreamStateRecv && s.state != RecvStreamStateSizeKnown {
380 | // We shouldn't be receiving in other states.
381 | return nil
382 | }
383 |
384 | s.insertSortedChunk(offset, last, payload)
385 | if s.chunks[0].offset <= s.readOffset {
386 | s.readable = true
387 | }
388 |
389 | return nil
390 | }
391 |
// read copies buffered data that is contiguous from readOffset into b,
// advances readOffset, and returns the number of bytes copied.
//
// If a reset has been received (ResetRecvd), it returns ErrorStreamReset and
// moves the stream to ResetRead. Consuming the chunk marked last moves the
// stream to DataRead and drops any remaining chunks. When nothing is copied
// (including when b is empty), it returns:
//   - ErrorWouldBlock in the Recv and SizeKnown states;
//   - io.EOF in any other state if s.chunks is nil;
//   - ErrorStreamIsClosed otherwise.
func (s *recvStreamBase) read(b []byte) (int, error) {
	s.log(logTypeStream, "Reading len=%v read offset=%v available chunks=%v",
		len(b), s.readOffset, len(s.chunks))

	if s.state == RecvStreamStateResetRecvd {
		s.log(logTypeStream, "Reading stopped for RST_STREAM")
		s.setRecvState(RecvStreamStateResetRead)
		return 0, ErrorStreamReset
	}

	read := 0

	for len(b) > 0 {
		if len(s.chunks) == 0 {
			break
		}

		// chunk is a copy of s.chunks[0]; trimming it below does not modify
		// the stored chunk. A partially consumed chunk is trimmed again on
		// the next call using the advanced readOffset.
		chunk := s.chunks[0]
		s.log(logTypeTrace, "next chunk %v", chunk)
		// We have a gap.
		if chunk.offset > s.readOffset {
			break
		}

		// Remove leading bytes that are below readOffset (already delivered).
		remove := s.readOffset - chunk.offset
		if remove > uint64(len(chunk.data)) {
			// Nothing left.
			s.chunks = s.chunks[1:]
			continue
		}

		chunk.offset += remove
		chunk.data = chunk.data[remove:]

		// Now figure out how much we can read
		n := copy(b, chunk.data)
		s.log(logTypeTrace, "read %v at offset %v", n, s.readOffset)
		chunk.data = chunk.data[n:]
		chunk.offset += uint64(n)
		s.readOffset += uint64(n)
		b = b[n:]
		read += n

		// This chunk is empty.
		if len(chunk.data) == 0 {
			s.chunks = s.chunks[1:]

			// The final chunk has been consumed: the stream is fully read.
			// Setting chunks to nil makes later reads report io.EOF.
			if chunk.last {
				s.setRecvState(RecvStreamStateDataRead)
				s.chunks = nil
				break
			}
		}
	}

	// If we have read no data, report why: would-block while the stream is
	// still receiving, otherwise EOF (nothing buffered) or closed.
	if read == 0 {
		switch s.state {
		case RecvStreamStateRecv, RecvStreamStateSizeKnown:
			return 0, ErrorWouldBlock
		default:
			if s.chunks == nil {
				return 0, io.EOF
			}
			return 0, ErrorStreamIsClosed
		}
	}
	s.log(logTypeStream, "Returning %v bytes chunks=%v", read, len(s.chunks))
	return read, nil
}
463 |
464 | func (s *recvStreamBase) handleReset(offset uint64) error {
465 | switch s.state {
466 | case RecvStreamStateRecv:
467 | s.fc.used = offset
468 | case RecvStreamStateDataRecvd, RecvStreamStateResetRead:
469 | panic("we don't use this state")
470 | case RecvStreamStateSizeKnown, RecvStreamStateDataRead, RecvStreamStateResetRecvd:
471 | if offset != s.fc.used {
472 | return ErrorProtocolViolation
473 | }
474 | default:
475 | panic(fmt.Sprintf("unknown state %v", s.state))
476 | }
477 |
478 | s.setRecvState(RecvStreamStateResetRecvd)
479 | s.chunks = nil
480 | return nil
481 | }
482 |
483 | // SendStream is a unidirectional stream for sending.
484 | type sendStream struct {
485 | c *Connection
486 | id uint64
487 | sendStreamBase
488 | }
489 |
490 | // Compile-time interface check.
491 | var _ SendStream = &sendStream{}
492 |
493 | func newSendStream(c *Connection, id uint64, initialMax uint64) sendStreamPrivate {
494 | return &sendStream{
495 | c: c, id: id,
496 | sendStreamBase: sendStreamBase{
497 | streamCommon: streamCommon{
498 | log: newStreamLogger(id, "send", c.log),
499 | fc: newFlowControl(initialMax),
500 | },
501 | state: SendStreamStateOpen,
502 | },
503 | }
504 | }
505 |
506 | // Id returns the id.
507 | func (s *sendStream) Id() uint64 {
508 | return s.id
509 | }
510 |
511 | // Write writes data.
512 | func (s *sendStream) Write(data []byte) (int, error) {
513 | s.log(logTypeStream, "Stream %v: writing %v bytes", s.Id(), len(data))
514 | if s.c.isClosed() {
515 | return 0, ErrorConnIsClosed
516 | }
517 |
518 | n, err := s.write(data, &s.c.sendFlowControl)
519 | if err != nil {
520 | if err == ErrorWouldBlock {
521 | s.c.updateStreamBlocked(s)
522 | s.c.updateBlocked()
523 | }
524 | return n, err
525 | }
526 |
527 | s.c.sendQueued(false)
528 | return n, nil
529 | }
530 |
531 | // Close makes the stream end cleanly.
532 | func (s *sendStream) Close() error {
533 | s.close()
534 | s.c.sendQueued(false)
535 | return nil
536 | }
537 |
538 | // Reset abandons writing on the stream.
539 | func (s *sendStream) Reset(code uint16) error {
540 | s.setSendState(SendStreamStateResetSent)
541 | f := newRstStreamFrame(s.id, code, s.fc.used)
542 | return s.c.sendFrame(f)
543 | }
544 |
545 | // recvStream is the implementation of a unidirectional stream for receiving.
546 | type recvStream struct {
547 | c *Connection
548 | id uint64
549 | recvStreamBase
550 | }
551 |
552 | // Compile-time interface check.
553 | var _ RecvStream = &recvStream{}
554 |
555 | func newRecvStream(c *Connection, id uint64, maxStreamData uint64) recvStreamPrivate {
556 | return &recvStream{
557 | c: c, id: id,
558 | recvStreamBase: recvStreamBase{
559 | streamCommon: streamCommon{
560 | log: newStreamLogger(id, "recv", c.log),
561 | fc: newFlowControl(maxStreamData),
562 | },
563 | state: RecvStreamStateRecv,
564 | readable: false,
565 | },
566 | }
567 | }
568 |
569 | // Id returns the id.
570 | func (s *recvStream) Id() uint64 {
571 | return s.id
572 | }
573 |
574 | // updateMaxStreamData checks the current flow control limit and sends
575 | // MAX_STREAM_DATA as necessary.
576 | func (s *recvStream) updateMaxStreamData(force bool) {
577 | s.log(logTypeFlowControl, "credit flow control %v", &s.fc)
578 | if force || s.fc.remaining() < kInitialMaxStreamData/2 {
579 | s.fc.max = s.readOffset + kInitialMaxData
580 | s.log(logTypeFlowControl, "increased flow control to %v", &s.fc)
581 | s.c.issueStreamCredit(s, s.fc.max)
582 | }
583 | }
584 |
585 | // Read implements io.Reader.
586 | func (s *recvStream) Read(b []byte) (int, error) {
587 | if s.c.isClosed() {
588 | return 0, io.EOF
589 | }
590 |
591 | n, err := s.read(b)
592 | if err != nil {
593 | return 0, err
594 | }
595 | s.c.amountRead += uint64(n)
596 | // Now issue credit for stream flow control, ...
597 | s.updateMaxStreamData(false)
598 | // ..., connection flow control, ...
599 | s.c.issueCredit(false)
600 | // ..., and streams.
601 | if s.state == RecvStreamStateDataRead {
602 | s.c.issueStreamIdCredit(streamTypeFromId(s.id, s.c.role))
603 | }
604 | return n, nil
605 | }
606 |
607 | func (s *recvStream) handleReset(offset uint64) error {
608 | err := s.recvStreamBase.handleReset(offset)
609 | if err != nil {
610 | return err
611 | }
612 | // Pretend that we read this much data.
613 | s.c.amountRead += s.fc.used - s.readOffset
614 | s.readOffset = s.fc.used
615 | s.c.issueCredit(false)
616 |
617 | return nil
618 | }
619 |
620 | // StopSending requests a reset.
621 | func (s *recvStream) StopSending(code uint16) error {
622 | f := newStopSendingFrame(s.id, code)
623 | return s.c.sendFrame(f)
624 | }
625 |
626 | // stream is a bidirectional stream.
627 | type stream struct {
628 | c *Connection
629 | id uint64
630 |
631 | sendStreamPrivate
632 | recvStreamPrivate
633 | }
634 |
635 | // Compile-time interface check.
636 | var _ Stream = &stream{}
637 |
638 | func newStream(c *Connection, id uint64, sendMax uint64, recvMax uint64) streamPrivate {
639 | return &stream{
640 | sendStreamPrivate: newSendStream(c, id, sendMax),
641 | recvStreamPrivate: newRecvStream(c, id, recvMax),
642 | }
643 | }
644 |
645 | // Id needs to be overwritten so that the ambiguity between send and receive can be resolved.
646 | func (s *stream) Id() uint64 {
647 | return s.sendStreamPrivate.Id()
648 | }
649 |
650 | type streamType uint8
651 |
652 | // These values match the low bits of the stream ID for a client, but the low
653 | // bit is flipped for a server.
654 | const (
655 | streamTypeBidirectionalLocal = streamType(0)
656 | streamTypeBidirectionalRemote = streamType(1)
657 | streamTypeUnidirectionalLocal = streamType(2)
658 | streamTypeUnidirectionalRemote = streamType(3)
659 | )
660 |
661 | func streamTypeFromId(id uint64, role Role) streamType {
662 | t := id & 3
663 | if role == RoleServer {
664 | t ^= 1
665 | }
666 | return streamType(t)
667 | }
668 |
669 | func (t streamType) suffix(role Role) uint64 {
670 | suff := uint64(t)
671 | if role == RoleServer {
672 | suff ^= 1
673 | }
674 | return suff
675 | }
676 |
677 | func (t streamType) String() string {
678 | switch t {
679 | case streamTypeBidirectionalLocal:
680 | return "bidirectional local"
681 | case streamTypeBidirectionalRemote:
682 | return "bidirectional remote"
683 | case streamTypeUnidirectionalLocal:
684 | return "unidirectional local"
685 | case streamTypeUnidirectionalRemote:
686 | return "unidirectional remote"
687 | default:
688 | panic("unknown stream type")
689 | }
690 | }
691 |
692 | type streamSet struct {
693 | // t is the type of stream relative to the endpoints role
694 | t streamType
695 | // role is the endpoint's role
696 | role Role
697 | // nstreams is the maximum number of streams (as opposed to the maximum ID)
698 | nstreams int
699 | // typeless array of streams because go doesn't have generics
700 | streams []hasIdentity
701 | }
702 |
703 | func newStreamSet(t streamType, role Role, nstreams int) *streamSet {
704 | return &streamSet{t, role, nstreams, make([]hasIdentity, 0, nstreams)}
705 | }
706 |
707 | func (ss *streamSet) check(id uint64) {
708 | // If sizeof(int) == sizeof(uint64), then we will never overflow int.
709 | assert(^uint64(0) == uint64(^uint(0)))
710 | assert((id & (^uint64(0) >> 2)) == id) // The top bits should be clear.
711 | assert((id & 3) == ss.t.suffix(ss.role))
712 | }
713 |
714 | func (ss *streamSet) index(id uint64) int {
715 | ss.check(id)
716 | return int(id >> 2)
717 | }
718 |
719 | func (ss *streamSet) id(index int) uint64 {
720 | assert(index >= 0)
721 | return uint64(index<<2) | uint64(ss.t.suffix(ss.role))
722 | }
723 |
// flowControl tracks one flow-control window: max is the current limit and
// used is the amount consumed so far. A window created with a limit of
// ^uint64(0) is unlimited.
type flowControl struct {
	unlimited bool
	max       uint64
	used      uint64
}

// newFlowControl returns a window with limit initialMax and nothing used.
func newFlowControl(initialMax uint64) flowControl {
	return flowControl{unlimited: initialMax == ^uint64(0), max: initialMax}
}

// String renders the window as "used/max", or "Unlimited".
func (fc *flowControl) String() string {
	if fc.unlimited {
		return "Unlimited"
	}
	return fmt.Sprintf("%d/%d", fc.used, fc.max)
}

// update raises the limit to max. Smaller values are ignored, so the limit
// never goes down.
func (fc *flowControl) update(max uint64) {
	if max > fc.max {
		fc.max = max
	}
}

// take consumes up to amount of credit from fc and, when other is non-nil,
// charges other (e.g. a connection-level window) the same amount. It returns
// the amount taken.
//
// For a limited fc, the grant is capped by fc's remaining credit and, if
// other is non-nil, by other's remaining credit. An unlimited fc grants the
// full amount without consulting other, although other is still charged.
func (fc *flowControl) take(other *flowControl, amount uint64) uint64 {
	taken := amount
	if !fc.unlimited {
		if r := fc.remaining(); taken > r {
			taken = r
		}
		// other is optional: check for nil here just as when charging it below.
		if other != nil {
			if r := other.remaining(); taken > r {
				taken = r
			}
		}
	}

	fc.used += taken
	if other != nil {
		other.used += taken
	}
	return taken
}

// remaining returns the unused credit. It saturates at zero if used has run
// past max (which the unlimited path of take can do to other), rather than
// wrapping around to a huge value.
func (fc *flowControl) remaining() uint64 {
	if fc.used >= fc.max {
		return 0
	}
	return fc.max - fc.used
}
776 |
777 | func (ss *streamSet) updateMax(id uint64) {
778 | ss.nstreams = ss.index(id) + 1
779 | }
780 |
781 | func (ss *streamSet) credit(n int) uint64 {
782 | ss.nstreams += n
783 | return ss.id(ss.nstreams - 1)
784 | }
785 |
786 | func (ss *streamSet) get(id uint64) hasIdentity {
787 | i := ss.index(id)
788 | if i >= len(ss.streams) {
789 | return nil
790 | }
791 | return ss.streams[i]
792 | }
793 |
794 | type streamSetCtor func(id uint64) hasIdentity
795 |
796 | func (ss *streamSet) create(ctor streamSetCtor) hasIdentity {
797 | i := len(ss.streams)
798 | if i >= ss.nstreams {
799 | return nil
800 | }
801 | ss.streams = append(ss.streams, ctor(ss.id(i)))
802 | return ss.streams[i]
803 | }
804 |
805 | func (ss *streamSet) ensure(id uint64, ctor streamSetCtor,
806 | notify func(s hasIdentity)) hasIdentity {
807 | i := ss.index(id)
808 | if i >= ss.nstreams {
809 | return nil
810 | }
811 | if i >= len(ss.streams) {
812 | needed := i - len(ss.streams) + 1
813 | start := len(ss.streams)
814 | ss.streams = append(ss.streams, make([]hasIdentity, needed)...)
815 | for j := start; j < len(ss.streams); j++ {
816 | s := ctor(ss.id(j))
817 | ss.check(s.Id())
818 | ss.streams[j] = s
819 | notify(ss.streams[j])
820 | }
821 | }
822 | return ss.streams[i]
823 | }
824 |
825 | func (ss *streamSet) forEach(f func(hasIdentity)) {
826 | for _, s := range ss.streams {
827 | f(s)
828 | }
829 | }
830 |
--------------------------------------------------------------------------------