├── .gitignore ├── h2quic ├── doc.go ├── h2quic_suite_test.go ├── request_body.go ├── gzipreader.go ├── upgrade.go ├── request_body_test.go ├── LICENSE ├── request.go ├── response_writer.go ├── response.go ├── request_test.go ├── request_writer_test.go ├── roundtrip.go ├── response_writer_test.go ├── request_writer.go ├── roundtrip_test.go ├── client.go ├── server.go ├── server_test.go └── client_test.go ├── doc.go ├── proxy.go ├── LICENSE ├── internal ├── testdata │ ├── cert.go │ ├── privkey.pem │ └── fullchain.pem ├── atomic │ └── atomic.go ├── lru │ ├── lru.go │ └── lru_test.go └── socks │ └── socks.go ├── cmd ├── quictun_client │ └── client.go └── quictun_server │ └── server.go ├── response.go ├── README.md ├── server.go ├── request_writer.go └── client.go /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | cmd/quictun_client/quictun_client 3 | cmd/quictun_server/quictun_server 4 | -------------------------------------------------------------------------------- /h2quic/doc.go: -------------------------------------------------------------------------------- 1 | // Package h2quic is a drop-in replacement for quic-go's h2quic package with 2 | // integrated quictun support. 3 | package h2quic 4 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // Package quictun is an implementation of quictun. The package can be used to 2 | // implement quictun clients and servers. 
3 | package quictun 4 | -------------------------------------------------------------------------------- /proxy.go: -------------------------------------------------------------------------------- 1 | package quictun 2 | 3 | import ( 4 | "io" 5 | ) 6 | // proxy copies src to dst until src is exhausted or the copy fails, then closes dst. Errors from io.Copy and Close are ignored. 7 | func proxy(dst io.WriteCloser, src io.Reader) { 8 | io.Copy(dst, src) 9 | // src is only an io.Reader here and is not closed; presumably the caller owns it (TODO confirm) 10 | dst.Close() 11 | // done proxying 12 | } 13 | -------------------------------------------------------------------------------- /h2quic/h2quic_suite_test.go: -------------------------------------------------------------------------------- 1 | package h2quic 2 | 3 | import ( 4 | . "github.com/onsi/ginkgo" 5 | . "github.com/onsi/gomega" 6 | 7 | "testing" 8 | ) 9 | 10 | func TestH2quic(t *testing.T) { 11 | RegisterFailHandler(Fail) 12 | RunSpecs(t, "H2quic Suite") 13 | } 14 | -------------------------------------------------------------------------------- /h2quic/request_body.go: -------------------------------------------------------------------------------- 1 | package h2quic 2 | 3 | import ( 4 | "io" 5 | 6 | quic "github.com/lucas-clemente/quic-go" 7 | ) 8 | // requestBody adapts a QUIC data stream to an io.ReadCloser and records whether it has been read. 9 | type requestBody struct { 10 | requestRead bool 11 | dataStream quic.Stream 12 | } 13 | 14 | // make sure the requestBody can be used as a http.Request.Body 15 | var _ io.ReadCloser = &requestBody{} 16 | 17 | func newRequestBody(stream quic.Stream) *requestBody { 18 | return &requestBody{dataStream: stream} 19 | } 20 | // Read marks the body as read and reads from the underlying data stream. 21 | func (b *requestBody) Read(p []byte) (int, error) { 22 | b.requestRead = true 23 | return b.dataStream.Read(p) 24 | } 25 | // Close is a no-op and always returns nil. 26 | func (b *requestBody) Close() error { 27 | // stream's Close() closes the write side, not the read side 28 | return nil 29 | } 30 | -------------------------------------------------------------------------------- /h2quic/gzipreader.go: -------------------------------------------------------------------------------- 1 | package h2quic 2 | 3 | // copied from net/transport.go 4 | 5 | // gzipReader wraps a
response body so it can lazily 6 | // call gzip.NewReader on the first call to Read 7 | import ( 8 | "compress/gzip" 9 | "io" 10 | ) 11 | 12 | // gzipReader lazily calls gzip.NewReader on the first call to Read; an error from gzip.NewReader is sticky. 13 | type gzipReader struct { 14 | body io.ReadCloser // underlying Response.Body 15 | zr *gzip.Reader // lazily-initialized gzip reader 16 | zerr error // sticky error 17 | } 18 | // Read creates the gzip reader on first use (returning the stored error if that failed) and then reads decompressed data from it. 19 | func (gz *gzipReader) Read(p []byte) (n int, err error) { 20 | if gz.zerr != nil { 21 | return 0, gz.zerr 22 | } 23 | if gz.zr == nil { 24 | gz.zr, err = gzip.NewReader(gz.body) 25 | if err != nil { 26 | gz.zerr = err 27 | return 0, err 28 | } 29 | } 30 | return gz.zr.Read(p) 31 | } 32 | // Close closes the underlying body. 33 | func (gz *gzipReader) Close() error { 34 | return gz.body.Close() 35 | } 36 | -------------------------------------------------------------------------------- /h2quic/upgrade.go: -------------------------------------------------------------------------------- 1 | package h2quic 2 | 3 | import ( 4 | "errors" 5 | 6 | quic "github.com/lucas-clemente/quic-go" 7 | ) 8 | 9 | var noKnownUpgradeProtocol = errors.New("no known upgrade protocol") 10 | 11 | // connectionUpgrade indicates that the connection has been upgraded to the 12 | // protocol set within. 13 | type connectionUpgrade struct { 14 | protocol string 15 | } 16 | 17 | func (c *connectionUpgrade) Error() string { 18 | return "connection has been upgraded to " + c.protocol 19 | } 20 | 21 | // UpgradeHandler is a function which can perform an upgrade to another protocol 22 | // by modifying a given QUIC session. 23 | type UpgradeHandler func(quic.Session) 24 | 25 | // upgradeHandlers maps protocol identifiers to their registered UpgradeHandler. 26 | var upgradeHandlers = map[string]UpgradeHandler{} 27 | 28 | // RegisterUpgradeHandler registers a handler function for the given protocol 29 | // identifier, such as "PROT/1.2". Registration is not synchronized, so register handlers before serving.
30 | func RegisterUpgradeHandler(protocol string, handler UpgradeHandler) { 31 | upgradeHandlers[protocol] = handler 32 | } 33 | -------------------------------------------------------------------------------- /h2quic/request_body_test.go: -------------------------------------------------------------------------------- 1 | package h2quic 2 | 3 | import ( 4 | . "github.com/onsi/ginkgo" 5 | . "github.com/onsi/gomega" 6 | ) 7 | 8 | var _ = Describe("Request body", func() { 9 | var ( 10 | stream *mockStream 11 | rb *requestBody 12 | ) 13 | 14 | BeforeEach(func() { 15 | stream = &mockStream{} 16 | stream.dataToRead.Write([]byte("foobar")) // provides data to be read 17 | rb = newRequestBody(stream) 18 | }) 19 | 20 | It("reads from the stream", func() { 21 | b := make([]byte, 10) 22 | n, _ := rb.Read(b) // read through the requestBody, not the mock stream directly 23 | Expect(n).To(Equal(6)) 24 | Expect(b[0:6]).To(Equal([]byte("foobar"))) 25 | }) 26 | 27 | It("saves if the stream was read from", func() { 28 | Expect(rb.requestRead).To(BeFalse()) 29 | rb.Read(make([]byte, 1)) 30 | Expect(rb.requestRead).To(BeTrue()) 31 | }) 32 | 33 | It("doesn't close the stream when closing the request body", func() { 34 | Expect(stream.closed).To(BeFalse()) 35 | err := rb.Close() 36 | Expect(err).ToNot(HaveOccurred()) 37 | Expect(stream.closed).To(BeFalse()) 38 | }) 39 | }) 40 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Julien Schmidt 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the
following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /h2quic/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 the quic-go authors & Google, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /internal/testdata/cert.go: -------------------------------------------------------------------------------- 1 | package testdata 2 | 3 | import ( 4 | "crypto/tls" 5 | "path" 6 | "runtime" 7 | ) 8 | 9 | // copied from github.com/lucas-clemente/quic-go/internal/testdata/cert.go 10 | 11 | var certPath string 12 | 13 | func init() { 14 | _, filename, _, ok := runtime.Caller(0) 15 | if !ok { 16 | panic("Failed to get current frame") 17 | } 18 | 19 | certPath = path.Dir(filename) 20 | } 21 | 22 | // GetCertificatePaths returns the paths to 'fullchain.pem' and 'privkey.pem' for the 23 | // quic.clemente.io cert. 24 | func GetCertificatePaths() (string, string) { 25 | return path.Join(certPath, "fullchain.pem"), path.Join(certPath, "privkey.pem") 26 | } 27 | 28 | // GetTLSConfig returns a tls config for quic.clemente.io 29 | func GetTLSConfig() *tls.Config { 30 | cert, err := tls.LoadX509KeyPair(GetCertificatePaths()) 31 | if err != nil { 32 | panic(err) 33 | } 34 | return &tls.Config{ 35 | Certificates: []tls.Certificate{cert}, 36 | } 37 | } 38 | 39 | // GetCertificate returns a certificate for quic.clemente.io 40 | func GetCertificate() tls.Certificate { 41 | cert, err := tls.LoadX509KeyPair(GetCertificatePaths()) 42 | if err != nil { 43 | panic(err) 44 | } 45 | return cert 46 | } 47 | -------------------------------------------------------------------------------- /internal/atomic/atomic.go: -------------------------------------------------------------------------------- 1 | package atomic 2 | 3 | import ( 4 | "sync/atomic" 5 | ) 6 | 7 | // noCopy may be embedded into structs which must not be copied 8 | // 
after the first use. 9 | // 10 | // See https://github.com/golang/go/issues/8005#issuecomment-190753527 11 | // for details. 12 | type noCopy struct{} 13 | 14 | // Lock is a no-op used by -copylocks checker from `go vet`. 15 | func (*noCopy) Lock() {} 16 | 17 | // Bool is a wrapper around uint32 for usage as a boolean value with 18 | // atomic access. 19 | type Bool struct { 20 | _noCopy noCopy 21 | value uint32 22 | } 23 | 24 | // IsSet returns whether the current boolean value is true 25 | func (b *Bool) IsSet() bool { 26 | return atomic.LoadUint32(&b.value) > 0 27 | } 28 | 29 | // Set sets the value of the bool regardless of the previous value 30 | func (b *Bool) Set(value bool) { 31 | if value { 32 | atomic.StoreUint32(&b.value, 1) 33 | } else { 34 | atomic.StoreUint32(&b.value, 0) 35 | } 36 | } 37 | 38 | // TrySet sets the value of the bool and returns whether the value has changed 39 | func (b *Bool) TrySet(value bool) bool { 40 | if value { 41 | return atomic.SwapUint32(&b.value, 1) == 0 42 | } 43 | return atomic.SwapUint32(&b.value, 0) > 0 44 | } 45 | -------------------------------------------------------------------------------- /cmd/quictun_client/client.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "crypto/tls" 5 | "flag" 6 | "fmt" 7 | "log" 8 | "os" 9 | "time" 10 | 11 | "github.com/julienschmidt/quictun" 12 | ) 13 | 14 | const ( 15 | // the User-Agent string is not observable, but should have the same length as a regular browser UA, e.g. 
that of Chrome 16 | userAgent = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_13_3) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/63.0.3239.108 X-quictun/0.1" 17 | 18 | // timeout for establishing connections to quictun server (in seconds) 19 | dialTimeout = 30 20 | ) 21 | 22 | func main() { 23 | // command-line flags and args 24 | listenFlag := flag.String("l", "localhost:1080", "local SOCKS listen address") 25 | insecureFlag := flag.Bool("invalidCerts", false, "accept all invalid certs (insecure)") 26 | flag.Usage = func() { 27 | fmt.Printf("Usage: %s [OPTIONS] QUICTUN_URL\n", os.Args[0]) 28 | flag.PrintDefaults() 29 | } 30 | flag.Parse() 31 | args := flag.Args() 32 | if len(args) != 1 { 33 | flag.Usage() 34 | return 35 | } 36 | tunnelAddr := args[0] 37 | 38 | // configure and run quictun client 39 | client := quictun.Client{ 40 | ListenAddr: *listenFlag, 41 | TunnelAddr: tunnelAddr, 42 | UserAgent: userAgent, 43 | DialTimeout: dialTimeout * time.Second, 44 | TlsCfg: &tls.Config{InsecureSkipVerify: *insecureFlag}, 45 | } 46 | log.Fatal(client.Run()) 47 | } 48 | -------------------------------------------------------------------------------- /internal/testdata/privkey.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN PRIVATE KEY----- 2 | MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC4xixIS9iBFy/k 3 | N8AxF1AKFQqK06aWQ4Sjku677az6on2q8gDCIiup7uaAiyUzmSgPRZRrcBESApw7 4 | w15XN/K17ZQb4Bw7Xp0O7rzKhtwH8ugz+Qs8ceK4ayKTCT/PoPmKnr+sL9Soo0LJ 5 | 8XFvNdY/v3cq6eScnqztOGJV3RSAd34e/EBVkiSOIeXgu8rjtGOxSuuWS52n6NBx 6 | uZrSjP/JClqV26I1NyH+psfoO+zZ8/cJUoC1clalsDhDfX33rPEHJ4KOvB6rAAh9 7 | VcFQtTfuOTIySzPuPMfarhq14lbt+pRjvEMru1uaQYL+t/OWrX/tLc/990G/rVmv 8 | Xfed9/yDAgMBAAECggEAN48PHaYAscBBHERPO/Ogk4eEJf5CJwiiR3UU59ktnCdj 9 | 1hTyeW1A59X35UrxorQ4wW7QlAWcfGfghm/WXC9sgZuwXzliA9ANNcI/bj5ixtkZ 10 | TRdjc4di/sToHoI3d70Vi8L0K1guf46ntIUu8JulkoGF2Zd+sEFeCe5cUyko0v+Y 11 | SGd1WHmMm526Iw07FkS642Vqdx9DcQZNeSTi8girFGOoLEWkP5JgrUqoLvokBPDW 12 | 
c/sa4zLsscFCYwafG2olQZLrHlyXlMv6BRKWa/5T62DF1+uVatYOgQgH9/RIu9Jh 13 | jmMgxnA+enXOyGdj/b0lYm0uV/IEUyDtPMSeYDAgcQKBgQDgGX2WQANbxHvvOwO+ 14 | 0radIIsxvQBsO/fyc0NaKNXVgGBfoXomg4iLb2qkDHQkVQS4xiOwfZ6PZxLTa1eT 15 | wZPL749qqPPEwuIJYiJ+6fNNhqLb/qC4UbIQpG7yb+G7O6yjPwnLmHoX9qfcF4He 16 | NojNHB9ltQIe3SgGG96CG3maZwKBgQDTE5s4xyH1OkaPnm0dCgEqDHHpcNLqVNH+ 17 | gSwf/eBBDmclv26Y6HGa9GrPeVSk9kGxF4n5AXm57tnTeMFc8nNzCqcgvaTwRahq 18 | 6ur79PJmKRDTX6/RWK8Srh7Sox5L7jr1Dn/qyOT2579B9eovZRCthZjw21XJGjW2 19 | i5nrdQ3zhQKBgQCjqToUns9VF5vDTQAhPlXrTrcZLgS/BtS/lfocQDJaaBT6Aj3p 20 | Hqp72nSxNf8kAYsfPmUWIcfIxufyyzP8TqUXjO7aYGUWz5SwcaDruwPbHHaX38+U 21 | jOVUTiJQn/DlAmHEHueSbtrL4XEZxXksxfsGgIFVj+nqjG0MeRH5RwN6BQKBgBdv 22 | GdCX6yE6sxLG1/5dWfu9Hfh42jHB8P58gNWcbgVLABCkzDaVt+coM6ONKOSXonty 23 | zZKjo0wNRInB4lXbZQ3kpOFxrJowYZ5dLnGCpFbLQF73RKHNYsKEKk/gZECx1kHW 24 | tkTuwNzYpddA4hsY8V0SdARplYCaNFRr80681CuxAoGASqAbFhJk6YhvBcAEDZit 25 | h92qsH1GCYdqsnbvnJtZIDYch+uyURNnO7sjomm8NkAqjvlC+YfvPdEXUliXfdg9 26 | MkwkZJrCwH/RlHg429agimrUnsFdO26rxjzLFBaUOrfAkYru3YvpHbpjVTTt7O0C 27 | 7QNPy6556B/57VNoAToQOhM= 28 | -----END PRIVATE KEY----- 29 | -------------------------------------------------------------------------------- /cmd/quictun_server/server.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "net/http" 7 | "time" 8 | 9 | "github.com/julienschmidt/quictun" 10 | "github.com/julienschmidt/quictun/h2quic" 11 | "github.com/julienschmidt/quictun/internal/lru" 12 | "github.com/julienschmidt/quictun/internal/testdata" 13 | ) 14 | 15 | const ( 16 | dialTimeout = 30 17 | ) 18 | 19 | func main() { 20 | // command-line args 21 | listenFlag := flag.String("l", "localhost:6121", "QUIC listen address") 22 | flag.Parse() 23 | args := flag.Args() 24 | if len(args) > 0 { 25 | flag.Usage() 26 | return 27 | } 28 | listenAddr := *listenFlag 29 | 30 | quictunServer := quictun.Server{ 31 | DialTimeout: dialTimeout * time.Second, 32 | 
SequenceCache: lru.New(10), 33 | } 34 | 35 | // Register the upgrade handler for the quictun protocol 36 | h2quic.RegisterUpgradeHandler("QTP/0.1", quictunServer.Upgrade) 37 | 38 | http.HandleFunc("/secret", func(w http.ResponseWriter, r *http.Request) { 39 | // replay protection 40 | if !quictunServer.CheckSequenceNumber(r.Header.Get("QTP")) { 41 | w.Header().Set("Connection", "close") 42 | w.WriteHeader(http.StatusBadRequest) 43 | r.Close = true 44 | return 45 | } 46 | 47 | // switch to quictun protocol (version 0.1) 48 | w.Header().Set("Connection", "Upgrade") 49 | w.Header().Set("Upgrade", "QTP/0.1") 50 | w.WriteHeader(http.StatusSwitchingProtocols) 51 | }) 52 | 53 | // HTTP server 54 | // Implementations for production usage should be embedded in an existing web server instead. 55 | server := h2quic.Server{ 56 | Server: &http.Server{Addr: listenAddr}, 57 | } 58 | certFile, keyFile := testdata.GetCertificatePaths() 59 | fmt.Printf("Start listening on %s...\n", listenAddr) 60 | err := server.ListenAndServeTLS(certFile, keyFile) 61 | if err != nil { 62 | fmt.Println(err) 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /h2quic/request.go: -------------------------------------------------------------------------------- 1 | package h2quic 2 | 3 | import ( 4 | "crypto/tls" 5 | "errors" 6 | "net/http" 7 | "net/url" 8 | "strconv" 9 | "strings" 10 | 11 | "golang.org/x/net/http2/hpack" 12 | ) 13 | 14 | func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) { 15 | var path, authority, method, contentLengthStr string 16 | httpHeaders := http.Header{} 17 | 18 | for _, h := range headers { 19 | switch h.Name { 20 | case ":path": 21 | path = h.Value 22 | case ":method": 23 | method = h.Value 24 | case ":authority": 25 | authority = h.Value 26 | case "content-length": 27 | contentLengthStr = h.Value 28 | default: 29 | if !h.IsPseudo() { 30 | httpHeaders.Add(h.Name, h.Value) 31 | } 32 | } 33 | } 34 | 35 | 
// concatenate cookie headers, see https://tools.ietf.org/html/rfc6265#section-5.4 36 | if len(httpHeaders["Cookie"]) > 0 { 37 | httpHeaders.Set("Cookie", strings.Join(httpHeaders["Cookie"], "; ")) 38 | } 39 | 40 | if len(path) == 0 || len(authority) == 0 || len(method) == 0 { 41 | return nil, errors.New(":path, :authority and :method must not be empty") 42 | } 43 | 44 | u, err := url.Parse(path) 45 | if err != nil { 46 | return nil, err 47 | } 48 | 49 | var contentLength int64 50 | if len(contentLengthStr) > 0 { 51 | contentLength, err = strconv.ParseInt(contentLengthStr, 10, 64) 52 | if err != nil { 53 | return nil, err 54 | } 55 | } 56 | 57 | return &http.Request{ 58 | Method: method, 59 | URL: u, 60 | Proto: "HTTP/2.0", 61 | ProtoMajor: 2, 62 | ProtoMinor: 0, 63 | Header: httpHeaders, 64 | Body: nil, 65 | ContentLength: contentLength, 66 | Host: authority, 67 | RequestURI: path, 68 | TLS: &tls.ConnectionState{}, 69 | }, nil 70 | } 71 | 72 | func hostnameFromRequest(req *http.Request) string { 73 | if len(req.Host) > 0 { 74 | return req.Host 75 | } 76 | if req.URL != nil { 77 | return req.URL.Host 78 | } 79 | return "" 80 | } 81 | -------------------------------------------------------------------------------- /internal/lru/lru.go: -------------------------------------------------------------------------------- 1 | package lru 2 | 3 | import ( 4 | "sync" 5 | ) 6 | 7 | type entry struct { 8 | key uint64 9 | value uint32 10 | next *entry 11 | } 12 | 13 | // LRU is an LRU cache. 14 | // Concurrent access is synchronized. 15 | // 16 | // A map is used as the index. 17 | // The LRU order is tracked in a linked list. 
18 | type LRU struct { 19 | capacity int // max number of entries 20 | head *entry // first entry in LRU order 21 | cache map[uint64]*entry // mapping of all key-value pairs 22 | lock sync.Mutex // guards the whole struct 23 | } 24 | 25 | // New creates a new LRU cache with the given capacity. It panics if capacity is less than 2. 26 | func New(capacity int) *LRU { 27 | if capacity < 2 { 28 | panic("capacity must be at least 2") 29 | } 30 | return &LRU{ 31 | capacity: capacity, 32 | cache: make(map[uint64]*entry, capacity), 33 | } 34 | } 35 | 36 | // Set sets the value for the given key and makes it the most recently used entry. If an entry for the given key already 37 | // exists, it is overwritten and its old value is returned; otherwise 0 is returned and, if the capacity is exceeded, the least recently used entry is evicted. 38 | func (l *LRU) Set(key uint64, value uint32) (old uint32) { 39 | l.lock.Lock() 40 | if ep, ok := l.cache[key]; ok { 41 | old = ep.value 42 | ep.value = value 43 | l.moveToFront(ep) 44 | l.lock.Unlock() 45 | return 46 | } 47 | 48 | // insert new entry for key 49 | ep := new(entry) 50 | ep.key = key 51 | ep.value = value 52 | ep.next = l.head 53 | l.head = ep 54 | l.cache[key] = ep 55 | 56 | if len(l.cache) > l.capacity { 57 | l.removeLast() 58 | } 59 | l.lock.Unlock() 60 | return 61 | } 62 | 63 | // Get returns the current value for the given key. 64 | // If no value for the given key exists, 0 is returned.
65 | func (l *LRU) Get(key uint64) (value uint32) { 66 | l.lock.Lock() 67 | if ep, ok := l.cache[key]; ok { 68 | value = ep.value 69 | l.moveToFront(ep) 70 | } 71 | l.lock.Unlock() 72 | return value 73 | } 74 | 75 | func (l *LRU) moveToFront(ep *entry) { 76 | // make ep the head; the list is singly linked, so ep's predecessor is found by walking from the head (callers hold l.lock) 77 | if l.head != ep { 78 | after := ep.next 79 | cur := l.head 80 | ep.next = cur 81 | for cur.next != ep { 82 | cur = cur.next 83 | } 84 | cur.next = after 85 | l.head = ep 86 | } 87 | } 88 | 89 | func (l *LRU) removeLast() { 90 | // unlink the tail entry; when called from Set the list holds more than capacity (>= 2) entries, so head.next is non-nil 91 | prev := l.head 92 | last := prev.next 93 | for last.next != nil { 94 | prev = last 95 | last = last.next 96 | } 97 | prev.next = nil 98 | 99 | // remove from cache 100 | delete(l.cache, last.key) 101 | } 102 | -------------------------------------------------------------------------------- /internal/lru/lru_test.go: -------------------------------------------------------------------------------- 1 | package lru 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | ) 7 | 8 | func printList(l *LRU) { 9 | cur := l.head 10 | i := 0 11 | for cur != nil { 12 | fmt.Println(i, cur.key, cur.value) 13 | cur = cur.next 14 | i++ 15 | } 16 | } 17 | 18 | func TestLRU(t *testing.T) { 19 | lru := New(2) 20 | 21 | if n := len(lru.cache); n != 0 { 22 | t.Fatalf("cache should be empty, actually has %d elements", n) 23 | } 24 | 25 | // insert first value 26 | old := lru.Set(1337, 42) 27 | if old != 0 { 28 | t.Fatalf("old value for new key is %d, should be 0", old) 29 | } 30 | 31 | if n := len(lru.cache); n != 1 { 32 | t.Fatalf("cache should have 1 element, actually has %d elements", n) 33 | } 34 | 35 | // overwrite existing value 36 | old = lru.Set(1337, 43) 37 | if old != 42 { 38 | t.Fatalf("old value for existing key is %d, should be 42", old) 39 | } 40 | 41 | if n := len(lru.cache); n != 1 { 42 | t.Fatalf("cache should have 1 element, actually has %d elements", n) 43 | } 44 | 45 | // insert second value 46 | old = lru.Set(1338, 42) 47 | if old !=
0 { 48 | t.Fatalf("old value for new key is %d, should be 0", old) 49 | } 50 | 51 | if n := len(lru.cache); n != 2 { 52 | t.Fatalf("cache should have 2 elements, actually has %d elements", n) 53 | } 54 | 55 | if head := lru.head.key; head != 1338 { 56 | t.Fatalf("newly inserted element is not head, key of head is %d", head) 57 | } 58 | 59 | // access the older value 60 | if v1 := lru.Get(1337); v1 != 43 { 61 | t.Fatalf("value of the first entry changed, should be 43, is %d", v1) 62 | } 63 | 64 | if head := lru.head.key; head != 1337 { 65 | t.Fatalf("accessed element is not head, key of head is %d", head) 66 | } 67 | 68 | // overwrite existing value 69 | old = lru.Set(1337, 42) 70 | if old != 43 { 71 | t.Fatalf("old value for existing key is %d, should be 43", old) 72 | } 73 | 74 | if n := len(lru.cache); n != 2 { 75 | t.Fatalf("cache should have 2 elements, actually has %d elements", n) 76 | } 77 | 78 | //printList(lru) 79 | 80 | // insert third value, removing the second 81 | old = lru.Set(1339, 7) 82 | if old != 0 { 83 | t.Fatalf("old value for new key is %d, should be 0", old) 84 | } 85 | 86 | if n := len(lru.cache); n != 2 { 87 | t.Fatalf("cache should have 2 elements, actually has %d elements", n) 88 | } 89 | 90 | if head := lru.head.key; head != 1339 { 91 | t.Fatalf("newly inserted element is not head, key of head is %d", head) 92 | } 93 | 94 | //printList(lru) 95 | } 96 | -------------------------------------------------------------------------------- /response.go: -------------------------------------------------------------------------------- 1 | package quictun 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "io" 7 | "io/ioutil" 8 | "net/http" 9 | "net/textproto" 10 | "strconv" 11 | "strings" 12 | 13 | "golang.org/x/net/http2" 14 | ) 15 | 16 | // copied from net/http2/transport.go 17 | 18 | var errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit") 19 | var noBody io.ReadCloser = 
ioutil.NopCloser(bytes.NewReader(nil)) 20 | 21 | // from the handleResponse function 22 | func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) { 23 | if f.Truncated { 24 | return nil, errResponseHeaderListSize 25 | } 26 | 27 | status := f.PseudoValue("status") 28 | if status == "" { 29 | return nil, errors.New("missing status pseudo header") 30 | } 31 | statusCode, err := strconv.Atoi(status) 32 | if err != nil { 33 | return nil, errors.New("malformed non-numeric status pseudo header") 34 | } 35 | 36 | header := make(http.Header) 37 | res := &http.Response{ 38 | Proto: "HTTP/2.0", 39 | ProtoMajor: 2, 40 | Header: header, 41 | StatusCode: statusCode, 42 | Status: status + " " + http.StatusText(statusCode), 43 | } 44 | for _, hf := range f.RegularFields() { 45 | key := http.CanonicalHeaderKey(hf.Name) 46 | if key == "Trailer" { 47 | t := res.Trailer 48 | if t == nil { 49 | t = make(http.Header) 50 | res.Trailer = t 51 | } 52 | foreachHeaderElement(hf.Value, func(v string) { 53 | t[http.CanonicalHeaderKey(v)] = nil 54 | }) 55 | } else { 56 | header[key] = append(header[key], hf.Value) 57 | } 58 | } 59 | 60 | return res, nil 61 | } 62 | 63 | // continuation of the handleResponse function 64 | func setLength(res *http.Response, isHead, streamEnded bool) *http.Response { 65 | if !streamEnded || isHead { 66 | res.ContentLength = -1 67 | if clens := res.Header["Content-Length"]; len(clens) == 1 { 68 | if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil { 69 | res.ContentLength = clen64 70 | } 71 | } 72 | } 73 | return res 74 | } 75 | 76 | // copied from net/http/server.go 77 | 78 | // foreachHeaderElement splits v according to the "#rule" construction 79 | // in RFC 2616 section 2.1 and calls fn for each non-empty element. 
80 | func foreachHeaderElement(v string, fn func(string)) { 81 | v = textproto.TrimString(v) 82 | if v == "" { 83 | return 84 | } 85 | if !strings.Contains(v, ",") { 86 | fn(v) 87 | return 88 | } 89 | for _, f := range strings.Split(v, ",") { 90 | if f = textproto.TrimString(f); f != "" { 91 | fn(f) 92 | } 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # quictun [![GoDoc](https://godoc.org/github.com/julienschmidt/quictun?status.svg)](http://godoc.org/github.com/julienschmidt/quictun) 2 | 3 | quictun is a simple hidden tunnel based on the QUIC protocol. 4 | 5 | This repository contains a proof-of-concept implementation of [quictun](https://github.com/julienschmidt/quictun-thesis). 6 | Its purpose is to demonstrate that quictun clients and servers can be implemented with minimal effort on top of an existing QUIC and HTTP/2 over QUIC implementation. 7 | The implementation uses the [quic-go](https://github.com/lucas-clemente/quic-go) QUIC implementation as a basis. 8 | 9 | Note that while quictun is meant to be implemented on top of [IETF QUIC](https://datatracker.ietf.org/wg/quic/about/), this proof-of-concept implementation uses Google QUIC instead, as at the time of development no usable implementation of the (still work-in-progress) IETF version exists. Due to the limitations of the underlying QUIC implementation, this quictun implementation is neither meant for production usage, nor for performance evaluation of the approach. 10 | 11 | 12 | ## Overview 13 | 14 | `h2quic` is a fork of [github.com/lucas-clemente/quic-go/h2quic](https://github.com/lucas-clemente/quic-go/tree/master/h2quic). It adds the upgrade mechanism to the HTTP/2 over QUIC (h2quic) implementation. The fork can be used as a drop-in replacement for the upstream package to add support for quictun. 
15 | 16 | `cmd/quictun_client` contains a very minimal client example. Actual clients MUST take care to be indistinguishable at the wire level from a legitimate HTTP/2 over QUIC client that a censor is unwilling to block. This could be achieved, e.g., by reusing the network stack of a QUIC-capable web browser. 17 | 18 | `cmd/quictun_server` likewise contains a minimal server example. Note that this example server is easily fingerprintable and thus blockable. 19 | 20 | 21 | ## Installation 22 | 23 | ```sh 24 | go get -u github.com/julienschmidt/quictun 25 | ``` 26 | 27 | 28 | ## Usage 29 | 30 | Clients should use the [`quictun.Client` struct](https://godoc.org/github.com/julienschmidt/quictun#Client). An example client can be found in `cmd/quictun_client`. 31 | 32 | Servers should either use the [`quictun.Server` struct](https://godoc.org/github.com/julienschmidt/quictun#Server) directly and manually implement the upgrade mechanism in the web server, or use the [`h2quic`](https://godoc.org/github.com/julienschmidt/quictun/h2quic) sub-package. 33 | 34 | A valid certificate is required to operate a server; one can be obtained, e.g., from [Let's Encrypt](https://letsencrypt.org/). For testing purposes, the client may instead be insecurely configured to accept any, possibly invalid, certificate. The example client provides a `-invalidCerts` flag for that purpose.
-------------------------------------------------------------------------------- /h2quic/response_writer.go: -------------------------------------------------------------------------------- 1 | package h2quic 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "net/http" 7 | "strconv" 8 | "strings" 9 | "sync" 10 | 11 | quic "github.com/lucas-clemente/quic-go" 12 | "golang.org/x/net/http2" 13 | "golang.org/x/net/http2/hpack" 14 | ) 15 | 16 | type responseWriter struct { 17 | dataStreamID quic.StreamID 18 | dataStream quic.Stream 19 | 20 | headerStream quic.Stream 21 | headerStreamMutex *sync.Mutex 22 | 23 | header http.Header 24 | status int // status code passed to WriteHeader 25 | headerWritten bool 26 | } 27 | 28 | func newResponseWriter(headerStream quic.Stream, headerStreamMutex *sync.Mutex, dataStream quic.Stream, dataStreamID quic.StreamID) *responseWriter { 29 | return &responseWriter{ 30 | header: http.Header{}, 31 | headerStream: headerStream, 32 | headerStreamMutex: headerStreamMutex, 33 | dataStream: dataStream, 34 | dataStreamID: dataStreamID, 35 | } 36 | } 37 | 38 | func (w *responseWriter) Header() http.Header { 39 | return w.header 40 | } 41 | 42 | func (w *responseWriter) WriteHeader(status int) { 43 | if w.headerWritten { 44 | return 45 | } 46 | w.headerWritten = true 47 | w.status = status 48 | 49 | var headers bytes.Buffer 50 | enc := hpack.NewEncoder(&headers) 51 | enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)}) 52 | 53 | for k, v := range w.header { 54 | for index := range v { 55 | enc.WriteField(hpack.HeaderField{Name: strings.ToLower(k), Value: v[index]}) 56 | } 57 | } 58 | 59 | fmt.Printf("Responding with %d\n", status) 60 | w.headerStreamMutex.Lock() 61 | defer w.headerStreamMutex.Unlock() 62 | h2framer := http2.NewFramer(w.headerStream, nil) 63 | err := h2framer.WriteHeaders(http2.HeadersFrameParam{ 64 | StreamID: uint32(w.dataStreamID), 65 | EndHeaders: true, 66 | BlockFragment: headers.Bytes(), 67 | }) 68 | if err 
!= nil { 69 | fmt.Printf("could not write h2 header: %s\n", err.Error()) 70 | } 71 | } 72 | 73 | func (w *responseWriter) Write(p []byte) (int, error) { 74 | if !w.headerWritten { 75 | w.WriteHeader(200) 76 | } 77 | if !bodyAllowedForStatus(w.status) { 78 | return 0, http.ErrBodyNotAllowed 79 | } 80 | return w.dataStream.Write(p) 81 | } 82 | 83 | func (w *responseWriter) Flush() {} 84 | 85 | // This is a NOP. Use http.Request.Context 86 | func (w *responseWriter) CloseNotify() <-chan bool { return make(<-chan bool) } 87 | 88 | // test that we implement http.Flusher 89 | var _ http.Flusher = &responseWriter{} 90 | 91 | // test that we implement http.CloseNotifier 92 | var _ http.CloseNotifier = &responseWriter{} 93 | 94 | // copied from http2/http2.go 95 | // bodyAllowedForStatus reports whether a given response status code 96 | // permits a body. See RFC 2616, section 4.4. 97 | func bodyAllowedForStatus(status int) bool { 98 | switch { 99 | case status >= 100 && status <= 199: 100 | return false 101 | case status == 204: 102 | return false 103 | case status == 304: 104 | return false 105 | } 106 | return true 107 | } 108 | -------------------------------------------------------------------------------- /h2quic/response.go: -------------------------------------------------------------------------------- 1 | package h2quic 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "io" 7 | "io/ioutil" 8 | "net/http" 9 | "net/textproto" 10 | "strconv" 11 | "strings" 12 | 13 | "golang.org/x/net/http2" 14 | ) 15 | 16 | // copied from net/http2/transport.go 17 | 18 | var errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit") 19 | var noBody io.ReadCloser = ioutil.NopCloser(bytes.NewReader(nil)) 20 | 21 | // from the handleResponse function 22 | func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) { 23 | if f.Truncated { 24 | return nil, errResponseHeaderListSize 25 | } 26 | 27 | status := f.PseudoValue("status") 28 | 
if status == "" { 29 | return nil, errors.New("missing status pseudo header") 30 | } 31 | statusCode, err := strconv.Atoi(status) 32 | if err != nil { 33 | return nil, errors.New("malformed non-numeric status pseudo header") 34 | } 35 | 36 | if statusCode == 100 { 37 | // TODO: handle this 38 | 39 | // traceGot100Continue(cs.trace) 40 | // if cs.on100 != nil { 41 | // cs.on100() // forces any write delay timer to fire 42 | // } 43 | // cs.pastHeaders = false // do it all again 44 | // return nil, nil 45 | } 46 | 47 | header := make(http.Header) 48 | res := &http.Response{ 49 | Proto: "HTTP/2.0", 50 | ProtoMajor: 2, 51 | Header: header, 52 | StatusCode: statusCode, 53 | Status: status + " " + http.StatusText(statusCode), 54 | } 55 | for _, hf := range f.RegularFields() { 56 | key := http.CanonicalHeaderKey(hf.Name) 57 | if key == "Trailer" { 58 | t := res.Trailer 59 | if t == nil { 60 | t = make(http.Header) 61 | res.Trailer = t 62 | } 63 | foreachHeaderElement(hf.Value, func(v string) { 64 | t[http.CanonicalHeaderKey(v)] = nil 65 | }) 66 | } else { 67 | header[key] = append(header[key], hf.Value) 68 | } 69 | } 70 | 71 | return res, nil 72 | } 73 | 74 | // continuation of the handleResponse function 75 | func setLength(res *http.Response, isHead, streamEnded bool) *http.Response { 76 | if !streamEnded || isHead { 77 | res.ContentLength = -1 78 | if clens := res.Header["Content-Length"]; len(clens) == 1 { 79 | if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil { 80 | res.ContentLength = clen64 81 | } else { 82 | // TODO: care? unlike http/1, it won't mess up our framing, so it's 83 | // more safe smuggling-wise to ignore. 84 | } 85 | } else if len(clens) > 1 { 86 | // TODO: care? unlike http/1, it won't mess up our framing, so it's 87 | // more safe smuggling-wise to ignore. 
88 | } 89 | } 90 | return res 91 | } 92 | 93 | // copied from net/http/server.go 94 | 95 | // foreachHeaderElement splits v according to the "#rule" construction 96 | // in RFC 2616 section 2.1 and calls fn for each non-empty element. 97 | func foreachHeaderElement(v string, fn func(string)) { 98 | v = textproto.TrimString(v) 99 | if v == "" { 100 | return 101 | } 102 | if !strings.Contains(v, ",") { 103 | fn(v) 104 | return 105 | } 106 | for _, f := range strings.Split(v, ",") { 107 | if f = textproto.TrimString(f); f != "" { 108 | fn(f) 109 | } 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | package quictun 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "net" 7 | "strconv" 8 | "time" 9 | 10 | "github.com/julienschmidt/quictun/internal/socks" 11 | quic "github.com/lucas-clemente/quic-go" 12 | ) 13 | 14 | // SequenceCache is a cache for client sequence numbers. 15 | // Implementations should limit the number of cached key-value pairs using a 16 | // strategy like least recently used (LRU). 17 | type SequenceCache interface { 18 | Set(key uint64, value uint32) (old uint32) 19 | Get(key uint64) (value uint32) 20 | } 21 | 22 | // Server is a quictun server which handles QUIC sessions upgraded to the 23 | // quictun protocol. 
24 | type Server struct { 25 | DialTimeout time.Duration 26 | SequenceCache SequenceCache 27 | } 28 | 29 | // CheckSequenceNumber checks and caches the sequence number sent by a client 30 | func (s *Server) CheckSequenceNumber(header string) bool { 31 | // parse clientID and sequenceNumber from header value 32 | if len(header) != 24 { 33 | return false 34 | } 35 | clientID, err := strconv.ParseUint(header[:16], 16, 64) 36 | if err != nil { 37 | return false 38 | } 39 | sequenceNumber, err := strconv.ParseUint(header[16:], 16, 32) 40 | if err != nil { 41 | return false 42 | } 43 | 44 | // the new sequence number must be larger than any previously seen number 45 | return s.SequenceCache.Set(clientID, uint32(sequenceNumber)) < uint32(sequenceNumber) 46 | } 47 | 48 | // Upgrade starts using a given QUIC session with the quictun protocol. 49 | // The quictun server immediately starts accepting new QUIC streams and assumes 50 | // them to speak the quictun protocol (QTP). 51 | // The actual protocol upgrade (via a HTTP/2 request-response) is handled 52 | // entirely by the web server. 
53 | func (s *Server) Upgrade(session quic.Session) { 54 | for { 55 | fmt.Println("Waiting for stream...") 56 | stream, err := session.AcceptStream() 57 | if err != nil { 58 | fmt.Println("accept stream:", err) 59 | session.Close(err) 60 | return 61 | } 62 | 63 | go s.handleQuictunStream(stream) 64 | } 65 | } 66 | 67 | func (s *Server) handleQuictunStream(stream quic.Stream) { 68 | streamID := stream.StreamID() 69 | fmt.Println("got stream", streamID) 70 | 71 | streamRd := bufio.NewReader(stream) 72 | req, err := socks.PeekRequest(streamRd) 73 | if err != nil { 74 | stream.Reset(err) 75 | stream.Close() 76 | fmt.Println("stream", streamID, ":", err) 77 | return 78 | } 79 | 80 | switch req.Cmd() { 81 | case socks.CmdConnect: 82 | remote, err := net.DialTimeout("tcp", req.Dest().String(), s.DialTimeout) 83 | if err != nil { 84 | fmt.Printf("stream %d: %#v\n", streamID, err) 85 | stream.Reset(nil) 86 | stream.Close() 87 | return 88 | } 89 | // remove request header from buffer 90 | if _, err = streamRd.Discard(len(req)); err != nil { 91 | stream.Reset(nil) 92 | stream.Close() 93 | remote.Close() 94 | fmt.Println("stream", streamID, ":", err) 95 | return 96 | } 97 | 98 | fmt.Println("Start proxying...") 99 | go proxy(stream, remote) // recv from remote and send to stream 100 | proxy(remote, streamRd) // recv from stream and send to remote 101 | default: 102 | socks.SendReply(stream, socks.StatusCmdNotSupported, nil) 103 | stream.Reset(nil) 104 | stream.Close() 105 | return 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /internal/testdata/fullchain.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIFAzCCA+ugAwIBAgISAwDXm1WX6waPiT39WxQ11nS6MA0GCSqGSIb3DQEBCwUA 3 | MEoxCzAJBgNVBAYTAlVTMRYwFAYDVQQKEw1MZXQncyBFbmNyeXB0MSMwIQYDVQQD 4 | ExpMZXQncyBFbmNyeXB0IEF1dGhvcml0eSBYMzAeFw0xNzA5MjgwNTU2MDBaFw0x 5 | 
NzEyMjcwNTU2MDBaMBsxGTAXBgNVBAMTEHF1aWMuY2xlbWVudGUuaW8wggEiMA0G 6 | CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC4xixIS9iBFy/kN8AxF1AKFQqK06aW 7 | Q4Sjku677az6on2q8gDCIiup7uaAiyUzmSgPRZRrcBESApw7w15XN/K17ZQb4Bw7 8 | Xp0O7rzKhtwH8ugz+Qs8ceK4ayKTCT/PoPmKnr+sL9Soo0LJ8XFvNdY/v3cq6eSc 9 | nqztOGJV3RSAd34e/EBVkiSOIeXgu8rjtGOxSuuWS52n6NBxuZrSjP/JClqV26I1 10 | NyH+psfoO+zZ8/cJUoC1clalsDhDfX33rPEHJ4KOvB6rAAh9VcFQtTfuOTIySzPu 11 | PMfarhq14lbt+pRjvEMru1uaQYL+t/OWrX/tLc/990G/rVmvXfed9/yDAgMBAAGj 12 | ggIQMIICDDAOBgNVHQ8BAf8EBAMCBaAwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsG 13 | AQUFBwMCMAwGA1UdEwEB/wQCMAAwHQYDVR0OBBYEFHYR4Sg0neNoAywvJB2+5/jP 14 | eon8MB8GA1UdIwQYMBaAFKhKamMEfd265tE5t6ZFZe/zqOyhMG8GCCsGAQUFBwEB 15 | BGMwYTAuBggrBgEFBQcwAYYiaHR0cDovL29jc3AuaW50LXgzLmxldHNlbmNyeXB0 16 | Lm9yZzAvBggrBgEFBQcwAoYjaHR0cDovL2NlcnQuaW50LXgzLmxldHNlbmNyeXB0 17 | Lm9yZy8wGwYDVR0RBBQwEoIQcXVpYy5jbGVtZW50ZS5pbzCB/gYDVR0gBIH2MIHz 18 | MAgGBmeBDAECATCB5gYLKwYBBAGC3xMBAQEwgdYwJgYIKwYBBQUHAgEWGmh0dHA6 19 | Ly9jcHMubGV0c2VuY3J5cHQub3JnMIGrBggrBgEFBQcCAjCBngyBm1RoaXMgQ2Vy 20 | dGlmaWNhdGUgbWF5IG9ubHkgYmUgcmVsaWVkIHVwb24gYnkgUmVseWluZyBQYXJ0 21 | aWVzIGFuZCBvbmx5IGluIGFjY29yZGFuY2Ugd2l0aCB0aGUgQ2VydGlmaWNhdGUg 22 | UG9saWN5IGZvdW5kIGF0IGh0dHBzOi8vbGV0c2VuY3J5cHQub3JnL3JlcG9zaXRv 23 | cnkvMA0GCSqGSIb3DQEBCwUAA4IBAQBpKQvttZCPi+NKj2iQYvdajk3JVFmSkBo6 24 | PiatcL4VnitR+ld/e8L9cqm0T1MbVrbFmsxR5yCqzYVTCm4s4HuUU2MGhgo2G5UW 25 | ZFSDsOJWHbFb1Q2sA7V0vdf9EMuhcHv/6lVx0wrBRt8deJFQlVxOthriBbNuXU51 26 | U+t6MsnQdCk+az/82uBYXK1DAthD42EeR3rS3uFhvAlRx/FZBt7yq+QuFQfDRqdG 27 | hRgPzEYR1yCBaIQuTYL1EJAYeR1HqaY0NeybIvieXSsAlWOTc6aMmNMOSVkeiYwp 28 | UUN+0WQQ7Wj20ZNay1hDpQrvwa/bq2n12cJjpjSo1SrJo3Ph6VPt 29 | -----END CERTIFICATE----- 30 | -----BEGIN CERTIFICATE----- 31 | MIIEkjCCA3qgAwIBAgIQCgFBQgAAAVOFc2oLheynCDANBgkqhkiG9w0BAQsFADA/ 32 | MSQwIgYDVQQKExtEaWdpdGFsIFNpZ25hdHVyZSBUcnVzdCBDby4xFzAVBgNVBAMT 33 | DkRTVCBSb290IENBIFgzMB4XDTE2MDMxNzE2NDA0NloXDTIxMDMxNzE2NDA0Nlow 34 | 
SjELMAkGA1UEBhMCVVMxFjAUBgNVBAoTDUxldCdzIEVuY3J5cHQxIzAhBgNVBAMT 35 | GkxldCdzIEVuY3J5cHQgQXV0aG9yaXR5IFgzMIIBIjANBgkqhkiG9w0BAQEFAAOC 36 | AQ8AMIIBCgKCAQEAnNMM8FrlLke3cl03g7NoYzDq1zUmGSXhvb418XCSL7e4S0EF 37 | q6meNQhY7LEqxGiHC6PjdeTm86dicbp5gWAf15Gan/PQeGdxyGkOlZHP/uaZ6WA8 38 | SMx+yk13EiSdRxta67nsHjcAHJyse6cF6s5K671B5TaYucv9bTyWaN8jKkKQDIZ0 39 | Z8h/pZq4UmEUEz9l6YKHy9v6Dlb2honzhT+Xhq+w3Brvaw2VFn3EK6BlspkENnWA 40 | a6xK8xuQSXgvopZPKiAlKQTGdMDQMc2PMTiVFrqoM7hD8bEfwzB/onkxEz0tNvjj 41 | /PIzark5McWvxI0NHWQWM6r6hCm21AvA2H3DkwIDAQABo4IBfTCCAXkwEgYDVR0T 42 | AQH/BAgwBgEB/wIBADAOBgNVHQ8BAf8EBAMCAYYwfwYIKwYBBQUHAQEEczBxMDIG 43 | CCsGAQUFBzABhiZodHRwOi8vaXNyZy50cnVzdGlkLm9jc3AuaWRlbnRydXN0LmNv 44 | bTA7BggrBgEFBQcwAoYvaHR0cDovL2FwcHMuaWRlbnRydXN0LmNvbS9yb290cy9k 45 | c3Ryb290Y2F4My5wN2MwHwYDVR0jBBgwFoAUxKexpHsscfrb4UuQdf/EFWCFiRAw 46 | VAYDVR0gBE0wSzAIBgZngQwBAgEwPwYLKwYBBAGC3xMBAQEwMDAuBggrBgEFBQcC 47 | ARYiaHR0cDovL2Nwcy5yb290LXgxLmxldHNlbmNyeXB0Lm9yZzA8BgNVHR8ENTAz 48 | MDGgL6AthitodHRwOi8vY3JsLmlkZW50cnVzdC5jb20vRFNUUk9PVENBWDNDUkwu 49 | Y3JsMB0GA1UdDgQWBBSoSmpjBH3duubRObemRWXv86jsoTANBgkqhkiG9w0BAQsF 50 | AAOCAQEA3TPXEfNjWDjdGBX7CVW+dla5cEilaUcne8IkCJLxWh9KEik3JHRRHGJo 51 | uM2VcGfl96S8TihRzZvoroed6ti6WqEBmtzw3Wodatg+VyOeph4EYpr/1wXKtx8/ 52 | wApIvJSwtmVi4MFU5aMqrSDE6ea73Mj2tcMyo5jMd6jmeWUHK8so/joWUoHOUgwu 53 | X4Po1QYz+3dszkDqMp4fklxBwXRsW10KXzPMTZ+sOPAveyxindmjkW8lGy+QsRlG 54 | PfZ+G6Z6h7mjem0Y+iWlkYcV4PIWL1iwBi8saCbGS5jN2p8M+X+Q7UNKEkROb3N6 55 | KOqkqm57TH2H3eDJAkSnh6/DNFu0Qg== 56 | -----END CERTIFICATE----- 57 | -------------------------------------------------------------------------------- /h2quic/request_test.go: -------------------------------------------------------------------------------- 1 | package h2quic 2 | 3 | import ( 4 | "net/http" 5 | "net/url" 6 | 7 | "golang.org/x/net/http2/hpack" 8 | 9 | . "github.com/onsi/ginkgo" 10 | . 
"github.com/onsi/gomega" 11 | ) 12 | 13 | var _ = Describe("Request", func() { 14 | It("populates request", func() { 15 | headers := []hpack.HeaderField{ 16 | {Name: ":path", Value: "/foo"}, 17 | {Name: ":authority", Value: "quic.clemente.io"}, 18 | {Name: ":method", Value: "GET"}, 19 | {Name: "content-length", Value: "42"}, 20 | } 21 | req, err := requestFromHeaders(headers) 22 | Expect(err).NotTo(HaveOccurred()) 23 | Expect(req.Method).To(Equal("GET")) 24 | Expect(req.URL.Path).To(Equal("/foo")) 25 | Expect(req.Proto).To(Equal("HTTP/2.0")) 26 | Expect(req.ProtoMajor).To(Equal(2)) 27 | Expect(req.ProtoMinor).To(Equal(0)) 28 | Expect(req.ContentLength).To(Equal(int64(42))) 29 | Expect(req.Header).To(BeEmpty()) 30 | Expect(req.Body).To(BeNil()) 31 | Expect(req.Host).To(Equal("quic.clemente.io")) 32 | Expect(req.RequestURI).To(Equal("/foo")) 33 | Expect(req.TLS).ToNot(BeNil()) 34 | }) 35 | 36 | It("concatenates the cookie headers", func() { 37 | headers := []hpack.HeaderField{ 38 | {Name: ":path", Value: "/foo"}, 39 | {Name: ":authority", Value: "quic.clemente.io"}, 40 | {Name: ":method", Value: "GET"}, 41 | {Name: "cookie", Value: "cookie1=foobar1"}, 42 | {Name: "cookie", Value: "cookie2=foobar2"}, 43 | } 44 | req, err := requestFromHeaders(headers) 45 | Expect(err).NotTo(HaveOccurred()) 46 | Expect(req.Header).To(Equal(http.Header{ 47 | "Cookie": []string{"cookie1=foobar1; cookie2=foobar2"}, 48 | })) 49 | }) 50 | 51 | It("handles other headers", func() { 52 | headers := []hpack.HeaderField{ 53 | {Name: ":path", Value: "/foo"}, 54 | {Name: ":authority", Value: "quic.clemente.io"}, 55 | {Name: ":method", Value: "GET"}, 56 | {Name: "cache-control", Value: "max-age=0"}, 57 | {Name: "duplicate-header", Value: "1"}, 58 | {Name: "duplicate-header", Value: "2"}, 59 | } 60 | req, err := requestFromHeaders(headers) 61 | Expect(err).NotTo(HaveOccurred()) 62 | Expect(req.Header).To(Equal(http.Header{ 63 | "Cache-Control": []string{"max-age=0"}, 64 | "Duplicate-Header": 
[]string{"1", "2"}, 65 | })) 66 | }) 67 | 68 | It("errors with missing path", func() { 69 | headers := []hpack.HeaderField{ 70 | {Name: ":authority", Value: "quic.clemente.io"}, 71 | {Name: ":method", Value: "GET"}, 72 | } 73 | _, err := requestFromHeaders(headers) 74 | Expect(err).To(MatchError(":path, :authority and :method must not be empty")) 75 | }) 76 | 77 | It("errors with missing method", func() { 78 | headers := []hpack.HeaderField{ 79 | {Name: ":path", Value: "/foo"}, 80 | {Name: ":authority", Value: "quic.clemente.io"}, 81 | } 82 | _, err := requestFromHeaders(headers) 83 | Expect(err).To(MatchError(":path, :authority and :method must not be empty")) 84 | }) 85 | 86 | It("errors with missing authority", func() { 87 | headers := []hpack.HeaderField{ 88 | {Name: ":path", Value: "/foo"}, 89 | {Name: ":method", Value: "GET"}, 90 | } 91 | _, err := requestFromHeaders(headers) 92 | Expect(err).To(MatchError(":path, :authority and :method must not be empty")) 93 | }) 94 | 95 | Context("extracting the hostname from a request", func() { 96 | var url *url.URL 97 | 98 | BeforeEach(func() { 99 | var err error 100 | url, err = url.Parse("https://quic.clemente.io:1337") 101 | Expect(err).ToNot(HaveOccurred()) 102 | }) 103 | 104 | It("uses req.Host if available", func() { 105 | req := &http.Request{ 106 | Host: "www.example.org", 107 | URL: url, 108 | } 109 | Expect(hostnameFromRequest(req)).To(Equal("www.example.org")) 110 | }) 111 | 112 | It("uses req.URL.Host if req.Host is not set", func() { 113 | req := &http.Request{URL: url} 114 | Expect(hostnameFromRequest(req)).To(Equal("quic.clemente.io:1337")) 115 | }) 116 | 117 | It("returns an empty hostname if nothing is set", func() { 118 | Expect(hostnameFromRequest(&http.Request{})).To(BeEmpty()) 119 | }) 120 | }) 121 | }) 122 | -------------------------------------------------------------------------------- /h2quic/request_writer_test.go: -------------------------------------------------------------------------------- 
1 | package h2quic 2 | 3 | import ( 4 | "bytes" 5 | "net/http" 6 | "net/url" 7 | "strconv" 8 | "strings" 9 | 10 | "golang.org/x/net/http2" 11 | "golang.org/x/net/http2/hpack" 12 | 13 | . "github.com/onsi/ginkgo" 14 | . "github.com/onsi/gomega" 15 | ) 16 | 17 | var _ = Describe("Request", func() { 18 | var ( 19 | rw *requestWriter 20 | headerStream *mockStream 21 | decoder *hpack.Decoder 22 | ) 23 | 24 | BeforeEach(func() { 25 | headerStream = &mockStream{} 26 | rw = newRequestWriter(headerStream) 27 | decoder = hpack.NewDecoder(4096, func(hf hpack.HeaderField) {}) 28 | }) 29 | 30 | decode := func(p []byte) (*http2.HeadersFrame, map[string] /* HeaderField.Name */ string /* HeaderField.Value */) { 31 | framer := http2.NewFramer(nil, bytes.NewReader(p)) 32 | frame, err := framer.ReadFrame() 33 | Expect(err).ToNot(HaveOccurred()) 34 | headerFrame := frame.(*http2.HeadersFrame) 35 | fields, err := decoder.DecodeFull(headerFrame.HeaderBlockFragment()) 36 | Expect(err).ToNot(HaveOccurred()) 37 | values := make(map[string]string) 38 | for _, headerField := range fields { 39 | values[headerField.Name] = headerField.Value 40 | } 41 | return headerFrame, values 42 | } 43 | 44 | It("writes a GET request", func() { 45 | req, err := http.NewRequest("GET", "https://quic.clemente.io/index.html?foo=bar", nil) 46 | Expect(err).ToNot(HaveOccurred()) 47 | rw.WriteRequest(req, 1337, true, false) 48 | headerFrame, headerFields := decode(headerStream.dataWritten.Bytes()) 49 | Expect(headerFrame.StreamID).To(Equal(uint32(1337))) 50 | Expect(headerFrame.HasPriority()).To(BeTrue()) 51 | Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io")) 52 | Expect(headerFields).To(HaveKeyWithValue(":method", "GET")) 53 | Expect(headerFields).To(HaveKeyWithValue(":path", "/index.html?foo=bar")) 54 | Expect(headerFields).To(HaveKeyWithValue(":scheme", "https")) 55 | Expect(headerFields).ToNot(HaveKey("accept-encoding")) 56 | }) 57 | 58 | It("sets the EndStream header", func() { 59 | 
req, err := http.NewRequest("GET", "https://quic.clemente.io/", nil) 60 | Expect(err).ToNot(HaveOccurred()) 61 | rw.WriteRequest(req, 1337, true, false) 62 | headerFrame, _ := decode(headerStream.dataWritten.Bytes()) 63 | Expect(headerFrame.StreamEnded()).To(BeTrue()) 64 | }) 65 | 66 | It("doesn't set the EndStream header, if requested", func() { 67 | req, err := http.NewRequest("GET", "https://quic.clemente.io/", nil) 68 | Expect(err).ToNot(HaveOccurred()) 69 | rw.WriteRequest(req, 1337, false, false) 70 | headerFrame, _ := decode(headerStream.dataWritten.Bytes()) 71 | Expect(headerFrame.StreamEnded()).To(BeFalse()) 72 | }) 73 | 74 | It("requests gzip compression, if requested", func() { 75 | req, err := http.NewRequest("GET", "https://quic.clemente.io/index.html?foo=bar", nil) 76 | Expect(err).ToNot(HaveOccurred()) 77 | rw.WriteRequest(req, 1337, true, true) 78 | _, headerFields := decode(headerStream.dataWritten.Bytes()) 79 | Expect(headerFields).To(HaveKeyWithValue("accept-encoding", "gzip")) 80 | }) 81 | 82 | It("writes a POST request", func() { 83 | form := url.Values{} 84 | form.Add("foo", "bar") 85 | req, err := http.NewRequest("POST", "https://quic.clemente.io/upload.html", strings.NewReader(form.Encode())) 86 | Expect(err).ToNot(HaveOccurred()) 87 | rw.WriteRequest(req, 5, true, false) 88 | _, headerFields := decode(headerStream.dataWritten.Bytes()) 89 | Expect(headerFields).To(HaveKeyWithValue(":method", "POST")) 90 | Expect(headerFields).To(HaveKey("content-length")) 91 | contentLength, err := strconv.Atoi(headerFields["content-length"]) 92 | Expect(err).ToNot(HaveOccurred()) 93 | Expect(contentLength).To(BeNumerically(">", 0)) 94 | }) 95 | 96 | It("sends cookies", func() { 97 | req, err := http.NewRequest("GET", "https://quic.clemente.io/", nil) 98 | Expect(err).ToNot(HaveOccurred()) 99 | cookie1 := &http.Cookie{ 100 | Name: "Cookie #1", 101 | Value: "Value #1", 102 | } 103 | cookie2 := &http.Cookie{ 104 | Name: "Cookie #2", 105 | Value: "Value #2", 
106 | } 107 | req.AddCookie(cookie1) 108 | req.AddCookie(cookie2) 109 | rw.WriteRequest(req, 11, true, false) 110 | _, headerFields := decode(headerStream.dataWritten.Bytes()) 111 | // TODO(lclemente): Remove Or() once we drop support for Go 1.8. 112 | Expect(headerFields).To(Or( 113 | HaveKeyWithValue("cookie", "Cookie #1=Value #1; Cookie #2=Value #2"), 114 | HaveKeyWithValue("cookie", `Cookie #1="Value #1"; Cookie #2="Value #2"`), 115 | )) 116 | }) 117 | }) 118 | -------------------------------------------------------------------------------- /h2quic/roundtrip.go: -------------------------------------------------------------------------------- 1 | package h2quic 2 | 3 | import ( 4 | "crypto/tls" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "strings" 10 | "sync" 11 | 12 | quic "github.com/lucas-clemente/quic-go" 13 | 14 | "golang.org/x/net/lex/httplex" 15 | ) 16 | 17 | type roundTripCloser interface { 18 | http.RoundTripper 19 | io.Closer 20 | } 21 | 22 | // RoundTripper implements the http.RoundTripper interface 23 | type RoundTripper struct { 24 | mutex sync.Mutex 25 | 26 | // DisableCompression, if true, prevents the Transport from 27 | // requesting compression with an "Accept-Encoding: gzip" 28 | // request header when the Request contains no existing 29 | // Accept-Encoding value. If the Transport requests gzip on 30 | // its own and gets a gzipped response, it's transparently 31 | // decoded in the Response.Body. However, if the user 32 | // explicitly requested gzip it is not automatically 33 | // uncompressed. 34 | DisableCompression bool 35 | 36 | // TLSClientConfig specifies the TLS configuration to use with 37 | // tls.Client. If nil, the default configuration is used. 38 | TLSClientConfig *tls.Config 39 | 40 | // QuicConfig is the quic.Config used for dialing new connections. 41 | // If nil, reasonable default values will be used. 
42 | QuicConfig *quic.Config 43 | 44 | clients map[string]roundTripCloser 45 | } 46 | 47 | // RoundTripOpt are options for the Transport.RoundTripOpt method. 48 | type RoundTripOpt struct { 49 | // OnlyCachedConn controls whether the RoundTripper may 50 | // create a new QUIC connection. If set true and 51 | // no cached connection is available, RoundTrip 52 | // will return ErrNoCachedConn. 53 | OnlyCachedConn bool 54 | } 55 | 56 | var _ roundTripCloser = &RoundTripper{} 57 | 58 | // ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set 59 | var ErrNoCachedConn = errors.New("h2quic: no cached connection was available") 60 | 61 | // RoundTripOpt is like RoundTrip, but takes options. 62 | func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { 63 | if req.URL == nil { 64 | closeRequestBody(req) 65 | return nil, errors.New("quic: nil Request.URL") 66 | } 67 | if req.URL.Host == "" { 68 | closeRequestBody(req) 69 | return nil, errors.New("quic: no Host in request URL") 70 | } 71 | if req.Header == nil { 72 | closeRequestBody(req) 73 | return nil, errors.New("quic: nil Request.Header") 74 | } 75 | 76 | if req.URL.Scheme == "https" { 77 | for k, vv := range req.Header { 78 | if !httplex.ValidHeaderFieldName(k) { 79 | return nil, fmt.Errorf("quic: invalid http header field name %q", k) 80 | } 81 | for _, v := range vv { 82 | if !httplex.ValidHeaderFieldValue(v) { 83 | return nil, fmt.Errorf("quic: invalid http header field value %q for key %v", v, k) 84 | } 85 | } 86 | } 87 | } else { 88 | closeRequestBody(req) 89 | return nil, fmt.Errorf("quic: unsupported protocol scheme: %s", req.URL.Scheme) 90 | } 91 | 92 | if req.Method != "" && !validMethod(req.Method) { 93 | closeRequestBody(req) 94 | return nil, fmt.Errorf("quic: invalid method %q", req.Method) 95 | } 96 | 97 | hostname := authorityAddr("https", hostnameFromRequest(req)) 98 | cl, err := r.getClient(hostname, opt.OnlyCachedConn) 99 | if err != nil { 100 | 
return nil, err 101 | } 102 | return cl.RoundTrip(req) 103 | } 104 | 105 | // RoundTrip does a round trip. 106 | func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { 107 | return r.RoundTripOpt(req, RoundTripOpt{}) 108 | } 109 | 110 | func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTripper, error) { 111 | r.mutex.Lock() 112 | defer r.mutex.Unlock() 113 | 114 | if r.clients == nil { 115 | r.clients = make(map[string]roundTripCloser) 116 | } 117 | 118 | client, ok := r.clients[hostname] 119 | if !ok { 120 | if onlyCached { 121 | return nil, ErrNoCachedConn 122 | } 123 | client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression}, r.QuicConfig) 124 | r.clients[hostname] = client 125 | } 126 | return client, nil 127 | } 128 | 129 | // Close closes the QUIC connections that this RoundTripper has used 130 | func (r *RoundTripper) Close() error { 131 | r.mutex.Lock() 132 | defer r.mutex.Unlock() 133 | for _, client := range r.clients { 134 | if err := client.Close(); err != nil { 135 | return err 136 | } 137 | } 138 | r.clients = nil 139 | return nil 140 | } 141 | 142 | func closeRequestBody(req *http.Request) { 143 | if req.Body != nil { 144 | req.Body.Close() 145 | } 146 | } 147 | 148 | func validMethod(method string) bool { 149 | /* 150 | Method = "OPTIONS" ; Section 9.2 151 | | "GET" ; Section 9.3 152 | | "HEAD" ; Section 9.4 153 | | "POST" ; Section 9.5 154 | | "PUT" ; Section 9.6 155 | | "DELETE" ; Section 9.7 156 | | "TRACE" ; Section 9.8 157 | | "CONNECT" ; Section 9.9 158 | | extension-method 159 | extension-method = token 160 | token = 1* 161 | */ 162 | return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1 163 | } 164 | 165 | // copied from net/http/http.go 166 | func isNotToken(r rune) bool { 167 | return !httplex.IsTokenRune(r) 168 | } 169 | -------------------------------------------------------------------------------- 
/h2quic/response_writer_test.go: -------------------------------------------------------------------------------- 1 | package h2quic 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "io" 7 | "net/http" 8 | "sync" 9 | "time" 10 | 11 | "golang.org/x/net/http2" 12 | "golang.org/x/net/http2/hpack" 13 | 14 | quic "github.com/lucas-clemente/quic-go" 15 | . "github.com/onsi/ginkgo" 16 | . "github.com/onsi/gomega" 17 | ) 18 | 19 | type mockStream struct { 20 | id quic.StreamID 21 | dataToRead bytes.Buffer 22 | dataWritten bytes.Buffer 23 | reset bool 24 | closed bool 25 | remoteClosed bool 26 | 27 | unblockRead chan struct{} 28 | ctx context.Context 29 | ctxCancel context.CancelFunc 30 | } 31 | 32 | func newMockStream(id quic.StreamID) *mockStream { 33 | s := &mockStream{ 34 | id: id, 35 | unblockRead: make(chan struct{}), 36 | } 37 | s.ctx, s.ctxCancel = context.WithCancel(context.Background()) 38 | return s 39 | } 40 | 41 | func (s *mockStream) Close() error { s.closed = true; s.ctxCancel(); return nil } 42 | func (s *mockStream) Reset(error) { s.reset = true } 43 | func (s *mockStream) CloseRemote(offset uint64) { s.remoteClosed = true; s.ctxCancel() } 44 | func (s mockStream) StreamID() quic.StreamID { return s.id } 45 | func (s *mockStream) Context() context.Context { return s.ctx } 46 | func (s *mockStream) SetDeadline(time.Time) error { panic("not implemented") } 47 | func (s *mockStream) SetReadDeadline(time.Time) error { panic("not implemented") } 48 | func (s *mockStream) SetWriteDeadline(time.Time) error { panic("not implemented") } 49 | 50 | func (s *mockStream) Read(p []byte) (int, error) { 51 | n, _ := s.dataToRead.Read(p) 52 | if n == 0 { // block if there's no data 53 | <-s.unblockRead 54 | return 0, io.EOF 55 | } 56 | return n, nil // never return an EOF 57 | } 58 | func (s *mockStream) Write(p []byte) (int, error) { return s.dataWritten.Write(p) } 59 | 60 | var _ = Describe("Response Writer", func() { 61 | var ( 62 | w *responseWriter 63 | headerStream 
*mockStream 64 | dataStream *mockStream 65 | ) 66 | 67 | BeforeEach(func() { 68 | headerStream = &mockStream{} 69 | dataStream = &mockStream{} 70 | w = newResponseWriter(headerStream, &sync.Mutex{}, dataStream, 5) 71 | }) 72 | 73 | decodeHeaderFields := func() map[string][]string { 74 | fields := make(map[string][]string) 75 | decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {}) 76 | h2framer := http2.NewFramer(nil, bytes.NewReader(headerStream.dataWritten.Bytes())) 77 | 78 | frame, err := h2framer.ReadFrame() 79 | Expect(err).ToNot(HaveOccurred()) 80 | Expect(frame).To(BeAssignableToTypeOf(&http2.HeadersFrame{})) 81 | hframe := frame.(*http2.HeadersFrame) 82 | mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe} 83 | Expect(mhframe.StreamID).To(BeEquivalentTo(5)) 84 | mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment()) 85 | Expect(err).ToNot(HaveOccurred()) 86 | for _, p := range mhframe.Fields { 87 | fields[p.Name] = append(fields[p.Name], p.Value) 88 | } 89 | return fields 90 | } 91 | 92 | It("writes status", func() { 93 | w.WriteHeader(http.StatusTeapot) 94 | fields := decodeHeaderFields() 95 | Expect(fields).To(HaveLen(1)) 96 | Expect(fields).To(HaveKeyWithValue(":status", []string{"418"})) 97 | }) 98 | 99 | It("writes headers", func() { 100 | w.Header().Add("content-length", "42") 101 | w.WriteHeader(http.StatusTeapot) 102 | fields := decodeHeaderFields() 103 | Expect(fields).To(HaveKeyWithValue("content-length", []string{"42"})) 104 | }) 105 | 106 | It("writes multiple headers with the same name", func() { 107 | const cookie1 = "test1=1; Max-Age=7200; path=/" 108 | const cookie2 = "test2=2; Max-Age=7200; path=/" 109 | w.Header().Add("set-cookie", cookie1) 110 | w.Header().Add("set-cookie", cookie2) 111 | w.WriteHeader(http.StatusTeapot) 112 | fields := decodeHeaderFields() 113 | Expect(fields).To(HaveKey("set-cookie")) 114 | cookies := fields["set-cookie"] 115 | Expect(cookies).To(ContainElement(cookie1)) 116 | 
Expect(cookies).To(ContainElement(cookie2)) 117 | }) 118 | 119 | It("writes data", func() { 120 | n, err := w.Write([]byte("foobar")) 121 | Expect(n).To(Equal(6)) 122 | Expect(err).ToNot(HaveOccurred()) 123 | // Should have written 200 on the header stream 124 | fields := decodeHeaderFields() 125 | Expect(fields).To(HaveKeyWithValue(":status", []string{"200"})) 126 | // And foobar on the data stream 127 | Expect(dataStream.dataWritten.Bytes()).To(Equal([]byte("foobar"))) 128 | }) 129 | 130 | It("writes data after WriteHeader is called", func() { 131 | w.WriteHeader(http.StatusTeapot) 132 | n, err := w.Write([]byte("foobar")) 133 | Expect(n).To(Equal(6)) 134 | Expect(err).ToNot(HaveOccurred()) 135 | // Should have written 418 on the header stream 136 | fields := decodeHeaderFields() 137 | Expect(fields).To(HaveKeyWithValue(":status", []string{"418"})) 138 | // And foobar on the data stream 139 | Expect(dataStream.dataWritten.Bytes()).To(Equal([]byte("foobar"))) 140 | }) 141 | 142 | It("does not WriteHeader() twice", func() { 143 | w.WriteHeader(200) 144 | w.WriteHeader(500) 145 | fields := decodeHeaderFields() 146 | Expect(fields).To(HaveLen(1)) 147 | Expect(fields).To(HaveKeyWithValue(":status", []string{"200"})) 148 | }) 149 | 150 | It("doesn't allow writes if the status code doesn't allow a body", func() { 151 | w.WriteHeader(304) 152 | n, err := w.Write([]byte("foobar")) 153 | Expect(n).To(BeZero()) 154 | Expect(err).To(MatchError(http.ErrBodyNotAllowed)) 155 | Expect(dataStream.dataWritten.Bytes()).To(HaveLen(0)) 156 | }) 157 | }) 158 | -------------------------------------------------------------------------------- /h2quic/request_writer.go: -------------------------------------------------------------------------------- 1 | package h2quic 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "net/http" 7 | "strconv" 8 | "strings" 9 | "sync" 10 | 11 | "golang.org/x/net/http2" 12 | "golang.org/x/net/http2/hpack" 13 | "golang.org/x/net/lex/httplex" 14 | 15 | quic 
"github.com/lucas-clemente/quic-go" 16 | ) 17 | 18 | type requestWriter struct { 19 | mutex sync.Mutex 20 | headerStream quic.Stream 21 | 22 | henc *hpack.Encoder 23 | hbuf bytes.Buffer // HPACK encoder writes into this 24 | } 25 | 26 | const defaultUserAgent = "quic-go" 27 | 28 | func newRequestWriter(headerStream quic.Stream) *requestWriter { 29 | rw := &requestWriter{ 30 | headerStream: headerStream, 31 | } 32 | rw.henc = hpack.NewEncoder(&rw.hbuf) 33 | return rw 34 | } 35 | 36 | func (w *requestWriter) WriteRequest(req *http.Request, dataStreamID quic.StreamID, endStream, requestGzip bool) error { 37 | // TODO: add support for trailers 38 | // TODO: add support for gzip compression 39 | // TODO: write continuation frames, if the header frame is too long 40 | 41 | w.mutex.Lock() 42 | defer w.mutex.Unlock() 43 | 44 | w.encodeHeaders(req, requestGzip, "", actualContentLength(req)) 45 | h2framer := http2.NewFramer(w.headerStream, nil) 46 | return h2framer.WriteHeaders(http2.HeadersFrameParam{ 47 | StreamID: uint32(dataStreamID), 48 | EndHeaders: true, 49 | EndStream: endStream, 50 | BlockFragment: w.hbuf.Bytes(), 51 | Priority: http2.PriorityParam{Weight: 0xff}, 52 | }) 53 | } 54 | 55 | // the rest of this files is copied from http2.Transport 56 | func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) { 57 | w.hbuf.Reset() 58 | 59 | host := req.Host 60 | if host == "" { 61 | host = req.URL.Host 62 | } 63 | host, err := httplex.PunycodeHostPort(host) 64 | if err != nil { 65 | return nil, err 66 | } 67 | 68 | var path string 69 | if req.Method != "CONNECT" { 70 | path = req.URL.RequestURI() 71 | if !validPseudoPath(path) { 72 | orig := path 73 | path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host) 74 | if !validPseudoPath(path) { 75 | if req.URL.Opaque != "" { 76 | return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque) 77 | } else { 78 | return 
nil, fmt.Errorf("invalid request :path %q", orig) 79 | } 80 | } 81 | } 82 | } 83 | 84 | // Check for any invalid headers and return an error before we 85 | // potentially pollute our hpack state. (We want to be able to 86 | // continue to reuse the hpack encoder for future requests) 87 | for k, vv := range req.Header { 88 | if !httplex.ValidHeaderFieldName(k) { 89 | return nil, fmt.Errorf("invalid HTTP header name %q", k) 90 | } 91 | for _, v := range vv { 92 | if !httplex.ValidHeaderFieldValue(v) { 93 | return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k) 94 | } 95 | } 96 | } 97 | 98 | // 8.1.2.3 Request Pseudo-Header Fields 99 | // The :path pseudo-header field includes the path and query parts of the 100 | // target URI (the path-absolute production and optionally a '?' character 101 | // followed by the query production (see Sections 3.3 and 3.4 of 102 | // [RFC3986]). 103 | w.writeHeader(":authority", host) 104 | w.writeHeader(":method", req.Method) 105 | if req.Method != "CONNECT" { 106 | w.writeHeader(":path", path) 107 | w.writeHeader(":scheme", req.URL.Scheme) 108 | } 109 | if trailers != "" { 110 | w.writeHeader("trailer", trailers) 111 | } 112 | 113 | var didUA bool 114 | for k, vv := range req.Header { 115 | lowKey := strings.ToLower(k) 116 | switch lowKey { 117 | case "host", "content-length": 118 | // Host is :authority, already sent. 119 | // Content-Length is automatic, set below. 120 | continue 121 | case "connection", "proxy-connection", "transfer-encoding", "upgrade", "keep-alive": 122 | // Per 8.1.2.2 Connection-Specific Header 123 | // Fields, don't send connection-specific 124 | // fields. We have already checked if any 125 | // are error-worthy so just ignore the rest. 126 | continue 127 | case "user-agent": 128 | // Match Go's http1 behavior: at most one 129 | // User-Agent. If set to nil or empty string, 130 | // then omit it. Otherwise if not mentioned, 131 | // include the default (below). 
132 | didUA = true 133 | if len(vv) < 1 { 134 | continue 135 | } 136 | vv = vv[:1] 137 | if vv[0] == "" { 138 | continue 139 | } 140 | } 141 | for _, v := range vv { 142 | w.writeHeader(lowKey, v) 143 | } 144 | } 145 | if shouldSendReqContentLength(req.Method, contentLength) { 146 | w.writeHeader("content-length", strconv.FormatInt(contentLength, 10)) 147 | } 148 | if addGzipHeader { 149 | w.writeHeader("accept-encoding", "gzip") 150 | } 151 | if !didUA { 152 | w.writeHeader("user-agent", defaultUserAgent) 153 | } 154 | return w.hbuf.Bytes(), nil 155 | } 156 | 157 | func (w *requestWriter) writeHeader(name, value string) { 158 | w.henc.WriteField(hpack.HeaderField{Name: name, Value: value}) 159 | } 160 | 161 | // shouldSendReqContentLength reports whether the http2.Transport should send 162 | // a "content-length" request header. This logic is basically a copy of the net/http 163 | // transferWriter.shouldSendContentLength. 164 | // The contentLength is the corrected contentLength (so 0 means actually 0, not unknown). 165 | // -1 means unknown. 166 | func shouldSendReqContentLength(method string, contentLength int64) bool { 167 | if contentLength > 0 { 168 | return true 169 | } 170 | if contentLength < 0 { 171 | return false 172 | } 173 | // For zero bodies, whether we send a content-length depends on the method. 174 | // It also kinda doesn't matter for http2 either way, with END_STREAM. 175 | switch method { 176 | case "POST", "PUT", "PATCH": 177 | return true 178 | default: 179 | return false 180 | } 181 | } 182 | 183 | func validPseudoPath(v string) bool { 184 | return (len(v) > 0 && v[0] == '/' && (len(v) == 1 || v[1] != '/')) || v == "*" 185 | } 186 | 187 | // actualContentLength returns a sanitized version of 188 | // req.ContentLength, where 0 actually means zero (not unknown) and -1 189 | // means unknown. 
190 | func actualContentLength(req *http.Request) int64 { 191 | if req.Body == nil { 192 | return 0 193 | } 194 | if req.ContentLength != 0 { 195 | return req.ContentLength 196 | } 197 | return -1 198 | } 199 | -------------------------------------------------------------------------------- /internal/socks/socks.go: -------------------------------------------------------------------------------- 1 | package socks 2 | 3 | import ( 4 | "bufio" 5 | "errors" 6 | "io" 7 | "net" 8 | "strconv" 9 | ) 10 | 11 | // See https://www.ietf.org/rfc/rfc1928.txt 12 | 13 | const socksVersion = 5 14 | 15 | // Commands 16 | const ( 17 | CmdConnect = 1 18 | CmdBind = 2 19 | CmdAssociate = 3 20 | ) 21 | 22 | // Address types 23 | const ( 24 | AtypIPv4 = 1 25 | AtypDomain = 3 26 | AtypIPv6 = 4 27 | ) 28 | 29 | // Auth Methods 30 | const ( 31 | AuthNoAuthenticationRequired = 0x00 32 | AuthNoAcceptableMethod = 0xFF 33 | ) 34 | 35 | // Status 36 | const ( 37 | StatusSucceeded = 0 38 | StatusGeneralFailure = 1 39 | StatusConnectionNotAllowed = 2 40 | StatusNetworkUnreachable = 3 41 | StatusHostUnreachable = 4 42 | StatusConnectionRefused = 5 43 | StatusTtlExpired = 6 44 | StatusCmdNotSupported = 7 45 | StatusAddrNotSupported = 8 46 | ) 47 | 48 | // Errors 49 | var ( 50 | ErrNoAuth = errors.New("could not authenticate SOCKS connection") 51 | ErrAtypNotSupported = errors.New("address type is not supported") 52 | ) 53 | 54 | func Auth(rd *bufio.Reader, w io.Writer) error { 55 | // 1 version 56 | // 1 nmethods 57 | // 1 method[nmethods] (we only read 1 at a time) 58 | var header [3]byte 59 | if _, err := io.ReadFull(rd, header[:]); err != nil { 60 | return err 61 | } 62 | 63 | // check SOCKS version 64 | if clVersion := header[0]; clVersion != socksVersion { 65 | return errors.New("incompatible SOCKS version: " + 66 | strconv.FormatUint(uint64(clVersion), 10)) 67 | } 68 | 69 | // check auth 70 | // currently only NoAuthenticationRequired is supported 71 | acceptableAuth := false 72 | if nMethods 
:= header[1]; nMethods > 0 { 73 | if method := header[2]; method == AuthNoAuthenticationRequired { 74 | acceptableAuth = true 75 | } 76 | for n := uint8(1); n < nMethods; n++ { 77 | // if we already have an acceptable auth method, we can skip all 78 | if acceptableAuth { 79 | if _, err := rd.Discard(int(nMethods - n)); err != nil { 80 | return err 81 | } 82 | break 83 | } 84 | 85 | // keep checking until we find an acceptable auth method 86 | method, err := rd.ReadByte() 87 | if err != nil { 88 | return err 89 | } 90 | if method == AuthNoAuthenticationRequired { 91 | acceptableAuth = true 92 | } 93 | } 94 | } 95 | 96 | // send auth method selection to client 97 | if !acceptableAuth { 98 | w.Write([]byte{socksVersion, AuthNoAcceptableMethod}) 99 | return ErrNoAuth 100 | } 101 | _, err := w.Write([]byte{socksVersion, AuthNoAuthenticationRequired}) 102 | return err 103 | } 104 | 105 | type Request []byte 106 | 107 | // PeekRequest peeks 108 | func PeekRequest(rd *bufio.Reader) (Request, error) { 109 | // 1 version 110 | // 1 command 111 | // 1 reserved 112 | // 1 atyp 113 | header, err := rd.Peek(4) 114 | if err != nil { 115 | return nil, err 116 | } 117 | 118 | // check SOCKS version 119 | if clVersion := header[0]; clVersion != socksVersion { 120 | return nil, errors.New("incompatible SOCKS version: " + 121 | strconv.FormatUint(uint64(clVersion), 10)) 122 | } 123 | 124 | // read address (IPv4, IPv6 or Domain) 125 | const addrStart = 4 126 | atyp := header[3] 127 | switch atyp { 128 | case AtypIPv4: 129 | // read IPv4 address + port 130 | buf, err := rd.Peek(addrStart + net.IPv4len + 2) 131 | return Request(buf), err 132 | case AtypDomain: 133 | header, err = rd.Peek(addrStart + 1) 134 | if err != nil { 135 | return nil, err 136 | } 137 | domainLen := int(header[4]) 138 | 139 | // read domain name + port 140 | buf, err := rd.Peek(addrStart + 1 + domainLen + 2) 141 | return Request(buf), err 142 | case AtypIPv6: 143 | // read IPv6 address + port 144 | buf, err := 
rd.Peek(addrStart + net.IPv6len + 2) 145 | return Request(buf), err 146 | default: 147 | return nil, ErrAtypNotSupported 148 | } 149 | } 150 | 151 | func (r Request) Cmd() byte { 152 | return r[1] 153 | } 154 | 155 | func (r Request) Dest() Addr { 156 | return Addr(r[3:]) 157 | } 158 | 159 | // Addr is a pair of IPv4, IPv6 or Domain and a port 160 | type Addr []byte 161 | 162 | // Type returns the address type 163 | func (a Addr) Type() byte { 164 | return a[0] 165 | } 166 | 167 | // Port returns the port of the address 168 | func (a Addr) Port() int { 169 | var i = len(a) - 2 170 | return (int(a[i]) << 8) | int(a[i+1]) 171 | } 172 | 173 | // String formats the address as a host:port string 174 | func (a Addr) String() string { 175 | var host string 176 | switch a.Type() { 177 | case AtypIPv4, AtypIPv6: 178 | host = (net.IP(a[1 : len(a)-2])).String() 179 | case AtypDomain: 180 | host = string(a[2 : len(a)-2]) 181 | default: 182 | return "" 183 | } 184 | return net.JoinHostPort(host, strconv.Itoa(a.Port())) 185 | } 186 | 187 | // TODO: allow to pass buffer or writer 188 | func NewIPAddr(ip net.IP, port int) Addr { 189 | port1 := byte(port >> 8) 190 | port2 := byte(port & 0xff) 191 | if ip4 := ip.To4(); ip4 != nil { 192 | return Addr{AtypIPv4, 193 | ip4[0], ip4[1], ip4[2], ip4[3], 194 | port1, port2, 195 | } 196 | } 197 | if ip16 := ip.To16(); ip16 != nil { 198 | return Addr{AtypIPv6, 199 | ip16[0], ip16[1], ip16[2], ip16[3], 200 | ip16[4], ip16[5], ip16[6], ip16[7], 201 | ip16[8], ip16[9], ip16[10], ip16[11], 202 | ip16[12], ip16[13], ip16[14], ip16[15], 203 | port1, port2, 204 | } 205 | } 206 | return nil 207 | } 208 | 209 | func SendReply(wr io.Writer, status byte, addr Addr) error { 210 | // buffer to avoid allocations in the common cases 211 | var buf [64]byte 212 | reply := buf[:] 213 | if len(addr)+3 > cap(buf) { 214 | reply = make([]byte, len(addr)+3) 215 | } 216 | 217 | // 1 ver 218 | reply[0] = socksVersion 219 | 220 | // 1 rep 221 | reply[1] = status 222 | 
223 | // 1 reserved 224 | 225 | if addr == nil { 226 | reply = reply[:4+net.IPv4len+2] 227 | 228 | // reply[3] = AtypDomain 229 | // reply[4] = 0 230 | 231 | // 1 address type 232 | reply[3] = AtypIPv4 233 | 234 | // 4 IPv4 235 | reply[4] = 0 236 | reply[5] = 0 237 | reply[6] = 0 238 | reply[7] = 0 239 | 240 | // 2 port 241 | reply[8] = 0 242 | reply[9] = 0 243 | } else { 244 | reply = reply[:3+len(addr)] 245 | copy(reply[3:], addr) 246 | } 247 | 248 | // write reply 249 | _, err := wr.Write(reply) 250 | return err 251 | } 252 | 253 | func HandleRequest(req *Request) { 254 | 255 | } 256 | -------------------------------------------------------------------------------- /h2quic/roundtrip_test.go: -------------------------------------------------------------------------------- 1 | package h2quic 2 | 3 | import ( 4 | "bytes" 5 | "crypto/tls" 6 | "errors" 7 | "io" 8 | "net/http" 9 | "time" 10 | 11 | quic "github.com/lucas-clemente/quic-go" 12 | . "github.com/onsi/ginkgo" 13 | . "github.com/onsi/gomega" 14 | ) 15 | 16 | type mockClient struct { 17 | closed bool 18 | } 19 | 20 | func (m *mockClient) RoundTrip(req *http.Request) (*http.Response, error) { 21 | return &http.Response{Request: req}, nil 22 | } 23 | func (m *mockClient) Close() error { 24 | m.closed = true 25 | return nil 26 | } 27 | 28 | var _ roundTripCloser = &mockClient{} 29 | 30 | type mockBody struct { 31 | reader bytes.Reader 32 | readErr error 33 | closeErr error 34 | closed bool 35 | } 36 | 37 | func (m *mockBody) Read(p []byte) (int, error) { 38 | if m.readErr != nil { 39 | return 0, m.readErr 40 | } 41 | return m.reader.Read(p) 42 | } 43 | 44 | func (m *mockBody) SetData(data []byte) { 45 | m.reader = *bytes.NewReader(data) 46 | } 47 | 48 | func (m *mockBody) Close() error { 49 | m.closed = true 50 | return m.closeErr 51 | } 52 | 53 | // make sure the mockBody can be used as a http.Request.Body 54 | var _ io.ReadCloser = &mockBody{} 55 | 56 | var _ = Describe("RoundTripper", func() { 57 | var ( 58 | 
rt *RoundTripper 59 | req1 *http.Request 60 | ) 61 | 62 | BeforeEach(func() { 63 | rt = &RoundTripper{} 64 | var err error 65 | req1, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil) 66 | Expect(err).ToNot(HaveOccurred()) 67 | }) 68 | 69 | Context("dialing hosts", func() { 70 | origDialAddr := dialAddr 71 | streamOpenErr := errors.New("error opening stream") 72 | 73 | BeforeEach(func() { 74 | origDialAddr = dialAddr 75 | dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) { 76 | // return an error when trying to open a stream 77 | // we don't want to test all the dial logic here, just that dialing happens at all 78 | return &mockSession{streamOpenErr: streamOpenErr}, nil 79 | } 80 | }) 81 | 82 | AfterEach(func() { 83 | dialAddr = origDialAddr 84 | }) 85 | 86 | It("creates new clients", func() { 87 | req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil) 88 | Expect(err).ToNot(HaveOccurred()) 89 | _, err = rt.RoundTrip(req) 90 | Expect(err).To(MatchError(streamOpenErr)) 91 | Expect(rt.clients).To(HaveLen(1)) 92 | }) 93 | 94 | It("uses the quic.Config, if provided", func() { 95 | config := &quic.Config{HandshakeTimeout: time.Millisecond} 96 | var receivedConfig *quic.Config 97 | dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) { 98 | receivedConfig = config 99 | return nil, errors.New("err") 100 | } 101 | rt.QuicConfig = config 102 | rt.RoundTrip(req1) 103 | Expect(receivedConfig).To(Equal(config)) 104 | }) 105 | 106 | It("reuses existing clients", func() { 107 | req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil) 108 | Expect(err).ToNot(HaveOccurred()) 109 | _, err = rt.RoundTrip(req) 110 | Expect(err).To(MatchError(streamOpenErr)) 111 | Expect(rt.clients).To(HaveLen(1)) 112 | req2, err := http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil) 113 | Expect(err).ToNot(HaveOccurred()) 114 | _, err = 
rt.RoundTrip(req2) 115 | Expect(err).To(MatchError(streamOpenErr)) 116 | Expect(rt.clients).To(HaveLen(1)) 117 | }) 118 | 119 | It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() { 120 | req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil) 121 | Expect(err).ToNot(HaveOccurred()) 122 | _, err = rt.RoundTripOpt(req, RoundTripOpt{OnlyCachedConn: true}) 123 | Expect(err).To(MatchError(ErrNoCachedConn)) 124 | }) 125 | }) 126 | 127 | Context("validating request", func() { 128 | It("rejects plain HTTP requests", func() { 129 | req, err := http.NewRequest("GET", "http://www.example.org/", nil) 130 | req.Body = &mockBody{} 131 | Expect(err).ToNot(HaveOccurred()) 132 | _, err = rt.RoundTrip(req) 133 | Expect(err).To(MatchError("quic: unsupported protocol scheme: http")) 134 | Expect(req.Body.(*mockBody).closed).To(BeTrue()) 135 | }) 136 | 137 | It("rejects requests without a URL", func() { 138 | req1.URL = nil 139 | req1.Body = &mockBody{} 140 | _, err := rt.RoundTrip(req1) 141 | Expect(err).To(MatchError("quic: nil Request.URL")) 142 | Expect(req1.Body.(*mockBody).closed).To(BeTrue()) 143 | }) 144 | 145 | It("rejects request without a URL Host", func() { 146 | req1.URL.Host = "" 147 | req1.Body = &mockBody{} 148 | _, err := rt.RoundTrip(req1) 149 | Expect(err).To(MatchError("quic: no Host in request URL")) 150 | Expect(req1.Body.(*mockBody).closed).To(BeTrue()) 151 | }) 152 | 153 | It("doesn't try to close the body if the request doesn't have one", func() { 154 | req1.URL = nil 155 | Expect(req1.Body).To(BeNil()) 156 | _, err := rt.RoundTrip(req1) 157 | Expect(err).To(MatchError("quic: nil Request.URL")) 158 | }) 159 | 160 | It("rejects requests without a header", func() { 161 | req1.Header = nil 162 | req1.Body = &mockBody{} 163 | _, err := rt.RoundTrip(req1) 164 | Expect(err).To(MatchError("quic: nil Request.Header")) 165 | Expect(req1.Body.(*mockBody).closed).To(BeTrue()) 166 | }) 167 | 168 | It("rejects requests with 
invalid header name fields", func() { 169 | req1.Header.Add("foobär", "value") 170 | _, err := rt.RoundTrip(req1) 171 | Expect(err).To(MatchError("quic: invalid http header field name \"foobär\"")) 172 | }) 173 | 174 | It("rejects requests with invalid header name values", func() { 175 | req1.Header.Add("foo", string([]byte{0x7})) 176 | _, err := rt.RoundTrip(req1) 177 | Expect(err.Error()).To(ContainSubstring("quic: invalid http header field value")) 178 | }) 179 | 180 | It("rejects requests with an invalid request method", func() { 181 | req1.Method = "foobär" 182 | req1.Body = &mockBody{} 183 | _, err := rt.RoundTrip(req1) 184 | Expect(err).To(MatchError("quic: invalid method \"foobär\"")) 185 | Expect(req1.Body.(*mockBody).closed).To(BeTrue()) 186 | }) 187 | }) 188 | 189 | Context("closing", func() { 190 | It("closes", func() { 191 | rt.clients = make(map[string]roundTripCloser) 192 | cl := &mockClient{} 193 | rt.clients["foo.bar"] = cl 194 | err := rt.Close() 195 | Expect(err).ToNot(HaveOccurred()) 196 | Expect(len(rt.clients)).To(BeZero()) 197 | Expect(cl.closed).To(BeTrue()) 198 | }) 199 | 200 | It("closes a RoundTripper that has never been used", func() { 201 | Expect(len(rt.clients)).To(BeZero()) 202 | err := rt.Close() 203 | Expect(err).ToNot(HaveOccurred()) 204 | Expect(len(rt.clients)).To(BeZero()) 205 | }) 206 | }) 207 | }) 208 | -------------------------------------------------------------------------------- /request_writer.go: -------------------------------------------------------------------------------- 1 | package quictun 2 | 3 | import ( 4 | "bytes" 5 | "encoding/base64" 6 | "fmt" 7 | "log" 8 | "net" 9 | "net/http" 10 | "strconv" 11 | "strings" 12 | 13 | "golang.org/x/net/http2" 14 | "golang.org/x/net/http2/hpack" 15 | "golang.org/x/net/idna" 16 | "golang.org/x/net/lex/httplex" 17 | 18 | quic "github.com/lucas-clemente/quic-go" 19 | ) 20 | 21 | // http://www.ietf.org/rfc/rfc2617.txt 22 | func basicAuth(username, password string) string { 23 | 
auth := username + ":" + password 24 | return base64.StdEncoding.EncodeToString([]byte(auth)) 25 | } 26 | 27 | // rest is mostly from http2.Transport 28 | 29 | // authorityAddr returns a given authority (a host/IP, or host:port / ip:port) 30 | // and returns a host:port. The port 443 is added if needed. 31 | func authorityAddr(host, port string) (addr string) { 32 | if port == "" { 33 | port = "443" 34 | } 35 | if a, err := idna.ToASCII(host); err == nil { 36 | host = a 37 | } 38 | // IPv6 address literal, without a port: 39 | if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { 40 | return host + ":" + port 41 | } 42 | return net.JoinHostPort(host, port) 43 | } 44 | 45 | // shouldSendReqContentLength reports whether the http2.Transport should send 46 | // a "content-length" request header. This logic is basically a copy of the net/http 47 | // transferWriter.shouldSendContentLength. 48 | // The contentLength is the corrected contentLength (so 0 means actually 0, not unknown). 49 | // -1 means unknown. 50 | func shouldSendReqContentLength(method string, contentLength int64) bool { 51 | if contentLength > 0 { 52 | return true 53 | } 54 | if contentLength < 0 { 55 | return false 56 | } 57 | // For zero bodies, whether we send a content-length depends on the method. 58 | // It also kinda doesn't matter for http2 either way, with END_STREAM. 59 | switch method { 60 | case "POST", "PUT", "PATCH": 61 | return true 62 | default: 63 | return false 64 | } 65 | } 66 | 67 | func validPseudoPath(v string) bool { 68 | return (len(v) > 0 && v[0] == '/' && (len(v) == 1 || v[1] != '/')) || v == "*" 69 | } 70 | 71 | // actualContentLength returns a sanitized version of 72 | // req.ContentLength, where 0 actually means zero (not unknown) and -1 73 | // means unknown. 
74 | func actualContentLength(req *http.Request) int64 { 75 | if req.Body == nil { 76 | return 0 77 | } 78 | if req.ContentLength != 0 { 79 | return req.ContentLength 80 | } 81 | return -1 82 | } 83 | 84 | type requestWriter struct { 85 | headerStream quic.Stream 86 | henc *hpack.Encoder 87 | hbuf bytes.Buffer // HPACK encoder writes into this 88 | } 89 | 90 | func newRequestWriter(headerStream quic.Stream) *requestWriter { 91 | rw := &requestWriter{ 92 | headerStream: headerStream, 93 | } 94 | rw.henc = hpack.NewEncoder(&rw.hbuf) 95 | return rw 96 | } 97 | 98 | func (rw *requestWriter) WriteRequest(req *http.Request, dataStreamID quic.StreamID, endStream bool) error { 99 | if u := req.URL.User; u != nil && req.Header.Get("Authorization") == "" { 100 | username := u.Username() 101 | password, _ := u.Password() 102 | req.Header.Set("Authorization", "Basic "+basicAuth(username, password)) 103 | } 104 | 105 | buf, err := rw.encodeHeaders(req, actualContentLength(req)) 106 | if err != nil { 107 | log.Fatal("Failed to encode request headers: ", err) 108 | return err 109 | } 110 | h2framer := http2.NewFramer(rw.headerStream, nil) 111 | return h2framer.WriteHeaders(http2.HeadersFrameParam{ 112 | StreamID: uint32(dataStreamID), 113 | EndHeaders: true, 114 | EndStream: endStream, 115 | BlockFragment: buf, 116 | Priority: http2.PriorityParam{Weight: 0xff}, 117 | }) 118 | } 119 | 120 | func (w *requestWriter) encodeHeaders(req *http.Request, contentLength int64) ([]byte, error) { 121 | w.hbuf.Reset() 122 | 123 | host := req.Host 124 | if host == "" { 125 | host = req.URL.Host 126 | } 127 | host, err := httplex.PunycodeHostPort(host) 128 | if err != nil { 129 | return nil, err 130 | } 131 | 132 | path := req.URL.RequestURI() 133 | if !validPseudoPath(path) { 134 | orig := path 135 | path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host) 136 | if !validPseudoPath(path) { 137 | if req.URL.Opaque != "" { 138 | return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = 
%q", orig, req.URL.Opaque) 139 | } else { 140 | return nil, fmt.Errorf("invalid request :path %q", orig) 141 | } 142 | } 143 | } 144 | 145 | // Check for any invalid headers and return an error before we 146 | // potentially pollute our hpack state. (We want to be able to 147 | // continue to reuse the hpack encoder for future requests) 148 | for k, vv := range req.Header { 149 | if !httplex.ValidHeaderFieldName(k) { 150 | return nil, fmt.Errorf("invalid HTTP header name %q", k) 151 | } 152 | for _, v := range vv { 153 | if !httplex.ValidHeaderFieldValue(v) { 154 | return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k) 155 | } 156 | } 157 | } 158 | 159 | // 8.1.2.3 Request Pseudo-Header Fields 160 | // The :path pseudo-header field includes the path and query parts of the 161 | // target URI (the path-absolute production and optionally a '?' character 162 | // followed by the query production (see Sections 3.3 and 3.4 of 163 | // [RFC3986]). 164 | w.writeHeader(":authority", host) 165 | w.writeHeader(":method", req.Method) 166 | w.writeHeader(":path", path) 167 | w.writeHeader(":scheme", req.URL.Scheme) 168 | 169 | var didUA bool 170 | for k, vv := range req.Header { 171 | lowKey := strings.ToLower(k) 172 | switch lowKey { 173 | case "host", "content-length": 174 | // Host is :authority, already sent. 175 | // Content-Length is automatic, set below. 176 | continue 177 | case "connection", "proxy-connection", "transfer-encoding", "upgrade", "keep-alive": 178 | // Per 8.1.2.2 Connection-Specific Header 179 | // Fields, don't send connection-specific 180 | // fields. We have already checked if any 181 | // are error-worthy so just ignore the rest. 182 | continue 183 | case "user-agent": 184 | // Match Go's http1 behavior: at most one 185 | // User-Agent. If set to nil or empty string, 186 | // then omit it. Otherwise if not mentioned, 187 | // include the default (below). 
188 | didUA = true 189 | if len(vv) < 1 { 190 | continue 191 | } 192 | vv = vv[:1] 193 | if vv[0] == "" { 194 | continue 195 | } 196 | } 197 | for _, v := range vv { 198 | w.writeHeader(lowKey, v) 199 | } 200 | } 201 | if shouldSendReqContentLength(req.Method, contentLength) { 202 | w.writeHeader("content-length", strconv.FormatInt(contentLength, 10)) 203 | } 204 | if !didUA { 205 | panic("user agent info is missing") 206 | } 207 | return w.hbuf.Bytes(), nil 208 | } 209 | 210 | func (w *requestWriter) writeHeader(name, value string) { 211 | //fmt.Printf("http2: Transport encoding header %q = %q\n", name, value) 212 | w.henc.WriteField(hpack.HeaderField{Name: name, Value: value}) 213 | } 214 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | package quictun 2 | 3 | import ( 4 | "bufio" 5 | "crypto/tls" 6 | "errors" 7 | "fmt" 8 | "log" 9 | "math/rand" 10 | "net" 11 | "net/http" 12 | "net/url" 13 | "time" 14 | 15 | "github.com/julienschmidt/quictun/internal/atomic" 16 | "github.com/julienschmidt/quictun/internal/socks" 17 | "golang.org/x/net/http2" 18 | "golang.org/x/net/http2/hpack" 19 | 20 | quic "github.com/lucas-clemente/quic-go" 21 | ) 22 | 23 | const protocolIdentifier = "QTP/0.1" 24 | 25 | var ( 26 | ErrInvalidResponse = errors.New("server returned an invalid response") 27 | ErrInvalidSequence = errors.New("client sequence number invalid") 28 | ErrNotAQuictunServer = errors.New("server does not seems to be a quictun server") 29 | ErrWrongCredentials = errors.New("authentication credentials seems to be wrong") 30 | ) 31 | 32 | // Client holds the configuration and state of a quictun client 33 | type Client struct { 34 | // config 35 | ListenAddr string 36 | TunnelAddr string 37 | UserAgent string 38 | TlsCfg *tls.Config 39 | QuicConfig *quic.Config 40 | DialTimeout time.Duration 41 | 42 | // state 43 | session quic.Session 44 | 
connected atomic.Bool 45 | 46 | // replay protection 47 | clientID uint64 48 | sequenceNumber uint32 49 | 50 | // header 51 | headerStream quic.Stream 52 | hDecoder *hpack.Decoder 53 | h2framer *http2.Framer 54 | } 55 | 56 | func (c *Client) generateClientID() { 57 | // generate clientID 58 | rand.Seed(time.Now().UnixNano()) 59 | c.clientID = rand.Uint64() 60 | } 61 | 62 | func (c *Client) connect() error { 63 | authURL := c.TunnelAddr 64 | 65 | // extract hostname from auth url 66 | uri, err := url.ParseRequestURI(authURL) 67 | if err != nil { 68 | log.Fatal("Invalid Auth URL: ", err) 69 | return err 70 | } 71 | hostname := authorityAddr(uri.Hostname(), uri.Port()) 72 | fmt.Println("Connecting to", hostname) 73 | 74 | c.session, err = quic.DialAddr(hostname, c.TlsCfg, c.QuicConfig) 75 | if err != nil { 76 | log.Fatal("Dial Err: ", err) 77 | return err 78 | } 79 | 80 | // once the version has been negotiated, open the header stream 81 | c.headerStream, err = c.session.OpenStream() 82 | if err != nil { 83 | log.Fatal("OpenStream Err: ", err) 84 | return err 85 | } 86 | //fmt.Println("Header StreamID:", c.headerStream.StreamID()) 87 | 88 | dataStream, err := c.session.OpenStreamSync() 89 | if err != nil { 90 | log.Fatal("OpenStreamSync Err: ", err) 91 | } 92 | //fmt.Println("Data StreamID:", dataStream.StreamID()) 93 | 94 | // build HTTP request 95 | // The authorization credentials are automatically encoded from the URL 96 | req, err := http.NewRequest("GET", authURL, nil) 97 | if err != nil { 98 | log.Fatal("NewRequest Err: ", err) 99 | return err 100 | } 101 | req.Header.Set("User-Agent", c.UserAgent) 102 | 103 | // request protocol upgrade 104 | req.Header.Set("Connection", "Upgrade") 105 | req.Header.Set("Upgrade", protocolIdentifier) 106 | 107 | // replay protection 108 | c.sequenceNumber++ 109 | req.Header.Set("QTP", fmt.Sprintf("%016X%08X", c.clientID, c.sequenceNumber)) 110 | 111 | rw := newRequestWriter(c.headerStream) 112 | endStream := true //endStream := 
!hasBody 113 | fmt.Println("requesting", authURL) 114 | err = rw.WriteRequest(req, dataStream.StreamID(), endStream) 115 | if err != nil { 116 | log.Fatal("WriteHeaders Err: ", err) 117 | } 118 | 119 | fmt.Println("Waiting...") 120 | // read frames from headerStream 121 | c.h2framer = http2.NewFramer(nil, c.headerStream) 122 | c.hDecoder = hpack.NewDecoder(4096, func(hf hpack.HeaderField) {}) 123 | 124 | frame, err := c.h2framer.ReadFrame() 125 | if err != nil { 126 | // c.headerErr = qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame") 127 | log.Fatal("cannot read frame: ", err) 128 | } 129 | hframe, ok := frame.(*http2.HeadersFrame) 130 | if !ok { 131 | // c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame") 132 | log.Fatal("not a headers frame: ", err) 133 | } 134 | mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe} 135 | mhframe.Fields, err = c.hDecoder.DecodeFull(hframe.HeaderBlockFragment()) 136 | if err != nil { 137 | // c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields") 138 | log.Fatal("cannot read header fields: ", err) 139 | } 140 | 141 | //fmt.Println("Frame for StreamID:", hframe.StreamID) 142 | 143 | rsp, err := responseFromHeaders(mhframe) 144 | if err != nil { 145 | log.Fatal("responseFromHeaders: ", err) 146 | } 147 | switch rsp.StatusCode { 148 | case http.StatusSwitchingProtocols: 149 | header := rsp.Header 150 | if header.Get("Connection") != "Upgrade" { 151 | return ErrInvalidResponse 152 | } 153 | if header.Get("Upgrade") != protocolIdentifier { 154 | return ErrNotAQuictunServer 155 | } 156 | return nil 157 | case http.StatusUnauthorized, http.StatusForbidden: 158 | return ErrWrongCredentials 159 | case http.StatusBadRequest: 160 | c.generateClientID() 161 | return ErrInvalidSequence 162 | default: 163 | return ErrInvalidResponse 164 | } 165 | } 166 | 167 | func (c *Client) watchCancel() { 168 | session := c.session 169 | if session == nil { 170 | 
fmt.Println("session is nil") 171 | return 172 | } 173 | 174 | ctx := session.Context() 175 | if ctx == nil { 176 | fmt.Println("ctx is nil") 177 | return 178 | } 179 | 180 | // TODO: add graceful shutdown channel 181 | <-ctx.Done() 182 | fmt.Println("session closed", ctx.Err()) 183 | c.connected.Set(false) 184 | } 185 | 186 | func (c *Client) tunnelConn(local net.Conn) { 187 | local.(*net.TCPConn).SetKeepAlive(true) 188 | // TODO: SetReadTimeout(conn) 189 | 190 | localRd := bufio.NewReader(local) 191 | 192 | // initiate SOCKS connection 193 | if err := socks.Auth(localRd, local); err != nil { 194 | fmt.Println(err) 195 | local.Close() 196 | return 197 | } 198 | 199 | req, err := socks.PeekRequest(localRd) 200 | if err != nil { 201 | fmt.Println(err) 202 | socks.SendReply(local, socks.StatusConnectionRefused, nil) 203 | local.Close() 204 | return 205 | } 206 | 207 | fmt.Println("request", req.Dest()) 208 | 209 | switch req.Cmd() { 210 | case socks.CmdConnect: 211 | fmt.Println("[Connect]") 212 | if err = socks.SendReply(local, socks.StatusSucceeded, nil); err != nil { 213 | fmt.Println(err) 214 | local.Close() 215 | return 216 | } 217 | 218 | default: 219 | socks.SendReply(local, socks.StatusCmdNotSupported, nil) 220 | local.Close() 221 | return 222 | } 223 | 224 | // TODO: check connected status again and reconnect if necessary 225 | stream, err := c.session.OpenStreamSync() 226 | if err != nil { 227 | fmt.Println("open stream err", err) 228 | local.Close() 229 | return 230 | } 231 | 232 | fmt.Println("Start proxying...") 233 | go proxy(local, stream) // recv from stream and send to local 234 | proxy(stream, localRd) // recv from local and send to stream 235 | } 236 | 237 | // Close closes the client 238 | func (c *Client) close(err error) error { 239 | if c.session == nil { 240 | return nil 241 | } 242 | return c.session.Close(err) 243 | } 244 | 245 | // Run starts the client to accept incoming SOCKS connections, which are tunneled 246 | // to the configured 
quictun server. 247 | // The tunnel connection is opened only on-demand. 248 | func (c *Client) Run() error { 249 | c.generateClientID() 250 | 251 | listener, err := net.Listen("tcp", c.ListenAddr) 252 | if err != nil { 253 | return fmt.Errorf("Failed to listen on %s: %s", c.ListenAddr, err) 254 | } 255 | 256 | fmt.Println("Listening for incoming SOCKS connection...") 257 | // accept local connections and tunnel them 258 | for { 259 | conn, err := listener.Accept() 260 | if err != nil { 261 | log.Println("Accept Err:", err) 262 | continue 263 | } 264 | 265 | fmt.Println("new SOCKS conn", conn.RemoteAddr().String()) 266 | 267 | if !c.connected.IsSet() { 268 | err = c.connect() 269 | if err != nil { 270 | fmt.Println("Failed to connect to tunnel host:", err) 271 | conn.Close() 272 | continue 273 | } 274 | // start watcher which closes when canceled 275 | go c.watchCancel() 276 | 277 | c.connected.Set(true) 278 | } 279 | 280 | go c.tunnelConn(conn) 281 | } 282 | } 283 | -------------------------------------------------------------------------------- /h2quic/client.go: -------------------------------------------------------------------------------- 1 | package h2quic 2 | 3 | import ( 4 | "crypto/tls" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "net" 9 | "net/http" 10 | "strings" 11 | "sync" 12 | 13 | "golang.org/x/net/http2" 14 | "golang.org/x/net/http2/hpack" 15 | "golang.org/x/net/idna" 16 | 17 | quic "github.com/lucas-clemente/quic-go" 18 | "github.com/lucas-clemente/quic-go/qerr" 19 | ) 20 | 21 | type roundTripperOpts struct { 22 | DisableCompression bool 23 | } 24 | 25 | var dialAddr = quic.DialAddr 26 | 27 | // client is a HTTP2 client doing QUIC requests 28 | type client struct { 29 | mutex sync.RWMutex 30 | 31 | tlsConf *tls.Config 32 | config *quic.Config 33 | opts *roundTripperOpts 34 | 35 | hostname string 36 | handshakeErr error 37 | dialOnce sync.Once 38 | 39 | session quic.Session 40 | headerStream quic.Stream 41 | headerErr *qerr.QuicError 42 | headerErrored 
chan struct{} // this channel is closed if an error occurs on the header stream 43 | requestWriter *requestWriter 44 | 45 | responses map[quic.StreamID]chan *http.Response 46 | } 47 | 48 | var _ http.RoundTripper = &client{} 49 | 50 | var defaultQuicConfig = &quic.Config{ 51 | RequestConnectionIDOmission: true, 52 | KeepAlive: true, 53 | } 54 | 55 | // newClient creates a new client 56 | func newClient( 57 | hostname string, 58 | tlsConfig *tls.Config, 59 | opts *roundTripperOpts, 60 | quicConfig *quic.Config, 61 | ) *client { 62 | config := defaultQuicConfig 63 | if quicConfig != nil { 64 | config = quicConfig 65 | } 66 | return &client{ 67 | hostname: authorityAddr("https", hostname), 68 | responses: make(map[quic.StreamID]chan *http.Response), 69 | tlsConf: tlsConfig, 70 | config: config, 71 | opts: opts, 72 | headerErrored: make(chan struct{}), 73 | } 74 | } 75 | 76 | // dial dials the connection 77 | func (c *client) dial() error { 78 | var err error 79 | c.session, err = dialAddr(c.hostname, c.tlsConf, c.config) 80 | if err != nil { 81 | return err 82 | } 83 | 84 | // once the version has been negotiated, open the header stream 85 | c.headerStream, err = c.session.OpenStream() 86 | if err != nil { 87 | return err 88 | } 89 | c.requestWriter = newRequestWriter(c.headerStream) 90 | go c.handleHeaderStream() 91 | return nil 92 | } 93 | 94 | func (c *client) handleHeaderStream() { 95 | decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {}) 96 | h2framer := http2.NewFramer(nil, c.headerStream) 97 | 98 | var lastStream quic.StreamID 99 | 100 | for { 101 | frame, err := h2framer.ReadFrame() 102 | if err != nil { 103 | c.headerErr = qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame") 104 | break 105 | } 106 | lastStream = quic.StreamID(frame.Header().StreamID) 107 | hframe, ok := frame.(*http2.HeadersFrame) 108 | if !ok { 109 | c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame") 110 | break 111 | } 112 | 
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe} 113 | mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment()) 114 | if err != nil { 115 | c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields") 116 | break 117 | } 118 | 119 | c.mutex.RLock() 120 | responseChan, ok := c.responses[quic.StreamID(hframe.StreamID)] 121 | c.mutex.RUnlock() 122 | if !ok { 123 | c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream)) 124 | break 125 | } 126 | 127 | rsp, err := responseFromHeaders(mhframe) 128 | if err != nil { 129 | c.headerErr = qerr.Error(qerr.InternalError, err.Error()) 130 | } 131 | responseChan <- rsp 132 | } 133 | 134 | // stop all running request 135 | fmt.Printf("Error handling header stream %d: %s\n", lastStream, c.headerErr.Error()) 136 | close(c.headerErrored) 137 | } 138 | 139 | // Roundtrip executes a request and returns a response 140 | func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { 141 | // TODO: add port to address, if it doesn't have one 142 | if req.URL.Scheme != "https" { 143 | return nil, errors.New("quic http2: unsupported scheme") 144 | } 145 | if authorityAddr("https", hostnameFromRequest(req)) != c.hostname { 146 | return nil, fmt.Errorf("h2quic Client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host) 147 | } 148 | 149 | c.dialOnce.Do(func() { 150 | c.handshakeErr = c.dial() 151 | }) 152 | 153 | if c.handshakeErr != nil { 154 | return nil, c.handshakeErr 155 | } 156 | 157 | hasBody := (req.Body != nil) 158 | 159 | responseChan := make(chan *http.Response) 160 | dataStream, err := c.session.OpenStreamSync() 161 | if err != nil { 162 | _ = c.CloseWithError(err) 163 | return nil, err 164 | } 165 | c.mutex.Lock() 166 | c.responses[dataStream.StreamID()] = responseChan 167 | c.mutex.Unlock() 168 | 169 | var requestedGzip bool 170 | if !c.opts.DisableCompression && 
req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" { 171 | requestedGzip = true 172 | } 173 | // TODO: add support for trailers 174 | endStream := !hasBody 175 | err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip) 176 | if err != nil { 177 | _ = c.CloseWithError(err) 178 | return nil, err 179 | } 180 | 181 | resc := make(chan error, 1) 182 | if hasBody { 183 | go func() { 184 | resc <- c.writeRequestBody(dataStream, req.Body) 185 | }() 186 | } 187 | 188 | var res *http.Response 189 | 190 | var receivedResponse bool 191 | var bodySent bool 192 | 193 | if !hasBody { 194 | bodySent = true 195 | } 196 | 197 | for !(bodySent && receivedResponse) { 198 | select { 199 | case res = <-responseChan: 200 | receivedResponse = true 201 | c.mutex.Lock() 202 | delete(c.responses, dataStream.StreamID()) 203 | c.mutex.Unlock() 204 | case err := <-resc: 205 | bodySent = true 206 | if err != nil { 207 | return nil, err 208 | } 209 | case <-c.headerErrored: 210 | // an error occured on the header stream 211 | _ = c.CloseWithError(c.headerErr) 212 | return nil, c.headerErr 213 | } 214 | } 215 | 216 | // TODO: correctly set this variable 217 | var streamEnded bool 218 | isHead := (req.Method == "HEAD") 219 | 220 | res = setLength(res, isHead, streamEnded) 221 | 222 | if streamEnded || isHead { 223 | res.Body = noBody 224 | } else { 225 | res.Body = dataStream 226 | if requestedGzip && res.Header.Get("Content-Encoding") == "gzip" { 227 | res.Header.Del("Content-Encoding") 228 | res.Header.Del("Content-Length") 229 | res.ContentLength = -1 230 | res.Body = &gzipReader{body: res.Body} 231 | res.Uncompressed = true 232 | } 233 | } 234 | 235 | res.Request = req 236 | return res, nil 237 | } 238 | 239 | func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (err error) { 240 | defer func() { 241 | cerr := body.Close() 242 | if err == nil { 243 | // TODO: what to do with dataStream here? 
Maybe reset it?
			err = cerr
		}
	}()

	_, err = io.Copy(dataStream, body)
	if err != nil {
		// TODO: what to do with dataStream here? Maybe reset it?
		return err
	}
	// Closing the stream closes only its write side, which marks the end of
	// the request body; RoundTrip still reads the response body from it.
	return dataStream.Close()
}

// CloseWithError closes the client's QUIC session with the error e.
// It is a no-op if no session has been dialed yet.
func (c *client) CloseWithError(e error) error {
	if c.session == nil {
		return nil
	}
	return c.session.Close(e)
}

// Close closes the client's QUIC session without an error.
func (c *client) Close() error {
	return c.CloseWithError(nil)
}

// copied from net/transport.go

// authorityAddr takes a given authority (a host/IP, or host:port / ip:port)
// and returns a host:port. If the authority has no port, 443 is added (80 when
// scheme is "http"). The host is converted to its ASCII (IDNA) form when that
// conversion succeeds.
func authorityAddr(scheme string, authority string) (addr string) {
	host, port, err := net.SplitHostPort(authority)
	if err != nil { // authority didn't have a port
		port = "443"
		if scheme == "http" {
			port = "80"
		}
		host = authority
	}
	if a, err := idna.ToASCII(host); err == nil {
		host = a
	}
	// IPv6 address literal, without a port:
	// the host already carries its brackets, and JoinHostPort would add a
	// second pair, so the port is appended by hand.
	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		return host + ":" + port
	}
	return net.JoinHostPort(host, port)
}
-------------------------------------------------------------------------------- /h2quic/server.go: --------------------------------------------------------------------------------
package h2quic

import (
	"crypto/tls"
	"errors"
	"fmt"
	"net"
	"net/http"
	"runtime"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	quic "github.com/lucas-clemente/quic-go"
	"github.com/lucas-clemente/quic-go/qerr"
	"golang.org/x/net/http2"
	"golang.org/x/net/http2/hpack"
)

// streamCreator is a quic.Session that can also look up (or open) a stream by
// its ID. The server uses it to find the data stream named by the stream ID of
// a request's HEADERS frame.
type streamCreator interface {
	quic.Session
	GetOrOpenStream(quic.StreamID) (quic.Stream, error)
}

// remoteCloser is asserted on data streams whose HEADERS frame already ended
// the stream, so that the receive side can be closed locally.
type remoteCloser interface {
CloseRemote(uint64)
}

// allows mocking of quic.Listen and quic.ListenAddr
var (
	quicListen     = quic.Listen
	quicListenAddr = quic.ListenAddr
)

// Server is a HTTP2 server listening for QUIC connections.
type Server struct {
	*http.Server

	// By providing a quic.Config, it is possible to set parameters of the QUIC connection.
	// If nil, it uses reasonable default values.
	QuicConfig *quic.Config

	// Private flag for demo, do not use
	CloseAfterFirstRequest bool

	// port is the port announced by SetQuicHeaders; 0 means it has not been
	// derived from Server.Addr yet.
	port uint32 // used atomically

	// listenerMutex guards listener, which serveImpl sets and Close clears.
	listenerMutex sync.Mutex
	listener      quic.Listener

	// supportedVersionsAsString is the comma-separated QUIC version list for
	// the Alt-Svc header; SetQuicHeaders computes it when it is empty.
	supportedVersionsAsString string
}

// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
// The TLS configuration is taken from s.TLSConfig.
func (s *Server) ListenAndServe() error {
	if s.Server == nil {
		return errors.New("use of h2quic.Server without http.Server")
	}
	return s.serveImpl(s.TLSConfig, nil)
}

// ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
// The certificate is loaded from certFile and keyFile; s.TLSConfig is not used.
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
	var err error
	certs := make([]tls.Certificate, 1)
	certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		return err
	}
	// We currently only use the cert-related stuff from tls.Config,
	// so we don't need to make a full copy.
	config := &tls.Config{
		Certificates: certs,
	}
	return s.serveImpl(config, nil)
}

// Serve an existing UDP connection.
// The TLS configuration is taken from s.TLSConfig.
func (s *Server) Serve(conn net.PacketConn) error {
	return s.serveImpl(s.TLSConfig, conn)
}

// serveImpl creates the QUIC listener (on s.Addr when conn is nil, otherwise
// on conn) and stores it in s.listener so that Close can stop it. It then
// accepts sessions until Accept fails and returns that error. Every session
// is handled in its own goroutine.
func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
	if s.Server == nil {
		return errors.New("use of h2quic.Server without http.Server")
	}
	// s.listener is only read and written under listenerMutex (see Close).
	s.listenerMutex.Lock()
	if s.listener != nil {
		s.listenerMutex.Unlock()
		return errors.New("ListenAndServe may only be called once")
	}

	var ln quic.Listener
	var err error
	if conn == nil {
		ln, err = quicListenAddr(s.Addr, tlsConfig, s.QuicConfig)
	} else {
		ln, err = quicListen(conn, tlsConfig, s.QuicConfig)
	}
	if err != nil {
		s.listenerMutex.Unlock()
		return err
	}
	s.listener = ln
	s.listenerMutex.Unlock()

	for {
		sess, err := ln.Accept()
		if err != nil {
			return err
		}
		// NOTE(review): this unchecked type assertion panics if a session
		// does not implement GetOrOpenStream.
		go s.handleHeaderStream(sess.(streamCreator))
	}
}

// handleHeaderStream accepts a stream from session and uses it as the header
// stream, serving one request per HEADERS frame. The first error returned by
// handleRequest closes the whole session.
func (s *Server) handleHeaderStream(session streamCreator) {
	stream, err := session.AcceptStream()
	if err != nil {
		session.Close(qerr.Error(qerr.InvalidHeadersStreamData, err.Error()))
		return
	}

	hpackDecoder := hpack.NewDecoder(4096, nil)
	h2framer := http2.NewFramer(nil, stream)

	var headerStreamMutex sync.Mutex // Protects concurrent calls to Write()
	for {
		if err := s.handleRequest(session, stream, &headerStreamMutex, hpackDecoder, h2framer); err != nil {
			// QuicErrors must originate from stream.Read() returning an error.
			// In this case, the session has already logged the error, so we don't
			// need to log it again.
			if _, ok := err.(*qerr.QuicError); !ok {
				fmt.Printf("error handling h2 request: %s\n", err.Error())
			}
			session.Close(err)
			return
		}
	}
}

// handleRequest reads one HEADERS frame from the header stream and decodes it
// into an http.Request. It then serves that request on the data stream whose
// ID the frame carries. The handler itself runs in a separate goroutine, so
// the returned error only concerns reading and decoding the request.
func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, headerStreamMutex *sync.Mutex, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error {
	h2frame, err := h2framer.ReadFrame()
	if err != nil {
		return qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")
	}
	h2headersFrame, ok := h2frame.(*http2.HeadersFrame)
	if !ok {
		return qerr.Error(qerr.InvalidHeadersStreamData, "expected a header frame")
	}
	if !h2headersFrame.HeadersEnded() {
		return errors.New("http2 header continuation not implemented")
	}
	headers, err := hpackDecoder.DecodeFull(h2headersFrame.HeaderBlockFragment())
	if err != nil {
		fmt.Println("invalid http2 headers encoding:", err.Error())
		return err
	}

	req, err := requestFromHeaders(headers)
	if err != nil {
		return err
	}

	fmt.Printf("%s %s%s, on data stream %d\n", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID)

	dataStream, err := session.GetOrOpenStream(quic.StreamID(h2headersFrame.StreamID))
	if err != nil {
		return err
	}
	// this can happen if the client immediately closes the data stream after sending the request and the runtime processes the reset before the request
	if dataStream == nil {
		return nil
	}

	// handleRequest should be as non-blocking as possible to minimize
	// head-of-line blocking. Potentially blocking code is run in a separate
	// goroutine, enabling handleRequest to return before the code is executed.
	go func() {
		// A HEADERS frame with END_STREAM means the request has no body:
		// close the receive side and consume the resulting EOF.
		streamEnded := h2headersFrame.StreamEnded()
		if streamEnded {
			dataStream.(remoteCloser).CloseRemote(0)
			streamEnded = true // NOTE(review): redundant, already true here
			_, _ = dataStream.Read([]byte{0}) // read the eof
		}

		req = req.WithContext(dataStream.Context())
		reqBody := newRequestBody(dataStream)
		req.Body = reqBody

		req.RemoteAddr = session.RemoteAddr().String()

		responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, quic.StreamID(h2headersFrame.StreamID))

		handler := s.Handler
		if handler == nil {
			handler = http.DefaultServeMux
		}
		// A panicking handler must not crash the server: recover, log the
		// stack and answer with a 500 below.
		panicked := false
		func() {
			defer func() {
				if p := recover(); p != nil {
					// Copied from net/http/server.go
					const size = 64 << 10
					buf := make([]byte, size)
					buf = buf[:runtime.Stack(buf, false)]
					fmt.Printf("http: panic serving: %v\n%s\n", p, buf)
					panicked = true
				}
			}()
			handler.ServeHTTP(responseWriter, req)
		}()

		if panicked {
			responseWriter.WriteHeader(500)
		} else {
			switch responseWriter.status {
			case http.StatusSwitchingProtocols:
				// Hand the whole session to the first registered upgrade
				// handler matching one of the response's Upgrade values.
				if protocols, ok := responseWriter.Header()["Upgrade"]; ok {
					fmt.Println("Upgrade to:", protocols)
					for _, protocol := range protocols {
						fmt.Println(protocol) // NOTE(review): leftover debug output?
						// note: this handler shadows the http.Handler above
						if handler, ok := upgradeHandlers[protocol]; ok {
							handler(session)
							break
						}
					}
				}
			case 0:
				// the handler never wrote a status: default to 200
				responseWriter.WriteHeader(200)
			}
		}

		if responseWriter.dataStream != nil {
			// The request had a body the handler never read: reset the
			// stream instead of leaving the client's data unconsumed.
			if !streamEnded && !reqBody.requestRead {
				responseWriter.dataStream.Reset(nil)
			}
			responseWriter.dataStream.Close()
		}
		if s.CloseAfterFirstRequest {
			time.Sleep(100 * time.Millisecond)
			session.Close(nil)
		}
	}()

	return nil
}

// Close the server immediately, aborting requests and sending
CONNECTION_CLOSE frames to connected clients. 250 | // Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established. 251 | func (s *Server) Close() error { 252 | s.listenerMutex.Lock() 253 | defer s.listenerMutex.Unlock() 254 | if s.listener != nil { 255 | err := s.listener.Close() 256 | s.listener = nil 257 | return err 258 | } 259 | return nil 260 | } 261 | 262 | // CloseGracefully shuts down the server gracefully. The server sends a GOAWAY frame first, then waits for either timeout to trigger, or for all running requests to complete. 263 | // CloseGracefully in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established. 264 | func (s *Server) CloseGracefully(timeout time.Duration) error { 265 | // TODO: implement 266 | return nil 267 | } 268 | 269 | // SetQuicHeaders can be used to set the proper headers that announce that this server supports QUIC. 270 | // The values that are set depend on the port information from s.Server.Addr, and currently look like this (if Addr has port 443): 271 | // Alt-Svc: quic=":443"; ma=2592000; v="33,32,31,30" 272 | func (s *Server) SetQuicHeaders(hdr http.Header) error { 273 | port := atomic.LoadUint32(&s.port) 274 | 275 | if port == 0 { 276 | // Extract port from s.Server.Addr 277 | _, portStr, err := net.SplitHostPort(s.Server.Addr) 278 | if err != nil { 279 | return err 280 | } 281 | portInt, err := net.LookupPort("tcp", portStr) 282 | if err != nil { 283 | return err 284 | } 285 | port = uint32(portInt) 286 | atomic.StoreUint32(&s.port, port) 287 | } 288 | 289 | if s.supportedVersionsAsString == "" { 290 | var versions []string 291 | for _, v := range quic.SupportedVersions { 292 | versions = append(versions, v.ToAltSvc()) 293 | } 294 | s.supportedVersionsAsString = strings.Join(versions, ",") 295 | } 296 | 297 | hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, 
s.supportedVersionsAsString)) 298 | 299 | return nil 300 | } 301 | 302 | // ListenAndServeQUIC listens on the UDP network address addr and calls the 303 | // handler for HTTP/2 requests on incoming connections. http.DefaultServeMux is 304 | // used when handler is nil. 305 | func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) error { 306 | server := &Server{ 307 | Server: &http.Server{ 308 | Addr: addr, 309 | Handler: handler, 310 | }, 311 | } 312 | return server.ListenAndServeTLS(certFile, keyFile) 313 | } 314 | 315 | // ListenAndServe listens on the given network address for both, TLS and QUIC 316 | // connetions in parallel. It returns if one of the two returns an error. 317 | // http.DefaultServeMux is used when handler is nil. 318 | // The correct Alt-Svc headers for QUIC are set. 319 | func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error { 320 | // Load certs 321 | var err error 322 | certs := make([]tls.Certificate, 1) 323 | certs[0], err = tls.LoadX509KeyPair(certFile, keyFile) 324 | if err != nil { 325 | return err 326 | } 327 | // We currently only use the cert-related stuff from tls.Config, 328 | // so we don't need to make a full copy. 
329 | config := &tls.Config{ 330 | Certificates: certs, 331 | } 332 | 333 | // Open the listeners 334 | udpAddr, err := net.ResolveUDPAddr("udp", addr) 335 | if err != nil { 336 | return err 337 | } 338 | udpConn, err := net.ListenUDP("udp", udpAddr) 339 | if err != nil { 340 | return err 341 | } 342 | defer udpConn.Close() 343 | 344 | tcpAddr, err := net.ResolveTCPAddr("tcp", addr) 345 | if err != nil { 346 | return err 347 | } 348 | tcpConn, err := net.ListenTCP("tcp", tcpAddr) 349 | if err != nil { 350 | return err 351 | } 352 | defer tcpConn.Close() 353 | 354 | tlsConn := tls.NewListener(tcpConn, config) 355 | defer tlsConn.Close() 356 | 357 | // Start the servers 358 | httpServer := &http.Server{ 359 | Addr: addr, 360 | TLSConfig: config, 361 | } 362 | 363 | quicServer := &Server{ 364 | Server: httpServer, 365 | } 366 | 367 | if handler == nil { 368 | handler = http.DefaultServeMux 369 | } 370 | httpServer.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 371 | quicServer.SetQuicHeaders(w.Header()) 372 | handler.ServeHTTP(w, r) 373 | }) 374 | 375 | hErr := make(chan error) 376 | qErr := make(chan error) 377 | go func() { 378 | hErr <- httpServer.Serve(tlsConn) 379 | }() 380 | go func() { 381 | qErr <- quicServer.Serve(udpConn) 382 | }() 383 | 384 | select { 385 | case err := <-hErr: 386 | quicServer.Close() 387 | return err 388 | case err := <-qErr: 389 | // Cannot close the HTTP server or wait for requests to complete properly :/ 390 | return err 391 | } 392 | } 393 | -------------------------------------------------------------------------------- /h2quic/server_test.go: -------------------------------------------------------------------------------- 1 | package h2quic 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "crypto/tls" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "net" 11 | "net/http" 12 | "strings" 13 | "sync" 14 | "time" 15 | 16 | "golang.org/x/net/http2" 17 | "golang.org/x/net/http2/hpack" 18 | 19 | 
"github.com/julienschmidt/quictun/internal/testdata"
	quic "github.com/lucas-clemente/quic-go"
	"github.com/lucas-clemente/quic-go/qerr"

	. "github.com/onsi/ginkgo"
	. "github.com/onsi/gomega"
)

// mockSession is a test double for the QUIC session handed to the server
// (it is passed where a streamCreator is expected). Its behaviour is driven
// entirely by the fields below.
type mockSession struct {
	closed              bool            // set by Close
	closedWithError     error           // error passed to Close
	dataStream          quic.Stream     // returned by GetOrOpenStream for any ID
	streamToAccept      quic.Stream     // returned by AcceptStream
	streamsToOpen       []quic.Stream   // handed out in order by OpenStream
	blockOpenStreamSync bool            // makes OpenStreamSync sleep for an hour first
	streamOpenErr       error           // if set, returned by OpenStream
	ctx                 context.Context // returned by Context
	ctxCancel           context.CancelFunc // called by Close
}

func (s *mockSession) GetOrOpenStream(id quic.StreamID) (quic.Stream, error) {
	return s.dataStream, nil
}
func (s *mockSession) AcceptStream() (quic.Stream, error) { return s.streamToAccept, nil }

// OpenStream pops the next stream from streamsToOpen; it panics if the slice
// is empty and streamOpenErr is nil.
func (s *mockSession) OpenStream() (quic.Stream, error) {
	if s.streamOpenErr != nil {
		return nil, s.streamOpenErr
	}
	str := s.streamsToOpen[0]
	s.streamsToOpen = s.streamsToOpen[1:]
	return str, nil
}
func (s *mockSession) OpenStreamSync() (quic.Stream, error) {
	if s.blockOpenStreamSync {
		time.Sleep(time.Hour)
	}
	return s.OpenStream()
}

// Close records the call and cancels the session context.
func (s *mockSession) Close(e error) error {
	s.closed = true
	s.closedWithError = e
	s.ctxCancel()
	return nil
}
func (s *mockSession) LocalAddr() net.Addr {
	panic("not implemented")
}
func (s *mockSession) RemoteAddr() net.Addr {
	return &net.UDPAddr{IP: []byte{127, 0, 0, 1}, Port: 42}
}
func (s *mockSession) Context() context.Context {
	return s.ctx
}

var _ = Describe("H2 server", func() {
	var (
		s                  *Server
		session            *mockSession
		dataStream         *mockStream
		origQuicListenAddr = quicListenAddr
	)

	BeforeEach(func() {
		s = &Server{
			Server: &http.Server{
				TLSConfig: testdata.GetTLSConfig(),
			},
		}
		dataStream = newMockStream(0)
		close(dataStream.unblockRead)
		session = &mockSession{dataStream: dataStream}
90 | session.ctx, session.ctxCancel = context.WithCancel(context.Background()) 91 | origQuicListenAddr = quicListenAddr 92 | }) 93 | 94 | AfterEach(func() { 95 | quicListenAddr = origQuicListenAddr 96 | }) 97 | 98 | Context("handling requests", func() { 99 | var ( 100 | h2framer *http2.Framer 101 | hpackDecoder *hpack.Decoder 102 | headerStream *mockStream 103 | ) 104 | 105 | BeforeEach(func() { 106 | headerStream = &mockStream{} 107 | hpackDecoder = hpack.NewDecoder(4096, nil) 108 | h2framer = http2.NewFramer(nil, headerStream) 109 | }) 110 | 111 | It("handles a sample GET request", func() { 112 | var handlerCalled bool 113 | s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 114 | defer GinkgoRecover() 115 | Expect(r.Host).To(Equal("www.example.com")) 116 | Expect(r.RemoteAddr).To(Equal("127.0.0.1:42")) 117 | handlerCalled = true 118 | }) 119 | headerStream.dataToRead.Write([]byte{ 120 | 0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5, 121 | // Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding 122 | 0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 123 | }) 124 | err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer) 125 | Expect(err).NotTo(HaveOccurred()) 126 | Eventually(func() bool { return handlerCalled }).Should(BeTrue()) 127 | Expect(dataStream.remoteClosed).To(BeTrue()) 128 | Expect(dataStream.reset).To(BeFalse()) 129 | }) 130 | 131 | It("returns 200 with an empty handler", func() { 132 | s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 133 | headerStream.dataToRead.Write([]byte{ 134 | 0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5, 135 | // Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding 136 | 0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 137 | }) 138 | err := 
s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer) 139 | Expect(err).NotTo(HaveOccurred()) 140 | Eventually(func() []byte { 141 | return headerStream.dataWritten.Bytes() 142 | }).Should(Equal([]byte{0x0, 0x0, 0x1, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, 0x88})) // 0x88 is 200 143 | }) 144 | 145 | It("correctly handles a panicking handler", func() { 146 | s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 147 | panic("foobar") 148 | }) 149 | headerStream.dataToRead.Write([]byte{ 150 | 0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5, 151 | // Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding 152 | 0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 153 | }) 154 | err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer) 155 | Expect(err).NotTo(HaveOccurred()) 156 | Eventually(func() []byte { 157 | return headerStream.dataWritten.Bytes() 158 | }).Should(Equal([]byte{0x0, 0x0, 0x1, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, 0x8e})) // 0x82 is 500 159 | }) 160 | 161 | It("resets the dataStream when client sends a body in GET request", func() { 162 | var handlerCalled bool 163 | s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 164 | Expect(r.Host).To(Equal("www.example.com")) 165 | handlerCalled = true 166 | }) 167 | headerStream.dataToRead.Write([]byte{ 168 | 0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, 169 | // Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding 170 | 0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 171 | }) 172 | err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer) 173 | Expect(err).NotTo(HaveOccurred()) 174 | Eventually(func() bool { return handlerCalled }).Should(BeTrue()) 175 | Eventually(func() bool { return dataStream.reset 
}).Should(BeTrue()) 176 | Expect(dataStream.remoteClosed).To(BeFalse()) 177 | }) 178 | 179 | It("resets the dataStream when the body of POST request is not read", func() { 180 | var handlerCalled bool 181 | s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 182 | Expect(r.Host).To(Equal("www.example.com")) 183 | Expect(r.Method).To(Equal("POST")) 184 | handlerCalled = true 185 | }) 186 | headerStream.dataToRead.Write([]byte{0x0, 0x0, 0x20, 0x1, 0x24, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0xff, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x83, 0x84, 0x87, 0x5c, 0x1, 0x37, 0x7a, 0x85, 0xed, 0x69, 0x88, 0xb4, 0xc7}) 187 | err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer) 188 | Expect(err).NotTo(HaveOccurred()) 189 | Eventually(func() bool { return dataStream.reset }).Should(BeTrue()) 190 | Consistently(func() bool { return dataStream.remoteClosed }).Should(BeFalse()) 191 | Expect(handlerCalled).To(BeTrue()) 192 | }) 193 | 194 | It("handles a request for which the client immediately resets the data stream", func() { 195 | session.dataStream = nil 196 | var handlerCalled bool 197 | s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 198 | handlerCalled = true 199 | }) 200 | headerStream.dataToRead.Write([]byte{ 201 | 0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5, 202 | // Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding 203 | 0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 204 | }) 205 | err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer) 206 | Expect(err).NotTo(HaveOccurred()) 207 | Consistently(func() bool { return handlerCalled }).Should(BeFalse()) 208 | }) 209 | 210 | It("resets the dataStream when the body of POST request is not read, and the request handler replaces the request.Body", func() { 211 | var 
handlerCalled bool 212 | s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 213 | r.Body = struct { 214 | io.Reader 215 | io.Closer 216 | }{} 217 | handlerCalled = true 218 | }) 219 | headerStream.dataToRead.Write([]byte{0x0, 0x0, 0x20, 0x1, 0x24, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0xff, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x83, 0x84, 0x87, 0x5c, 0x1, 0x37, 0x7a, 0x85, 0xed, 0x69, 0x88, 0xb4, 0xc7}) 220 | err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer) 221 | Expect(err).NotTo(HaveOccurred()) 222 | Eventually(func() bool { return dataStream.reset }).Should(BeTrue()) 223 | Consistently(func() bool { return dataStream.remoteClosed }).Should(BeFalse()) 224 | Expect(handlerCalled).To(BeTrue()) 225 | }) 226 | 227 | It("closes the dataStream if the body of POST request was read", func() { 228 | var handlerCalled bool 229 | s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 230 | Expect(r.Host).To(Equal("www.example.com")) 231 | Expect(r.Method).To(Equal("POST")) 232 | handlerCalled = true 233 | // read the request body 234 | b := make([]byte, 1000) 235 | n, _ := r.Body.Read(b) 236 | Expect(n).ToNot(BeZero()) 237 | }) 238 | headerStream.dataToRead.Write([]byte{0x0, 0x0, 0x20, 0x1, 0x24, 0x0, 0x0, 0x0, 0x5, 0x0, 0x0, 0x0, 0x0, 0xff, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 0x83, 0x84, 0x87, 0x5c, 0x1, 0x37, 0x7a, 0x85, 0xed, 0x69, 0x88, 0xb4, 0xc7}) 239 | dataStream.dataToRead.Write([]byte("foo=bar")) 240 | err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer) 241 | Expect(err).NotTo(HaveOccurred()) 242 | Eventually(func() bool { return handlerCalled }).Should(BeTrue()) 243 | Expect(dataStream.reset).To(BeFalse()) 244 | }) 245 | 246 | It("errors when non-header frames are received", func() { 247 | headerStream.dataToRead.Write([]byte{ 248 | 0x0, 0x0, 0x06, 0x0, 
0x0, 0x0, 0x0, 0x0, 0x5, 249 | 'f', 'o', 'o', 'b', 'a', 'r', 250 | }) 251 | err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer) 252 | Expect(err).To(MatchError("InvalidHeadersStreamData: expected a header frame")) 253 | }) 254 | 255 | It("Cancels the request context when the datstream is closed", func() { 256 | var handlerCalled bool 257 | s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 258 | defer GinkgoRecover() 259 | err := r.Context().Err() 260 | Expect(err).To(HaveOccurred()) 261 | Expect(err.Error()).To(Equal("context canceled")) 262 | handlerCalled = true 263 | }) 264 | headerStream.dataToRead.Write([]byte{ 265 | 0x0, 0x0, 0x11, 0x1, 0x5, 0x0, 0x0, 0x0, 0x5, 266 | // Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding 267 | 0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 268 | }) 269 | dataStream.Close() 270 | err := s.handleRequest(session, headerStream, &sync.Mutex{}, hpackDecoder, h2framer) 271 | Expect(err).NotTo(HaveOccurred()) 272 | Eventually(func() bool { return handlerCalled }).Should(BeTrue()) 273 | Expect(dataStream.remoteClosed).To(BeTrue()) 274 | Expect(dataStream.reset).To(BeFalse()) 275 | }) 276 | 277 | }) 278 | 279 | It("handles the header stream", func() { 280 | var handlerCalled bool 281 | s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 282 | Expect(r.Host).To(Equal("www.example.com")) 283 | handlerCalled = true 284 | }) 285 | headerStream := &mockStream{id: 3} 286 | headerStream.dataToRead.Write([]byte{ 287 | 0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, 288 | // Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding 289 | 0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 290 | }) 291 | session.streamToAccept = headerStream 292 | go 
s.handleHeaderStream(session) 293 | Eventually(func() bool { return handlerCalled }).Should(BeTrue()) 294 | }) 295 | 296 | It("closes the connection if it encounters an error on the header stream", func() { 297 | var handlerCalled bool 298 | s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 299 | handlerCalled = true 300 | }) 301 | headerStream := &mockStream{id: 3} 302 | headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100)) 303 | session.streamToAccept = headerStream 304 | go s.handleHeaderStream(session) 305 | Consistently(func() bool { return handlerCalled }).Should(BeFalse()) 306 | Eventually(func() bool { return session.closed }).Should(BeTrue()) 307 | Expect(session.closedWithError).To(MatchError(qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame"))) 308 | }) 309 | 310 | It("supports closing after first request", func() { 311 | s.CloseAfterFirstRequest = true 312 | s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 313 | headerStream := &mockStream{id: 3} 314 | headerStream.dataToRead.Write([]byte{ 315 | 0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, 316 | // Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding 317 | 0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 318 | }) 319 | session.streamToAccept = headerStream 320 | Expect(session.closed).To(BeFalse()) 321 | go s.handleHeaderStream(session) 322 | Eventually(func() bool { return session.closed }).Should(BeTrue()) 323 | }) 324 | 325 | It("uses the default handler as fallback", func() { 326 | var handlerCalled bool 327 | http.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 328 | Expect(r.Host).To(Equal("www.example.com")) 329 | handlerCalled = true 330 | })) 331 | headerStream := &mockStream{id: 3} 332 | headerStream.dataToRead.Write([]byte{ 333 | 0x0, 0x0, 0x11, 0x1, 0x4, 0x0, 0x0, 0x0, 0x5, 334 | // 
Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding 335 | 0x82, 0x86, 0x84, 0x41, 0x8c, 0xf1, 0xe3, 0xc2, 0xe5, 0xf2, 0x3a, 0x6b, 0xa0, 0xab, 0x90, 0xf4, 0xff, 336 | }) 337 | session.streamToAccept = headerStream 338 | go s.handleHeaderStream(session) 339 | Eventually(func() bool { return handlerCalled }).Should(BeTrue()) 340 | }) 341 | 342 | Context("setting http headers", func() { 343 | var expected http.Header 344 | 345 | getExpectedHeader := func(versions []quic.VersionNumber) http.Header { 346 | var versionsAsString []string 347 | for _, v := range versions { 348 | versionsAsString = append(versionsAsString, v.ToAltSvc()) 349 | } 350 | return http.Header{ 351 | "Alt-Svc": {fmt.Sprintf(`quic=":443"; ma=2592000; v="%s"`, strings.Join(versionsAsString, ","))}, 352 | } 353 | } 354 | 355 | BeforeEach(func() { 356 | Expect(getExpectedHeader([]quic.VersionNumber{99, 90, 9})).To(Equal(http.Header{"Alt-Svc": {`quic=":443"; ma=2592000; v="99,90,9"`}})) 357 | expected = getExpectedHeader(quic.SupportedVersions) 358 | }) 359 | 360 | It("sets proper headers with numeric port", func() { 361 | s.Server.Addr = ":443" 362 | hdr := http.Header{} 363 | err := s.SetQuicHeaders(hdr) 364 | Expect(err).NotTo(HaveOccurred()) 365 | Expect(hdr).To(Equal(expected)) 366 | }) 367 | 368 | It("sets proper headers with full addr", func() { 369 | s.Server.Addr = "127.0.0.1:443" 370 | hdr := http.Header{} 371 | err := s.SetQuicHeaders(hdr) 372 | Expect(err).NotTo(HaveOccurred()) 373 | Expect(hdr).To(Equal(expected)) 374 | }) 375 | 376 | It("sets proper headers with string port", func() { 377 | s.Server.Addr = ":https" 378 | hdr := http.Header{} 379 | err := s.SetQuicHeaders(hdr) 380 | Expect(err).NotTo(HaveOccurred()) 381 | Expect(hdr).To(Equal(expected)) 382 | }) 383 | 384 | It("works multiple times", func() { 385 | s.Server.Addr = ":https" 386 | hdr := http.Header{} 387 | err := s.SetQuicHeaders(hdr) 388 | Expect(err).NotTo(HaveOccurred()) 389 
| Expect(hdr).To(Equal(expected)) 390 | hdr = http.Header{} 391 | err = s.SetQuicHeaders(hdr) 392 | Expect(err).NotTo(HaveOccurred()) 393 | Expect(hdr).To(Equal(expected)) 394 | }) 395 | }) 396 | 397 | It("should error when ListenAndServe is called with s.Server nil", func() { 398 | err := (&Server{}).ListenAndServe() 399 | Expect(err).To(MatchError("use of h2quic.Server without http.Server")) 400 | }) 401 | 402 | It("should error when ListenAndServeTLS is called with s.Server nil", func() { 403 | err := (&Server{}).ListenAndServeTLS(testdata.GetCertificatePaths()) 404 | Expect(err).To(MatchError("use of h2quic.Server without http.Server")) 405 | }) 406 | 407 | It("should nop-Close() when s.server is nil", func() { 408 | err := (&Server{}).Close() 409 | Expect(err).NotTo(HaveOccurred()) 410 | }) 411 | 412 | Context("ListenAndServe", func() { 413 | BeforeEach(func() { 414 | s.Server.Addr = "localhost:0" 415 | }) 416 | 417 | AfterEach(func() { 418 | Expect(s.Close()).To(Succeed()) 419 | }) 420 | 421 | It("may only be called once", func() { 422 | cErr := make(chan error) 423 | for i := 0; i < 2; i++ { 424 | go func() { 425 | defer GinkgoRecover() 426 | err := s.ListenAndServe() 427 | if err != nil { 428 | cErr <- err 429 | } 430 | }() 431 | } 432 | err := <-cErr 433 | Expect(err).To(MatchError("ListenAndServe may only be called once")) 434 | err = s.Close() 435 | Expect(err).NotTo(HaveOccurred()) 436 | }, 0.5) 437 | 438 | It("uses the quic.Config to start the quic server", func() { 439 | conf := &quic.Config{HandshakeTimeout: time.Nanosecond} 440 | var receivedConf *quic.Config 441 | quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Listener, error) { 442 | receivedConf = config 443 | return nil, errors.New("listen err") 444 | } 445 | s.QuicConfig = conf 446 | go s.ListenAndServe() 447 | Eventually(func() *quic.Config { return receivedConf }).Should(Equal(conf)) 448 | }) 449 | }) 450 | 451 | Context("ListenAndServeTLS", func() { 452 | 
BeforeEach(func() { 453 | s.Server.Addr = "localhost:0" 454 | }) 455 | 456 | AfterEach(func() { 457 | err := s.Close() 458 | Expect(err).NotTo(HaveOccurred()) 459 | }) 460 | 461 | It("may only be called once", func() { 462 | cErr := make(chan error) 463 | for i := 0; i < 2; i++ { 464 | go func() { 465 | defer GinkgoRecover() 466 | err := s.ListenAndServeTLS(testdata.GetCertificatePaths()) 467 | if err != nil { 468 | cErr <- err 469 | } 470 | }() 471 | } 472 | err := <-cErr 473 | Expect(err).To(MatchError("ListenAndServe may only be called once")) 474 | err = s.Close() 475 | Expect(err).NotTo(HaveOccurred()) 476 | }, 0.5) 477 | }) 478 | 479 | It("closes gracefully", func() { 480 | err := s.CloseGracefully(0) 481 | Expect(err).NotTo(HaveOccurred()) 482 | }) 483 | 484 | It("errors when listening fails", func() { 485 | testErr := errors.New("listen error") 486 | quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.Listener, error) { 487 | return nil, testErr 488 | } 489 | fullpem, privkey := testdata.GetCertificatePaths() 490 | err := ListenAndServeQUIC("", fullpem, privkey, nil) 491 | Expect(err).To(MatchError(testErr)) 492 | }) 493 | }) 494 | -------------------------------------------------------------------------------- /h2quic/client_test.go: -------------------------------------------------------------------------------- 1 | package h2quic 2 | 3 | import ( 4 | "bytes" 5 | "compress/gzip" 6 | "context" 7 | "crypto/tls" 8 | "errors" 9 | "io" 10 | "net/http" 11 | 12 | "golang.org/x/net/http2" 13 | "golang.org/x/net/http2/hpack" 14 | 15 | quic "github.com/lucas-clemente/quic-go" 16 | "github.com/lucas-clemente/quic-go/qerr" 17 | 18 | "time" 19 | 20 | . "github.com/onsi/ginkgo" 21 | . 
"github.com/onsi/gomega" 22 | ) 23 | 24 | var _ = Describe("Client", func() { 25 | var ( 26 | client *client 27 | session *mockSession 28 | headerStream *mockStream 29 | req *http.Request 30 | origDialAddr = dialAddr 31 | ) 32 | 33 | BeforeEach(func() { 34 | origDialAddr = dialAddr 35 | hostname := "quic.clemente.io:1337" 36 | client = newClient(hostname, nil, &roundTripperOpts{}, nil) 37 | Expect(client.hostname).To(Equal(hostname)) 38 | session = &mockSession{} 39 | session.ctx, session.ctxCancel = context.WithCancel(context.Background()) 40 | client.session = session 41 | 42 | headerStream = newMockStream(3) 43 | client.headerStream = headerStream 44 | client.requestWriter = newRequestWriter(headerStream) 45 | var err error 46 | req, err = http.NewRequest("GET", "https://localhost:1337", nil) 47 | Expect(err).ToNot(HaveOccurred()) 48 | }) 49 | 50 | AfterEach(func() { 51 | dialAddr = origDialAddr 52 | }) 53 | 54 | It("saves the TLS config", func() { 55 | tlsConf := &tls.Config{InsecureSkipVerify: true} 56 | client = newClient("", tlsConf, &roundTripperOpts{}, nil) 57 | Expect(client.tlsConf).To(Equal(tlsConf)) 58 | }) 59 | 60 | It("saves the QUIC config", func() { 61 | quicConf := &quic.Config{HandshakeTimeout: time.Nanosecond} 62 | client = newClient("", &tls.Config{}, &roundTripperOpts{}, quicConf) 63 | Expect(client.config).To(Equal(quicConf)) 64 | }) 65 | 66 | It("uses the default QUIC config if none is give", func() { 67 | client = newClient("", &tls.Config{}, &roundTripperOpts{}, nil) 68 | Expect(client.config).ToNot(BeNil()) 69 | Expect(client.config).To(Equal(defaultQuicConfig)) 70 | }) 71 | 72 | It("adds the port to the hostname, if none is given", func() { 73 | client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil) 74 | Expect(client.hostname).To(Equal("quic.clemente.io:443")) 75 | }) 76 | 77 | It("dials", func(done Done) { 78 | client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil) 79 | session.streamsToOpen = 
[]quic.Stream{newMockStream(3), newMockStream(5)} 80 | dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { 81 | return session, nil 82 | } 83 | close(headerStream.unblockRead) 84 | go client.RoundTrip(req) 85 | Eventually(func() quic.Session { return client.session }).Should(Equal(session)) 86 | close(done) 87 | }, 2) 88 | 89 | It("errors when dialing fails", func() { 90 | testErr := errors.New("handshake error") 91 | client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil) 92 | dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { 93 | return nil, testErr 94 | } 95 | _, err := client.RoundTrip(req) 96 | Expect(err).To(MatchError(testErr)) 97 | }) 98 | 99 | It("errors if it can't open a stream", func() { 100 | testErr := errors.New("you shall not pass") 101 | client = newClient("localhost:1337", nil, &roundTripperOpts{}, nil) 102 | session.streamOpenErr = testErr 103 | dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { 104 | return session, nil 105 | } 106 | _, err := client.RoundTrip(req) 107 | Expect(err).To(MatchError(testErr)) 108 | }) 109 | 110 | It("returns a request when dial fails", func() { 111 | testErr := errors.New("dial error") 112 | dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { 113 | return nil, testErr 114 | } 115 | request, err := http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil) 116 | Expect(err).ToNot(HaveOccurred()) 117 | 118 | var doErr error 119 | go func() { 120 | _, doErr = client.RoundTrip(request) 121 | }() 122 | _, err = client.RoundTrip(request) 123 | Expect(err).To(MatchError(testErr)) 124 | Eventually(func() error { return doErr }).Should(MatchError(testErr)) 125 | }) 126 | 127 | Context("Doing requests", func() { 128 | var request *http.Request 129 | var dataStream *mockStream 130 | 131 | getRequest := func(data []byte) *http2.MetaHeadersFrame { 132 | r := 
bytes.NewReader(data) 133 | decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {}) 134 | h2framer := http2.NewFramer(nil, r) 135 | frame, err := h2framer.ReadFrame() 136 | Expect(err).ToNot(HaveOccurred()) 137 | mhframe := &http2.MetaHeadersFrame{HeadersFrame: frame.(*http2.HeadersFrame)} 138 | mhframe.Fields, err = decoder.DecodeFull(mhframe.HeadersFrame.HeaderBlockFragment()) 139 | Expect(err).ToNot(HaveOccurred()) 140 | return mhframe 141 | } 142 | 143 | getHeaderFields := func(f *http2.MetaHeadersFrame) map[string]string { 144 | fields := make(map[string]string) 145 | for _, hf := range f.Fields { 146 | fields[hf.Name] = hf.Value 147 | } 148 | return fields 149 | } 150 | 151 | BeforeEach(func() { 152 | var err error 153 | dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { 154 | return session, nil 155 | } 156 | dataStream = newMockStream(5) 157 | session.streamsToOpen = []quic.Stream{headerStream, dataStream} 158 | request, err = http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil) 159 | Expect(err).ToNot(HaveOccurred()) 160 | }) 161 | 162 | It("does a request", func(done Done) { 163 | var doRsp *http.Response 164 | var doErr error 165 | var doReturned bool 166 | go func() { 167 | doRsp, doErr = client.RoundTrip(request) 168 | doReturned = true 169 | }() 170 | 171 | Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty()) 172 | Eventually(func() map[quic.StreamID]chan *http.Response { return client.responses }).Should(HaveKey(quic.StreamID(5))) 173 | rsp := &http.Response{ 174 | Status: "418 I'm a teapot", 175 | StatusCode: 418, 176 | } 177 | Expect(client.responses[5]).ToNot(BeClosed()) 178 | Expect(client.headerErrored).ToNot(BeClosed()) 179 | client.responses[5] <- rsp 180 | Eventually(func() bool { return doReturned }).Should(BeTrue()) 181 | Expect(doErr).ToNot(HaveOccurred()) 182 | Expect(doRsp).To(Equal(rsp)) 183 | Expect(doRsp.Body).To(Equal(dataStream)) 
184 | Expect(doRsp.ContentLength).To(BeEquivalentTo(-1)) 185 | Expect(doRsp.Request).To(Equal(request)) 186 | 187 | close(done) 188 | }) 189 | 190 | It("closes the quic client when encountering an error on the header stream", func(done Done) { 191 | headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100)) 192 | var doReturned bool 193 | go func() { 194 | defer GinkgoRecover() 195 | var err error 196 | rsp, err := client.RoundTrip(request) 197 | Expect(err).To(MatchError(client.headerErr)) 198 | Expect(rsp).To(BeNil()) 199 | doReturned = true 200 | }() 201 | 202 | Eventually(func() bool { return doReturned }).Should(BeTrue()) 203 | Expect(client.headerErr).To(MatchError(qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame"))) 204 | Expect(client.session.(*mockSession).closedWithError).To(MatchError(client.headerErr)) 205 | close(done) 206 | }, 2) 207 | 208 | It("returns subsequent request if there was an error on the header stream before", func(done Done) { 209 | expectedErr := qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame") 210 | session.streamsToOpen = []quic.Stream{headerStream, dataStream, newMockStream(7)} 211 | headerStream.dataToRead.Write(bytes.Repeat([]byte{0}, 100)) 212 | var firstReqReturned bool 213 | go func() { 214 | defer GinkgoRecover() 215 | _, err := client.RoundTrip(request) 216 | Expect(err).To(MatchError(expectedErr)) 217 | firstReqReturned = true 218 | }() 219 | 220 | Eventually(func() bool { return firstReqReturned }).Should(BeTrue()) 221 | // now that the first request failed due to an error on the header stream, try another request 222 | _, err := client.RoundTrip(request) 223 | Expect(err).To(MatchError(expectedErr)) 224 | close(done) 225 | }) 226 | 227 | It("blocks if no stream is available", func() { 228 | session.streamsToOpen = []quic.Stream{headerStream} 229 | session.blockOpenStreamSync = true 230 | var doReturned bool 231 | go func() { 232 | defer GinkgoRecover() 233 | _, err := 
client.RoundTrip(request) 234 | Expect(err).ToNot(HaveOccurred()) 235 | doReturned = true 236 | }() 237 | go client.handleHeaderStream() 238 | 239 | Consistently(func() bool { return doReturned }).Should(BeFalse()) 240 | }) 241 | 242 | Context("validating the address", func() { 243 | It("refuses to do requests for the wrong host", func() { 244 | req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil) 245 | Expect(err).ToNot(HaveOccurred()) 246 | _, err = client.RoundTrip(req) 247 | Expect(err).To(MatchError("h2quic Client BUG: RoundTrip called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)")) 248 | }) 249 | 250 | It("refuses to do plain HTTP requests", func() { 251 | req, err := http.NewRequest("https", "http://quic.clemente.io:1337/foobar.html", nil) 252 | Expect(err).ToNot(HaveOccurred()) 253 | _, err = client.RoundTrip(req) 254 | Expect(err).To(MatchError("quic http2: unsupported scheme")) 255 | }) 256 | 257 | It("adds the port for request URLs without one", func(done Done) { 258 | var err error 259 | client = newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil) 260 | req, err := http.NewRequest("https", "https://quic.clemente.io/foobar.html", nil) 261 | Expect(err).ToNot(HaveOccurred()) 262 | 263 | var doErr error 264 | var doReturned bool 265 | // the client.RoundTrip will block, because the encryption level is still set to Unencrypted 266 | go func() { 267 | _, doErr = client.RoundTrip(req) 268 | doReturned = true 269 | }() 270 | 271 | Consistently(doReturned).Should(BeFalse()) 272 | Expect(doErr).ToNot(HaveOccurred()) 273 | close(done) 274 | }) 275 | }) 276 | 277 | It("sets the EndStream header for requests without a body", func() { 278 | go func() { client.RoundTrip(request) }() 279 | Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeNil()) 280 | mhf := getRequest(headerStream.dataWritten.Bytes()) 281 | Expect(mhf.HeadersFrame.StreamEnded()).To(BeTrue()) 
282 | }) 283 | 284 | It("sets the EndStream header to false for requests with a body", func() { 285 | request.Body = &mockBody{} 286 | go func() { client.RoundTrip(request) }() 287 | Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeNil()) 288 | mhf := getRequest(headerStream.dataWritten.Bytes()) 289 | Expect(mhf.HeadersFrame.StreamEnded()).To(BeFalse()) 290 | }) 291 | 292 | Context("requests containing a Body", func() { 293 | var requestBody []byte 294 | var response *http.Response 295 | 296 | BeforeEach(func() { 297 | requestBody = []byte("request body") 298 | body := &mockBody{} 299 | body.SetData(requestBody) 300 | request.Body = body 301 | response = &http.Response{ 302 | StatusCode: 200, 303 | Header: http.Header{"Content-Length": []string{"1000"}}, 304 | } 305 | // fake a handshake 306 | client.dialOnce.Do(func() {}) 307 | session.streamsToOpen = []quic.Stream{dataStream} 308 | }) 309 | 310 | It("sends a request", func() { 311 | var doRsp *http.Response 312 | var doErr error 313 | var doReturned bool 314 | go func() { 315 | defer GinkgoRecover() 316 | doRsp, doErr = client.RoundTrip(request) 317 | Expect(doErr).ToNot(HaveOccurred()) 318 | doReturned = true 319 | }() 320 | Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil()) 321 | client.responses[5] <- response 322 | Eventually(func() bool { return doReturned }).Should(BeTrue()) 323 | Expect(dataStream.dataWritten.Bytes()).To(Equal(requestBody)) 324 | Expect(dataStream.closed).To(BeTrue()) 325 | Expect(request.Body.(*mockBody).closed).To(BeTrue()) 326 | Expect(doRsp).To(Equal(response)) 327 | }) 328 | 329 | It("returns the error that occurred when reading the body", func() { 330 | testErr := errors.New("testErr") 331 | request.Body.(*mockBody).readErr = testErr 332 | 333 | var doRsp *http.Response 334 | var doErr error 335 | var doReturned bool 336 | go func() { 337 | doRsp, doErr = client.RoundTrip(request) 338 | doReturned = true 339 | }() 
340 | Eventually(func() bool { return doReturned }).Should(BeTrue()) 341 | Expect(doErr).To(MatchError(testErr)) 342 | Expect(doRsp).To(BeNil()) 343 | Expect(request.Body.(*mockBody).closed).To(BeTrue()) 344 | }) 345 | 346 | It("returns the error that occurred when closing the body", func() { 347 | testErr := errors.New("testErr") 348 | request.Body.(*mockBody).closeErr = testErr 349 | 350 | var doRsp *http.Response 351 | var doErr error 352 | var doReturned bool 353 | go func() { 354 | doRsp, doErr = client.RoundTrip(request) 355 | doReturned = true 356 | }() 357 | Eventually(func() bool { return doReturned }).Should(BeTrue()) 358 | Expect(doErr).To(MatchError(testErr)) 359 | Expect(doRsp).To(BeNil()) 360 | Expect(request.Body.(*mockBody).closed).To(BeTrue()) 361 | }) 362 | }) 363 | 364 | Context("gzip compression", func() { 365 | var gzippedData []byte // a gzipped foobar 366 | var response *http.Response 367 | 368 | BeforeEach(func() { 369 | var b bytes.Buffer 370 | w := gzip.NewWriter(&b) 371 | w.Write([]byte("foobar")) 372 | w.Close() 373 | gzippedData = b.Bytes() 374 | response = &http.Response{ 375 | StatusCode: 200, 376 | Header: http.Header{"Content-Length": []string{"1000"}}, 377 | } 378 | }) 379 | 380 | It("adds the gzip header to requests", func(done Done) { 381 | var doRsp *http.Response 382 | var doErr error 383 | go func() { doRsp, doErr = client.RoundTrip(request) }() 384 | 385 | Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil()) 386 | dataStream.dataToRead.Write(gzippedData) 387 | response.Header.Add("Content-Encoding", "gzip") 388 | client.responses[5] <- response 389 | Eventually(func() *http.Response { return doRsp }).ShouldNot(BeNil()) 390 | Expect(doErr).ToNot(HaveOccurred()) 391 | headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes())) 392 | Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip")) 393 | Expect(doRsp.ContentLength).To(BeEquivalentTo(-1)) 394 | 
Expect(doRsp.Header.Get("Content-Encoding")).To(BeEmpty()) 395 | Expect(doRsp.Header.Get("Content-Length")).To(BeEmpty()) 396 | close(dataStream.unblockRead) 397 | data := make([]byte, 6) 398 | _, err := io.ReadFull(doRsp.Body, data) 399 | Expect(err).ToNot(HaveOccurred()) 400 | Expect(data).To(Equal([]byte("foobar"))) 401 | close(done) 402 | }, 2) 403 | 404 | It("doesn't add gzip if the header disable it", func() { 405 | client.opts.DisableCompression = true 406 | var doErr error 407 | go func() { _, doErr = client.RoundTrip(request) }() 408 | 409 | Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil()) 410 | Expect(doErr).ToNot(HaveOccurred()) 411 | Eventually(func() []byte { return headerStream.dataWritten.Bytes() }).ShouldNot(BeEmpty()) 412 | headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes())) 413 | Expect(headers).ToNot(HaveKey("accept-encoding")) 414 | }) 415 | 416 | It("only decompresses the response if the response contains the right content-encoding header", func() { 417 | var doRsp *http.Response 418 | var doErr error 419 | go func() { doRsp, doErr = client.RoundTrip(request) }() 420 | 421 | Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil()) 422 | dataStream.dataToRead.Write([]byte("not gzipped")) 423 | client.responses[5] <- response 424 | Eventually(func() *http.Response { return doRsp }).ShouldNot(BeNil()) 425 | Expect(doErr).ToNot(HaveOccurred()) 426 | headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes())) 427 | Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip")) 428 | data := make([]byte, 11) 429 | doRsp.Body.Read(data) 430 | Expect(doRsp.ContentLength).ToNot(BeEquivalentTo(-1)) 431 | Expect(data).To(Equal([]byte("not gzipped"))) 432 | }) 433 | 434 | It("doesn't add the gzip header for requests that have the accept-enconding set", func() { 435 | request.Header.Add("accept-encoding", "gzip") 436 | var doRsp *http.Response 437 | 
var doErr error 438 | go func() { doRsp, doErr = client.RoundTrip(request) }() 439 | 440 | Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil()) 441 | dataStream.dataToRead.Write([]byte("gzipped data")) 442 | client.responses[5] <- response 443 | Eventually(func() *http.Response { return doRsp }).ShouldNot(BeNil()) 444 | Expect(doErr).ToNot(HaveOccurred()) 445 | headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes())) 446 | Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip")) 447 | data := make([]byte, 12) 448 | doRsp.Body.Read(data) 449 | Expect(doRsp.ContentLength).ToNot(BeEquivalentTo(-1)) 450 | Expect(data).To(Equal([]byte("gzipped data"))) 451 | }) 452 | }) 453 | 454 | Context("handling the header stream", func() { 455 | var h2framer *http2.Framer 456 | 457 | BeforeEach(func() { 458 | h2framer = http2.NewFramer(&headerStream.dataToRead, nil) 459 | client.responses[23] = make(chan *http.Response) 460 | }) 461 | 462 | It("reads header values from a response", func() { 463 | // Taken from https://http2.github.io/http2-spec/compression.html#request.examples.with.huffman.coding 464 | data := []byte{0x48, 0x03, 0x33, 0x30, 0x32, 0x58, 0x07, 0x70, 0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x61, 0x1d, 0x4d, 0x6f, 0x6e, 0x2c, 0x20, 0x32, 0x31, 0x20, 0x4f, 0x63, 0x74, 0x20, 0x32, 0x30, 0x31, 0x33, 0x20, 0x32, 0x30, 0x3a, 0x31, 0x33, 0x3a, 0x32, 0x31, 0x20, 0x47, 0x4d, 0x54, 0x6e, 0x17, 0x68, 0x74, 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x63, 0x6f, 0x6d} 465 | headerStream.dataToRead.Write([]byte{0x0, 0x0, byte(len(data)), 0x1, 0x5, 0x0, 0x0, 0x0, 23}) 466 | headerStream.dataToRead.Write(data) 467 | go client.handleHeaderStream() 468 | var rsp *http.Response 469 | Eventually(client.responses[23]).Should(Receive(&rsp)) 470 | Expect(rsp).ToNot(BeNil()) 471 | Expect(rsp.Proto).To(Equal("HTTP/2.0")) 472 | Expect(rsp.ProtoMajor).To(BeEquivalentTo(2)) 
473 | Expect(rsp.StatusCode).To(BeEquivalentTo(302)) 474 | Expect(rsp.Status).To(Equal("302 Found")) 475 | Expect(rsp.Header).To(HaveKeyWithValue("Location", []string{"https://www.example.com"})) 476 | Expect(rsp.Header).To(HaveKeyWithValue("Cache-Control", []string{"private"})) 477 | }) 478 | 479 | It("errors if the H2 frame is not a HeadersFrame", func() { 480 | h2framer.WritePing(true, [8]byte{0, 0, 0, 0, 0, 0, 0, 0}) 481 | 482 | var handlerReturned bool 483 | go func() { 484 | client.handleHeaderStream() 485 | handlerReturned = true 486 | }() 487 | 488 | Eventually(client.headerErrored).Should(BeClosed()) 489 | Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame"))) 490 | Eventually(func() bool { return handlerReturned }).Should(BeTrue()) 491 | }) 492 | 493 | It("errors if it can't read the HPACK encoded header fields", func() { 494 | h2framer.WriteHeaders(http2.HeadersFrameParam{ 495 | StreamID: 23, 496 | EndHeaders: true, 497 | BlockFragment: []byte("invalid HPACK data"), 498 | }) 499 | 500 | var handlerReturned bool 501 | go func() { 502 | client.handleHeaderStream() 503 | handlerReturned = true 504 | }() 505 | 506 | Eventually(client.headerErrored).Should(BeClosed()) 507 | Expect(client.headerErr).To(MatchError(qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields"))) 508 | Eventually(func() bool { return handlerReturned }).Should(BeTrue()) 509 | }) 510 | }) 511 | }) 512 | }) 513 | --------------------------------------------------------------------------------