├── .gitignore ├── LICENSE ├── README.md ├── addr.go ├── addr_test.go ├── circle.yml ├── cmd └── echo │ └── main.go ├── conn.go ├── conn_test.go ├── copy.go ├── copy_test.go ├── error.go ├── error_test.go ├── handler.go ├── handler_test.go ├── httpx ├── README.md ├── accept.go ├── accept_test.go ├── conn.go ├── conn_test.go ├── encoding.go ├── encoding_test.go ├── handler.go ├── handler_test.go ├── httplex.go ├── httplex_test.go ├── httpxtest │ ├── test_server.go │ └── test_transport.go ├── media.go ├── media_test.go ├── proto.go ├── proto_test.go ├── proxy.go ├── proxy_test.go ├── quote.go ├── quote_test.go ├── retry.go ├── retry_test.go ├── server.go ├── server_test.go ├── transport.go └── upgrade.go ├── ip.go ├── ip_test.go ├── listen.go ├── pair.go ├── pair_darwin.go ├── pair_linux.go ├── pair_test.go ├── proxy.go ├── proxy_darwin.go ├── proxy_linux.go ├── proxy_test.go ├── server.go ├── server_test.go ├── tunnel.go ├── tunnel_test.go ├── unix.go └── unix_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | *.prof 25 | 26 | # Emacs 27 | *~ 28 | 29 | # binary 30 | /echo -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 Segment 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, 
modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | netx [![CircleCI](https://circleci.com/gh/segmentio/netx.svg?style=shield)](https://circleci.com/gh/segmentio/netx) [![Go Report Card](https://goreportcard.com/badge/github.com/segmentio/netx)](https://goreportcard.com/report/github.com/segmentio/netx) [![GoDoc](https://godoc.org/github.com/segmentio/netx?status.svg)](https://godoc.org/github.com/segmentio/netx) 2 | ==== 3 | 4 | Go package augmenting the standard net package with more basic building blocks 5 | for writing network applications. 6 | 7 | Motivations 8 | ----------- 9 | 10 | The intent of this package is to provide reusable tools that fit well with the 11 | standard net package and extend it with features that aren't available, such as 12 | server implementations with support for graceful shutdowns, and to define 13 | interfaces that allow other packages to provide plugins that work with 14 | these tools.
15 | -------------------------------------------------------------------------------- /addr.go: -------------------------------------------------------------------------------- 1 | package netx 2 | 3 | import ( 4 | "net" 5 | "strconv" 6 | "strings" 7 | ) 8 | 9 | // NetAddr is a type satisfying the net.Addr interface. 10 | type NetAddr struct { 11 | Net string 12 | Addr string 13 | } 14 | 15 | // Network returns a.Net. 16 | func (a *NetAddr) Network() string { return a.Net } 17 | 18 | // String returns a.Addr. 19 | func (a *NetAddr) String() string { return a.Addr } 20 | 21 | // MultiAddr is used for compound listeners returned by MultiListener. 22 | type MultiAddr []net.Addr 23 | 24 | // Network returns a comma-separated list of the addresses' networks. 25 | func (addr MultiAddr) Network() string { 26 | s := make([]string, len(addr)) 27 | for i, a := range addr { 28 | s[i] = a.Network() 29 | } 30 | return strings.Join(s, ",") 31 | } 32 | 33 | // String returns a comma-separated list of the addresses' string 34 | // representations. 35 | func (addr MultiAddr) String() string { 36 | s := make([]string, len(addr)) 37 | for i, a := range addr { 38 | s[i] = a.String() 39 | } 40 | return strings.Join(s, ",") 41 | } 42 | 43 | // SplitNetAddr splits s at the first "://" into a network and an address; net is empty when s contains no "://". 44 | func SplitNetAddr(s string) (net string, addr string) { 45 | if i := strings.Index(s, "://"); i >= 0 { 46 | net, addr = s[:i], s[i+3:] 47 | } else { 48 | addr = s 49 | } 50 | return 51 | } 52 | 53 | // SplitAddrPort splits the address and port from s. 54 | // 55 | // It wraps the standard net.SplitHostPort and expects the port part to be a 56 | // number: port is -1 if s is not in host:port form (addr is then s itself) or 57 | // if the port could not be parsed as a decimal integer.
58 | func SplitAddrPort(s string) (addr string, port int) { 59 | h, p, err := net.SplitHostPort(s) 60 | 61 | if err != nil { 62 | addr = s 63 | port = -1 64 | return 65 | } 66 | 67 | if port, err = strconv.Atoi(p); err != nil { 68 | port = -1 69 | } 70 | 71 | addr = h 72 | return 73 | } 74 | -------------------------------------------------------------------------------- /addr_test.go: -------------------------------------------------------------------------------- 1 | package netx 2 | 3 | import "testing" 4 | 5 | func TestNetAddr(t *testing.T) { 6 | a := &NetAddr{ 7 | Net: "N", 8 | Addr: "A", 9 | } 10 | 11 | if s := a.Network(); s != "N" { 12 | t.Error("bad network:", s) 13 | } 14 | 15 | if s := a.String(); s != "A" { 16 | t.Error("bad address:", s) 17 | } 18 | } 19 | 20 | func TestMultiAddr(t *testing.T) { 21 | m := MultiAddr{ 22 | &NetAddr{"N1", "A1"}, 23 | &NetAddr{"N2", "A2"}, 24 | } 25 | 26 | if s := m.Network(); s != "N1,N2" { 27 | t.Error("bad network:", s) 28 | } 29 | 30 | if s := m.String(); s != "A1,A2" { 31 | t.Error("bad address:", s) 32 | } 33 | } 34 | 35 | func TestSplitNetAddr(t *testing.T) { 36 | tests := []struct { 37 | s string 38 | n string 39 | a string 40 | }{ 41 | { 42 | s: "", 43 | n: "", 44 | a: "", 45 | }, 46 | { 47 | s: "tcp://", 48 | n: "tcp", 49 | a: "", 50 | }, 51 | { 52 | s: "127.0.0.1:4242", 53 | n: "", 54 | a: "127.0.0.1:4242", 55 | }, 56 | { 57 | s: "tcp://127.0.0.1:4242", 58 | n: "tcp", 59 | a: "127.0.0.1:4242", 60 | }, 61 | } 62 | 63 | for _, test := range tests { 64 | t.Run(test.s, func(t *testing.T) { 65 | n, a := SplitNetAddr(test.s) 66 | 67 | if n != test.n { 68 | t.Error("bad network:", n) 69 | } 70 | 71 | if a != test.a { 72 | t.Error("bad address:", a) 73 | } 74 | }) 75 | } 76 | } 77 | 78 | func TestSplitAddrPort(t *testing.T) { 79 | tests := []struct { 80 | s string 81 | a string 82 | p int 83 | }{ 84 | { 85 | s: "", 86 | a: "", 87 | p: -1, 88 | }, 89 | { 90 | s: "127.0.0.1", 91 | a: "127.0.0.1", 92 | p: -1, 93 | }, 94 | 
{ 95 | s: "127.0.0.1:4242", 96 | a: "127.0.0.1", 97 | p: 4242, 98 | }, 99 | { 100 | s: "[::1]:4242", 101 | a: "::1", 102 | p: 4242, 103 | }, 104 | { 105 | s: ":1234", 106 | a: "", 107 | p: 1234, 108 | }, 109 | { 110 | s: "127.0.0.1:http", 111 | a: "127.0.0.1", 112 | p: -1, 113 | }, 114 | } 115 | 116 | for _, test := range tests { 117 | t.Run(test.s, func(t *testing.T) { 118 | a, p := SplitAddrPort(test.s) 119 | 120 | if a != test.a { 121 | t.Error("bad address:", a) 122 | } 123 | 124 | if p != test.p { 125 | t.Error("bad port:", p) 126 | } 127 | }) 128 | } 129 | } 130 | -------------------------------------------------------------------------------- /circle.yml: -------------------------------------------------------------------------------- 1 | machine: 2 | services: 3 | - docker 4 | 5 | dependencies: 6 | override: 7 | - docker pull segment/golang:latest 8 | 9 | test: 10 | override: 11 | - > 12 | docker run 13 | $(env | grep -E '^CIRCLE_|^DOCKER_|^CIRCLECI=|^CI=' | sed 's/^/--env /g' | tr "\\n" " ") 14 | --rm 15 | --tty 16 | --interactive 17 | --name go 18 | --net host 19 | --volume /var/run/docker.sock:/run/docker.sock 20 | --volume ${GOPATH%%:*}/src:/go/src 21 | --volume ${PWD}:/go/src/github.com/${CIRCLE_PROJECT_USERNAME}/${CIRCLE_PROJECT_REPONAME} 22 | --workdir /go/src/github.com/${CIRCLE_PROJECT_USERNAME}/${CIRCLE_PROJECT_REPONAME} 23 | segment/golang:latest 24 | -------------------------------------------------------------------------------- /cmd/echo/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "flag" 6 | "log" 7 | "os" 8 | "os/signal" 9 | 10 | "github.com/segmentio/netx" 11 | ) 12 | 13 | // main runs an echo server on the -bind address in the selected -mode and 14 | // cancels the server's context when an interrupt signal is received. 15 | func main() { 16 | var bind string 17 | var mode string 18 | 19 | flag.StringVar(&bind, "bind", ":4242", "The network address to listen for incoming connections.") 20 | flag.StringVar(&mode, "mode", "raw", "The echo mode, either 'line' or 'raw'") 21 | flag.Parse() 22 |
23 | var handler netx.Handler 24 | 25 | switch mode { 26 | case "line": 27 | handler = netx.EchoLine 28 | case "raw": 29 | handler = netx.Echo 30 | default: 31 | log.Fatal("bad echo mode: ", mode) 32 | } 33 | 34 | log.Printf("setting echo mode to '%s'", mode) 35 | log.Printf("listening on %s", bind) 36 | 37 | lstn, err := netx.Listen(bind) 38 | if err != nil { 39 | log.Fatal(err) 40 | } 41 | 42 | if u, ok := lstn.(*netx.RecvUnixListener); ok { 43 | c, err := netx.DupUnix(u.UnixConn()) 44 | if err != nil { 45 | log.Fatal(err) 46 | } 47 | handler = netx.NewSendUnixHandler(c, handler) 48 | } 49 | 50 | // signal.Notify never blocks when delivering a signal, so the channel 51 | // needs a buffer or an interrupt arriving before the receive is dropped. 52 | sigchan := make(chan os.Signal, 1) 53 | signal.Notify(sigchan, os.Interrupt) 54 | 55 | ctx, cancel := context.WithCancel(context.Background()) 56 | go func() { 57 | sig := <-sigchan 58 | log.Print("signal: ", sig) 59 | cancel() 60 | }() 61 | 62 | server := &netx.Server{ 63 | Handler: handler, 64 | Context: ctx, 65 | } 66 | 67 | if err := server.Serve(lstn); err != nil { 68 | log.Fatal(err) 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /conn.go: -------------------------------------------------------------------------------- 1 | package netx 2 | 3 | import ( 4 | "net" 5 | "os" 6 | "syscall" 7 | ) 8 | 9 | // DupUnix makes a duplicate of the given unix connection. 10 | func DupUnix(conn *net.UnixConn) (*net.UnixConn, error) { 11 | c, err := dup(conn) 12 | if err != nil { 13 | return nil, err 14 | } 15 | return c.(*net.UnixConn), nil 16 | } 17 | 18 | // DupTCP makes a duplicate of the given TCP connection.
19 | func DupTCP(conn *net.TCPConn) (*net.TCPConn, error) { 20 | c, err := dup(conn) 21 | if err != nil { 22 | return nil, err 23 | } 24 | return c.(*net.TCPConn), nil 25 | } 26 | 27 | func dup(conn fileConn) (net.Conn, error) { 28 | f, err := conn.File() 29 | if err != nil { 30 | return nil, err 31 | } 32 | syscall.SetNonblock(int(f.Fd()), true) 33 | defer f.Close() 34 | return net.FileConn(f) 35 | } 36 | 37 | // BaseConn returns the base connection object of conn. 38 | // 39 | // The function works by dynamically checking whether conn implements the 40 | // `BaseConn() net.Conn` method, recursing dynamically to find the root connection 41 | // object. 42 | func BaseConn(conn net.Conn) net.Conn { 43 | for ok := true; ok; { 44 | var b baseConn 45 | if b, ok = conn.(baseConn); ok { 46 | conn = b.BaseConn() 47 | } 48 | } 49 | return conn 50 | } 51 | 52 | // BasePacketConn returns the base connection object of conn. 53 | // 54 | // The function works by dynamically checking whether conn implements the 55 | // `BasePacketConn() net.PacketConn` method, recursing dynamically to find the root connection 56 | // object. 57 | func BasePacketConn(conn net.PacketConn) net.PacketConn { 58 | for ok := true; ok; { 59 | var b basePacketConn 60 | if b, ok = conn.(basePacketConn); ok { 61 | conn = b.BasePacketConn() 62 | } 63 | } 64 | return conn 65 | } 66 | 67 | // baseConn is an interface implemented by connection wrappers wanting to expose 68 | // the underlying net.Conn object they use. 69 | type baseConn interface { 70 | BaseConn() net.Conn 71 | } 72 | 73 | // basePacketConn is an interface implemented by connection wrappers wanting to 74 | // expose the underlying net.PacketConn object they use. 75 | type basePacketConn interface { 76 | BasePacketConn() net.PacketConn 77 | } 78 | 79 | // fileConn is used internally to figure out if a net.Conn value also exposes a 80 | // File method. 
81 | type fileConn interface { 82 | File() (*os.File, error) 83 | } 84 | -------------------------------------------------------------------------------- /conn_test.go: -------------------------------------------------------------------------------- 1 | package netx 2 | 3 | import ( 4 | "net" 5 | "testing" 6 | ) 7 | 8 | type baseTestConn struct{ net.Conn } 9 | 10 | func (c *baseTestConn) BaseConn() net.Conn { return c.Conn } 11 | 12 | func TestBaseConn(t *testing.T) { 13 | c1 := &net.TCPConn{} 14 | c2 := &baseTestConn{c1} 15 | 16 | tests := []struct { 17 | name string 18 | base net.Conn 19 | conn net.Conn 20 | }{ 21 | { 22 | name: "base:false", 23 | base: c1, 24 | conn: c1, 25 | }, 26 | { 27 | name: "base:true", 28 | base: c1, 29 | conn: c2, 30 | }, 31 | } 32 | 33 | for _, test := range tests { 34 | t.Run(test.name, func(t *testing.T) { 35 | if base := BaseConn(test.conn); base != test.base { 36 | t.Errorf("bad base conn: %#v", base) 37 | } 38 | }) 39 | } 40 | } 41 | 42 | type baseTestPacketConn struct{ net.PacketConn } 43 | 44 | func (c *baseTestPacketConn) BasePacketConn() net.PacketConn { return c.PacketConn } 45 | 46 | func TestBasePacketConn(t *testing.T) { 47 | c1 := &net.UDPConn{} 48 | c2 := &baseTestPacketConn{c1} 49 | 50 | tests := []struct { 51 | name string 52 | base net.PacketConn 53 | conn net.PacketConn 54 | }{ 55 | { 56 | name: "base:false", 57 | base: c1, 58 | conn: c1, 59 | }, 60 | { 61 | name: "base:true", 62 | base: c1, 63 | conn: c2, 64 | }, 65 | } 66 | 67 | for _, test := range tests { 68 | t.Run(test.name, func(t *testing.T) { 69 | if base := BasePacketConn(test.conn); base != test.base { 70 | t.Errorf("bad base packet conn: %#v", base) 71 | } 72 | }) 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /copy.go: -------------------------------------------------------------------------------- 1 | package netx 2 | 3 | import ( 4 | "io" 5 | "sync" 6 | ) 7 | 8 | // Copy behaves exactly like io.Copy 
but uses an internal buffer pool to release 9 | // pressure off of the garbage collector. 10 | func Copy(w io.Writer, r io.Reader) (n int64, err error) { 11 | // Check for io.WriterTo and io.ReaderFrom so we don't hold a buffer during 12 | // the copy if one of these interfaces is already implemented, io.CopyBuffer 13 | // will double-check on that and fail but that's OK, the cost is likely 14 | // going to be small compared to the rest of the time spent moving bytes 15 | // from the reader to the writer. 16 | if from, ok := r.(io.WriterTo); ok { 17 | return from.WriteTo(w) 18 | } 19 | if to, ok := w.(io.ReaderFrom); ok { 20 | return to.ReadFrom(r) 21 | } 22 | buf := bufferPool.Get().(*buffer) 23 | n, err = io.CopyBuffer(w, r, buf.b) 24 | bufferPool.Put(buf) 25 | return 26 | } 27 | 28 | // buffer is a simple wrapper around []byte, it prevents Go from making a memory 29 | // allocation when converting the byte slice to an interface{}. 30 | type buffer struct{ b []byte } 31 | 32 | var bufferPool = sync.Pool{ 33 | New: func() interface{} { return &buffer{make([]byte, 8192, 8192)} }, 34 | } 35 | -------------------------------------------------------------------------------- /copy_test.go: -------------------------------------------------------------------------------- 1 | package netx 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "io/ioutil" 7 | "testing" 8 | ) 9 | 10 | func TestCopy(t *testing.T) { 11 | t.Run("WriterTo", func(t *testing.T) { 12 | w := bytes.NewBuffer(nil) 13 | r := bytes.NewBuffer([]byte("Hello World!")) 14 | 15 | if n, err := Copy(w, r); err != nil { 16 | t.Error(err) 17 | } else if n != 12 { 18 | t.Error("bad byte count:", n) 19 | } 20 | 21 | if s := w.String(); s != "Hello World!" 
{ 22 | t.Error("bad output:", s) 23 | } 24 | }) 25 | 26 | t.Run("ReaderFrom", func(t *testing.T) { 27 | w := bytes.NewBuffer(nil) 28 | r := &testBuffer{[]byte("Hello World!")} 29 | 30 | if n, err := Copy(w, r); err != nil { 31 | t.Error(err) 32 | } else if n != 12 { 33 | t.Error("bad byte count:", n) 34 | } 35 | 36 | if s := w.String(); s != "Hello World!" { 37 | t.Error("bad output:", s) 38 | } 39 | }) 40 | 41 | t.Run("Basic", func(t *testing.T) { 42 | c1, c2, err := ConnPair("tcp") 43 | if err != nil { 44 | t.Error(err) 45 | return 46 | } 47 | defer c1.Close() 48 | defer c2.Close() 49 | 50 | go Copy(c1, c2) 51 | 52 | if _, err := c2.Write([]byte("Hello World!")); err != nil { 53 | t.Error(err) 54 | return 55 | } 56 | c2.Close() 57 | 58 | b, err := ioutil.ReadAll(c1) 59 | if err != nil { 60 | t.Error(err) 61 | return 62 | } 63 | 64 | if s := string(b); s != "Hello World!" { 65 | t.Error("bad output:", s) 66 | } 67 | }) 68 | } 69 | 70 | type testBuffer struct{ b []byte } 71 | 72 | func (buf *testBuffer) Read(b []byte) (n int, err error) { 73 | if len(b) == 0 { 74 | return 75 | } 76 | if len(buf.b) == 0 { 77 | err = io.EOF 78 | return 79 | } 80 | n = copy(b, buf.b) 81 | buf.b = buf.b[n:] 82 | return 83 | } 84 | -------------------------------------------------------------------------------- /error.go: -------------------------------------------------------------------------------- 1 | package netx 2 | 3 | import ( 4 | "errors" 5 | "net" 6 | ) 7 | 8 | // Timeout returns a new network error representing a timeout. 9 | func Timeout(msg string) net.Error { return &timeout{msg} } 10 | 11 | type timeout struct{ msg string } 12 | 13 | func (t *timeout) Error() string { return t.msg } 14 | func (t *timeout) Timeout() bool { return true } 15 | func (t *timeout) Temporary() bool { return true } 16 | 17 | // IsTemporary checks whether err is a temporary error. 
18 | func IsTemporary(err error) bool { 19 | e, ok := err.(interface { 20 | Temporary() bool 21 | }) 22 | return ok && e != nil && e.Temporary() 23 | } 24 | 25 | // IsTimeout checks whether err resulted from a timeout. 26 | func IsTimeout(err error) bool { 27 | e, ok := err.(interface { 28 | Timeout() bool 29 | }) 30 | return ok && e != nil && e.Timeout() 31 | } 32 | 33 | var ( 34 | // ErrLineTooLong should be used by line-based protocol readers that detect 35 | // a line longer than they were configured to handle. 36 | ErrLineTooLong = errors.New("the line is too long") 37 | 38 | // ErrNoPipeline should be used by handlers that detect an attempt to use 39 | // pipelining when they don't support it. 40 | ErrNoPipeline = errors.New("pipelining is not supported") 41 | ) 42 | -------------------------------------------------------------------------------- /error_test.go: -------------------------------------------------------------------------------- 1 | package netx 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | ) 7 | 8 | type testError struct { 9 | temporary bool 10 | timeout bool 11 | } 12 | 13 | func (e testError) Error() string { return "test error" } 14 | func (e testError) Temporary() bool { return e.temporary } 15 | func (e testError) Timeout() bool { return e.timeout } 16 | 17 | func TestIsTemporary(t *testing.T) { 18 | tests := []struct { 19 | e error 20 | x bool 21 | }{ 22 | {testError{temporary: false}, false}, 23 | {testError{temporary: true}, true}, 24 | {errors.New(""), false}, 25 | } 26 | 27 | for _, test := range tests { 28 | t.Run("", func(t *testing.T) { 29 | if x := IsTemporary(test.e); x != test.x { 30 | t.Error(test.e) 31 | } 32 | }) 33 | } 34 | } 35 | 36 | func TestIsTimeout(t *testing.T) { 37 | tests := []struct { 38 | e error 39 | x bool 40 | }{ 41 | {testError{timeout: false}, false}, 42 | {testError{timeout: true}, true}, 43 | {errors.New(""), false}, 44 | } 45 | 46 | for _, test := range tests { 47 | t.Run("", func(t *testing.T) { 48 | if 
x := IsTimeout(test.e); x != test.x { 49 | t.Error(test.e) 50 | } 51 | }) 52 | } 53 | 54 | } 55 | 56 | func TestTimeout(t *testing.T) { 57 | err := Timeout("something went wrong") 58 | 59 | if !IsTimeout(err) { 60 | t.Error("not a timeout error") 61 | } 62 | 63 | if !IsTemporary(err) { 64 | t.Error("not a temporary error") 65 | } 66 | 67 | if s := err.Error(); s != "something went wrong" { 68 | t.Error("bad error message:", s) 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /handler.go: -------------------------------------------------------------------------------- 1 | package netx 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "io" 7 | "net" 8 | "time" 9 | ) 10 | 11 | // A Handler manages a network connection. 12 | // 13 | // The ServeConn method is called by a Server when a new client connection is 14 | // established, the method receives the connection and a context object that 15 | // the server may use to indicate that it's shutting down. 16 | // 17 | // Servers recover from panics that escape the handlers and log the error and 18 | // stack trace. 19 | type Handler interface { 20 | ServeConn(ctx context.Context, conn net.Conn) 21 | } 22 | 23 | // The HandlerFunc type allows simple functions to be used as connection 24 | // handlers. 25 | type HandlerFunc func(context.Context, net.Conn) 26 | 27 | // ServeConn calls f. 28 | func (f HandlerFunc) ServeConn(ctx context.Context, conn net.Conn) { 29 | f(ctx, conn) 30 | } 31 | 32 | // CloseHandler wraps handler to ensure that the connections it receives are 33 | // always closed after it returns. 34 | func CloseHandler(handler Handler) Handler { 35 | return HandlerFunc(func(ctx context.Context, conn net.Conn) { 36 | defer conn.Close() 37 | handler.ServeConn(ctx, conn) 38 | }) 39 | } 40 | 41 | var ( 42 | // Echo is the implementation of a connection handler that simply sends what 43 | // it receives back to the client. 
44 | Echo Handler = HandlerFunc(echo) 45 | 46 | // EchoLine is the implementation of a connection handler that reads lines 47 | // and echoes them back to the client, expecting the client not to send more 48 | // than one line before getting it echoed back. 49 | // 50 | // The implementation supports cancellations and ensures that no partial 51 | // lines are read from the connection. 52 | // 53 | // The maximum line length is limited to 8192 bytes. 54 | EchoLine Handler = HandlerFunc(echoLine) 55 | 56 | // Pass is the implementation of a connection handler that does nothing. 57 | Pass Handler = HandlerFunc(pass) 58 | ) 59 | 60 | // echo copies everything read from conn back to it, closing conn once the 61 | // copy ends or ctx is done. 62 | func echo(ctx context.Context, conn net.Conn) { 63 | ctx, cancel := context.WithCancel(ctx) 64 | 65 | go func() { 66 | defer cancel() 67 | Copy(conn, conn) 68 | }() 69 | 70 | <-ctx.Done() 71 | conn.Close() 72 | } 73 | 74 | // echoLine writes each line returned by readLine back to conn. It returns 75 | // quietly on EOF or when ctx ends, and calls fatal on any other error. 76 | func echoLine(ctx context.Context, conn net.Conn) { 77 | r := bufio.NewReaderSize(conn, 8192) 78 | 79 | for { 80 | line, err := readLine(ctx, conn, r) 81 | 82 | switch err { 83 | case nil: 84 | case io.EOF, context.Canceled, context.DeadlineExceeded: 85 | // A context with a deadline ends with DeadlineExceeded, which is a 86 | // normal shutdown just like cancellation. 87 | return 88 | default: 89 | fatal(conn, err) 90 | } 91 | 92 | if _, err := conn.Write(line); err != nil { 93 | fatal(conn, err) 94 | } 95 | } 96 | } 97 | 98 | func pass(ctx context.Context, conn net.Conn) { 99 | // do nothing 100 | } 101 | 102 | // fatal closes conn and panics with err; per the Handler documentation, 103 | // servers recover from panics that escape handlers. 104 | func fatal(conn net.Conn, err error) { 105 | conn.Close() 106 | panic(err) 107 | } 108 | 109 | // readLine waits until one complete line, including its "\n" or "\r\n" 110 | // terminator, is buffered in r and returns it. The returned slice aliases 111 | // r's buffer and is only valid until the next read from r. 112 | // 113 | // Buffered bytes are peeked instead of read until a whole line is present, 114 | // so a line interrupted by the 1 second read deadline is never partially 115 | // consumed. It returns ctx.Err() once ctx is done, ErrNoPipeline if more 116 | // data follows the line, ErrLineTooLong if the buffer fills up without a 117 | // newline, and the read error otherwise (an unterminated last line is 118 | // dropped at EOF). 119 | func readLine(ctx context.Context, conn net.Conn, r *bufio.Reader) ([]byte, error) { 120 | for { 121 | select { 122 | default: 123 | case <-ctx.Done(): 124 | return nil, ctx.Err() 125 | } 126 | 127 | buf, _ := r.Peek(r.Buffered()) 128 | 129 | for i, c := range buf { 130 | if c != '\n' { 131 | continue 132 | } 133 | if i+1 != len(buf) { 134 | return nil, ErrNoPipeline 135 | } 136 | r.Discard(len(buf)) 137 | return buf, nil 138 | } 139 | 140 | conn.SetReadDeadline(time.Now().Add(1 * time.Second)) 141 | 142 | // Wait for at least one byte more than is already buffered. 143 | if _, err := r.Peek(len(buf) + 1); err != nil { 144 | switch { 145 | case IsTimeout(err): 146 | // Loop so that ctx is checked again before waiting any longer. 147 | case err == bufio.ErrBufferFull: 148 | return nil, ErrLineTooLong 149 | default: 150 | return nil, err 151 | } 152 | } 153 | } 154 | }
155 | -------------------------------------------------------------------------------- /handler_test.go: -------------------------------------------------------------------------------- 1 | package netx 2 | 3 | import ( 4 | "context" 5 | "io/ioutil" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | func TestEcho(t *testing.T) { 11 | c1, c2, err := TCPConnPair("tcp") 12 | if err != nil { 13 | t.Error(err) 14 | return 15 | } 16 | defer c1.Close() 17 | defer c2.Close() 18 | 19 | ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) 20 | defer cancel() 21 | 22 | go CloseHandler(Echo).ServeConn(ctx, c2) 23 | 24 | if _, err := c1.Write([]byte("Hello World!\n")); err != nil { 25 | t.Error(err) 26 | return 27 | } 28 | c1.CloseWrite() 29 | 30 | b, err := ioutil.ReadAll(c1) 31 | if err != nil { 32 | t.Error(err) 33 | } 34 | if s := string(b); s != "Hello World!\n" { 35 | t.Error("bad output:", s) 36 | } 37 | } 38 | 39 | func TestEchoLine(t *testing.T) { 40 | c1, c2, err := TCPConnPair("tcp") 41 | if err != nil { 42 | t.Error(err) 43 | return 44 | } 45 | defer c1.Close() 46 | defer c2.Close() 47 | 48 | ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) 49 | defer cancel() 50 | 51 | go CloseHandler(EchoLine).ServeConn(ctx, c2) 52 | 53 | if _, err := c1.Write([]byte("Hello World!\r\n")); err != nil { 54 | t.Error(err) 55 | return 56 | } 57 | c1.CloseWrite() 58 | 59 | b, err := ioutil.ReadAll(c1) 60 | if err != nil { 61 | t.Error(err) 62 | } 63 | if s := string(b); s != "Hello World!\r\n" { 64 | t.Error("bad output:", s) 65 | } 66 | } 67 | 68 | func TestPass(t *testing.T) { 69 | c1, c2, err := TCPConnPair("tcp") 70 | if err != nil { 71 | t.Error(err) 72 | return 73 | } 74 | defer c1.Close() 75 | defer c2.Close() 76 | 77 | ctx, cancel :=
context.WithTimeout(context.Background(), 1*time.Second) 78 | defer cancel() 79 | 80 | go CloseHandler(Pass).ServeConn(ctx, c2) 81 | 82 | b, err := ioutil.ReadAll(c1) 83 | if err != nil { 84 | t.Error(err) 85 | } 86 | if s := string(b); s != "" { 87 | t.Error("bad output:", s) 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /httpx/README.md: -------------------------------------------------------------------------------- 1 | httpx [![GoDoc](https://godoc.org/github.com/segmentio/netx/httpx?status.svg)](https://godoc.org/github.com/segmentio/netx/httpx) 2 | ==== 3 | 4 | Go package augmenting the standard net/http package with more basic building 5 | blocks for writing http applications. 6 | 7 | Motivations 8 | ----------- 9 | 10 | The intent of this package is to provide reusable tools that fit well with the 11 | standard net package and extend it with features that aren't available. 12 | -------------------------------------------------------------------------------- /httpx/accept.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "sort" 7 | "strconv" 8 | "strings" 9 | ) 10 | 11 | // Negotiate performs an Accept header negotiation where the server can expose 12 | // the content in the given list of types. 13 | // 14 | // If none types match the method returns the first element in the list of 15 | // types. 16 | // 17 | // Here's an example of a typical use of this function: 18 | // 19 | // accept := Negotiate(req.Header.Get("Accept"), "image/png", "image/jpg") 20 | // 21 | func Negotiate(accept string, types ...string) string { 22 | a, _ := ParseAccept(accept) 23 | return a.Negotiate(types...) 24 | } 25 | 26 | // NegotiateEncoding performs an Accept-Encoding header negotiation where the 27 | // server can expose the content in the given list of codings. 
28 | // 29 | // If none of the codings match, the function returns an empty string to 30 | // indicate that the server should not apply any encoding to its response. 31 | // 32 | // Here's an example of a typical use of this function: 33 | // 34 | // encoding := NegotiateEncoding(req.Header.Get("Accept-Encoding"), "gzip", "deflate") 35 | // 36 | func NegotiateEncoding(accept string, codings ...string) string { 37 | a, _ := ParseAcceptEncoding(accept) // the parse error is discarded 38 | return a.Negotiate(codings...) 39 | } 40 | 41 | // AcceptItem is the representation of an item in an Accept header. 42 | type AcceptItem struct { 43 | typ string 44 | sub string 45 | q float32 // quality value; 1.0 unless a q parameter is present 46 | params []MediaParam // parameters that precede q 47 | extens []MediaParam // parameters that follow q 48 | } 49 | 50 | // String satisfies the fmt.Stringer interface. 51 | func (item AcceptItem) String() string { 52 | return fmt.Sprint(item) 53 | } 54 | 55 | // Format satisfies the fmt.Formatter interface; it ignores the verb and writes type/subtype, params, q (one decimal) and extensions separated by ';'. 56 | func (item AcceptItem) Format(w fmt.State, _ rune) { 57 | fmt.Fprintf(w, "%s/%s", item.typ, item.sub) 58 | 59 | for _, p := range item.params { 60 | fmt.Fprintf(w, ";%v", p) 61 | } 62 | 63 | fmt.Fprintf(w, ";q=%.1f", item.q) 64 | 65 | for _, e := range item.extens { 66 | fmt.Fprintf(w, ";%v", e) 67 | } 68 | } 69 | 70 | // ParseAcceptItem parses a single item in an Accept header; parameters before q are kept as params, those after it as extensions, and q defaults to 1.0. 71 | func ParseAcceptItem(s string) (item AcceptItem, err error) { 72 | var r MediaRange 73 | 74 | if r, err = ParseMediaRange(s); err != nil { 75 | err = errorInvalidAccept(s) 76 | return 77 | } 78 | 79 | item = AcceptItem{ 80 | typ: r.typ, 81 | sub: r.sub, 82 | q: 1.0, 83 | params: r.params, 84 | } 85 | 86 | for i, p := range r.params { 87 | if p.name == "q" { 88 | item.q = q(p.value) 89 | if item.params = r.params[:i]; len(item.params) == 0 { 90 | item.params = nil 91 | } 92 | if item.extens = r.params[i+1:]; len(item.extens) == 0 { 93 | item.extens = nil 94 | } 95 | break 96 | } 97 | } 98 | 99 | return 100 | } 101 | 102 | // Accept is the representation of an Accept header.
103 | type Accept []AcceptItem 104 | 105 | // String satisfies the fmt.Stringer interface. 106 | func (accept Accept) String() string { 107 | return fmt.Sprint(accept) 108 | } 109 | 110 | // Format satisfies the fmt.Formatter interface. 111 | func (accept Accept) Format(w fmt.State, r rune) { 112 | for i, item := range accept { 113 | if i != 0 { 114 | fmt.Fprint(w, ", ") 115 | } 116 | item.Format(w, r) 117 | } 118 | } 119 | 120 | // Negotiate performs an Accept header negotiation where the server can expose 121 | // the content in the given list of types. 122 | // 123 | // If none types match the method returns the first element in the list of 124 | // types. 125 | func (accept Accept) Negotiate(types ...string) string { 126 | if len(types) == 0 { 127 | return "" 128 | } 129 | for _, acc := range accept { 130 | for _, typ := range types { 131 | t2, err := ParseMediaType(typ) 132 | if err != nil { 133 | continue 134 | } 135 | t1 := MediaType{ 136 | typ: acc.typ, 137 | sub: acc.sub, 138 | } 139 | if t1.Contains(t2) { 140 | return typ 141 | } 142 | } 143 | } 144 | return types[0] 145 | } 146 | 147 | // Less satisfies sort.Interface. 148 | func (accept Accept) Less(i int, j int) bool { 149 | ai, aj := &accept[i], &accept[j] 150 | 151 | if ai.q > aj.q { 152 | return true 153 | } 154 | 155 | if ai.q < aj.q { 156 | return false 157 | } 158 | 159 | if ai.typ == aj.typ && ai.sub == aj.sub { 160 | n1 := len(ai.params) + len(ai.extens) 161 | n2 := len(aj.params) + len(aj.extens) 162 | return n1 > n2 163 | } 164 | 165 | if ai.typ != aj.typ { 166 | return mediaTypeLess(ai.typ, aj.typ) 167 | } 168 | 169 | return mediaTypeLess(ai.sub, aj.sub) 170 | } 171 | 172 | // Swap satisfies sort.Interface. 173 | func (accept Accept) Swap(i int, j int) { 174 | accept[i], accept[j] = accept[j], accept[i] 175 | } 176 | 177 | // Len satisfies sort.Interface. 
178 | func (accept Accept) Len() int { 179 | return len(accept) 180 | } 181 | 182 | // ParseAccept parses the value of an Accept header from s. 183 | func ParseAccept(s string) (accept Accept, err error) { 184 | var head string 185 | var tail = s 186 | 187 | for len(tail) != 0 { 188 | var item AcceptItem 189 | head, tail = splitTrimOWS(tail, ',') 190 | 191 | if item, err = ParseAcceptItem(head); err != nil { 192 | return 193 | } 194 | 195 | accept = append(accept, item) 196 | } 197 | 198 | sort.Sort(accept) 199 | return 200 | } 201 | 202 | // AcceptEncodingItem represents a single item in an Accept-Encoding header. 203 | type AcceptEncodingItem struct { 204 | coding string 205 | q float32 206 | } 207 | 208 | // String satisfies the fmt.Stringer interface. 209 | func (item AcceptEncodingItem) String() string { 210 | return fmt.Sprint(item) 211 | } 212 | 213 | // Format satisfies the fmt.Formatter interface. 214 | func (item AcceptEncodingItem) Format(w fmt.State, _ rune) { 215 | fmt.Fprintf(w, "%s;q=%.1f", item.coding, item.q) 216 | } 217 | 218 | // ParseAcceptEncodingItem parses a single item in an Accept-Encoding header. 219 | func ParseAcceptEncodingItem(s string) (item AcceptEncodingItem, err error) { 220 | if i := strings.IndexByte(s, ';'); i < 0 { 221 | item.coding = s 222 | item.q = 1.0 223 | } else { 224 | var p MediaParam 225 | 226 | if p, err = ParseMediaParam(trimOWS(s[i+1:])); err != nil { 227 | goto error 228 | } 229 | if p.name != "q" { 230 | goto error 231 | } 232 | 233 | item.coding = s[:i] 234 | item.q = q(p.value) 235 | } 236 | if !isToken(item.coding) { 237 | goto error 238 | } 239 | return 240 | error: 241 | err = errorInvalidAcceptEncoding(s) 242 | return 243 | } 244 | 245 | // AcceptEncoding respresents an Accept-Encoding header. 246 | type AcceptEncoding []AcceptEncodingItem 247 | 248 | // String satisfies the fmt.Stringer interface. 
249 | func (accept AcceptEncoding) String() string { 250 | return fmt.Sprint(accept) 251 | } 252 | 253 | // Format satisfies the fmt.Formatter interface. 254 | func (accept AcceptEncoding) Format(w fmt.State, r rune) { 255 | for i, item := range accept { 256 | if i != 0 { 257 | fmt.Fprint(w, ", ") 258 | } 259 | item.Format(w, r) 260 | } 261 | } 262 | 263 | // Negotiate performs an Accept-Encoding header negotiation where the server can 264 | // expose the content in the given list of codings. 265 | // 266 | // If none types match the method returns an empty string to indicate that the 267 | // server should not apply any encoding to its response. 268 | func (accept AcceptEncoding) Negotiate(codings ...string) string { 269 | for _, acc := range accept { 270 | for _, coding := range codings { 271 | if coding == acc.coding { 272 | return coding 273 | } 274 | } 275 | } 276 | return "" 277 | } 278 | 279 | // Less satisfies sort.Interface. 280 | func (accept AcceptEncoding) Less(i int, j int) bool { 281 | ai, aj := &accept[i], &accept[j] 282 | return ai.q > aj.q || (ai.q == aj.q && mediaTypeLess(ai.coding, aj.coding)) 283 | } 284 | 285 | // Swap satisfies sort.Interface. 286 | func (accept AcceptEncoding) Swap(i int, j int) { 287 | accept[i], accept[j] = accept[j], accept[i] 288 | } 289 | 290 | // Len satisfies sort.Interface. 291 | func (accept AcceptEncoding) Len() int { 292 | return len(accept) 293 | } 294 | 295 | // ParseAcceptEncoding parses an Accept-Encoding header value from s. 
296 | func ParseAcceptEncoding(s string) (accept AcceptEncoding, err error) { 297 | var head string 298 | var tail = s 299 | 300 | for len(tail) != 0 { 301 | var item AcceptEncodingItem 302 | head, tail = splitTrimOWS(tail, ',') 303 | 304 | if item, err = ParseAcceptEncodingItem(head); err != nil { 305 | return 306 | } 307 | 308 | accept = append(accept, item) 309 | } 310 | 311 | sort.Sort(accept) 312 | return 313 | } 314 | 315 | func errorInvalidAccept(s string) error { 316 | return errors.New("invalid Accept header value: " + s) 317 | } 318 | 319 | func errorInvalidAcceptEncoding(s string) error { 320 | return errors.New("invalid Accept-Encoding header value: " + s) 321 | } 322 | 323 | func q(s string) float32 { 324 | q, _ := strconv.ParseFloat(s, 32) 325 | return float32(q) 326 | } 327 | -------------------------------------------------------------------------------- /httpx/accept_test.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | import ( 4 | "reflect" 5 | "strings" 6 | "testing" 7 | ) 8 | 9 | func TestParseAcceptItemSuccess(t *testing.T) { 10 | tests := []struct { 11 | s string 12 | a AcceptItem 13 | }{ 14 | { 15 | s: `text/html`, 16 | a: AcceptItem{ 17 | typ: "text", 18 | sub: "html", 19 | q: 1.0, 20 | }, 21 | }, 22 | { 23 | s: `text/*;q=0`, 24 | a: AcceptItem{ 25 | typ: "text", 26 | sub: "*", 27 | }, 28 | }, 29 | { 30 | s: `text/html; param="Hello World!"; q=1.0; ext=value`, 31 | a: AcceptItem{ 32 | typ: "text", 33 | sub: "html", 34 | q: 1.0, 35 | params: []MediaParam{{"param", "Hello World!"}}, 36 | extens: []MediaParam{{"ext", "value"}}, 37 | }, 38 | }, 39 | } 40 | 41 | for _, test := range tests { 42 | t.Run(test.a.String(), func(t *testing.T) { 43 | a, err := ParseAcceptItem(test.s) 44 | 45 | if err != nil { 46 | t.Error(err) 47 | } 48 | 49 | if !reflect.DeepEqual(a, test.a) { 50 | t.Error(a) 51 | } 52 | }) 53 | } 54 | 55 | } 56 | 57 | func TestParseAcceptItemFailure(t *testing.T) { 58 | 
tests := []struct { 59 | s string 60 | }{ 61 | {``}, // empty string 62 | {`garbage`}, // garbage 63 | } 64 | 65 | for _, test := range tests { 66 | t.Run(test.s, func(t *testing.T) { 67 | if a, err := ParseAcceptItem(test.s); err == nil { 68 | t.Error(a) 69 | } 70 | }) 71 | } 72 | } 73 | 74 | func TestParseAcceptSuccess(t *testing.T) { 75 | tests := []struct { 76 | s string 77 | a Accept 78 | }{ 79 | { 80 | s: `text/html`, 81 | a: Accept{ 82 | { 83 | typ: "text", 84 | sub: "html", 85 | q: 1.0, 86 | }, 87 | }, 88 | }, 89 | { 90 | s: `text/*; q=0, text/html; param="Hello World!"; q=1.0; ext=value`, 91 | a: Accept{ 92 | { 93 | typ: "text", 94 | sub: "html", 95 | q: 1.0, 96 | params: []MediaParam{{"param", "Hello World!"}}, 97 | extens: []MediaParam{{"ext", "value"}}, 98 | }, 99 | { 100 | typ: "text", 101 | sub: "*", 102 | }, 103 | }, 104 | }, 105 | } 106 | 107 | for _, test := range tests { 108 | t.Run(test.a.String(), func(t *testing.T) { 109 | a, err := ParseAccept(test.s) 110 | 111 | if err != nil { 112 | t.Error(err) 113 | } 114 | 115 | if !reflect.DeepEqual(a, test.a) { 116 | t.Error(a) 117 | } 118 | }) 119 | } 120 | } 121 | 122 | func TestParseAcceptFailure(t *testing.T) { 123 | tests := []struct { 124 | s string 125 | }{ 126 | {`garbage`}, // garbage 127 | } 128 | 129 | for _, test := range tests { 130 | t.Run(test.s, func(t *testing.T) { 131 | if a, err := ParseAccept(test.s); err == nil { 132 | t.Error(a) 133 | } 134 | }) 135 | } 136 | } 137 | 138 | func TestAcceptSort(t *testing.T) { 139 | a, err := ParseAccept(`text/*, image/*;q=0.5, text/plain;q=1.0, text/html, text/html;level=1, */*`) 140 | 141 | if err != nil { 142 | t.Error(err) 143 | return 144 | } 145 | 146 | if !reflect.DeepEqual(a, Accept{ 147 | {typ: "text", sub: "html", q: 1, params: []MediaParam{{"level", "1"}}}, 148 | {typ: "text", sub: "html", q: 1}, 149 | {typ: "text", sub: "plain", q: 1}, 150 | {typ: "text", sub: "*", q: 1}, 151 | {typ: "*", sub: "*", q: 1}, 152 | {typ: "image", sub: "*", q: 
0.5}, 153 | }) { 154 | t.Error(a) 155 | } 156 | } 157 | 158 | func TestAcceptNegotiate(t *testing.T) { 159 | tests := []struct { 160 | t []string 161 | s string 162 | }{ 163 | { 164 | t: nil, 165 | s: "", 166 | }, 167 | { 168 | t: []string{"text/html"}, 169 | s: "text/html", 170 | }, 171 | { 172 | t: []string{"application/json"}, 173 | s: "application/json", 174 | }, 175 | { 176 | t: []string{"application/msgpack"}, 177 | s: "application/msgpack", 178 | }, 179 | { 180 | t: []string{"application/json", "application/msgpack"}, 181 | s: "application/msgpack", 182 | }, 183 | { 184 | t: []string{"application/msgpack", "application/json"}, 185 | s: "application/msgpack", 186 | }, 187 | { 188 | t: []string{"msgpack", "application/json"}, // first type is bad 189 | s: "application/json", 190 | }, 191 | } 192 | 193 | for _, test := range tests { 194 | t.Run(test.s, func(t *testing.T) { 195 | if s := Negotiate("application/msgpack;q=1.0, application/json;q=0.5", test.t...); s != test.s { 196 | t.Error(s) 197 | } 198 | }) 199 | } 200 | } 201 | 202 | func TestAcceptNegotiateEncoding(t *testing.T) { 203 | tests := []struct { 204 | c []string 205 | s string 206 | }{ 207 | { 208 | c: nil, 209 | s: "", 210 | }, 211 | { 212 | c: []string{"gzip"}, 213 | s: "gzip", 214 | }, 215 | { 216 | c: []string{"deflate"}, 217 | s: "deflate", 218 | }, 219 | { 220 | c: []string{"deflate", "gzip"}, 221 | s: "gzip", 222 | }, 223 | } 224 | 225 | for _, test := range tests { 226 | t.Run(strings.Join(test.c, ","), func(t *testing.T) { 227 | if s := NegotiateEncoding("gzip;q=1.0, deflate;q=0.5", test.c...); s != test.s { 228 | t.Error(s) 229 | } 230 | }) 231 | } 232 | } 233 | 234 | func TestParseAcceptEncodingItemSuccess(t *testing.T) { 235 | tests := []struct { 236 | s string 237 | a AcceptEncodingItem 238 | }{ 239 | { 240 | s: `gzip`, 241 | a: AcceptEncodingItem{ 242 | coding: "gzip", 243 | q: 1.0, 244 | }, 245 | }, 246 | { 247 | s: `gzip;q=0`, 248 | a: AcceptEncodingItem{ 249 | coding: "gzip", 250 | 
}, 251 | }, 252 | } 253 | 254 | for _, test := range tests { 255 | t.Run(test.a.String(), func(t *testing.T) { 256 | a, err := ParseAcceptEncodingItem(test.s) 257 | 258 | if err != nil { 259 | t.Error(err) 260 | } 261 | 262 | if !reflect.DeepEqual(a, test.a) { 263 | t.Error(a) 264 | } 265 | }) 266 | } 267 | 268 | } 269 | 270 | func TestParseAcceptEncodingItemFailure(t *testing.T) { 271 | tests := []struct { 272 | s string 273 | }{ 274 | {``}, // empty string 275 | {`q=`}, // missing value 276 | {`gzip;key=value`}, // not q=X 277 | } 278 | 279 | for _, test := range tests { 280 | t.Run(test.s, func(t *testing.T) { 281 | if a, err := ParseAcceptEncodingItem(test.s); err == nil { 282 | t.Error(a) 283 | } 284 | }) 285 | } 286 | } 287 | 288 | func TestParseAcceptEncodingSuccess(t *testing.T) { 289 | tests := []struct { 290 | s string 291 | a AcceptEncoding 292 | }{ 293 | { 294 | s: `gzip;q=1.0, *;q=0, identity; q=0.5`, 295 | a: AcceptEncoding{ 296 | { 297 | coding: "gzip", 298 | q: 1.0, 299 | }, 300 | { 301 | coding: "identity", 302 | q: 0.5, 303 | }, 304 | { 305 | coding: "*", 306 | q: 0.0, 307 | }, 308 | }, 309 | }, 310 | } 311 | 312 | for _, test := range tests { 313 | t.Run(test.a.String(), func(t *testing.T) { 314 | a, err := ParseAcceptEncoding(test.s) 315 | 316 | if err != nil { 317 | t.Error(err) 318 | } 319 | 320 | if !reflect.DeepEqual(a, test.a) { 321 | t.Error(a) 322 | } 323 | }) 324 | } 325 | } 326 | 327 | func TestParseAcceptEncodingFailure(t *testing.T) { 328 | tests := []struct { 329 | s string 330 | }{ 331 | {`gzip;`}, // missing q=X 332 | {`gzip;key=value`}, // not q=X 333 | } 334 | 335 | for _, test := range tests { 336 | t.Run(test.s, func(t *testing.T) { 337 | if a, err := ParseAcceptEncoding(test.s); err == nil { 338 | t.Error(a) 339 | } 340 | }) 341 | } 342 | } 343 | -------------------------------------------------------------------------------- /httpx/conn.go: -------------------------------------------------------------------------------- 1 | 
package httpx 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "io" 7 | "net" 8 | "net/http" 9 | "time" 10 | 11 | "github.com/segmentio/netx" 12 | ) 13 | 14 | // ConnTransport is a http.RoundTripper that works on a pre-established network 15 | // connection. 16 | type ConnTransport struct { 17 | // Conn is the connection to use to send requests and receive responses. 18 | Conn net.Conn 19 | 20 | // Buffer may be set to a bufio.ReadWriter which will be used to buffer all 21 | // I/O done on the connection. 22 | Buffer *bufio.ReadWriter 23 | 24 | // DialContext is used to open a connection when Conn is set to nil. 25 | // If the function is nil the transport uses a default dialer. 26 | DialContext func(context.Context, string, string) (net.Conn, error) 27 | 28 | // ResponseHeaderTimeout, if non-zero, specifies the amount of time to wait 29 | // for a server's response headers after fully writing the request (including 30 | // its body, if any). This time does not include the time to read the response 31 | // body. 32 | ResponseHeaderTimeout time.Duration 33 | 34 | // MaxResponseHeaderBytes specifies a limit on how many response bytes are 35 | // allowed in the server's response header. 36 | // 37 | // Zero means to use a default limit. 38 | MaxResponseHeaderBytes int 39 | } 40 | 41 | // the default dialer used by ConnTransport when neither Conn nor DialContext is 42 | // set. 43 | var dialer net.Dialer 44 | 45 | // RoundTrip satisfies the http.RoundTripper interface. 
46 | func (t *ConnTransport) RoundTrip(req *http.Request) (res *http.Response, err error) { 47 | var ctx = req.Context() 48 | var conn net.Conn 49 | var dial func(context.Context, string, string) (net.Conn, error) 50 | 51 | if conn = t.Conn; conn == nil { 52 | if dial = t.DialContext; dial == nil { 53 | dial = dialer.DialContext 54 | } 55 | if conn, err = dial(ctx, "tcp", req.Host); err != nil { 56 | return 57 | } 58 | } 59 | 60 | var c = &connReader{Conn: conn, limit: -1} 61 | var b = t.Buffer 62 | var r *bufio.Reader 63 | var w *bufio.Writer 64 | 65 | if b != nil && b.Reader != nil { 66 | r = b.Reader 67 | r.Reset(c) 68 | } else { 69 | r = bufio.NewReader(c) 70 | } 71 | 72 | if b != nil && b.Writer != nil { 73 | w = b.Writer 74 | w.Reset(c) 75 | } else { 76 | w = bufio.NewWriter(c) 77 | } 78 | 79 | if err = req.Write(w); err != nil { 80 | return 81 | } 82 | if err = w.Flush(); err != nil { 83 | return 84 | } 85 | 86 | switch limit := t.MaxResponseHeaderBytes; { 87 | case limit == 0: 88 | c.limit = http.DefaultMaxHeaderBytes 89 | case limit > 0: 90 | c.limit = limit 91 | } 92 | 93 | if timeout := t.ResponseHeaderTimeout; timeout != 0 { 94 | c.SetReadDeadline(time.Now().Add(timeout)) 95 | } 96 | if res, err = http.ReadResponse(r, req); err != nil { 97 | return 98 | } 99 | 100 | if dial != nil { 101 | res.Body = struct { 102 | io.Reader 103 | io.Closer 104 | }{ 105 | Reader: res.Body, 106 | Closer: conn, 107 | } 108 | } 109 | 110 | c.limit = -1 111 | c.SetReadDeadline(time.Time{}) 112 | return 113 | } 114 | 115 | // connReader is a net.Conn wrappers used by the HTTP server to limit the size 116 | // of the request header. 117 | // 118 | // A cancel function can also be set on the reader, it is expected to be used to 119 | // cancel the associated request context to notify the handlers that a client is 120 | // gone and the request can be aborted. 
121 | type connReader struct { 122 | net.Conn 123 | limit int 124 | cancel context.CancelFunc 125 | } 126 | 127 | // Read satsifies the io.Reader interface. 128 | func (c *connReader) Read(b []byte) (n int, err error) { 129 | if c.limit == 0 { 130 | err = io.EOF 131 | return 132 | } 133 | 134 | n1 := len(b) 135 | n2 := c.limit 136 | 137 | if n2 > 0 && n1 > n2 { 138 | n1 = n2 139 | } 140 | 141 | if n, err = c.Conn.Read(b[:n1]); n > 0 && n2 > 0 { 142 | c.limit -= n 143 | } 144 | 145 | if err != nil && !netx.IsTemporary(err) { 146 | c.cancel() 147 | } 148 | 149 | return 150 | } 151 | -------------------------------------------------------------------------------- /httpx/conn_test.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | import ( 4 | "bufio" 5 | "io" 6 | "net" 7 | "net/http" 8 | "testing" 9 | "time" 10 | 11 | "github.com/segmentio/netx/httpx/httpxtest" 12 | ) 13 | 14 | func TestConnTransportDefault(t *testing.T) { 15 | httpxtest.TestTransport(t, func() http.RoundTripper { 16 | return &ConnTransport{} 17 | }) 18 | } 19 | 20 | func TestConnTransportConfigured(t *testing.T) { 21 | httpxtest.TestTransport(t, func() http.RoundTripper { 22 | return &ConnTransport{ 23 | Buffer: &bufio.ReadWriter{ 24 | Reader: bufio.NewReader(nil), 25 | Writer: bufio.NewWriter(nil), 26 | }, 27 | ResponseHeaderTimeout: 100 * time.Millisecond, 28 | MaxResponseHeaderBytes: 65536, 29 | } 30 | }) 31 | } 32 | 33 | func TestConnReader(t *testing.T) { 34 | c1, c2 := net.Pipe() 35 | defer c1.Close() 36 | defer c2.Close() 37 | 38 | ok := false 39 | cr := &connReader{ 40 | Conn: c1, 41 | limit: 10, 42 | cancel: func() { ok = true }, 43 | } 44 | 45 | go c2.Write([]byte("Hello World!")) 46 | var b [16]byte 47 | 48 | n, err := cr.Read(b[:]) 49 | if err != nil { 50 | t.Error(err) 51 | return 52 | } 53 | 54 | if n > 10 { 55 | t.Error("too many bytes read from c1:", n) 56 | return 57 | } 58 | 59 | if _, err := cr.Read(b[:]); err != io.EOF { 
60 | t.Error("expected io.EOF but got", err) 61 | return 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /httpx/encoding.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | import ( 4 | "compress/flate" 5 | "compress/gzip" 6 | "compress/zlib" 7 | "io" 8 | "net/http" 9 | "strings" 10 | ) 11 | 12 | // ContentEncoding is an interfae implemented by types that provide the 13 | // implementation of a content encoding for HTTP responses. 14 | type ContentEncoding interface { 15 | // Coding returns the format in which the content encoding's writers 16 | // can encode HTTP responses. 17 | Coding() string 18 | 19 | // NewReader wraps r in a reader that supports the content encoding's 20 | // format. 21 | NewReader(r io.Reader) (io.ReadCloser, error) 22 | 23 | // NewWriter wraps w in a writer that applies the content encoding's 24 | // format. 25 | NewWriter(w io.Writer) (io.WriteCloser, error) 26 | } 27 | 28 | // NewEncodingTransport wraps transport to support decoding the responses with 29 | // specified content encodings. 30 | // 31 | // If contentEncodings is nil (no arguments were passed) the returned transport 32 | // uses DefaultEncodings. 
33 | func NewEncodingTransport(transport http.RoundTripper, contentEncodings ...ContentEncoding) http.RoundTripper { 34 | if contentEncodings == nil { 35 | contentEncodings = defaultEncodings() 36 | } 37 | 38 | encodings := make(map[string]ContentEncoding, len(contentEncodings)) 39 | codings := make([]string, 0, len(contentEncodings)) 40 | 41 | for _, encoding := range contentEncodings { 42 | coding := encoding.Coding() 43 | codings = append(codings, coding) 44 | encodings[coding] = encoding 45 | } 46 | 47 | acceptEncoding := strings.Join(codings, ", ") 48 | 49 | return RoundTripperFunc(func(req *http.Request) (res *http.Response, err error) { 50 | req.Header["Accept-Encoding"] = []string{acceptEncoding} 51 | 52 | if res, err = transport.RoundTrip(req); err != nil { 53 | return 54 | } 55 | 56 | coding := res.Header["Content-Encoding"] 57 | 58 | if len(coding) != 0 { 59 | encoding := encodings[coding[0]] 60 | 61 | if encoding != nil { 62 | var decoder io.ReadCloser 63 | 64 | if decoder, err = encoding.NewReader(res.Body); err != nil { 65 | res.Body.Close() 66 | return 67 | } 68 | 69 | req.ContentLength = -1 70 | res.Body = &contentEncodingReader{ 71 | decoder: decoder, 72 | body: res.Body, 73 | } 74 | 75 | delete(res.Header, "Content-Length") 76 | delete(res.Header, "Content-Encoding") 77 | } 78 | } 79 | 80 | return 81 | }) 82 | } 83 | 84 | type contentEncodingReader struct { 85 | decoder io.ReadCloser 86 | body io.ReadCloser 87 | } 88 | 89 | func (r *contentEncodingReader) Read(b []byte) (int, error) { 90 | return r.decoder.Read(b) 91 | } 92 | 93 | func (r *contentEncodingReader) Close() error { 94 | r.decoder.Close() 95 | r.body.Close() 96 | return nil 97 | } 98 | 99 | // NewEncodingHandler wraps handler to support encoding the responses by 100 | // negotiating the coding based on the given list of supported content encodings. 101 | // 102 | // If contentEncodings is nil (no arguments were passed) the returned handler 103 | // uses DefaultEncodings. 
104 | func NewEncodingHandler(handler http.Handler, contentEncodings ...ContentEncoding) http.Handler { 105 | if contentEncodings == nil { 106 | contentEncodings = defaultEncodings() 107 | } 108 | 109 | encodings := make(map[string]ContentEncoding, len(contentEncodings)) 110 | codings := make([]string, 0, len(contentEncodings)) 111 | 112 | for _, encoding := range contentEncodings { 113 | coding := encoding.Coding() 114 | encodings[coding] = encoding 115 | codings = append(codings, coding) 116 | } 117 | 118 | return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { 119 | accept := req.Header["Accept-Encoding"] 120 | 121 | if len(accept) != 0 { 122 | coding := NegotiateEncoding(accept[0], codings...) 123 | 124 | if len(coding) != 0 { 125 | if w, err := encodings[coding].NewWriter(res); err == nil { 126 | defer w.Close() 127 | 128 | h := res.Header() 129 | h["Content-Encoding"] = []string{coding} 130 | delete(h, "Content-Length") 131 | 132 | res = &contentEncodingWriter{res, w} 133 | delete(req.Header, "Accept-Encoding") 134 | } 135 | } 136 | } 137 | 138 | handler.ServeHTTP(res, req) 139 | }) 140 | } 141 | 142 | type contentEncodingWriter struct { 143 | http.ResponseWriter 144 | io.Writer 145 | } 146 | 147 | func (w *contentEncodingWriter) Write(b []byte) (int, error) { 148 | return w.Writer.Write(b) 149 | } 150 | 151 | // DeflateEncoding implements the ContentEncoding interface for the deflate 152 | // algorithm. 153 | type DeflateEncoding struct { 154 | Level int 155 | } 156 | 157 | // NewDeflateEncoding creates a new content encoding with the default compression 158 | // level. 159 | func NewDeflateEncoding() *DeflateEncoding { 160 | return NewDeflateEncodingLevel(flate.DefaultCompression) 161 | } 162 | 163 | // NewDeflateEncodingLevel creates a new content encoding with the given 164 | // compression level. 
165 | func NewDeflateEncodingLevel(level int) *DeflateEncoding { 166 | return &DeflateEncoding{ 167 | Level: level, 168 | } 169 | } 170 | 171 | // Coding satsifies the ContentEncoding interface. 172 | func (e *DeflateEncoding) Coding() string { 173 | return "deflate" 174 | } 175 | 176 | // NewReader satisfies the ContentEncoding interface. 177 | func (e *DeflateEncoding) NewReader(r io.Reader) (io.ReadCloser, error) { 178 | return flate.NewReader(r), nil 179 | } 180 | 181 | // NewWriter satsifies the ContentEncoding interface. 182 | func (e *DeflateEncoding) NewWriter(w io.Writer) (io.WriteCloser, error) { 183 | return flate.NewWriter(w, e.Level) 184 | } 185 | 186 | // GzipEncoding implements the ContentEncoding interface for the gzip 187 | // algorithm. 188 | type GzipEncoding struct { 189 | Level int 190 | } 191 | 192 | // NewGzipEncoding creates a new content encoding with the default compression 193 | // level. 194 | func NewGzipEncoding() *GzipEncoding { 195 | return NewGzipEncodingLevel(gzip.DefaultCompression) 196 | } 197 | 198 | // NewGzipEncodingLevel creates a new content encoding with the given 199 | // compression level. 200 | func NewGzipEncodingLevel(level int) *GzipEncoding { 201 | return &GzipEncoding{ 202 | Level: level, 203 | } 204 | } 205 | 206 | // Coding satsifies the ContentEncoding interface. 207 | func (e *GzipEncoding) Coding() string { 208 | return "gzip" 209 | } 210 | 211 | // NewReader satisfies the ContentEncoding interface. 212 | func (e *GzipEncoding) NewReader(r io.Reader) (io.ReadCloser, error) { 213 | return gzip.NewReader(r) 214 | } 215 | 216 | // NewWriter satsifies the ContentEncoding interface. 217 | func (e *GzipEncoding) NewWriter(w io.Writer) (io.WriteCloser, error) { 218 | return gzip.NewWriterLevel(w, e.Level) 219 | } 220 | 221 | // ZlibEncoding implements the ContentEncoding interface for the zlib 222 | // algorithm. 
223 | type ZlibEncoding struct { 224 | Level int 225 | } 226 | 227 | // NewZlibEncoding creates a new content encoding with the default compression 228 | // level. 229 | func NewZlibEncoding() *ZlibEncoding { 230 | return NewZlibEncodingLevel(zlib.DefaultCompression) 231 | } 232 | 233 | // NewZlibEncodingLevel creates a new content encoding with the given 234 | // compression level. 235 | func NewZlibEncodingLevel(level int) *ZlibEncoding { 236 | return &ZlibEncoding{ 237 | Level: level, 238 | } 239 | } 240 | 241 | // Coding satsifies the ContentEncoding interface. 242 | func (e *ZlibEncoding) Coding() string { 243 | return "zlib" 244 | } 245 | 246 | // NewReader satisfies the ContentEncoding interface. 247 | func (e *ZlibEncoding) NewReader(r io.Reader) (io.ReadCloser, error) { 248 | return zlib.NewReader(r) 249 | } 250 | 251 | // NewWriter satsifies the ContentEncoding interface. 252 | func (e *ZlibEncoding) NewWriter(w io.Writer) (io.WriteCloser, error) { 253 | return zlib.NewWriterLevel(w, e.Level) 254 | } 255 | 256 | func defaultEncodings() []ContentEncoding { 257 | return []ContentEncoding{ 258 | NewGzipEncoding(), 259 | NewZlibEncoding(), 260 | NewDeflateEncoding(), 261 | } 262 | } 263 | -------------------------------------------------------------------------------- /httpx/encoding_test.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | import ( 4 | "compress/flate" 5 | "compress/gzip" 6 | "compress/zlib" 7 | "io" 8 | "io/ioutil" 9 | "net/http" 10 | "net/http/httptest" 11 | "testing" 12 | ) 13 | 14 | func TestEncodingHandler(t *testing.T) { 15 | tests := []struct { 16 | coding string 17 | newReader func(io.Reader) io.ReadCloser 18 | }{ 19 | { 20 | coding: "deflate", 21 | newReader: flate.NewReader, 22 | }, 23 | { 24 | coding: "gzip", 25 | newReader: func(r io.Reader) io.ReadCloser { 26 | z, _ := gzip.NewReader(r) 27 | return z 28 | }, 29 | }, 30 | { 31 | coding: "zlib", 32 | newReader: func(r 
io.Reader) io.ReadCloser { 33 | z, _ := zlib.NewReader(r) 34 | return z 35 | }, 36 | }, 37 | } 38 | 39 | h := NewEncodingHandler(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { 40 | res.Write([]byte("Hello World!")) 41 | })) 42 | 43 | for _, test := range tests { 44 | t.Run(test.coding, func(t *testing.T) { 45 | req := httptest.NewRequest("GET", "/", nil) 46 | req.Header.Set("Accept-Encoding", test.coding) 47 | 48 | res := httptest.NewRecorder() 49 | 50 | h.ServeHTTP(res, req) 51 | res.Flush() 52 | 53 | r := test.newReader(res.Body) 54 | b, _ := ioutil.ReadAll(r) 55 | 56 | if res.Code != http.StatusOK { 57 | t.Error("bad status:", res.Code) 58 | } 59 | if coding := res.HeaderMap.Get("Content-Encoding"); coding != test.coding { 60 | t.Error("bad content encoding:", coding) 61 | } 62 | if s := string(b); s != "Hello World!" { 63 | t.Error("bad content:", s) 64 | } 65 | }) 66 | } 67 | } 68 | 69 | func TestEncodingTransport(t *testing.T) { 70 | tests := []struct { 71 | encoding ContentEncoding 72 | }{ 73 | {NewDeflateEncoding()}, 74 | {NewGzipEncoding()}, 75 | {NewZlibEncoding()}, 76 | } 77 | 78 | for _, test := range tests { 79 | t.Run(test.encoding.Coding(), func(t *testing.T) { 80 | server := httptest.NewServer(NewEncodingHandler(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { 81 | res.Write([]byte("Hello World!")) 82 | }))) 83 | defer server.Close() 84 | 85 | client := http.Client{ 86 | Transport: NewEncodingTransport(http.DefaultTransport, test.encoding), 87 | } 88 | 89 | res, err := client.Get(server.URL + "/") 90 | if err != nil { 91 | t.Error(err) 92 | return 93 | } 94 | 95 | b, _ := ioutil.ReadAll(res.Body) 96 | res.Body.Close() 97 | 98 | if s := string(b); s != "Hello World!" 
{ 99 | t.Error(s) 100 | } 101 | }) 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /httpx/handler.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | import "net/http" 4 | 5 | // StatusHandler returns a HTTP handler that always responds with status and an 6 | // empty body. 7 | func StatusHandler(status int) http.Handler { 8 | return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { 9 | res.WriteHeader(status) 10 | }) 11 | } 12 | -------------------------------------------------------------------------------- /httpx/handler_test.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | ) 9 | 10 | func TestStatusHandler(t *testing.T) { 11 | for _, status := range []int{http.StatusOK, http.StatusNotFound} { 12 | t.Run(fmt.Sprint(status), func(t *testing.T) { 13 | req := httptest.NewRequest("GET", "/", nil) 14 | res := httptest.NewRecorder() 15 | 16 | handler := StatusHandler(status) 17 | handler.ServeHTTP(res, req) 18 | 19 | if res.Code != status { 20 | t.Error(res.Code) 21 | } 22 | }) 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /httpx/httplex.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | // ============================================================================= 4 | // This file was adapted from x/net/lex/httplex because we needed some 5 | // unexported functions of the package. 6 | // The tests for the functions we capied have also been ported in 7 | // httplex_test.go. 8 | // ============================================================================= 9 | 10 | // Copyright 2016 The Go Authors. All rights reserved. 
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package httplex contains rules around lexical matters of various
// HTTP-related specifications.
//
// This package is shared by the standard library (which vendors it)
// and x/net/http2. It comes with no API stability promise.
//package httplex

import (
	"strings"
	"unicode/utf8"
)

// isTokenTable reports, for every byte value below 127, whether that byte may
// appear in an HTTP token ("tchar"). Bytes outside the table (127 and above)
// are never token bytes; see isTokenByte.
var isTokenTable = [127]bool{
	'!':  true,
	'#':  true,
	'$':  true,
	'%':  true,
	'&':  true,
	'\'': true,
	'*':  true,
	'+':  true,
	'-':  true,
	'.':  true,
	'0':  true,
	'1':  true,
	'2':  true,
	'3':  true,
	'4':  true,
	'5':  true,
	'6':  true,
	'7':  true,
	'8':  true,
	'9':  true,
	'A':  true,
	'B':  true,
	'C':  true,
	'D':  true,
	'E':  true,
	'F':  true,
	'G':  true,
	'H':  true,
	'I':  true,
	'J':  true,
	'K':  true,
	'L':  true,
	'M':  true,
	'N':  true,
	'O':  true,
	'P':  true,
	'Q':  true,
	'R':  true,
	'S':  true,
	'T':  true,
	'U':  true,
	'V':  true,
	'W':  true,
	'X':  true,
	'Y':  true,
	'Z':  true,
	'^':  true,
	'_':  true,
	'`':  true,
	'a':  true,
	'b':  true,
	'c':  true,
	'd':  true,
	'e':  true,
	'f':  true,
	'g':  true,
	'h':  true,
	'i':  true,
	'j':  true,
	'k':  true,
	'l':  true,
	'm':  true,
	'n':  true,
	'o':  true,
	'p':  true,
	'q':  true,
	'r':  true,
	's':  true,
	't':  true,
	'u':  true,
	'v':  true,
	'w':  true,
	'x':  true,
	'y':  true,
	'z':  true,
	'|':  true,
	'~':  true,
}

// isTokenByte returns true if b is a byte that can be found in a token.
func isTokenByte(b byte) bool {
	i := int(b)
	return i < len(isTokenTable) && isTokenTable[i]
}

// isToken returns true if s is a valid HTTP token: a non-empty string made
// only of token bytes. Any non-ASCII byte makes s invalid.
func isToken(s string) bool {
	if len(s) == 0 {
		return false
	}
	for i := 0; i < len(s); i++ {
		if !isTokenByte(s[i]) {
			return false
		}
	}
	return true
}

// headerValuesContainsToken reports whether any string in values
// contains the provided token, ASCII case-insensitively.
func headerValuesContainsToken(values []string, token string) bool {
	for _, v := range values {
		if headerValueContainsToken(v, token) {
			return true
		}
	}
	return false
}

// isOWS reports whether b is an optional whitespace byte, as defined
// by RFC 7230 section 3.2.3.
func isOWS(b byte) bool { return b == ' ' || b == '\t' }

// trimOWS returns x with all optional whitespace removed from the
// beginning and end.
func trimOWS(x string) string {
	// TODO: consider using strings.Trim(x, " \t") instead,
	// if and when it's fast enough. See issue 10292.
	// But this ASCII-only code will probably always beat UTF-8
	// aware code.
	for len(x) > 0 && isOWS(x[0]) {
		x = x[1:]
	}
	for len(x) > 0 && isOWS(x[len(x)-1]) {
		x = x[:len(x)-1]
	}
	return x
}

// headerValueContainsToken reports whether v (assumed to be a
// 0#element, in the ABNF extension described in RFC 7230 section 7)
// contains token amongst its comma-separated tokens, ASCII
// case-insensitively.
func headerValueContainsToken(v string, token string) bool {
	// Walk the list with a loop rather than recursing once per element.
	// v typically comes from a request header, and a value made of many
	// commas would otherwise drive recursion of matching depth.
	for comma := strings.IndexByte(v, ','); comma != -1; comma = strings.IndexByte(v, ',') {
		if tokenEqual(trimOWS(v[:comma]), token) {
			return true
		}
		v = v[comma+1:]
	}
	return tokenEqual(trimOWS(v), token)
}

// lowerASCII returns the ASCII lowercase version of b.
func lowerASCII(b byte) byte {
	if 'A' <= b && b <= 'Z' {
		return b + ('a' - 'A')
	}
	return b
}

// tokenEqual reports whether t1 and t2 are equal, ASCII case-insensitively.
// Any non-ASCII rune in t1 makes the comparison fail.
func tokenEqual(t1, t2 string) bool {
	if len(t1) != len(t2) {
		return false
	}
	for i, b := range t1 {
		if b >= utf8.RuneSelf {
			// No UTF-8 or non-ASCII allowed in tokens.
			return false
		}
		if lowerASCII(byte(b)) != lowerASCII(t2[i]) {
			return false
		}
	}
	return true
}

// -------------------------------------------------------------------------------- /httpx/httplex_test.go: --------------------------------------------------------------------------------
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4 | 5 | //package httplex 6 | package httpx 7 | 8 | import ( 9 | "fmt" 10 | "testing" 11 | ) 12 | 13 | func TestHeaderValuesContainsToken(t *testing.T) { 14 | tests := []struct { 15 | vals []string 16 | token string 17 | want bool 18 | }{ 19 | { 20 | vals: []string{"foo"}, 21 | token: "foo", 22 | want: true, 23 | }, 24 | { 25 | vals: []string{"bar", "foo"}, 26 | token: "foo", 27 | want: true, 28 | }, 29 | { 30 | vals: []string{"foo"}, 31 | token: "FOO", 32 | want: true, 33 | }, 34 | { 35 | vals: []string{"foo"}, 36 | token: "bar", 37 | want: false, 38 | }, 39 | { 40 | vals: []string{" foo "}, 41 | token: "FOO", 42 | want: true, 43 | }, 44 | { 45 | vals: []string{"foo,bar"}, 46 | token: "FOO", 47 | want: true, 48 | }, 49 | { 50 | vals: []string{"bar,foo,bar"}, 51 | token: "FOO", 52 | want: true, 53 | }, 54 | { 55 | vals: []string{"bar , foo"}, 56 | token: "FOO", 57 | want: true, 58 | }, 59 | { 60 | vals: []string{"foo ,bar "}, 61 | token: "FOO", 62 | want: true, 63 | }, 64 | { 65 | vals: []string{"bar, foo ,bar"}, 66 | token: "FOO", 67 | want: true, 68 | }, 69 | { 70 | vals: []string{"bar , foo"}, 71 | token: "FOO", 72 | want: true, 73 | }, 74 | } 75 | for _, tt := range tests { 76 | t.Run(fmt.Sprint(tt.vals), func(t *testing.T) { 77 | got := headerValuesContainsToken(tt.vals, tt.token) 78 | if got != tt.want { 79 | t.Errorf("headerValuesContainsToken(%q, %q) = %v; want %v", tt.vals, tt.token, got, tt.want) 80 | } 81 | }) 82 | } 83 | } 84 | 85 | func TestTokenEqual(t *testing.T) { 86 | tests := []struct { 87 | t1 string 88 | t2 string 89 | eq bool 90 | }{ 91 | { 92 | t1: "", 93 | t2: "", 94 | eq: true, 95 | }, 96 | { 97 | t1: "A", 98 | t2: "B", 99 | eq: false, 100 | }, 101 | { 102 | t1: "A", 103 | t2: "a", 104 | eq: true, 105 | }, 106 | { 107 | t1: "你好", 108 | t2: "你好", 109 | eq: false, 110 | }, 111 | { 112 | t1: "123", 113 | t2: "A", 114 | eq: false, 115 | }, 116 | } 117 | 118 | for _, test := range tests { 119 | t.Run(fmt.Sprintf("%q==%q:%v", test.t1, test.t2, 
test.eq), func(t *testing.T) { 120 | if eq := tokenEqual(test.t1, test.t2); eq != test.eq { 121 | t.Error(eq) 122 | } 123 | }) 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /httpx/httpxtest/test_server.go: -------------------------------------------------------------------------------- 1 | package httpxtest 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "io" 7 | "io/ioutil" 8 | "net" 9 | "net/http" 10 | "testing" 11 | "time" 12 | ) 13 | 14 | // ServerConfig is used to configure the HTTP server started by MakeServer. 15 | type ServerConfig struct { 16 | Handler http.Handler 17 | ReadTimeout time.Duration 18 | WriteTimeout time.Duration 19 | MaxHeaderBytes int 20 | } 21 | 22 | // MakeServer is a function called by the TestServer test suite to create a new 23 | // HTTP server that runs the given config. 24 | // The function must return the URL at which the server can be accessed and a 25 | // closer function to terminate the server. 26 | type MakeServer func(ServerConfig) (url string, close func()) 27 | 28 | // TestServer is a test suite for HTTP servers, inspired by 29 | // golang.org/x/net/nettest.TestConn. 30 | func TestServer(t *testing.T, f MakeServer) { 31 | run := func(name string, test func(*testing.T, MakeServer)) { 32 | t.Run(name, func(t *testing.T) { 33 | t.Parallel() 34 | test(t, f) 35 | }) 36 | } 37 | run("Basic", testServerBasic) 38 | run("Transfer-Encoding:chunked", testServerTransferEncodingChunked) 39 | run("ErrBodyNotAllowed", testServerErrBodyNotAllowed) 40 | run("ErrContentLength", testServerErrContentLength) 41 | run("ReadTimeout", testServerReadTimeout) 42 | run("WriteTimeout", testServerWriteTimeout) 43 | } 44 | 45 | // tests that basic features of the http server are working as expected, setting 46 | // a content length and a response status should work fine. 
47 | func testServerBasic(t *testing.T, f MakeServer) { 48 | url, close := f(ServerConfig{ 49 | Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 50 | w.Header().Set("Content-Length", "12") 51 | w.WriteHeader(http.StatusAccepted) 52 | w.Write([]byte("Hello World!")) 53 | }), 54 | }) 55 | defer close() 56 | 57 | res, err := http.Get(url + "/") 58 | if err != nil { 59 | t.Error(err) 60 | return 61 | } 62 | 63 | buf := &bytes.Buffer{} 64 | buf.ReadFrom(res.Body) 65 | 66 | if err := res.Body.Close(); err != nil { 67 | t.Error("error closing the response body:", err) 68 | } 69 | if res.StatusCode != http.StatusAccepted { 70 | t.Error("bad response code:", res.StatusCode) 71 | } 72 | if s := buf.String(); s != "Hello World!" { 73 | t.Error("bad response body:", s) 74 | } 75 | } 76 | 77 | // test that a chunked transfer encoding on the connection works as expected, 78 | // this is done by sending a huge payload via multiple calls to Write. 79 | func testServerTransferEncodingChunked(t *testing.T, f MakeServer) { 80 | b := make([]byte, 128) 81 | 82 | url, close := f(ServerConfig{ 83 | Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 84 | // No Content-Length is set, the server should be using 85 | // "Transfer-Encoding: chunked" in the response. 
86 | for i := 0; i != 100; i++ { 87 | if _, err := w.Write(b); err != nil { 88 | t.Error(err) 89 | return 90 | } 91 | } 92 | }), 93 | }) 94 | defer close() 95 | 96 | res, err := http.Get(url + "/") 97 | if err != nil { 98 | t.Error(err) 99 | return 100 | } 101 | 102 | r := &countReader{R: res.Body} 103 | io.Copy(ioutil.Discard, r) 104 | 105 | if err := res.Body.Close(); err != nil { 106 | t.Error("error closing the response body:", err) 107 | } 108 | if res.StatusCode != http.StatusOK { 109 | t.Error("bad response code:", res.StatusCode) 110 | } 111 | if r.N != (100 * len(b)) { 112 | t.Error("bad response body length:", r.N) 113 | } 114 | } 115 | 116 | // test that the server's response writer returns http.ErrBodyNotAllowed when 117 | // the program attempts to write a body on a response that doesn't allow one. 118 | func testServerErrBodyNotAllowed(t *testing.T, f MakeServer) { 119 | tests := []struct { 120 | reason string 121 | handler http.Handler 122 | }{ 123 | { 124 | reason: "101 Switching Protocols", 125 | handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 126 | w.WriteHeader(http.StatusSwitchingProtocols) 127 | }), 128 | }, 129 | { 130 | reason: "101 Switching Protocols", 131 | handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 132 | w.WriteHeader(http.StatusNoContent) 133 | }), 134 | }, 135 | { 136 | reason: "101 Switching Protocols", 137 | handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 138 | w.WriteHeader(http.StatusNotModified) 139 | }), 140 | }, 141 | } 142 | 143 | for _, test := range tests { 144 | t.Run(test.reason, func(t *testing.T) { 145 | url, close := f(ServerConfig{ 146 | Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 147 | test.handler.ServeHTTP(w, req) 148 | // No body is allowed on this response, the Write method 149 | // must return an error indicating that the program is 150 | // misbehaving. 
151 | if _, err := w.Write([]byte("Hello World!")); err != http.ErrBodyNotAllowed { 152 | t.Errorf("expected http.ErrBodyNotAllowed but got %v", err) 153 | } 154 | }), 155 | }) 156 | defer close() 157 | 158 | res, err := http.Get(url + "/") 159 | if err != nil { 160 | t.Error(err) 161 | return 162 | } 163 | 164 | r := &countReader{R: res.Body} 165 | io.Copy(ioutil.Discard, r) 166 | res.Body.Close() 167 | 168 | if r.N != 0 { 169 | t.Errorf("expected no body in the response but received %d bytes", r.N) 170 | } 171 | }) 172 | } 173 | } 174 | 175 | // test that the server's response writer returns http.ErrContentLength when the 176 | // program attempts to write more data in the response than it previously set on 177 | // the Content-Length header. 178 | func testServerErrContentLength(t *testing.T, f MakeServer) { 179 | url, close := f(ServerConfig{ 180 | Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 181 | w.Header().Set("Content-Length", "1") 182 | w.WriteHeader(http.StatusOK) 183 | // The program writes too many bytes to the response, it must be 184 | // notified by getting an error on the Write call. 185 | if _, err := w.Write([]byte("Hello World!")); err != http.ErrContentLength { 186 | t.Errorf("expected http.ErrContentLength but got %v", err) 187 | } 188 | }), 189 | }) 190 | defer close() 191 | 192 | res, err := http.Get(url + "/") 193 | if err != nil { 194 | t.Error(err) 195 | return 196 | } 197 | 198 | r := &countReader{R: res.Body} 199 | io.Copy(ioutil.Discard, r) 200 | res.Body.Close() 201 | 202 | if r.N != 1 { 203 | t.Errorf("expected at 1 byte in the response but received %d bytes", r.N) 204 | } 205 | } 206 | 207 | // test that the server properly closes connections when reading a request takes 208 | // too much time. 
209 | func testServerReadTimeout(t *testing.T, f MakeServer) { 210 | url, close := f(ServerConfig{ 211 | Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {}), 212 | ReadTimeout: 100 * time.Millisecond, 213 | }) 214 | defer close() 215 | 216 | conn, err := net.Dial("tcp", url[7:]) // trim "http://" 217 | if err != nil { 218 | t.Error(err) 219 | return 220 | } 221 | defer conn.Close() 222 | 223 | // Write the beginning of a request but doesn't terminate it, the server 224 | // should timeout the connection after 100ms. 225 | if _, err := conn.Write([]byte("GET / HTTP/1.1")); err != nil { 226 | t.Error(err) 227 | return 228 | } 229 | 230 | var b [128]byte 231 | if n, err := conn.Read(b[:]); err != io.EOF { 232 | t.Errorf("expected io.EOF on the read operation but got %v (%d bytes)", err, n) 233 | } 234 | } 235 | 236 | // test that the server properly closes connections when the client doesn't read 237 | // the response. 238 | func testServerWriteTimeout(t *testing.T, f MakeServer) { 239 | b := make([]byte, 10*(1<<22)) // 40MB 240 | 241 | url, close := f(ServerConfig{ 242 | Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 243 | if _, err := w.Write(b); err == nil { 244 | t.Error(err) 245 | } 246 | }), 247 | WriteTimeout: 100 * time.Millisecond, 248 | }) 249 | defer close() 250 | 251 | conn, err := net.Dial("tcp", url[7:]) // trim "http://" 252 | if err != nil { 253 | t.Error(err) 254 | return 255 | } 256 | defer conn.Close() 257 | 258 | r := bufio.NewReader(conn) 259 | w := bufio.NewWriter(conn) 260 | 261 | req, _ := http.NewRequest("GET", "/", nil) 262 | req.Write(w) 263 | 264 | if err := w.Flush(); err != nil { 265 | t.Error(err) 266 | return 267 | } 268 | 269 | // Wait so the server can timeout the request. 
270 | time.Sleep(200 * time.Millisecond) 271 | 272 | res, err := http.ReadResponse(r, req) 273 | if err != nil { 274 | return // OK, nothing was sent 275 | } 276 | 277 | body := &countReader{R: res.Body} 278 | io.Copy(ioutil.Discard, body) 279 | res.Body.Close() 280 | 281 | if body.N >= len(b) { 282 | t.Errorf("the server shouldn't have been able to send the entire response body of %d bytes", body.N) 283 | } 284 | } 285 | 286 | // countReader is an io.Reader which counts how many bytes were read. 287 | type countReader struct { 288 | R io.Reader 289 | N int 290 | } 291 | 292 | func (r *countReader) Read(b []byte) (n int, err error) { 293 | if n, err = r.R.Read(b); n > 0 { 294 | r.N += n 295 | } 296 | return 297 | } 298 | -------------------------------------------------------------------------------- /httpx/httpxtest/test_transport.go: -------------------------------------------------------------------------------- 1 | package httpxtest 2 | 3 | import ( 4 | "io/ioutil" 5 | "net/http" 6 | "net/http/httptest" 7 | "strings" 8 | "testing" 9 | ) 10 | 11 | // MakeTransport constructs a new HTTP transport used by a single sub-test of 12 | // TestTransport. 13 | type MakeTransport func() http.RoundTripper 14 | 15 | // TestTransport is a test suite for HTTP transports, inspired by 16 | // golang.org/x/net/nettest.TestConn. 
17 | func TestTransport(t *testing.T, f MakeTransport) { 18 | run := func(name string, test func(*testing.T, MakeTransport)) { 19 | t.Run(name, func(t *testing.T) { 20 | t.Parallel() 21 | test(t, f) 22 | }) 23 | } 24 | run("Basic", testTransportHEAD) 25 | } 26 | 27 | func testTransportHEAD(t *testing.T, f MakeTransport) { 28 | tests := []struct { 29 | method string 30 | path string 31 | reqBody string 32 | resBody string 33 | }{ 34 | { 35 | method: "HEAD", 36 | path: "/", 37 | }, 38 | { 39 | method: "GET", 40 | path: "/", 41 | resBody: "Hello World!", 42 | }, 43 | { 44 | method: "POST", 45 | path: "/hello/world", 46 | reqBody: "answer", 47 | resBody: "42", 48 | }, 49 | { 50 | method: "PUT", 51 | path: "/hello/world", 52 | reqBody: "answer=42", 53 | resBody: "", 54 | }, 55 | { 56 | method: "DELETE", 57 | path: "/hello/world", 58 | }, 59 | } 60 | 61 | for _, test := range tests { 62 | test := test 63 | t.Run(test.method, func(t *testing.T) { 64 | t.Parallel() 65 | 66 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 67 | if test.method != req.Method { 68 | t.Errorf("bad method received by the server, expected %s but got %s", test.method, req.Method) 69 | } 70 | if test.path != req.URL.Path { 71 | t.Errorf("bad path received by the server, expected %s but got %s", test.path, req.URL.Path) 72 | } 73 | 74 | b, err := ioutil.ReadAll(req.Body) 75 | req.Body.Close() 76 | 77 | if err != nil { 78 | t.Errorf("the server got an error while reading the request body: %v", err) 79 | } 80 | if s := string(b); s != test.reqBody { 81 | t.Errorf("bad request body received by the server, expected %#v but got %#v", test.reqBody, s) 82 | } 83 | 84 | h := w.Header() 85 | h.Set("Content-Type", "text/plain") 86 | 87 | if _, err := w.Write([]byte(test.resBody)); err != nil { 88 | t.Errorf("the server got an error while writing the response body: %v", err) 89 | } 90 | })) 91 | defer server.Close() 92 | 93 | req, err := 
http.NewRequest(test.method, server.URL+test.path, strings.NewReader(test.reqBody)) 94 | if err != nil { 95 | t.Error(err) 96 | return 97 | } 98 | req.Header.Set("Content-Type", "text/plain") 99 | 100 | res, err := f().RoundTrip(req) 101 | if err != nil { 102 | t.Error(err) 103 | return 104 | } 105 | 106 | b, err := ioutil.ReadAll(res.Body) 107 | res.Body.Close() 108 | 109 | if err != nil { 110 | t.Error(err) 111 | return 112 | } 113 | if s := string(b); s != test.resBody { 114 | t.Errorf("bad body received by the client, expected %#v but got %#v", test.resBody, s) 115 | } 116 | }) 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /httpx/media.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "strings" 7 | ) 8 | 9 | // MediaRange is a representation of a HTTP media range as described in 10 | // https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html 11 | type MediaRange struct { 12 | typ string 13 | sub string 14 | params []MediaParam 15 | } 16 | 17 | // Param return the value of the parameter with name, which will be an empty 18 | // string if none was found. 19 | func (r MediaRange) Param(name string) string { 20 | for _, p := range r.params { 21 | if tokenEqual(p.name, name) { 22 | return p.value 23 | } 24 | } 25 | return "" 26 | } 27 | 28 | // String satisfies the fmt.Stringer interface. 29 | func (r MediaRange) String() string { 30 | return fmt.Sprint(r) 31 | } 32 | 33 | // Format satisfies the fmt.Formatter interface. 34 | func (r MediaRange) Format(w fmt.State, _ rune) { 35 | fmt.Fprintf(w, "%s/%s", r.typ, r.sub) 36 | 37 | for _, p := range r.params { 38 | fmt.Fprintf(w, ";%v", p) 39 | } 40 | } 41 | 42 | // ParseMediaRange parses a string representation of a HTTP media range from s. 
43 | func ParseMediaRange(s string) (r MediaRange, err error) { 44 | var s1 string 45 | var s2 string 46 | var s3 string 47 | var i int 48 | var j int 49 | var mp []MediaParam 50 | 51 | if i = strings.IndexByte(s, '/'); i < 0 { 52 | goto error 53 | } 54 | 55 | s1 = s[:i] 56 | 57 | if j = strings.IndexByte(s[i+1:], ';'); j < 0 { 58 | s2 = s[i+1:] 59 | } else { 60 | s2 = s[i+1 : i+1+j] 61 | s3 = s[i+j+2:] 62 | } 63 | 64 | if !isToken(s1) { 65 | goto error 66 | } 67 | if !isToken(s2) { 68 | goto error 69 | } 70 | 71 | for len(s3) != 0 { 72 | var p MediaParam 73 | 74 | if i = strings.IndexByte(s3, ';'); i < 0 { 75 | i = len(s3) 76 | } 77 | 78 | if p, err = ParseMediaParam(trimOWS(s3[:i])); err != nil { 79 | goto error 80 | } 81 | 82 | mp = append(mp, p) 83 | s3 = s3[i:] 84 | 85 | if len(s3) != 0 { 86 | s3 = s3[1:] 87 | } 88 | } 89 | 90 | r = MediaRange{ 91 | typ: s1, 92 | sub: s2, 93 | params: mp, 94 | } 95 | return 96 | error: 97 | err = errorInvalidMediaRange(s) 98 | return 99 | } 100 | 101 | func errorInvalidMediaRange(s string) error { 102 | return errors.New("invalid media range: " + s) 103 | } 104 | 105 | // MediaParam is a representation of a HTTP media parameter as described in 106 | // https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html 107 | type MediaParam struct { 108 | name string 109 | value string 110 | } 111 | 112 | // String satisfies the fmt.Stringer interface. 113 | func (p MediaParam) String() string { 114 | return fmt.Sprint(p) 115 | } 116 | 117 | // Format satisfies the fmt.Formatter interface. 118 | func (p MediaParam) Format(w fmt.State, _ rune) { 119 | fmt.Fprintf(w, "%v=%v", p.name, quoted(p.value)) 120 | } 121 | 122 | // ParseMediaParam parses a string representation of a HTTP media parameter 123 | // from s. 
124 | func ParseMediaParam(s string) (p MediaParam, err error) { 125 | var s1 string 126 | var s2 string 127 | var q quoted 128 | var i = strings.IndexByte(s, '=') 129 | 130 | if i < 0 { 131 | goto error 132 | } 133 | 134 | s1 = s[:i] 135 | s2 = s[i+1:] 136 | 137 | if !isToken(s1) { 138 | goto error 139 | } 140 | if q, err = parseQuoted(s2); err != nil { 141 | goto error 142 | } 143 | 144 | p = MediaParam{ 145 | name: s1, 146 | value: string(q), 147 | } 148 | return 149 | error: 150 | err = errorInvalidMediaParam(s) 151 | return 152 | } 153 | 154 | // MediaType is a representation of a HTTP media type which is usually expressed 155 | // in the form of a main and sub type as in "main/sub", where both may be the 156 | // special wildcard token "*". 157 | type MediaType struct { 158 | typ string 159 | sub string 160 | } 161 | 162 | // Contains returns true if t is a superset or is equal to t2. 163 | func (t MediaType) Contains(t2 MediaType) bool { 164 | return t.typ == "*" || (t.typ == t2.typ && (t.sub == "*" || t.sub == t2.sub)) 165 | } 166 | 167 | // String satisfies the fmt.Stringer interface. 168 | func (t MediaType) String() string { 169 | return fmt.Sprint(t) 170 | } 171 | 172 | // Format satisfies the fmt.Formatter interface. 173 | func (t MediaType) Format(w fmt.State, _ rune) { 174 | fmt.Fprintf(w, "%s/%s", t.typ, t.sub) 175 | } 176 | 177 | // ParseMediaType parses the media type in s. 
178 | func ParseMediaType(s string) (t MediaType, err error) { 179 | var s1 string 180 | var s2 string 181 | var i = strings.IndexByte(s, '/') 182 | 183 | if i < 0 { 184 | goto error 185 | } 186 | 187 | s1 = s[:i] 188 | s2 = s[i+1:] 189 | 190 | if !isToken(s1) || !isToken(s2) { 191 | goto error 192 | } 193 | 194 | t = MediaType{ 195 | typ: s1, 196 | sub: s2, 197 | } 198 | return 199 | error: 200 | err = errorInvalidMediaType(s) 201 | return 202 | } 203 | 204 | func mediaTypeLess(t1 string, t2 string) bool { 205 | if t1 == t2 { 206 | return false 207 | } 208 | if t1 == "*" { 209 | return false 210 | } 211 | if t2 == "*" { 212 | return true 213 | } 214 | return t1 < t2 215 | } 216 | 217 | func errorInvalidMediaParam(s string) error { 218 | return errors.New("invalid media parameter: " + s) 219 | } 220 | 221 | func errorInvalidMediaType(s string) error { 222 | return errors.New("invalid media type: " + s) 223 | } 224 | -------------------------------------------------------------------------------- /httpx/media_test.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | func TestParseMediaTypeSuccess(t *testing.T) { 9 | tests := []struct { 10 | s string 11 | m MediaType 12 | }{ 13 | { 14 | s: `*/*`, 15 | m: MediaType{typ: "*", sub: "*"}, 16 | }, 17 | { 18 | s: `text/*`, 19 | m: MediaType{typ: "text", sub: "*"}, 20 | }, 21 | { 22 | s: `text/plain`, 23 | m: MediaType{typ: "text", sub: "plain"}, 24 | }, 25 | } 26 | 27 | for _, test := range tests { 28 | t.Run(test.m.String(), func(t *testing.T) { 29 | m, err := ParseMediaType(test.s) 30 | 31 | if err != nil { 32 | t.Error(err) 33 | } 34 | 35 | if m != test.m { 36 | t.Error(m) 37 | } 38 | }) 39 | } 40 | } 41 | 42 | func TestParseMediaTypeFailure(t *testing.T) { 43 | tests := []struct { 44 | s string 45 | }{ 46 | {``}, // empty string 47 | {`/`}, // missing type and subtype 48 | {`text`}, // missing separator 49 | 
{`,/plain`}, // bad type 50 | {`text/,`}, // bad subtype 51 | } 52 | 53 | for _, test := range tests { 54 | t.Run(test.s, func(t *testing.T) { 55 | if m, err := ParseMediaType(test.s); err == nil { 56 | t.Error(m) 57 | } 58 | }) 59 | } 60 | } 61 | 62 | func TestMediaTypeContainsTrue(t *testing.T) { 63 | tests := []struct { 64 | t1 MediaType 65 | t2 MediaType 66 | }{ 67 | { 68 | t1: MediaType{typ: "*", sub: "*"}, 69 | t2: MediaType{typ: "text", sub: "plain"}, 70 | }, 71 | { 72 | t1: MediaType{typ: "text", sub: "*"}, 73 | t2: MediaType{typ: "text", sub: "plain"}, 74 | }, 75 | { 76 | t1: MediaType{typ: "text", sub: "plain"}, 77 | t2: MediaType{typ: "text", sub: "plain"}, 78 | }, 79 | } 80 | 81 | for _, test := range tests { 82 | t.Run(test.t1.String()+":"+test.t2.String(), func(t *testing.T) { 83 | if !test.t1.Contains(test.t2) { 84 | t.Error("nope") 85 | } 86 | }) 87 | } 88 | } 89 | 90 | func TestMediaTypeContainsFalse(t *testing.T) { 91 | tests := []struct { 92 | t1 MediaType 93 | t2 MediaType 94 | }{ 95 | { 96 | t1: MediaType{typ: "text", sub: "*"}, 97 | t2: MediaType{typ: "image", sub: "png"}, 98 | }, 99 | { 100 | t1: MediaType{typ: "text", sub: "plain"}, 101 | t2: MediaType{typ: "text", sub: "html"}, 102 | }, 103 | } 104 | 105 | for _, test := range tests { 106 | t.Run(test.t1.String()+":"+test.t2.String(), func(t *testing.T) { 107 | if test.t1.Contains(test.t2) { 108 | t.Error("nope") 109 | } 110 | }) 111 | } 112 | } 113 | 114 | func TestParseMediaParamSuccess(t *testing.T) { 115 | tests := []struct { 116 | s string 117 | p MediaParam 118 | }{ 119 | { 120 | s: `key=value`, 121 | p: MediaParam{name: "key", value: "value"}, 122 | }, 123 | { 124 | s: `key="你好"`, 125 | p: MediaParam{name: "key", value: "你好"}, 126 | }, 127 | } 128 | 129 | for _, test := range tests { 130 | t.Run(test.p.String(), func(t *testing.T) { 131 | p, err := ParseMediaParam(test.s) 132 | 133 | if err != nil { 134 | t.Error(err) 135 | } 136 | 137 | if p != test.p { 138 | t.Error(p) 139 | } 140 
| }) 141 | } 142 | } 143 | 144 | func TestParseMediaParamFailure(t *testing.T) { 145 | tests := []struct { 146 | s string 147 | }{ 148 | {``}, // empty string 149 | {`key`}, // missing = 150 | {`key=`}, // missing value 151 | {`=value`}, // missing key 152 | {`=`}, // missing key and value 153 | {`key=你好`}, // non-token and non-quoted value 154 | } 155 | 156 | for _, test := range tests { 157 | t.Run(test.s, func(t *testing.T) { 158 | if p, err := ParseMediaParam(test.s); err == nil { 159 | t.Error(p) 160 | } 161 | }) 162 | } 163 | } 164 | 165 | func TestParseMediaRangeSuccess(t *testing.T) { 166 | tests := []struct { 167 | s string 168 | r MediaRange 169 | }{ 170 | { 171 | s: `image/*`, 172 | r: MediaRange{ 173 | typ: "image", 174 | sub: "*", 175 | }, 176 | }, 177 | { 178 | s: `image/*;`, // trailing ';' 179 | r: MediaRange{ 180 | typ: "image", 181 | sub: "*", 182 | }, 183 | }, 184 | { 185 | s: `text/html;key1=hello;key2="你好"`, 186 | r: MediaRange{ 187 | typ: "text", 188 | sub: "html", 189 | params: []MediaParam{{"key1", "hello"}, {"key2", "你好"}}, 190 | }, 191 | }, 192 | } 193 | 194 | for _, test := range tests { 195 | t.Run(test.r.String(), func(t *testing.T) { 196 | r, err := ParseMediaRange(test.s) 197 | 198 | if err != nil { 199 | t.Error(err) 200 | } 201 | 202 | if !reflect.DeepEqual(r, test.r) { 203 | t.Error(r) 204 | } 205 | }) 206 | } 207 | } 208 | 209 | func TestParseMediaRangeFailure(t *testing.T) { 210 | tests := []struct { 211 | s string 212 | }{ 213 | {``}, // empty string 214 | {`image`}, // no Media type 215 | {`/`}, // bad Media type 216 | {`image/,`}, // bad sub-type 217 | {`image/*;bad`}, // bad parameters 218 | } 219 | 220 | for _, test := range tests { 221 | t.Run(test.s, func(t *testing.T) { 222 | if m, err := ParseMediaRange(test.s); err == nil { 223 | t.Error(m) 224 | } 225 | }) 226 | } 227 | } 228 | 229 | func TestMediaRangeParam(t *testing.T) { 230 | r := MediaRange{ 231 | typ: "image", 232 | sub: "*", 233 | params: []MediaParam{{"answer", 
"42"}}, 234 | } 235 | 236 | p1 := r.Param("answer") 237 | p2 := r.Param("other") 238 | 239 | if p1 != "42" { 240 | t.Error("found bad Media parameter:", p1) 241 | } 242 | 243 | if p2 != "" { 244 | t.Error("found non-existing Media parameter:", p2) 245 | } 246 | } 247 | 248 | func TestMediaTypeLess(t *testing.T) { 249 | tests := []struct { 250 | t1 string 251 | t2 string 252 | less bool 253 | }{ 254 | { 255 | t1: "", 256 | t2: "", 257 | less: false, 258 | }, 259 | { 260 | t1: "*", 261 | t2: "*", 262 | less: false, 263 | }, 264 | { 265 | t1: "*", 266 | t2: "text", 267 | less: false, 268 | }, 269 | { 270 | t1: "text", 271 | t2: "*", 272 | less: true, 273 | }, 274 | { 275 | t1: "text", 276 | t2: "text", 277 | less: false, 278 | }, 279 | { 280 | t1: "plain", 281 | t2: "html", 282 | less: false, 283 | }, 284 | { 285 | t1: "html", 286 | t2: "plain", 287 | less: true, 288 | }, 289 | } 290 | 291 | for _, test := range tests { 292 | t.Run(test.t1+"<"+test.t2, func(t *testing.T) { 293 | if less := mediaTypeLess(test.t1, test.t2); less != test.less { 294 | t.Error(less) 295 | } 296 | }) 297 | } 298 | } 299 | -------------------------------------------------------------------------------- /httpx/proto.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | import ( 4 | "net" 5 | "net/http" 6 | "strconv" 7 | "strings" 8 | 9 | "github.com/segmentio/netx" 10 | ) 11 | 12 | // protoEqual checks if the protocol version used for req is equal to 13 | // HTTP/major.minor 14 | func protoEqual(req *http.Request, major int, minor int) bool { 15 | return req.ProtoMajor == major && req.ProtoMinor == minor 16 | } 17 | 18 | // protoVersion returns the version part of the protocol identifier of req. 19 | func protoVersion(req *http.Request) string { 20 | proto := req.Proto 21 | if strings.HasPrefix(proto, "HTTP/") { 22 | proto = proto[5:] 23 | } 24 | return proto 25 | } 26 | 27 | // copyHeader copies the HTTP header src into dst. 
28 | func copyHeader(dst http.Header, src http.Header) { 29 | for name, values := range src { 30 | dst[name] = append(make([]string, 0, len(values)), values...) 31 | } 32 | } 33 | 34 | // deleteHopFields deletes the hop-by-hop fields from header. 35 | func deleteHopFields(h http.Header) { 36 | forEachHeaderValues(h["Connection"], func(v string) { 37 | if v != "close" { 38 | h.Del(v) 39 | } 40 | }) 41 | h.Del("Connection") 42 | h.Del("Keep-Alive") 43 | h.Del("Proxy-Authenticate") 44 | h.Del("Proxy-Authorization") 45 | h.Del("Proxy-Connection") 46 | h.Del("Te") 47 | h.Del("Trailer") 48 | h.Del("Transfer-Encoding") 49 | h.Del("Upgrade") 50 | } 51 | 52 | // translateXForwardedFor converts the X-Forwarded-* headers in their equivalent 53 | // Forwarded header representation. 54 | func translateXForwarded(h http.Header) { 55 | xFor := h.Get("X-Forwarded-For") 56 | xBy := h.Get("X-Forwarded-By") 57 | xPort := h.Get("X-Forwarded-Port") 58 | xProto := h.Get("X-Forwarded-Proto") 59 | forwarded := "" 60 | 61 | // If there's more than one entry in the X-Forwarded-For header it gets way 62 | // too complex to report all the different combinations of X-Forwarded-* 63 | // headers, and there's no standard saying which ones should or shouldn't be 64 | // included so we just translate the X-Forwarded-For list and pass on the 65 | // other ones. 
66 | if n := strings.Count(xFor, ","); n != 0 { 67 | s := make([]string, 0, n+1) 68 | forEachHeaderValues([]string{xFor}, func(v string) { 69 | s = append(s, "for="+quoteForwarded(v)) 70 | }) 71 | forwarded = strings.Join(s, ", ") 72 | } else { 73 | if len(xPort) != 0 { 74 | xFor = net.JoinHostPort(trimOWS(xFor), trimOWS(xPort)) 75 | } 76 | forwarded = makeForwarded(trimOWS(xProto), trimOWS(xFor), trimOWS(xBy)) 77 | } 78 | 79 | if len(forwarded) != 0 { 80 | h.Set("Forwarded", forwarded) 81 | } 82 | 83 | h.Del("X-Forwarded-For") 84 | h.Del("X-Forwarded-By") 85 | h.Del("X-Forwarded-Port") 86 | h.Del("X-Forwarded-Proto") 87 | } 88 | 89 | // quoteForwarded returns addr, quoted if necessary in order to be used in the 90 | // Forwarded header. 91 | func quoteForwarded(addr string) string { 92 | if netx.IsIPv4(addr) { 93 | return addr 94 | } 95 | if netx.IsIPv6(addr) { 96 | return quote("[" + addr + "]") 97 | } 98 | return quote(addr) 99 | } 100 | 101 | // mameForwarded builds a Forwarded header value from proto, forAddr, and byAddr. 102 | func makeForwarded(proto string, forAddr string, byAddr string) string { 103 | s := make([]string, 0, 4) 104 | if len(proto) != 0 { 105 | s = append(s, "proto="+quoted(proto).String()) 106 | } 107 | if len(forAddr) != 0 { 108 | s = append(s, "for="+quoteForwarded(forAddr)) 109 | } 110 | if len(byAddr) != 0 { 111 | s = append(s, "by="+quoteForwarded(byAddr)) 112 | } 113 | return strings.Join(s, ";") 114 | } 115 | 116 | // addForwarded adds proto, forAddr, and byAddr to the Forwarded header. 117 | func addForwarded(header http.Header, proto string, forAddr string, byAddr string) { 118 | addHeaderValue(header, "Forwarded", makeForwarded(proto, forAddr, byAddr)) 119 | } 120 | 121 | // makeVia creates a Via header value from version and host. 122 | func makeVia(version string, host string) string { 123 | return version + " " + host 124 | } 125 | 126 | // addVia adds version and host to the Via header. 
127 | func addVia(header http.Header, version string, host string) { 128 | addHeaderValue(header, "Via", makeVia(version, host)) 129 | } 130 | 131 | // addHeaderValue adds value to the name header. 132 | func addHeaderValue(header http.Header, name string, value string) { 133 | if prev := header.Get(name); len(prev) != 0 { 134 | value = prev + ", " + value 135 | } 136 | header.Set(name, value) 137 | } 138 | 139 | // maxForwards returns the value of the Max-Forward header. 140 | func maxForwards(header http.Header) (max int, err error) { 141 | if s := header.Get("Max-Forwards"); len(s) == 0 { 142 | max = -1 143 | } else { 144 | max, err = strconv.Atoi(s) 145 | } 146 | return 147 | } 148 | 149 | // connectionUpgrade returns the value of the Upgrade header if it is present in 150 | // the Connection header. 151 | func connectionUpgrade(header http.Header) string { 152 | if !headerValuesContainsToken(header["Connection"], "Upgrade") { 153 | return "" 154 | } 155 | return header.Get("Upgrade") 156 | } 157 | 158 | // headerValuesRemoveTokens removes tokens from values, returning a new list of values. 159 | func headerValuesRemoveTokens(values []string, tokens ...string) []string { 160 | result := make([]string, 0, len(values)) 161 | for _, v := range values { 162 | var item []string 163 | forEachValue: 164 | for len(v) != 0 { 165 | var s string 166 | s, v = readHeaderValue(v) 167 | for _, t := range tokens { 168 | if tokenEqual(t, s) { 169 | continue forEachValue 170 | } 171 | } 172 | item = append(item, s) 173 | } 174 | if len(item) != 0 { 175 | result = append(result, strings.Join(item, ", ")) 176 | } 177 | } 178 | return result 179 | } 180 | 181 | // forEachHeaderValues through each value of l, where each element of l is a 182 | // comma-separated list of values, calling f on each element. 
183 | func forEachHeaderValues(l []string, f func(string)) { 184 | for _, a := range l { 185 | for len(a) != 0 { 186 | var s string 187 | s, a = readHeaderValue(a) 188 | f(s) 189 | } 190 | } 191 | } 192 | 193 | // readHeaderValue tries to read the next value in a comma-separated list. 194 | func readHeaderValue(s string) (value string, tail string) { 195 | if off := strings.IndexByte(s, ','); off >= 0 { 196 | value, tail = s[:off], s[off+1:] 197 | } else { 198 | value = s 199 | } 200 | value = trimOWS(value) 201 | return 202 | } 203 | 204 | // isIdempotent returns true if method is idempotent. 205 | func isIdempotent(method string) bool { 206 | switch method { 207 | case http.MethodHead, http.MethodGet, http.MethodPut, http.MethodDelete, http.MethodOptions: 208 | return true 209 | } 210 | return false 211 | } 212 | 213 | // isRetriable returns true if the status is a retriable error. 214 | func isRetriable(status int) bool { 215 | switch status { 216 | case http.StatusInternalServerError: 217 | case http.StatusBadGateway: 218 | case http.StatusServiceUnavailable: 219 | case http.StatusGatewayTimeout: 220 | default: 221 | return false 222 | } 223 | return true 224 | } 225 | -------------------------------------------------------------------------------- /httpx/proto_test.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "reflect" 7 | "testing" 8 | ) 9 | 10 | func TestProtoEqual(t *testing.T) { 11 | tests := []struct { 12 | req *http.Request 13 | maj int 14 | min int 15 | res bool 16 | }{ 17 | { 18 | req: &http.Request{}, 19 | maj: 0, 20 | min: 0, 21 | res: true, 22 | }, 23 | { 24 | req: &http.Request{ProtoMajor: 1, ProtoMinor: 0}, 25 | maj: 1, 26 | min: 0, 27 | res: true, 28 | }, 29 | { 30 | req: &http.Request{ProtoMajor: 1, ProtoMinor: 1}, 31 | maj: 1, 32 | min: 1, 33 | res: true, 34 | }, 35 | { 36 | req: &http.Request{ProtoMajor: 0, ProtoMinor: 9}, 37 | maj: 1, 38 | 
min: 0, 39 | res: false, 40 | }, 41 | { 42 | req: &http.Request{ProtoMajor: 1, ProtoMinor: 0}, 43 | maj: 1, 44 | min: 1, 45 | res: false, 46 | }, 47 | } 48 | 49 | for _, test := range tests { 50 | t.Run(fmt.Sprintf("HTTP/%d.%d", test.maj, test.min), func(t *testing.T) { 51 | if res := protoEqual(test.req, test.maj, test.min); res != test.res { 52 | t.Error(res) 53 | } 54 | }) 55 | } 56 | } 57 | 58 | func TestProtoVersion(t *testing.T) { 59 | tests := []struct { 60 | req *http.Request 61 | out string 62 | }{ 63 | { 64 | req: &http.Request{}, 65 | out: "", 66 | }, 67 | { 68 | req: &http.Request{Proto: "bad"}, 69 | out: "bad", 70 | }, 71 | { 72 | req: &http.Request{Proto: "HTTP/1.0"}, 73 | out: "1.0", 74 | }, 75 | { 76 | req: &http.Request{Proto: "HTTP/1.1"}, 77 | out: "1.1", 78 | }, 79 | } 80 | 81 | for _, test := range tests { 82 | t.Run(test.out, func(t *testing.T) { 83 | if s := protoVersion(test.req); s != test.out { 84 | t.Error(s) 85 | } 86 | }) 87 | } 88 | } 89 | 90 | func TestCopyHeader(t *testing.T) { 91 | h1 := http.Header{"Content-Type": {"text/html"}} 92 | h2 := http.Header{"Content-Type": {"text/html"}, "Content-Length": {"42"}} 93 | 94 | copyHeader(h1, h2) 95 | 96 | if !reflect.DeepEqual(h1, h2) { 97 | t.Error(h1) 98 | } 99 | } 100 | 101 | func TestDleeteHopFields(t *testing.T) { 102 | h := http.Header{ 103 | "Connection": {"Upgrade", "Other"}, 104 | "Keep-Alive": {}, 105 | "Proxy-Authenticate": {}, 106 | "Proxy-Authorization": {}, 107 | "Proxy-Connection": {}, 108 | "Te": {}, 109 | "Trailer": {}, 110 | "Transfer-Encoding": {}, 111 | "Upgrade": {}, 112 | "Other": {}, 113 | "Content-Type": {"text/html"}, 114 | } 115 | 116 | deleteHopFields(h) 117 | 118 | if !reflect.DeepEqual(h, http.Header{ 119 | "Content-Type": {"text/html"}, 120 | }) { 121 | t.Error(h) 122 | } 123 | } 124 | 125 | func TestTranslateXForwarded(t *testing.T) { 126 | tests := []struct { 127 | in http.Header 128 | out http.Header 129 | }{ 130 | { 131 | in: http.Header{}, 132 | out: 
http.Header{}, 133 | }, 134 | 135 | { 136 | in: http.Header{ 137 | "X-Forwarded-For": {"127.0.0.1"}, 138 | }, 139 | out: http.Header{ 140 | "Forwarded": {"for=127.0.0.1"}, 141 | }, 142 | }, 143 | 144 | { 145 | in: http.Header{ 146 | "X-Forwarded-For": {"127.0.0.1"}, 147 | "X-Forwarded-Port": {"56789"}, 148 | }, 149 | out: http.Header{ 150 | "Forwarded": {`for="127.0.0.1:56789"`}, 151 | }, 152 | }, 153 | 154 | { 155 | in: http.Header{ 156 | "X-Forwarded-For": {"127.0.0.1"}, 157 | "X-Forwarded-Port": {"56789"}, 158 | "X-Forwarded-Proto": {"https"}, 159 | }, 160 | out: http.Header{ 161 | "Forwarded": {`proto=https;for="127.0.0.1:56789"`}, 162 | }, 163 | }, 164 | 165 | { 166 | in: http.Header{ 167 | "X-Forwarded-For": {"127.0.0.1"}, 168 | "X-Forwarded-Port": {"56789"}, 169 | "X-Forwarded-Proto": {"https"}, 170 | "X-Forwarded-By": {"localhost"}, 171 | }, 172 | out: http.Header{ 173 | "Forwarded": {`proto=https;for="127.0.0.1:56789";by="localhost"`}, 174 | }, 175 | }, 176 | 177 | { 178 | in: http.Header{ 179 | "X-Forwarded-For": {"212.53.1.6, 127.0.0.1"}, 180 | "X-Forwarded-Port": {"56789"}, 181 | "X-Forwarded-Proto": {"https"}, 182 | "X-Forwarded-By": {"localhost"}, 183 | }, 184 | out: http.Header{ 185 | "Forwarded": {`for=212.53.1.6, for=127.0.0.1`}, 186 | }, 187 | }, 188 | } 189 | 190 | for _, test := range tests { 191 | t.Run("", func(t *testing.T) { 192 | translateXForwarded(test.in) 193 | 194 | if !reflect.DeepEqual(test.in, test.out) { 195 | t.Error(test.in) 196 | } 197 | }) 198 | } 199 | } 200 | 201 | func TestQuoteForwarded(t *testing.T) { 202 | tests := []struct { 203 | in string 204 | out string 205 | }{ 206 | {"", `""`}, 207 | {"127.0.0.1", `127.0.0.1`}, 208 | {"2001:db8:cafe::17", `"[2001:db8:cafe::17]"`}, 209 | {"[2001:db8:cafe::17]", `"[2001:db8:cafe::17]"`}, 210 | {"_gazonk", `"_gazonk"`}, 211 | } 212 | 213 | for _, test := range tests { 214 | t.Run(test.in, func(t *testing.T) { 215 | if s := quoteForwarded(test.in); s != test.out { 216 | t.Error(s) 217 | 
} 218 | }) 219 | } 220 | } 221 | 222 | func TestMakeForwarded(t *testing.T) { 223 | tests := []struct { 224 | proto string 225 | forAddr string 226 | byAddr string 227 | forwarded string 228 | }{ 229 | { /* all zero-values */ }, 230 | 231 | { 232 | proto: "http", 233 | forwarded: "proto=http", 234 | }, 235 | 236 | { 237 | forAddr: "127.0.0.1", 238 | forwarded: "for=127.0.0.1", 239 | }, 240 | 241 | { 242 | byAddr: "127.0.0.1", 243 | forwarded: "by=127.0.0.1", 244 | }, 245 | 246 | { 247 | proto: "http", 248 | forAddr: "212.53.1.6", 249 | byAddr: "127.0.0.1", 250 | forwarded: "proto=http;for=212.53.1.6;by=127.0.0.1", 251 | }, 252 | } 253 | 254 | for _, test := range tests { 255 | if s := makeForwarded(test.proto, test.forAddr, test.byAddr); s != test.forwarded { 256 | t.Error(s) 257 | } 258 | } 259 | } 260 | 261 | func TestAddForwarded(t *testing.T) { 262 | tests := []struct { 263 | header http.Header 264 | proto string 265 | forAddr string 266 | byAddr string 267 | result http.Header 268 | }{ 269 | { 270 | header: http.Header{}, 271 | proto: "http", 272 | forAddr: "127.0.0.1:56789", 273 | byAddr: "127.0.0.1:80", 274 | result: http.Header{"Forwarded": { 275 | `proto=http;for="127.0.0.1:56789";by="127.0.0.1:80"`, 276 | }}, 277 | }, 278 | 279 | { 280 | header: http.Header{"Forwarded": { 281 | `proto=https;for="212.53.1.6:54387";by="127.0.0.1:443"`, 282 | }}, 283 | proto: "http", 284 | forAddr: "127.0.0.1:56789", 285 | byAddr: "127.0.0.1:80", 286 | result: http.Header{"Forwarded": { 287 | `proto=https;for="212.53.1.6:54387";by="127.0.0.1:443", proto=http;for="127.0.0.1:56789";by="127.0.0.1:80"`, 288 | }}, 289 | }, 290 | } 291 | 292 | for _, test := range tests { 293 | t.Run(test.result.Get("Forwarded"), func(t *testing.T) { 294 | addForwarded(test.header, test.proto, test.forAddr, test.byAddr) 295 | if !reflect.DeepEqual(test.header, test.result) { 296 | t.Error(test.header) 297 | } 298 | }) 299 | } 300 | } 301 | 302 | func TestMakeVia(t *testing.T) { 303 | tests := 
[]struct { 304 | version string 305 | host string 306 | via string 307 | }{ 308 | { 309 | version: "1.1", 310 | host: "localhost", 311 | via: "1.1 localhost", 312 | }, 313 | } 314 | 315 | for _, test := range tests { 316 | t.Run(test.via, func(t *testing.T) { 317 | if via := makeVia(test.version, test.host); via != test.via { 318 | t.Error(via) 319 | } 320 | }) 321 | } 322 | } 323 | 324 | func TestAddVia(t *testing.T) { 325 | tests := []struct { 326 | version string 327 | host string 328 | in http.Header 329 | out http.Header 330 | }{ 331 | { 332 | version: "1.1", 333 | host: "localhost", 334 | in: http.Header{}, 335 | out: http.Header{ 336 | "Via": {"1.1 localhost"}, 337 | }, 338 | }, 339 | { 340 | version: "1.1", 341 | host: "localhost", 342 | in: http.Header{ 343 | "Via": {"1.1 laptop"}, 344 | }, 345 | out: http.Header{ 346 | "Via": {"1.1 laptop, 1.1 localhost"}, 347 | }, 348 | }, 349 | } 350 | 351 | for _, test := range tests { 352 | t.Run(test.out.Get("Via"), func(t *testing.T) { 353 | addVia(test.in, test.version, test.host) 354 | 355 | if !reflect.DeepEqual(test.in, test.out) { 356 | t.Error(test.in) 357 | } 358 | }) 359 | } 360 | } 361 | 362 | func TestMaxForwards(t *testing.T) { 363 | tests := []struct { 364 | in http.Header 365 | max int 366 | }{ 367 | { 368 | in: http.Header{}, 369 | max: -1, 370 | }, 371 | { 372 | in: http.Header{"Max-Forwards": {"42"}}, 373 | max: 42, 374 | }, 375 | } 376 | 377 | for _, test := range tests { 378 | t.Run("", func(t *testing.T) { 379 | max, err := maxForwards(test.in) 380 | 381 | if err != nil { 382 | t.Error(err) 383 | } 384 | 385 | if max != test.max { 386 | t.Error("max =", max) 387 | } 388 | }) 389 | } 390 | } 391 | 392 | func TestConnectionUpgrade(t *testing.T) { 393 | tests := []struct { 394 | in http.Header 395 | out string 396 | }{ 397 | { 398 | in: http.Header{}, 399 | out: "", 400 | }, 401 | { 402 | in: http.Header{ // missing "Connection" 403 | "Upgrade": {"websocket"}, 404 | }, 405 | out: "", 406 | }, 407 | { 
408 | in: http.Header{ // missing "Upgrade" 409 | "Connection": {"Upgrade"}, 410 | }, 411 | out: "", 412 | }, 413 | { 414 | in: http.Header{ 415 | "Connection": {"Upgrade"}, 416 | "Upgrade": {"websocket"}, 417 | }, 418 | out: "websocket", 419 | }, 420 | } 421 | 422 | for _, test := range tests { 423 | t.Run(test.out, func(t *testing.T) { 424 | if s := connectionUpgrade(test.in); s != test.out { 425 | t.Error(s) 426 | } 427 | }) 428 | } 429 | } 430 | 431 | func TestHeaderValuesRemoveTokens(t *testing.T) { 432 | tests := []struct { 433 | values []string 434 | tokens []string 435 | result []string 436 | }{ 437 | { // empty values 438 | values: []string{}, 439 | tokens: []string{}, 440 | result: []string{}, 441 | }, 442 | { // remove nothing 443 | values: []string{"A", "B", "C"}, 444 | tokens: []string{"other"}, 445 | result: []string{"A", "B", "C"}, 446 | }, 447 | { // remove first 448 | values: []string{"A", "B", "C"}, 449 | tokens: []string{"A"}, 450 | result: []string{"B", "C"}, 451 | }, 452 | { // remove middle 453 | values: []string{"A", "B", "C"}, 454 | tokens: []string{"B"}, 455 | result: []string{"A", "C"}, 456 | }, 457 | { // remove last 458 | values: []string{"A", "B", "C"}, 459 | tokens: []string{"C"}, 460 | result: []string{"A", "B"}, 461 | }, 462 | { // remove all 463 | values: []string{"A", "B", "C"}, 464 | tokens: []string{"A", "B", "C"}, 465 | result: []string{}, 466 | }, 467 | { // remove inner (single) 468 | values: []string{"A, B", "C"}, 469 | tokens: []string{"A"}, 470 | result: []string{"B", "C"}, 471 | }, 472 | { // remove inner (multi) 473 | values: []string{"A, B, C"}, 474 | tokens: []string{"A", "B"}, 475 | result: []string{"C"}, 476 | }, 477 | } 478 | 479 | for _, test := range tests { 480 | t.Run("", func(t *testing.T) { 481 | result := headerValuesRemoveTokens(test.values, test.tokens...) 
482 | 483 | if !reflect.DeepEqual(result, test.result) { 484 | t.Error(result) 485 | } 486 | }) 487 | } 488 | } 489 | 490 | func TestIsIdempotent(t *testing.T) { 491 | tests := []struct { 492 | method string 493 | is bool 494 | }{ 495 | {"HEAD", true}, 496 | {"GET", true}, 497 | {"PUT", true}, 498 | {"OPTIONS", true}, 499 | {"DELETE", true}, 500 | {"POST", false}, 501 | {"TRACE", false}, 502 | {"PATCH", false}, 503 | } 504 | 505 | for _, test := range tests { 506 | t.Run(test.method, func(t *testing.T) { 507 | if is := isIdempotent(test.method); is != test.is { 508 | t.Error(is) 509 | } 510 | }) 511 | } 512 | } 513 | 514 | func TestIsRetriable(t *testing.T) { 515 | tests := []struct { 516 | status int 517 | retry bool 518 | }{ 519 | {http.StatusOK, false}, 520 | {http.StatusInternalServerError, true}, 521 | {http.StatusNotImplemented, false}, 522 | {http.StatusBadGateway, true}, 523 | {http.StatusServiceUnavailable, true}, 524 | {http.StatusGatewayTimeout, true}, 525 | {http.StatusHTTPVersionNotSupported, false}, 526 | } 527 | 528 | for _, test := range tests { 529 | t.Run(fmt.Sprint(test.status), func(t *testing.T) { 530 | if retry := isRetriable(test.status); retry != test.retry { 531 | t.Error(retry) 532 | } 533 | }) 534 | } 535 | } 536 | -------------------------------------------------------------------------------- /httpx/proxy.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "crypto/tls" 7 | "io" 8 | "io/ioutil" 9 | "net" 10 | "net/http" 11 | "net/http/httputil" 12 | "strconv" 13 | "sync" 14 | "time" 15 | 16 | "github.com/segmentio/netx" 17 | ) 18 | 19 | // ReverseProxy is a HTTP handler which implements the logic of a reverse HTTP 20 | // proxy, forwarding incoming requests to backend servers. 21 | // 22 | // The implementation is similar to httputil.ReverseProxy but the implementation 23 | // has some differences. 
// Instead of using a Director function to rewrite the
// request to its destination the proxy expects the request it receives to be
// already well constructed to be forwarded to a backend server. Any conforming
// HTTP client aware of being behind a proxy would have included the full URL in
// the request line which the proxy will use to extract the backend address.
//
// The proxy also converts the X-Forwarded headers to Forwarded as defined by
// RFC 7239 (see https://tools.ietf.org/html/rfc7239).
//
// HTTP upgrades are also supported by this reverse HTTP proxy implementation,
// the proxy forwards the HTTP handshake requesting an upgrade to the backend
// server, then if it gets a successful 101 Switching Protocols response it will
// start acting as a simple TCP tunnel between the client and backend server.
//
// Finally, the proxy also handles the Max-Forwards header for TRACE and OPTIONS
// methods, decrementing the value or responding to the client directly instead
// of forwarding the request when the hop limit is exhausted.
type ReverseProxy struct {
	// Transport is used to forward HTTP requests to backend servers. If nil,
	// http.DefaultTransport is used instead.
	Transport http.RoundTripper

	// DialContext is used for dialing new TCP connections on HTTP upgrades or
	// CONNECT requests. If nil, a net.Dialer with a 10 second timeout is used.
	DialContext func(context.Context, string, string) (net.Conn, error)

	// TLSClientConfig specifies the TLS configuration to use for HTTP upgrades
	// that happen over a secured link.
	// If nil, the default configuration is used.
	TLSClientConfig *tls.Config
}

// ServeHTTP satisfies the http.Handler interface.
56 | func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, req *http.Request) { 57 | remoteAddr := req.RemoteAddr 58 | localAddr := requestLocalAddr(req) 59 | 60 | // Forwarded requests always use the HTTP/1.1 protocol when talking to the 61 | // backend server. 62 | outurl := *req.URL 63 | outreq := *req 64 | outreq.URL = &outurl 65 | outreq.Header = make(http.Header, len(req.Header)) 66 | outreq.Proto = "HTTP/1.1" 67 | outreq.ProtoMajor = 1 68 | outreq.ProtoMinor = 1 69 | outreq.Close = false 70 | 71 | // No target host was set on the request URL, assuming the client intended 72 | // to reach req.Host then. 73 | if len(outreq.URL.Host) == 0 { 74 | outreq.URL.Host = req.Host 75 | } 76 | 77 | // No target protocol was set, attempting to guess it from the port that the 78 | // client is trying to connect to (fail later otherwise). 79 | if len(outreq.URL.Scheme) == 0 { 80 | outreq.URL.Scheme = guessScheme(localAddr, req.URL.Host) 81 | } 82 | 83 | // Remove hop-by-hop headers from the request so they aren't forwarded to 84 | // the backend servers. 85 | copyHeader(outreq.Header, req.Header) 86 | deleteHopFields(outreq.Header) 87 | 88 | // There must be host set on the URL otherwise the proxy cannot forward the 89 | // request to any backend server. 90 | if len(outreq.URL.Host) == 0 { 91 | w.WriteHeader(http.StatusBadRequest) 92 | return 93 | } 94 | 95 | // Add proxy headers, Forwarded, Via, and convert X-Forwarded-For. 96 | if _, hasFwd := outreq.Header["Forwarded"]; !hasFwd { 97 | translateXForwarded(outreq.Header) 98 | } 99 | addForwarded(outreq.Header, outreq.URL.Scheme, remoteAddr, localAddr) 100 | addVia(outreq.Header, protoVersion(req), localAddr) 101 | 102 | switch method := outreq.Method; method { 103 | case http.MethodConnect: 104 | p.serveCONNECT(w, &outreq) 105 | return 106 | case http.MethodTrace, http.MethodOptions: 107 | // Decrement the Max-Forward header for TRACE and OPTIONS requests. 
108 | max, err := maxForwards(outreq.Header) 109 | if max--; max == 0 || err != nil { 110 | if method == http.MethodTrace { 111 | p.serveTRACE(w, &outreq) 112 | } else { 113 | p.serveOPTIONS(w, &outreq) 114 | } 115 | return 116 | } 117 | outreq.Header.Set("Max-Forward", strconv.Itoa(max)) 118 | } 119 | 120 | // The proxy has to forward a protocol upgrade, we open a new connection to 121 | // the target host that we can make exclusive use of, then the handshake is 122 | // performed and the proxy starts passing bytes back and forth. 123 | if upgrade := connectionUpgrade(req.Header); len(upgrade) != 0 { 124 | outreq.Header.Set("Connection", "Upgrade") 125 | outreq.Header.Set("Upgrade", upgrade) 126 | p.serveUpgrade(w, &outreq) 127 | return 128 | } 129 | 130 | transport := p.Transport 131 | if transport == nil { 132 | transport = http.DefaultTransport 133 | } 134 | 135 | res, err := transport.RoundTrip(&outreq) 136 | if err != nil { 137 | w.WriteHeader(http.StatusBadGateway) 138 | return 139 | } 140 | 141 | deleteHopFields(res.Header) 142 | copyHeader(w.Header(), res.Header) 143 | 144 | w.WriteHeader(res.StatusCode) 145 | netx.Copy(w, res.Body) 146 | res.Body.Close() 147 | 148 | deleteHopFields(res.Trailer) 149 | copyHeader(w.Header(), res.Trailer) 150 | } 151 | 152 | func (p *ReverseProxy) serveCONNECT(w http.ResponseWriter, req *http.Request) { 153 | dial := p.DialContext 154 | if dial == nil { 155 | dial = (&net.Dialer{Timeout: 10 * time.Second}).DialContext 156 | } 157 | 158 | join := &sync.WaitGroup{} 159 | defer join.Wait() 160 | 161 | ctx, cancel := context.WithCancel(req.Context()) 162 | defer cancel() 163 | 164 | backend, err := dial(ctx, "tcp", req.URL.Host) 165 | if err != nil { 166 | w.WriteHeader(http.StatusBadGateway) 167 | return 168 | } 169 | defer backend.Close() 170 | 171 | io.Copy(ioutil.Discard, req.Body) 172 | req.Body.Close() 173 | w.WriteHeader(http.StatusOK) 174 | 175 | frontend, rw, err := w.(http.Hijacker).Hijack() 176 | if err != nil { 177 | 
panic(err) 178 | } 179 | defer frontend.Close() 180 | 181 | join.Add(1) 182 | go func(r *bufio.Reader) { 183 | defer join.Done() 184 | defer cancel() 185 | 186 | if _, err := r.WriteTo(backend); err != nil { 187 | return 188 | } 189 | 190 | r = nil 191 | netx.Copy(backend, frontend) 192 | }(rw.Reader) 193 | 194 | join.Add(1) 195 | go func(w *bufio.Writer) { 196 | defer join.Done() 197 | defer cancel() 198 | 199 | if err := w.Flush(); err != nil { 200 | return 201 | } 202 | 203 | w = nil 204 | netx.Copy(frontend, backend) 205 | }(rw.Writer) 206 | 207 | rw = nil 208 | <-ctx.Done() 209 | } 210 | 211 | func (p *ReverseProxy) serveOPTIONS(w http.ResponseWriter, req *http.Request) { 212 | w.WriteHeader(http.StatusOK) 213 | } 214 | 215 | func (p *ReverseProxy) serveTRACE(w http.ResponseWriter, req *http.Request) { 216 | content, err := httputil.DumpRequest(req, true) 217 | if err != nil { 218 | panic(err) 219 | } 220 | w.Header().Set("Content-Type", "message/http") 221 | w.WriteHeader(http.StatusOK) 222 | w.Write(content) 223 | } 224 | 225 | func (p *ReverseProxy) serveUpgrade(w http.ResponseWriter, req *http.Request) { 226 | dial := p.DialContext 227 | if dial == nil { 228 | dial = (&net.Dialer{Timeout: 10 * time.Second}).DialContext 229 | } 230 | 231 | ctx := req.Context() 232 | 233 | backend, err := dial(ctx, "tcp", req.URL.Host) 234 | if err != nil { 235 | w.WriteHeader(http.StatusBadGateway) 236 | return 237 | } 238 | if req.URL.Scheme == "https" { 239 | backend = tls.Client(backend, p.TLSClientConfig) 240 | } 241 | defer backend.Close() 242 | 243 | res, err := (&ConnTransport{ 244 | Conn: backend, 245 | ResponseHeaderTimeout: 10 * time.Second, 246 | }).RoundTrip(req) 247 | if err != nil { 248 | w.WriteHeader(http.StatusBadGateway) 249 | return 250 | } 251 | 252 | // Forward the response to the protocol upgrade request, removing the 253 | // hop-by-hop headers, except the Upgrade header which is used by some 254 | // protocol upgrades. 
255 | upgrade := res.Header["Upgrade"] 256 | deleteHopFields(res.Header) 257 | if len(upgrade) != 0 { 258 | res.Header["Upgrade"] = upgrade 259 | res.Header["Connection"] = []string{"Upgrade"} 260 | } 261 | copyHeader(w.Header(), res.Header) 262 | w.WriteHeader(res.StatusCode) 263 | netx.Copy(w, res.Body) 264 | res.Body.Close() 265 | 266 | // Switching to a different protocol failed apparently, stopping here and 267 | // the server will wait for the next request on that connection. 268 | if res.StatusCode != http.StatusSwitchingProtocols { 269 | return 270 | } 271 | 272 | // No need to keep references to these objects anymore, the GC may collect 273 | // them if possible. 274 | upgrade = nil 275 | req = nil 276 | res = nil 277 | 278 | frontend, rw, err := w.(http.Hijacker).Hijack() 279 | if err != nil { 280 | w.WriteHeader(http.StatusInternalServerError) 281 | return 282 | } 283 | defer frontend.Close() 284 | 285 | if err := rw.Writer.Flush(); err != nil { 286 | return // the client is gone 287 | } 288 | 289 | done := make(chan struct{}, 2) 290 | go forward(rw.Writer, backend, done) 291 | go forward(backend, rw.Reader, done) 292 | 293 | // Wait for either the connections to be closed or the context to be 294 | // canceled. 295 | select { 296 | case <-done: 297 | case <-ctx.Done(): 298 | } 299 | } 300 | 301 | // guessScheme attempts to guess the protocol that should be used for a proxied 302 | // request (either http or https). 303 | func guessScheme(localAddr string, remoteAddr string) string { 304 | if scheme, _ := netx.SplitNetAddr(localAddr); scheme == "tls" { 305 | return "https" 306 | } 307 | switch _, port, _ := net.SplitHostPort(remoteAddr); port { 308 | case "", "80": 309 | return "http" 310 | case "443": 311 | return "https" 312 | } 313 | return "http" 314 | } 315 | 316 | // forward copies bytes from r to w, sending a signal on the done channel when 317 | // the copy completes. 
318 | func forward(w io.Writer, r io.Reader, done chan<- struct{}) { 319 | defer func() { done <- struct{}{} }() 320 | netx.Copy(w, r) 321 | } 322 | 323 | // requestLocalAddr looks for the request's local address in its context and 324 | // returns the string representation. 325 | func requestLocalAddr(req *http.Request) string { 326 | addr := contextLocalAddr(req.Context()) 327 | if addr == nil { 328 | return "" 329 | } 330 | return addr.String() 331 | } 332 | 333 | // contextLocalAddr looks for the request's local address in ctx and returns it. 334 | func contextLocalAddr(ctx context.Context) net.Addr { 335 | val := ctx.Value(http.LocalAddrContextKey) 336 | if val == nil { 337 | return nil 338 | } 339 | addr, _ := val.(net.Addr) 340 | return addr 341 | } 342 | -------------------------------------------------------------------------------- /httpx/proxy_test.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | import ( 4 | "net/http" 5 | "testing" 6 | 7 | "github.com/segmentio/netx" 8 | "github.com/segmentio/netx/httpx/httpxtest" 9 | ) 10 | 11 | func TestProxy(t *testing.T) { 12 | httpxtest.TestServer(t, func(config httpxtest.ServerConfig) (string, func()) { 13 | origin, closeOrigin := listenAndServe(&Server{ 14 | ReadTimeout: config.ReadTimeout, 15 | WriteTimeout: config.WriteTimeout, 16 | MaxHeaderBytes: config.MaxHeaderBytes, 17 | Handler: config.Handler, 18 | }) 19 | 20 | proxy, closeProxy := listenAndServe(&Server{ 21 | ReadTimeout: config.ReadTimeout, 22 | WriteTimeout: config.WriteTimeout, 23 | MaxHeaderBytes: config.MaxHeaderBytes, 24 | Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 25 | _, req.URL.Host = netx.SplitNetAddr(origin) 26 | (&ReverseProxy{}).ServeHTTP(w, req) 27 | }), 28 | }) 29 | 30 | return proxy, func() { 31 | closeProxy() 32 | closeOrigin() 33 | } 34 | }) 35 | } 36 | -------------------------------------------------------------------------------- 
/httpx/quote.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | ) 8 | 9 | // quoted is a type alias to string that implements the fmt.Stringer and 10 | // fmt.Formatter interfaces to efficiently output HTTP tokens or quoted 11 | // strings. 12 | type quoted string 13 | 14 | // parseQuoted parses s, potentially removing the quotes if it is a quoted 15 | // string. 16 | func parseQuoted(s string) (q quoted, err error) { 17 | if isToken(s) { 18 | q = quoted(s) 19 | return 20 | } 21 | 22 | n := len(s) 23 | 24 | if n < 2 || s[0] != '"' || s[n-1] != '"' { 25 | err = errorInvalidQuotedString(s) 26 | return 27 | } 28 | 29 | e := false 30 | b := bytes.Buffer{} 31 | b.Grow(n) 32 | 33 | for _, c := range s[1 : n-1] { 34 | if e { 35 | e = false 36 | switch c { 37 | case 'n': 38 | c = '\n' 39 | case 'r': 40 | c = '\r' 41 | case 't': 42 | c = '\t' 43 | case 'v': 44 | c = '\v' 45 | case 'f': 46 | c = '\f' 47 | case 'b': 48 | c = '\b' 49 | case '0': 50 | c = '\x00' 51 | } 52 | } else if c == '\\' { 53 | e = true 54 | continue 55 | } else if c == '"' { 56 | err = errorInvalidQuotedString(s) 57 | return 58 | } 59 | b.WriteRune(c) 60 | } 61 | 62 | if e { 63 | err = errorInvalidQuotedString(s) 64 | return 65 | } 66 | 67 | q = quoted(b.String()) 68 | return 69 | } 70 | 71 | // String satisfies the fmt.Stringer interface. 72 | func (q quoted) String() string { 73 | return fmt.Sprint(q) 74 | } 75 | 76 | // Format satisfies the fmt.Formatter interface. 77 | func (q quoted) Format(w fmt.State, _ rune) { 78 | if s := string(q); isToken(s) { 79 | fmt.Fprint(w, s) 80 | } else { 81 | writeQuoted(w, s) 82 | } 83 | } 84 | 85 | // quote returns a quoted representation of s, not checking whether s is a valid 86 | // HTTP token. 
func quote(s string) string {
	var buf bytes.Buffer
	buf.Grow(len(s) + 10) // room for the quotes and a few escapes
	writeQuoted(&buf, s)
	return buf.String()
}

// writeQuoted writes s to w surrounded by double quotes. Backslashes and double
// quotes are backslash-escaped, and the control characters \n, \r, \t, \f, \v,
// \b and NUL are written as two-character escape sequences.
func writeQuoted(w io.Writer, s string) {
	fmt.Fprint(w, `"`)
	last := 0
	for i := 0; i < len(s); i++ {
		esc, ok := escapeQuotedByte(s[i])
		if !ok {
			continue
		}
		// Flush the verbatim run preceding the escaped byte, then the escape.
		fmt.Fprintf(w, `%s\%c`, s[last:i], esc)
		last = i + 1
	}
	fmt.Fprintf(w, `%s"`, s[last:])
}

// escapeQuotedByte reports the character that must follow a backslash to
// encode c inside a quoted string, or false when c is written verbatim.
func escapeQuotedByte(c byte) (byte, bool) {
	switch c {
	case '\\', '"':
		return c, true
	case '\n':
		return 'n', true
	case '\r':
		return 'r', true
	case '\t':
		return 't', true
	case '\f':
		return 'f', true
	case '\v':
		return 'v', true
	case '\b':
		return 'b', true
	case '\x00':
		return '0', true
	}
	return 0, false
}

// split cuts s at the first b that is not inside a double-quoted string,
// honouring backslash escapes within quoted strings. b cannot be a double
// quote. When no such b exists, head is s and tail is empty.
func split(s string, b byte) (head string, tail string) {
	inString, escaped := false, false
	for i := range s {
		c := s[i]
		switch {
		case escaped:
			escaped = false
		case inString && c == '\\':
			escaped = true
		case c == '"':
			inString = !inString
		case !inString && c == b:
			return s[:i], s[i+1:]
		}
	}
	return s, ""
}

// split s at b, properly handling inner quoted strings in s, and trimming the
// results for leading and trailing white spaces.
168 | func splitTrimOWS(s string, b byte) (head string, tail string) { 169 | head, tail = split(s, b) 170 | head = trimOWS(head) 171 | tail = trimOWS(tail) 172 | return 173 | } 174 | 175 | func errorInvalidQuotedString(s string) error { 176 | return fmt.Errorf("invalid quoted string: %#v", s) 177 | } 178 | -------------------------------------------------------------------------------- /httpx/quote_test.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | import "testing" 4 | 5 | var quotedTests = []struct { 6 | s string 7 | q quoted 8 | }{ 9 | {`""`, ``}, 10 | {`hello`, `hello`}, 11 | {`"你好"`, `你好`}, 12 | {`"hello\\world"`, "hello\\world"}, 13 | {`"hello\"world"`, "hello\"world"}, 14 | {`"hello\nworld"`, "hello\nworld"}, 15 | {`"hello\rworld"`, "hello\rworld"}, 16 | {`"hello\tworld"`, "hello\tworld"}, 17 | {`"hello\vworld"`, "hello\vworld"}, 18 | {`"hello\fworld"`, "hello\fworld"}, 19 | {`"hello\bworld"`, "hello\bworld"}, 20 | {`"hello\0world"`, "hello\x00world"}, 21 | } 22 | 23 | func TestParseQuotedFailure(t *testing.T) { 24 | tests := []struct { 25 | s string 26 | }{ 27 | {`,`}, // non-token and non-quoted-string (single byte) 28 | {`,,,`}, // non-token and non-quoted-string (multi byte) 29 | {`"hello"world"`}, // unexpected double-quote 30 | {`"hello\"`}, // terminated by an escaped quote 31 | } 32 | 33 | for _, test := range tests { 34 | t.Run(test.s, func(t *testing.T) { 35 | if q, err := parseQuoted(test.s); err == nil { 36 | t.Error(q) 37 | } 38 | }) 39 | } 40 | } 41 | 42 | func TestParseQuotedSuccess(t *testing.T) { 43 | for _, test := range quotedTests { 44 | t.Run(test.s, func(t *testing.T) { 45 | q, err := parseQuoted(test.s) 46 | 47 | if err != nil { 48 | t.Error(err) 49 | } 50 | 51 | if q != test.q { 52 | t.Error(q) 53 | } 54 | }) 55 | } 56 | } 57 | 58 | func TestQuotedString(t *testing.T) { 59 | for _, test := range quotedTests { 60 | t.Run(string(test.q), func(t *testing.T) { 61 | if s := 
test.q.String(); s != test.s { 62 | t.Error(s) 63 | } 64 | }) 65 | } 66 | } 67 | 68 | func TestSplitTrimOWS(t *testing.T) { 69 | tests := []struct { 70 | s string 71 | b byte 72 | head string 73 | tail string 74 | }{ 75 | { 76 | s: ``, 77 | b: ',', 78 | head: "", 79 | tail: "", 80 | }, 81 | { 82 | s: `hello, world`, 83 | b: ',', 84 | head: "hello", 85 | tail: "world", 86 | }, 87 | { 88 | s: `key1="hello, world", key2=`, 89 | b: ',', 90 | head: `key1="hello, world"`, 91 | tail: `key2=`, 92 | }, 93 | { 94 | s: `key1="message: \"hello, world\"", key2=`, 95 | b: ',', 96 | head: `key1="message: \"hello, world\""`, 97 | tail: `key2=`, 98 | }, 99 | } 100 | 101 | for _, test := range tests { 102 | t.Run(test.s, func(t *testing.T) { 103 | head, tail := splitTrimOWS(test.s, test.b) 104 | 105 | if head != test.head { 106 | t.Errorf("bad head: %#v", head) 107 | } 108 | 109 | if tail != test.tail { 110 | t.Errorf("bad tail: %#v", tail) 111 | } 112 | }) 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /httpx/retry.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "time" 10 | ) 11 | 12 | const ( 13 | // DefaultMaxAttempts is the default number of attempts used by RetryHandler 14 | // and RetryTransport. 15 | DefaultMaxAttempts = 10 16 | ) 17 | 18 | // A RetryHandler is a http.Handler which retries calls to its sub-handler if 19 | // they fail with a 5xx code. When a request is retried the handler will apply 20 | // an exponential backoff to maximize the chances of success (because it is 21 | // usually unlikely that a failed request will succeed right away). 22 | // 23 | // Note that only idempotent methods are retried, because the handler doesn't 24 | // have enough context about why it failed, it wouldn't be safe to retry other 25 | // HTTP methods. 
26 | type RetryHandler struct { 27 | // Handler is the sub-handler that the RetryHandler delegates requests to. 28 | // 29 | // ServeHTTP will panic if Handler is nil. 30 | Handler http.Handler 31 | 32 | // MaxAttampts is the maximum number of attempts that the handler will make 33 | // at handling a single request. 34 | // Zero means to use a default value. 35 | MaxAttempts int 36 | } 37 | 38 | // ServeHTTP satisfies the http.Handler interface. 39 | func (h *RetryHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { 40 | body := &retryRequestBody{ReadCloser: req.Body} 41 | req.Body = body 42 | 43 | res := &retryResponseWriter{ResponseWriter: w} 44 | max := h.MaxAttempts 45 | if max == 0 { 46 | max = DefaultMaxAttempts 47 | } 48 | 49 | for attempt := 0; true; { 50 | res.status = 0 51 | res.header = make(http.Header, 10) 52 | res.buffer.Reset() 53 | 54 | h.Handler.ServeHTTP(res, req) 55 | 56 | if res.status < 500 { 57 | return // success 58 | } 59 | 60 | if body.n != 0 { 61 | break 62 | } 63 | 64 | if !isRetriable(res.status) { 65 | break 66 | } 67 | 68 | if !isIdempotent(req.Method) { 69 | break 70 | } 71 | 72 | if attempt++; attempt >= max { 73 | break 74 | } 75 | 76 | if sleep(req.Context(), backoff(attempt)) != nil { 77 | break 78 | } 79 | } 80 | 81 | if res.status == 0 { 82 | res.status = http.StatusServiceUnavailable 83 | } 84 | 85 | // 5xx error, write the buffered response to the original writer. 86 | copyHeader(w.Header(), res.header) 87 | w.WriteHeader(res.status) 88 | res.buffer.WriteTo(w) 89 | } 90 | 91 | // RetryTransport is a http.RoundTripper which retries calls to its sub-handler 92 | // if they failed with connection or server errors. When a request is retried 93 | // the handler will apply an exponential backoff to maximize the chances of 94 | // success (because it is usually unlikely that a failed request will succeed 95 | // right away). 
96 | // 97 | // Note that only idempotent methods are retried, because the handler doesn't 98 | // have enough context about why it failed, it wouldn't be safe to retry other 99 | // HTTP methods. 100 | type RetryTransport struct { 101 | // Transport is the sub-transport that the RetryTransport delegates requests 102 | // to. 103 | // 104 | // http.DefaultTransport is used if Transport is nil. 105 | Transport http.RoundTripper 106 | 107 | // MaxAttampts is the maximum number of attempts that the handler will make 108 | // at handling a single request. 109 | // Zero means to use a default value. 110 | MaxAttempts int 111 | } 112 | 113 | // RoundTrip satisfies the http.RoundTripper interface. 114 | func (t *RetryTransport) RoundTrip(req *http.Request) (res *http.Response, err error) { 115 | transport := t.Transport 116 | if transport == nil { 117 | transport = http.DefaultTransport 118 | } 119 | 120 | body := &retryRequestBody{ReadCloser: req.Body} 121 | req.Body = body 122 | 123 | max := t.MaxAttempts 124 | if max == 0 { 125 | max = DefaultMaxAttempts 126 | } 127 | 128 | for attempt := 0; true; { 129 | if res, err = transport.RoundTrip(req); err == nil { 130 | if res.StatusCode < 500 || !isRetriable(res.StatusCode) { 131 | break // success 132 | } 133 | } 134 | 135 | if body.n != 0 { 136 | err = fmt.Errorf("%s %s: failed and cannot be retried because %d bytes of the body have already been sent", req.Method, req.URL.Path, body.n) 137 | break 138 | } 139 | 140 | if !isIdempotent(req.Method) { 141 | err = fmt.Errorf("%s %s: failed and cannot be retried because the method is not idempotent", req.Method, req.URL.Path) 142 | break 143 | } 144 | 145 | if attempt++; attempt >= max { 146 | err = fmt.Errorf("%s %s: failed %d times: %s", req.Method, req.URL.Path, attempt, err) 147 | break 148 | } 149 | 150 | if err = sleep(req.Context(), backoff(attempt)); err != nil { 151 | break 152 | } 153 | } 154 | 155 | return 156 | } 157 | 158 | // retryResponseWriter is a 
http.ResponseWriter which captures 5xx responses. 159 | type retryResponseWriter struct { 160 | http.ResponseWriter 161 | status int 162 | header http.Header 163 | buffer bytes.Buffer 164 | } 165 | 166 | // Header satisfies the http.ResponseWriter interface. 167 | func (w *retryResponseWriter) Header() http.Header { 168 | return w.header 169 | } 170 | 171 | // WriteHeader satisfies the http.ResponseWriter interface. 172 | func (w *retryResponseWriter) WriteHeader(status int) { 173 | if w.status == 0 { 174 | w.status = status 175 | if status < 500 { 176 | copyHeader(w.ResponseWriter.Header(), w.header) 177 | w.ResponseWriter.WriteHeader(status) 178 | } 179 | } 180 | } 181 | 182 | // Write satisfies the http.ResponseWriter interface. 183 | func (w *retryResponseWriter) Write(b []byte) (int, error) { 184 | w.WriteHeader(http.StatusOK) 185 | if w.status >= 500 { 186 | return w.buffer.Write(b) 187 | } 188 | return w.ResponseWriter.Write(b) 189 | } 190 | 191 | // retryRequestBody is a io.ReadCloser wrapper which counts how many bytes were 192 | // processed by a request body. 193 | type retryRequestBody struct { 194 | io.ReadCloser 195 | n int 196 | } 197 | 198 | // Read satisfies the io.Reader interface. 199 | func (r *retryRequestBody) Read(b []byte) (n int, err error) { 200 | if n, err = r.ReadCloser.Read(b); n > 0 { 201 | r.n += n 202 | } 203 | return 204 | } 205 | 206 | // backoff returns the amount of time a goroutine should wait before retrying 207 | // what it was doing considering that it already made n attempts. 208 | func backoff(n int) time.Duration { 209 | return time.Duration(n*n) * 10 * time.Millisecond 210 | } 211 | 212 | // sleep puts the goroutine to sleep until either ctx is canceled or d amount of 213 | // time elapses. 
214 | func sleep(ctx context.Context, d time.Duration) (err error) { 215 | if d != 0 { 216 | timer := time.NewTimer(d) 217 | select { 218 | case <-timer.C: 219 | case <-ctx.Done(): 220 | err = ctx.Err() 221 | } 222 | timer.Stop() 223 | } 224 | return 225 | } 226 | -------------------------------------------------------------------------------- /httpx/retry_test.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "io/ioutil" 7 | "net/http" 8 | "net/http/httptest" 9 | "strings" 10 | "testing" 11 | "time" 12 | 13 | "github.com/segmentio/netx/httpx/httpxtest" 14 | ) 15 | 16 | func TestRetryTransportDefault(t *testing.T) { 17 | httpxtest.TestTransport(t, func() http.RoundTripper { 18 | return &RetryTransport{} 19 | }) 20 | } 21 | 22 | func TestRetryTransportConfigured(t *testing.T) { 23 | httpxtest.TestTransport(t, func() http.RoundTripper { 24 | return &RetryTransport{ 25 | MaxAttempts: 1, 26 | } 27 | }) 28 | } 29 | 30 | func TestRetryHandler(t *testing.T) { 31 | tests := []struct { 32 | method string 33 | body string 34 | status int 35 | maxAttempts int 36 | }{ 37 | { // HEAD + default max attempts 38 | method: "HEAD", 39 | status: http.StatusOK, 40 | }, 41 | { // GET + default max attempts 42 | method: "GET", 43 | status: http.StatusOK, 44 | }, 45 | { // PUT + default max attempts 46 | method: "PUT", 47 | status: http.StatusOK, 48 | }, 49 | { // DELETE + default max attempts 50 | method: "DELETE", 51 | status: http.StatusOK, 52 | }, 53 | { // POST (not idempotent) + default max attempts 54 | method: "POST", 55 | status: http.StatusInternalServerError, 56 | }, 57 | { // GET + low max attempts 58 | method: "GET", 59 | status: http.StatusInternalServerError, 60 | maxAttempts: 1, 61 | }, 62 | { // PUT + non-empty request boyd 63 | method: "PUT", 64 | body: "Hello World!", 65 | status: http.StatusInternalServerError, 66 | }, 67 | } 68 | 69 | for _, test := range tests { 70 | 
t.Run("", func(t *testing.T) { 71 | req := httptest.NewRequest(test.method, "/", strings.NewReader(test.body)) 72 | res := httptest.NewRecorder() 73 | 74 | attempt := 0 75 | handler := &RetryHandler{ 76 | Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 77 | io.Copy(ioutil.Discard, req.Body) 78 | req.Body.Close() 79 | if attempt == 0 { 80 | w.WriteHeader(http.StatusInternalServerError) 81 | } else { 82 | w.WriteHeader(http.StatusOK) 83 | } 84 | attempt++ 85 | }), 86 | MaxAttempts: test.maxAttempts, 87 | } 88 | handler.ServeHTTP(res, req) 89 | 90 | if res.Code != test.status { 91 | t.Errorf("bad status code: expected %d but got %d", test.status, res.Code) 92 | } 93 | }) 94 | } 95 | } 96 | 97 | func TestSleepTimeout(t *testing.T) { 98 | ctx := context.Background() 99 | t0 := time.Now() 100 | err := sleep(ctx, 10*time.Millisecond) 101 | t1 := time.Now() 102 | 103 | if err != nil { 104 | t.Error(err) 105 | } 106 | 107 | if t1.Sub(t0) < (10 * time.Millisecond) { 108 | t.Error("sleep returned too early") 109 | } 110 | } 111 | 112 | func TestSleepCanceled(t *testing.T) { 113 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) 114 | defer cancel() 115 | 116 | if err := sleep(ctx, 1*time.Second); err != context.DeadlineExceeded { 117 | t.Error(err) 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /httpx/server.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "io/ioutil" 10 | "log" 11 | "net" 12 | "net/http" 13 | "strconv" 14 | "time" 15 | 16 | "github.com/segmentio/netx" 17 | ) 18 | 19 | // A Server implements the netx.Handler interface, it provides the handling of 20 | // HTTP requests from a net.Conn, graceful shutdowns... 21 | type Server struct { 22 | // Handler is called by the server for each HTTP request it received. 
23 | Handler http.Handler 24 | 25 | // Upgrader is called by the server when an HTTP upgrade is detected. 26 | Upgrader http.Handler 27 | 28 | // IdleTimeout is the maximum amount of time the server waits on an inactive 29 | // connection before closing it. 30 | // Zero means no timeout. 31 | IdleTimeout time.Duration 32 | 33 | // ReadTimeout is the maximum amount of time the server waits for a request 34 | // to be fully read. 35 | // Zero means no timeout. 36 | ReadTimeout time.Duration 37 | 38 | // WriteTimeout is the maximum amount of time the server gives for responses 39 | // to be sent. 40 | // Zero means no timeout. 41 | WriteTimeout time.Duration 42 | 43 | // MaxHeaderBytes controls the maximum number of bytes the will read parsing 44 | // the request header's keys and values, including the request line. It does 45 | // not limit the size of the request body. 46 | // If zero, http.DefaultMaxHeaderBytes is used. 47 | MaxHeaderBytes int 48 | 49 | // ErrorLog specifies an optional logger for errors that occur when 50 | // attempting to proxy the request. If nil, logging goes to os.Stderr via 51 | // the log package's standard logger. 52 | ErrorLog *log.Logger 53 | 54 | // ServerName is the name of the server, returned in the "Server" response 55 | // header field. 56 | ServerName string 57 | } 58 | 59 | // ServeConn satisfies the netx.Handler interface. 
60 | func (s *Server) ServeConn(ctx context.Context, conn net.Conn) { 61 | maxHeaderBytes := s.MaxHeaderBytes 62 | if maxHeaderBytes == 0 { 63 | maxHeaderBytes = http.DefaultMaxHeaderBytes 64 | } 65 | 66 | baseHeader := http.Header{ 67 | "Content-Type": {"application/octet-stream"}, 68 | "Server": {s.ServerName}, 69 | } 70 | if idleTimeout := s.IdleTimeout; idleTimeout != 0 { 71 | baseHeader.Set("Connection", "Keep-Alive") 72 | baseHeader.Set("Keep-Alive", fmt.Sprintf("timeout=%d", int(idleTimeout/time.Second))) 73 | } 74 | 75 | // The request context is completely detached from the server's main context 76 | // to allow in-flight request to be completed before terminating the server. 77 | var reqctx context.Context 78 | var cancel context.CancelFunc 79 | reqctx = context.Background() 80 | reqctx = context.WithValue(reqctx, http.LocalAddrContextKey, conn.LocalAddr()) 81 | reqctx, cancel = context.WithCancel(reqctx) 82 | 83 | sc := newServerConn(conn, cancel) 84 | defer sc.Close() 85 | 86 | res := &responseWriter{ 87 | header: make(http.Header, 10), 88 | conn: sc, 89 | timeout: s.WriteTimeout, 90 | } 91 | copyHeader(res.header, baseHeader) 92 | 93 | for { 94 | var req *http.Request 95 | var err error 96 | var closed bool 97 | 98 | if err = sc.waitReadyRead(ctx, s.IdleTimeout); err != nil { 99 | return 100 | } 101 | if req, err = sc.readRequest(reqctx, maxHeaderBytes, s.ReadTimeout); err != nil { 102 | return 103 | } 104 | res.req = req 105 | 106 | if closed = req.Close; closed { 107 | if req.ProtoAtLeast(1, 1) { 108 | res.header.Add("Connection", "close") 109 | } 110 | } else { 111 | if protoEqual(req, 1, 0) { 112 | res.header.Add("Connection", "keep-alive") 113 | } 114 | } 115 | 116 | s.serveHTTP(res, req, conn) 117 | 118 | if res.err != nil { // hijacked, or lost the connection 119 | return 120 | } 121 | if closed || req.Close { 122 | return 123 | } 124 | 125 | netx.Copy(ioutil.Discard, req.Body) 126 | req.Body.Close() 127 | 128 | res.reset(baseHeader) 129 | } 130 
| } 131 | 132 | // ServeProxy satisfies the netx.ProxyHandler interface, it is used to support 133 | // transparent HTTP proxies, it rewrites the request to take into account the 134 | // fact that it was received on an intercepted connection and that the client 135 | // wasn't aware that it was being proxied. 136 | func (s *Server) ServeProxy(ctx context.Context, conn net.Conn, target net.Addr) { 137 | handler := http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { 138 | scheme := req.URL.Scheme 139 | if len(scheme) == 0 { 140 | scheme, _ = netx.SplitNetAddr(req.Host) 141 | } 142 | 143 | // If the Host had no scheme we're propbably in a transparent proxy and 144 | // the client didn't know it had to place the full URL in the header. 145 | // We attempt to guess the protocol from the network connection itself. 146 | if len(scheme) == 0 { 147 | if conn.LocalAddr().Network() == "tls" { 148 | scheme = "https" 149 | } else { 150 | scheme = "http" 151 | } 152 | } 153 | 154 | // Rewrite the URL to which the request will be forwarded. 155 | req.URL.Scheme = scheme 156 | req.URL.Host = target.String() 157 | 158 | // Fallback to the orignal server's handler. 159 | s.serveHTTP(res, req, conn) 160 | }) 161 | server := *s 162 | server.Upgrader = handler 163 | server.Handler = handler 164 | server.ServeConn(ctx, conn) 165 | } 166 | 167 | func (s *Server) serveHTTP(w http.ResponseWriter, req *http.Request, conn net.Conn) { 168 | defer func() { 169 | res := w.(*responseWriter) 170 | err := recover() 171 | 172 | if err != nil { 173 | netx.Recover(err, conn, s.ErrorLog) 174 | 175 | // If the header wasn't written yet when the error occurred we can 176 | // attempt to keep using the connection, otherwise we abort to 177 | // notify the client that something went wrong. 
178 | if res.status == 0 { 179 | res.WriteHeader(http.StatusInternalServerError) 180 | } else { 181 | req.Close = true 182 | return 183 | } 184 | } 185 | 186 | res.close() 187 | res.Flush() 188 | }() 189 | 190 | handler := s.Handler 191 | upgrade := connectionUpgrade(req.Header) 192 | 193 | switch { 194 | case len(req.Header["Expect"]) != 0: 195 | handler = StatusHandler(http.StatusExpectationFailed) 196 | 197 | case len(upgrade) != 0: 198 | if s.Upgrader == nil { 199 | handler = StatusHandler(http.StatusNotImplemented) 200 | } else { 201 | handler = s.Upgrader 202 | } 203 | } 204 | 205 | handler.ServeHTTP(w, req) 206 | } 207 | 208 | // serverConn is a net.Conn that embeds a I/O buffers and a connReader, this is 209 | // mainly used as an optimization to reduce the number of dynamic memory 210 | // allocations. 211 | type serverConn struct { 212 | c connReader 213 | bufio.Reader 214 | bufio.Writer 215 | } 216 | 217 | func newServerConn(conn net.Conn, cancel context.CancelFunc) *serverConn { 218 | c := &serverConn{c: connReader{Conn: conn, limit: -1, cancel: cancel}} 219 | c.Reader = *bufio.NewReader(&c.c) 220 | c.Writer = *bufio.NewWriter(&c.c) 221 | return c 222 | } 223 | 224 | func (conn *serverConn) Close() error { return conn.c.Close() } 225 | func (conn *serverConn) LocalAddr() net.Addr { return conn.c.LocalAddr() } 226 | func (conn *serverConn) RemoteAddr() net.Addr { return conn.c.RemoteAddr() } 227 | func (conn *serverConn) SetDeadline(t time.Time) error { return conn.c.SetDeadline(t) } 228 | func (conn *serverConn) SetReadDeadline(t time.Time) error { return conn.c.SetReadDeadline(t) } 229 | func (conn *serverConn) SetWriteDeadline(t time.Time) error { return conn.c.SetWriteDeadline(t) } 230 | 231 | func (conn *serverConn) waitReadyRead(ctx context.Context, timeout time.Duration) (err error) { 232 | deadline := time.Now().Add(timeout) 233 | for { 234 | conn.SetReadDeadline(time.Now().Add(1 * time.Second)) 235 | 236 | if _, err = conn.Peek(1); err == nil { 
237 | return 238 | } 239 | 240 | if !netx.IsTimeout(err) { 241 | return 242 | } 243 | 244 | if timeout != 0 && deadline.Before(time.Now()) { 245 | err = netx.Timeout("i/o timeout waiting for an HTTP request") 246 | return 247 | } 248 | 249 | select { 250 | case <-ctx.Done(): 251 | return 252 | default: 253 | } 254 | } 255 | } 256 | 257 | func (conn *serverConn) readRequest(ctx context.Context, maxHeaderBytes int, timeout time.Duration) (req *http.Request, err error) { 258 | // Limit the size of the request header, if readRequest attempts to read 259 | // more than maxHeaderBytes it will get io.EOF. 260 | conn.c.limit = maxHeaderBytes 261 | 262 | if timeout != 0 { 263 | conn.SetReadDeadline(time.Now().Add(timeout)) 264 | } else { 265 | conn.SetReadDeadline(time.Time{}) 266 | } 267 | 268 | if req, err = http.ReadRequest(&conn.Reader); err != nil { 269 | return 270 | } 271 | req = req.WithContext(ctx) 272 | 273 | // Drop the size limit on the connection reader to let the request body 274 | // go through. 275 | conn.c.limit = -1 276 | return 277 | } 278 | 279 | // responseWriter is an implementation of the http.ResponseWriter interface. 280 | // 281 | // Instances of responseWriter provide most of the features exposed in the 282 | // standard library, however it doesn't do automatic detection of the content 283 | // type. 
284 | type responseWriter struct { 285 | status int // status code of the response 286 | header http.Header // header sent in the response 287 | conn *serverConn // connection that the server got a request from 288 | req *http.Request // request that the writer sends a response for 289 | timeout time.Duration // timeout for the full write operation 290 | err error // any error detected internally by the writer 291 | remain uint64 // the remaining number of bytes to write 292 | hasBody bool // true when the request method allows to send a response body 293 | chunked bool // true when the writer uses "Transfer-Encoding: chunked" 294 | cw chunkWriter // chunk writer used with "Transfer-Encoding: chunked" 295 | } 296 | 297 | // Hijack satisfies the http.Hijacker interface. 298 | func (res *responseWriter) Hijack() (conn net.Conn, rw *bufio.ReadWriter, err error) { 299 | if res.err != nil { 300 | err = res.err 301 | return 302 | } 303 | 304 | if res.chunked { 305 | if err = res.cw.Flush(); err != nil { 306 | res.err = err 307 | return 308 | } 309 | } 310 | 311 | conn, rw = res.conn.c.Conn, bufio.NewReadWriter(&res.conn.Reader, &res.conn.Writer) 312 | res.conn = nil 313 | res.status = http.StatusSwitchingProtocols 314 | res.err = http.ErrHijacked 315 | 316 | // Cancel all deadlines on the connection before returning it. 317 | conn.SetDeadline(time.Time{}) 318 | return 319 | } 320 | 321 | // Header satisfies the http.ResponseWriter interface. 322 | func (res *responseWriter) Header() http.Header { 323 | return res.header 324 | } 325 | 326 | // WriteHeader satisfies the http.ResponseWriter interface. 
327 | func (res *responseWriter) WriteHeader(status int) { 328 | if res.status != 0 { 329 | return 330 | } 331 | if status == 0 { 332 | status = http.StatusOK 333 | } 334 | res.status = status 335 | res.hasBody = status >= 200 && 336 | status != http.StatusNoContent && 337 | status != http.StatusNotModified 338 | 339 | // The chunkWriter's buffer is unused for now, we'll use it to write the 340 | // status line and avoid a couple of memory allocations (because byte 341 | // slices sent to the bufio.Writer will be seen as escaping to the 342 | // underlying io.Writer). 343 | var b = res.cw.b[:0] 344 | var c = res.conn 345 | var h = res.header 346 | 347 | if timeout := res.timeout; timeout != 0 { 348 | c.SetWriteDeadline(time.Now().Add(timeout)) 349 | } 350 | 351 | if res.hasBody { 352 | if s, hasLen := h["Content-Length"]; !hasLen { 353 | h.Set("Transfer-Encoding", "chunked") 354 | res.chunked = true 355 | res.cw.w = res.conn 356 | res.cw.n = 0 357 | } else if res.remain, res.err = strconv.ParseUint(s[0], 10, 64); res.err != nil { 358 | // The program put an invalid value in Content-Length, that's a 359 | // programming error. 360 | res.err = errors.New("bad Content-Length: " + s[0]) 361 | return 362 | } 363 | } else { 364 | // In case the application mistakenly added these. 365 | h.Del("Transfer-Encoding") 366 | h.Del("Content-Length") 367 | } 368 | 369 | if _, hasDate := h["Date"]; !hasDate { 370 | h.Set("Date", now().Format(time.RFC1123)) 371 | } 372 | 373 | b = append(b, res.req.Proto...) 374 | b = append(b, ' ') 375 | b = strconv.AppendInt(b, int64(status), 10) 376 | b = append(b, ' ') 377 | b = append(b, http.StatusText(status)...) 
378 | b = append(b, '\r', '\n') 379 | 380 | if _, err := c.Write(b); err != nil { 381 | res.err = err 382 | return 383 | } 384 | if err := h.Write(c); err != nil { 385 | res.err = err 386 | return 387 | } 388 | if _, err := c.WriteString("\r\n"); err != nil { 389 | res.err = err 390 | return 391 | } 392 | } 393 | 394 | // Write satisfies the io.Writer and http.ResponseWriter interfaces. 395 | func (res *responseWriter) Write(b []byte) (n int, err error) { 396 | if err = res.err; err == nil { 397 | res.WriteHeader(0) 398 | 399 | if !res.hasBody { 400 | err = http.ErrBodyNotAllowed 401 | return 402 | } 403 | 404 | if res.chunked { 405 | n, err = res.cw.Write(b) 406 | } else { 407 | for len(b) != 0 && err == nil { 408 | if res.remain == 0 { 409 | // The program sent more bytes that it declared in the 410 | // Content-Length header. 411 | err = http.ErrContentLength 412 | return 413 | } 414 | 415 | n1 := uint64(len(b)) 416 | n2 := res.remain 417 | 418 | if n1 > n2 { 419 | n1 = n2 420 | } 421 | 422 | if n, err = res.conn.Write(b[:int(n1)]); n > 0 { 423 | b = b[n:] 424 | res.remain -= uint64(n) 425 | } 426 | } 427 | } 428 | 429 | res.err = err 430 | } 431 | return 432 | } 433 | 434 | // Flush satsifies the http.Flusher interface. 
435 | func (res *responseWriter) Flush() { 436 | if res.err == nil { 437 | res.WriteHeader(0) 438 | 439 | if res.err == nil { 440 | if res.chunked { 441 | if res.err = res.cw.Flush(); res.err != nil { 442 | return 443 | } 444 | } 445 | res.err = res.conn.Flush() 446 | } 447 | } 448 | } 449 | 450 | func (res *responseWriter) close() { 451 | if res.chunked { 452 | res.WriteHeader(0) 453 | 454 | if res.err == nil { 455 | res.err = res.cw.Close() 456 | } 457 | } 458 | } 459 | 460 | func (res *responseWriter) reset(baseHeader http.Header) { 461 | res.remain = 0 462 | res.hasBody = false 463 | res.chunked = false 464 | res.cw.w = nil 465 | res.cw.n = 0 466 | res.req = nil 467 | res.header = make(http.Header, 10) 468 | copyHeader(res.header, baseHeader) 469 | } 470 | 471 | // chunkWriter provides the implementation of an HTTP writer that outputs a 472 | // response body using the chunked transfer encoding. 473 | type chunkWriter struct { 474 | w io.Writer // writer to which data are flushed 475 | n int // offset in of the last byte in b 476 | a [8]byte // buffer used for writing the chunk size 477 | b [512]byte // buffer used to aggregate small chunks 478 | } 479 | 480 | func (res *chunkWriter) Write(b []byte) (n int, err error) { 481 | for len(b) != 0 { 482 | n1 := len(b) 483 | n2 := len(res.b) - res.n 484 | 485 | if n1 >= n2 { 486 | if res.n == 0 { 487 | // Nothing is buffered and we have a large chunk already, bypass 488 | // the chunkWriter's buffer and directly output to its writer. 
489 | return res.writeChunk(b) 490 | } 491 | n1 = n2 492 | } 493 | 494 | copy(res.b[res.n:], b[:n1]) 495 | res.n += n1 496 | n += n1 497 | 498 | if b = b[n1:]; len(b) != 0 { 499 | if err = res.Flush(); err != nil { 500 | break 501 | } 502 | } 503 | } 504 | return 505 | } 506 | 507 | func (res *chunkWriter) Close() (err error) { 508 | if err = res.Flush(); err == nil { 509 | _, err = res.w.Write(append(res.a[:0], "0\r\n\r\n"...)) 510 | } 511 | return 512 | } 513 | 514 | func (res *chunkWriter) Flush() (err error) { 515 | var n int 516 | 517 | if n, err = res.writeChunk(res.b[:res.n]); n > 0 { 518 | if n == res.n { 519 | res.n = 0 520 | } else { 521 | // Not all buffered data could be flushed, moving the bytes to the 522 | // front of the chunkWriter's buffer. 523 | copy(res.b[:], res.b[n:res.n]) 524 | res.n -= n 525 | } 526 | } 527 | 528 | return 529 | } 530 | 531 | func (res *chunkWriter) writeChunk(b []byte) (n int, err error) { 532 | if len(b) == 0 { 533 | // Don't write empty chunks, they would be misinterpreted as the end of 534 | // the stream. 
535 | return 536 | } 537 | 538 | a := append(strconv.AppendInt(res.a[:0], int64(len(b)), 16), '\r', '\n') 539 | 540 | if _, err = res.w.Write(a); err != nil { 541 | return 542 | } 543 | if n, err = res.w.Write(b); err != nil { 544 | return 545 | } 546 | _, err = res.w.Write(a[len(a)-2:]) // CRLF 547 | return 548 | } 549 | 550 | var ( 551 | timezone = time.FixedZone("GMT", 0) 552 | ) 553 | 554 | func now() time.Time { 555 | return time.Now().In(timezone) 556 | } 557 | -------------------------------------------------------------------------------- /httpx/server_test.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | import ( 4 | "context" 5 | "log" 6 | "os" 7 | "testing" 8 | 9 | "github.com/segmentio/netx" 10 | "github.com/segmentio/netx/httpx/httpxtest" 11 | ) 12 | 13 | func TestServer(t *testing.T) { 14 | httpxtest.TestServer(t, func(config httpxtest.ServerConfig) (string, func()) { 15 | return listenAndServe(&Server{ 16 | Handler: config.Handler, 17 | ReadTimeout: config.ReadTimeout, 18 | WriteTimeout: config.WriteTimeout, 19 | MaxHeaderBytes: config.MaxHeaderBytes, 20 | }) 21 | }) 22 | } 23 | 24 | func listenAndServe(h netx.Handler) (url string, close func()) { 25 | lstn, err := netx.Listen("127.0.0.1:0") 26 | if err != nil { 27 | panic(err) 28 | } 29 | 30 | ctx, cancel := context.WithCancel(context.Background()) 31 | go (&netx.Server{ 32 | Handler: h, 33 | Context: ctx, 34 | ErrorLog: log.New(os.Stderr, "listen: ", 0), 35 | }).Serve(lstn) 36 | 37 | url, close = "http://"+lstn.Addr().String(), cancel 38 | return 39 | } 40 | -------------------------------------------------------------------------------- /httpx/transport.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | import "net/http" 4 | 5 | // RoundTripperFunc makes it possible to use regular functions as transports for 6 | // HTTP clients. 
7 | type RoundTripperFunc func(*http.Request) (*http.Response, error) 8 | 9 | // RoundTrip calls f. 10 | func (f RoundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { 11 | return f(r) 12 | } 13 | -------------------------------------------------------------------------------- /httpx/upgrade.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | import ( 4 | "net/http" 5 | "sync" 6 | ) 7 | 8 | // UpgradeMap maps protocol names to HTTP handlers that should be used to 9 | // perform upgrades from HTTP to a different protocol. 10 | // 11 | // It is expected that the upgrade handler flushes and hijacks the connection 12 | // after sending the response, and doesn't return from its ServeHTTP method 13 | // until it's done serving the connection (or it would be closed prematuraly). 14 | // 15 | // A special-case is made for the name "*" which indicates that the handler is 16 | // set as a the fallback upgrade handler to handle unrecognized protocols. 17 | // 18 | // Keys in the UpgradeMap map should be formatted with http.CanonicalHeaderKey. 19 | type UpgradeMap map[string]http.Handler 20 | 21 | // ServeHTTP satisfies the http.Handler interface so an UpgradeMap can be used as 22 | // handler on an HTTP server. 23 | func (u UpgradeMap) ServeHTTP(w http.ResponseWriter, req *http.Request) { 24 | h := u[http.CanonicalHeaderKey(req.Header.Get("Upgrade"))] 25 | 26 | if h == nil { 27 | h = u["*"] 28 | } 29 | 30 | if h == nil { 31 | w.WriteHeader(http.StatusNotImplemented) 32 | return 33 | } 34 | 35 | h.ServeHTTP(w, req) 36 | } 37 | 38 | // UpgradeMux maps protocol names to HTTP handlers that should be used to 39 | // perform upgrades from HTTP to a different protocol. 
40 | // 41 | // It is expected that the upgrade handler flushes and hijacks the connection 42 | // after sending the response, and doesn't return from its ServeHTTP method 43 | // until it's done serving the connection (or it would be closed prematuraly). 44 | // 45 | // UpgradeMux exposes the exact same API than http.ServeMux, therefore is safe 46 | // to use by multiple concurrent goroutines. 47 | type UpgradeMux struct { 48 | mutex sync.RWMutex 49 | upgrader UpgradeMap 50 | } 51 | 52 | // NewUpgradeMux allocates and returns a new UpgradeMux. 53 | func NewUpgradeMux() *UpgradeMux { 54 | return &UpgradeMux{} 55 | } 56 | 57 | // Handle registers a handler for the given protocol name. If a handler already 58 | // exists for name, Handle panics. 59 | // 60 | // A special-case is made for the name "*" which indicates that the handler is 61 | // set as a the fallback upgrade handler to handle unrecognized protocols. 62 | func (mux *UpgradeMux) Handle(name string, handler http.Handler) { 63 | var key string 64 | 65 | if name != "*" { 66 | key = http.CanonicalHeaderKey(name) 67 | } 68 | 69 | defer mux.mutex.Unlock() 70 | mux.mutex.Lock() 71 | 72 | if mux.upgrader[key] != nil { 73 | panic("an upgrade handler already exists for " + name) 74 | } 75 | 76 | if mux.upgrader == nil { 77 | mux.upgrader = make(UpgradeMap) 78 | } 79 | 80 | mux.upgrader[key] = handler 81 | } 82 | 83 | // HandleFunc registers a handler function for the given protocol name. If a 84 | // handler already exists for name, HandleFunc panics. 85 | func (mux *UpgradeMux) HandleFunc(name string, handler func(http.ResponseWriter, *http.Request)) { 86 | mux.Handle(name, http.HandlerFunc(handler)) 87 | } 88 | 89 | // Handler returns the appropriate http.Handler for serving req. 
90 | func (mux *UpgradeMux) Handler(req *http.Request) http.Handler { 91 | key := http.CanonicalHeaderKey(req.Header.Get("Upgrade")) 92 | 93 | if len(key) == 0 { 94 | return nil 95 | } 96 | 97 | mux.mutex.RLock() 98 | h := mux.upgrader[key] 99 | mux.mutex.RUnlock() 100 | return h 101 | } 102 | 103 | // ServeHTTP satisfies the http.Handler interface so UpgradeMux can be used as 104 | // handler on an HTTP server. 105 | func (mux *UpgradeMux) ServeHTTP(w http.ResponseWriter, req *http.Request) { 106 | h := mux.Handler(req) 107 | 108 | if h == nil { 109 | w.WriteHeader(http.StatusNotImplemented) 110 | return 111 | } 112 | 113 | h.ServeHTTP(w, req) 114 | } 115 | -------------------------------------------------------------------------------- /ip.go: -------------------------------------------------------------------------------- 1 | package netx 2 | 3 | import "net" 4 | 5 | // IsIP checks if s is a valid representation of an IPv4 or IPv6 address. 6 | func IsIP(s string) bool { 7 | return net.ParseIP(s) != nil 8 | } 9 | 10 | // IsIPv4 checks if s is a valid representation of an IPv4 address. 11 | func IsIPv4(s string) bool { 12 | ip := net.ParseIP(s) 13 | return ip != nil && ip.To4() != nil 14 | } 15 | 16 | // IsIPv6 checks if s is a valid representation of an IPv6 address. 
17 | func IsIPv6(s string) bool { 18 | ip := net.ParseIP(s) 19 | return ip != nil && ip.To4() == nil 20 | } 21 | -------------------------------------------------------------------------------- /ip_test.go: -------------------------------------------------------------------------------- 1 | package netx 2 | 3 | import "testing" 4 | 5 | func TestIsIP(t *testing.T) { 6 | tests := []struct { 7 | s string 8 | x bool 9 | }{ 10 | {"", false}, 11 | {"127", false}, 12 | {"127.0.0.1", true}, 13 | {"::", true}, 14 | {"::1", true}, 15 | } 16 | 17 | for _, test := range tests { 18 | t.Run(test.s, func(t *testing.T) { 19 | if x := IsIP(test.s); x != test.x { 20 | t.Error(test.s) 21 | } 22 | }) 23 | } 24 | } 25 | 26 | func TestIsIPv4(t *testing.T) { 27 | tests := []struct { 28 | s string 29 | x bool 30 | }{ 31 | {"", false}, 32 | {"127", false}, 33 | {"127.0.0.1", true}, 34 | {"::", false}, 35 | {"::1", false}, 36 | } 37 | 38 | for _, test := range tests { 39 | t.Run(test.s, func(t *testing.T) { 40 | if x := IsIPv4(test.s); x != test.x { 41 | t.Error(test.s) 42 | } 43 | }) 44 | } 45 | } 46 | 47 | func TestIsIPv6(t *testing.T) { 48 | tests := []struct { 49 | s string 50 | x bool 51 | }{ 52 | {"", false}, 53 | {"127", false}, 54 | {"127.0.0.1", false}, 55 | {"::", true}, 56 | {"::1", true}, 57 | } 58 | 59 | for _, test := range tests { 60 | t.Run(test.s, func(t *testing.T) { 61 | if x := IsIPv6(test.s); x != test.x { 62 | t.Error(test.s) 63 | } 64 | }) 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /listen.go: -------------------------------------------------------------------------------- 1 | package netx 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | "net" 7 | "os" 8 | "strconv" 9 | "strings" 10 | "sync" 11 | ) 12 | 13 | // Listen is equivalent to net.Listen but guesses the network from the address. 
14 | //
15 | // The function accepts addresses that may be prefixed by a URL scheme to set
16 | // the protocol that will be used, supported protocols are tcp, tcp4, tcp6,
17 | // unix, unixpacket, and fd.
18 | //
19 | // The address may contain a path to a file for unix sockets, a pair of an IP
20 | // address and port, a pair of a network interface name and port, or just port.
21 | //
22 | // If the port is omitted for network addresses the operating system will pick
23 | // one automatically.
24 | //
25 | // When the address resolves to several addresses (for example all the
26 | // addresses of a network interface), the returned listener accepts
27 | // connections on all of them. If listening on any of them fails, the
28 | // listeners already opened are closed and the error is returned.
29 | func Listen(address string) (lstn net.Listener, err error) {
30 | 	var network string
31 | 	var addrs []string
32 |
33 | 	if network, addrs, err = resolveListen(address, "tcp", "unix", []string{
34 | 		"tcp",
35 | 		"tcp4",
36 | 		"tcp6",
37 | 		"unix",
38 | 		"unixpacket",
39 | 		"fd",
40 | 	}); err != nil {
41 | 		return
42 | 	}
43 |
44 | 	if len(addrs) == 1 {
45 | 		return listen(network, addrs[0])
46 | 	}
47 |
48 | 	lstns := make([]net.Listener, 0, len(addrs))
49 |
50 | 	for _, a := range addrs {
51 | 		l, e := listen(network, a)
52 | 		if e != nil {
53 | 			// Close the listeners opened so far before reporting the error.
54 | 			for _, l := range lstns {
55 | 				l.Close()
56 | 			}
57 | 			return nil, e
58 | 		}
59 | 		lstns = append(lstns, l)
60 | 	}
61 |
62 | 	lstn = MultiListener(lstns...)
63 | 	return
64 | }
65 |
66 | // listen opens a listener on a single address. With the "fd" network,
67 | // address is the number of an open unix socket file descriptor which is
68 | // passed to NewRecvUnixListener.
69 | func listen(network string, address string) (lstn net.Listener, err error) {
70 | 	if network != "fd" {
71 | 		return net.Listen(network, address)
72 | 	}
73 |
74 | 	var c net.Conn
75 | 	if c, err = connFromFD(address); err != nil {
76 | 		return
77 | 	}
78 |
79 | 	uc, ok := c.(*net.UnixConn)
80 | 	if !ok {
81 | 		c.Close()
82 | 		err = errors.New("file descriptor in fd://" + address + " is not a unix socket")
83 | 		return
84 | 	}
85 |
86 | 	return NewRecvUnixListener(uc), nil
87 | }
88 |
89 | // connFromFD wraps the file descriptor whose number is held in address (the
90 | // part following fd://) in a net.Conn. net.FileConn works on a copy of the
91 | // descriptor; the original descriptor is closed before returning.
92 | func connFromFD(address string) (net.Conn, error) {
93 | 	fd, err := strconv.Atoi(address)
94 | 	if err != nil {
95 | 		return nil, errors.New("invalid file descriptor in fd://" + address)
96 | 	}
97 | 	if fd < 0 {
98 | 		return nil, errors.New("invalid negative file descriptor in fd://" + address)
99 | 	}
100 |
101 | 	f := os.NewFile(uintptr(fd), "fd")
102 | 	defer f.Close()
103 |
104 | 	return net.FileConn(f)
105 | }
106 |
107 | // ListenPacket is similar to Listen but returns a PacketConn, and works with
108 | // udp, udp4, udp6, ip, ip4, ip6, unixgram (also accepted as unixdgram), or fd
109 | // protocols.
110 | func ListenPacket(address string) (conn net.PacketConn, err error) {
111 | 	var network string
112 | 	var addrs []string
113 |
114 | 	if network, addrs, err = resolveListen(address, "udp", "unixdgram", []string{
115 | 		"udp",
116 | 		"udp4",
117 | 		"udp6",
118 | 		"ip",
119 | 		"ip4",
120 | 		"ip6",
121 | 		"unixdgram",
122 | 		"unixgram",
123 | 		"fd",
124 | 	}); err != nil {
125 | 		return
126 | 	}
127 |
128 | 	if network == "fd" {
129 | 		var c net.Conn
130 | 		if c, err = connFromFD(addrs[0]); err != nil {
131 | 			return
132 | 		}
133 |
134 | 		var ok bool
135 | 		if conn, ok = c.(net.PacketConn); !ok {
136 | 			c.Close()
137 | 			err = errors.New("file descriptor in fd://" + addrs[0] + " is not a packet socket")
138 | 		}
139 | 		return
140 | 	}
141 |
142 | 	// "unixdgram" is accepted for backward compatibility, but the net package
143 | 	// only knows this network as "unixgram".
144 | 	if network == "unixdgram" {
145 | 		network = "unixgram"
146 | 	}
147 |
148 | 	// TODO: listen on all addresses?
149 | 	for _, a := range addrs {
150 | 		if conn, err = net.ListenPacket(network, a); err == nil {
151 | 			break
152 | 		}
153 | 	}
154 |
155 | 	return
156 | }
157 |
158 | // resolveListen splits address into the network to listen on and the list
159 | // of addresses to pass to the net package. A scheme found in protocols takes
160 | // precedence; otherwise defaultProtoNetwork is used for IP addresses, network
161 | // interface names and ":port" addresses, and defaultProtoUnix for anything
162 | // else, which is assumed to be the path of a unix domain socket.
163 | func resolveListen(address string, defaultProtoNetwork string, defaultProtoUnix string, protocols []string) (network string, addrs []string, err error) {
164 | 	var host string
165 | 	var port string
166 | 	var ifi *net.Interface
167 |
168 | 	if off := strings.Index(address, "://"); off >= 0 {
169 | 		for _, proto := range protocols {
170 | 			if strings.HasPrefix(address, proto+"://") {
171 | 				network, address = proto, address[len(proto)+3:]
172 | 				break
173 | 			}
174 | 		}
175 |
176 | 		if len(network) == 0 {
177 | 			err = errors.New("unsupported protocol: " + address[:off])
178 | 			return
179 | 		}
180 | 	}
181 |
182 | 	if network == "fd" {
183 | 		if _, err = strconv.Atoi(address); err != nil {
184 | 			err = errors.New("expected file descriptor number with fd:// protocol but found " + address)
185 | 		}
186 | 		addrs = []string{address}
187 | 		return
188 | 	}
189 |
190 | 	if strings.HasPrefix(address, ":") { // :port
191 | 		// An explicit scheme (e.g. tcp6://:80) takes precedence.
192 | 		if len(network) == 0 {
193 | 			network = defaultProtoNetwork
194 | 		}
195 | 		addrs = []string{address}
196 | 		return
197 | 	}
198 |
199 | 	if host, port, err = net.SplitHostPort(address); err != nil {
200 | 		// The address has no port (the ":port" form was handled above), it
201 | 		// is made of the host part only.
202 | 		host, port, err = address, "", nil
203 | 	}
204 |
205 | 	if IsIP(host) {
206 | 		// The function received a simple IP address to listen on.
207 | 		if len(network) == 0 {
208 | 			network = defaultProtoNetwork
209 | 		}
210 | 		addrs = append(addrs, joinListenHostPort(network, host, port))
211 |
212 | 	} else if ifi, err = net.InterfaceByName(host); err == nil {
213 | 		// The function received the name of a network interface, we have to
214 | 		// lookup the list of all network addresses to listen on.
215 | 		var ifa []net.Addr
216 |
217 | 		if ifa, err = ifi.Addrs(); err != nil {
218 | 			return
219 | 		}
220 |
221 | 		if len(network) == 0 {
222 | 			network = defaultProtoNetwork
223 | 		}
224 |
225 | 		for _, a := range ifa {
226 | 			addrs = append(addrs, joinListenHostPort(network, interfaceAddrHost(ifi, a), port))
227 | 		}
228 |
229 | 		// Without this check callers would end up listening on nothing.
230 | 		if len(addrs) == 0 {
231 | 			err = errors.New("no addresses found on network interface " + host)
232 | 		}
233 |
234 | 	} else {
235 | 		err = nil
236 | 		// Neither an IP address nor a network interface name was passed, we
237 | 		// assume this address is probably the path to a unix domain socket.
238 | 		addrs = append(addrs, address)
239 |
240 | 		if len(network) == 0 {
241 | 			network = defaultProtoUnix
242 | 		}
243 | 	}
244 |
245 | 	return
246 | }
247 |
248 | // joinListenHostPort combines host and port into an address for network.
249 | // When port is empty, TCP and UDP networks get port "0" so that the operating
250 | // system picks one; other networks get the bare host.
251 | func joinListenHostPort(network string, host string, port string) string {
252 | 	if len(port) == 0 {
253 | 		if !strings.HasPrefix(network, "tcp") && !strings.HasPrefix(network, "udp") {
254 | 			return host
255 | 		}
256 | 		port = "0"
257 | 	}
258 | 	return net.JoinHostPort(host, port)
259 | }
260 |
261 | // interfaceAddrHost returns the host part of a, an address of the network
262 | // interface ifi, in a form accepted by the net package. Interface addresses
263 | // may be reported as *net.IPNet, whose String method appends the "/prefix"
264 | // length, and IPv6 link-local addresses need the interface name as zone.
265 | func interfaceAddrHost(ifi *net.Interface, a net.Addr) string {
266 | 	var ip net.IP
267 |
268 | 	switch v := a.(type) {
269 | 	case *net.IPNet:
270 | 		ip = v.IP
271 | 	case *net.IPAddr:
272 | 		ip = v.IP
273 | 	default:
274 | 		return a.String()
275 | 	}
276 |
277 | 	s := ip.String()
278 | 	if ip.To4() == nil && ip.IsLinkLocalUnicast() {
279 | 		s += "%" + ifi.Name
280 | 	}
281 | 	return s
282 | }
283 |
284 | // MultiListener returns a compound listener made of the given list of
285 | // listeners.
286 | func MultiListener(lstn ...net.Listener) net.Listener {
287 | 	c := make(chan net.Conn)
288 | 	e := make(chan error)
289 | 	d := make(chan struct{})
290 | 	x := make(chan struct{})
291 | 	m := &multiListener{
292 | 		l: append(make([]net.Listener, 0, len(lstn)), lstn...),
293 | 		c: c,
294 | 		e: e,
295 | 		d: d,
296 | 		x: x,
297 | 	}
298 |
299 | 	for _, l := range m.l {
300 | 		go func(l net.Listener, c chan<- net.Conn, e chan<- error, d chan<- struct{}) {
301 | 			defer func() { d <- struct{}{} }()
302 | 			for {
303 | 				if conn, err := l.Accept(); err == nil {
304 | 					c <- conn
305 | 				} else {
306 | 					e <- err
307 |
308 | 					if !IsTemporary(err) {
309 | 						break
310 | 					}
311 | 				}
312 | 			}
313 | 		}(l, c, e, d)
314 | 	}
315 |
316 | 	return m
317 | }
318 |
319 | type multiListener struct {
320 | 	l []net.Listener  // the list of listeners
321 | 	c <-chan net.Conn // connections from Accept are published on this channel
322 | 	e <-chan error    // errors from Accept are published on this channel
323 | 	d <-chan struct{} // each goroutine publishes to this channel when they exit
324 | 	x chan struct{}   // closed when the listener is closed
325 |
326 | 	// Used by Close to allow multiple goroutines to call the method as well as
327 | 	// allowing the method to be called multiple times.
328 | 	once sync.Once
329 | }
330 |
331 | func (m *multiListener) Accept() (conn net.Conn, err error) {
332 | 	select {
333 | 	case conn = <-m.c:
334 | 	case err = <-m.e:
335 | 	case <-m.x:
336 | 		err = io.ErrClosedPipe
337 | 	}
338 | 	return
339 | }
340 |
341 | func (m *multiListener) Close() (err error) {
342 | 	m.once.Do(func() {
343 | 		var errs []string
344 |
345 | 		for _, l := range m.l {
346 | 			if e := l.Close(); e != nil {
347 | 				errs = append(errs, e.Error())
348 | 			}
349 | 		}
350 |
351 | 		for i, n := 0, len(m.l); i != n; {
352 | 			select {
353 | 			case conn := <-m.c:
354 | 				conn.Close()
355 | 			case <-m.e:
356 | 			case <-m.d:
357 | 				i++
358 | 			}
359 | 		}
360 |
361 | 		if errs != nil {
362 | 			err = errors.New(strings.Join(errs, "; "))
363 | 		}
364 |
365 | 		close(m.x)
366 | 	})
367 | 	return
368 | }
369 |
370 | func (m *multiListener) Addr() net.Addr {
371 | 	a := make(MultiAddr, len(m.l))
372 |
373 | 	for i, l := range m.l {
374 | 		a[i] = l.Addr()
375 | 	}
376 |
377 | 	return a
378 | }
379 |
--------------------------------------------------------------------------------
/pair.go:
--------------------------------------------------------------------------------
1 | package netx
2 |
3 | import (
4 | 	"errors"
5 | 	"net"
6 | 	"os"
7 | 	"syscall"
8 | )
9 |
10 | // ConnPair returns a pair of connections, each of them being the end of a
11 | // bidirectional communication channel. network should be one of "tcp", "tcp4",
12 | // "tcp6", or "unix".
13 | //
14 | // On error, both returned connections are nil.
15 | func ConnPair(network string) (net.Conn, net.Conn, error) {
16 | 	switch network {
17 | 	case "unix":
18 | 		c1, c2, err := UnixConnPair()
19 | 		if err != nil {
20 | 			// Don't wrap nil pointers in non-nil net.Conn interfaces.
21 | 			return nil, nil, err
22 | 		}
23 | 		return c1, c2, nil
24 | 	case "tcp", "tcp4", "tcp6":
25 | 		c1, c2, err := TCPConnPair(network)
26 | 		if err != nil {
27 | 			return nil, nil, err
28 | 		}
29 | 		return c1, c2, nil
30 | 	default:
31 | 		return nil, nil, errors.New("unsupported network pair: " + network)
32 | 	}
33 | }
34 |
35 | // TCPConnPair returns a pair of TCP connections, each of them being the end of a
36 | // bidirectional communication channel.
26 | func TCPConnPair(network string) (nc1 *net.TCPConn, nc2 *net.TCPConn, err error) {
27 | 	var lstn *net.TCPListener
28 | 	var ch1 = make(chan error, 1)
29 | 	var ch2 = make(chan *net.TCPConn, 1)
30 |
31 | 	if lstn, err = net.ListenTCP(network, nil); err != nil {
32 | 		return
33 | 	}
34 | 	defer lstn.Close()
35 |
36 | 	go func() {
37 | 		var conn *net.TCPConn
38 | 		var err error
39 |
40 | 		if conn, err = net.DialTCP(network, nil, lstn.Addr().(*net.TCPAddr)); err != nil {
41 | 			ch1 <- err
42 | 		} else {
43 | 			ch2 <- conn
44 | 		}
45 | 	}()
46 |
47 | 	if nc1, err = lstn.AcceptTCP(); err != nil {
48 | 		// The dial may still succeed (or may have succeeded already): close
49 | 		// its connection in the background so it doesn't leak, without making
50 | 		// the caller wait for the dial to complete.
51 | 		go func() {
52 | 			select {
53 | 			case conn := <-ch2:
54 | 				conn.Close()
55 | 			case <-ch1:
56 | 			}
57 | 		}()
58 | 		return
59 | 	}
60 |
61 | 	select {
62 | 	case nc2 = <-ch2:
63 | 	case err = <-ch1:
64 | 		nc1.Close()
65 | 		nc1 = nil
66 | 	}
67 | 	return
68 | }
69 |
70 | // UnixConnPair returns a pair of unix connections, each of them being the end of a
71 | // bidirectional communication channel.
72 | func UnixConnPair() (uc1 *net.UnixConn, uc2 *net.UnixConn, err error) {
73 | 	var fd1 int
74 | 	var fd2 int
75 |
76 | 	if fd1, fd2, err = socketpair(syscall.AF_LOCAL, syscall.SOCK_STREAM, 0); err != nil {
77 | 		return
78 | 	}
79 |
80 | 	f1 := os.NewFile(uintptr(fd1), "")
81 | 	f2 := os.NewFile(uintptr(fd2), "")
82 |
83 | 	defer f1.Close()
84 | 	defer f2.Close()
85 |
86 | 	var c1 net.Conn
87 | 	var c2 net.Conn
88 |
89 | 	if c1, err = net.FileConn(f1); err != nil {
90 | 		return
91 | 	}
92 |
93 | 	if c2, err = net.FileConn(f2); err != nil {
94 | 		c1.Close()
95 | 		return
96 | 	}
97 |
98 | 	uc1 = c1.(*net.UnixConn)
99 | 	uc2 = c2.(*net.UnixConn)
100 | 	return
101 | }
102 |
--------------------------------------------------------------------------------
/pair_darwin.go:
--------------------------------------------------------------------------------
1 | package netx
2 |
3 | import (
4 | 	"os"
5 | 	"syscall"
6 | )
7 |
8 | // socketpair creates a pair of connected sockets, both marked close-on-exec
9 | // and non-blocking. Unlike the linux version, the flags are applied after the
10 | // sockets are created.
11 | func socketpair(domain int, socktype int, protocol int) (fd1 int, fd2 int, err error) {
12 | 	var fds [2]int
13 | 	fd1 = -1
14 | 	fd2 = -1
15 |
16 | 	// ForkLock prevents a child process from being forked, and inheriting the
17 | 	// sockets, before they are marked close-on-exec.
18 | 	syscall.ForkLock.Lock()
19 | 	fds, err = syscall.Socketpair(domain, socktype, protocol)
20 | 	if err == nil {
21 | 		syscall.CloseOnExec(fds[0])
22 | 		syscall.CloseOnExec(fds[1])
23 | 	}
24 | 	syscall.ForkLock.Unlock()
25 |
26 | 	if err != nil {
27 | 		err = os.NewSyscallError("socketpair", err)
28 | 		return
29 | 	}
30 |
31 | 	for _, fd := range fds {
32 | 		if err = syscall.SetNonblock(fd, true); err != nil {
33 | 			syscall.Close(fds[0])
34 | 			syscall.Close(fds[1])
35 | 			err = os.NewSyscallError("setnonblock", err)
36 | 			return
37 | 		}
38 | 	}
39 |
40 | 	fd1 = fds[0]
41 | 	fd2 = fds[1]
42 | 	return
43 | }
44 |
--------------------------------------------------------------------------------
/pair_linux.go:
--------------------------------------------------------------------------------
1 | package netx
2 |
3 | import (
4 | 	"os"
5 | 	"syscall"
6 | )
7 |
8 | // socketpair creates a pair of connected sockets, both marked close-on-exec
9 | // and non-blocking at creation time through the SOCK_CLOEXEC and SOCK_NONBLOCK
10 | // flags.
11 | func socketpair(domain int, socktype int, protocol int) (fd1 int, fd2 int, err error) {
12 | 	var fds [2]int
13 | 	fd1 = -1
14 | 	fd2 = -1
15 |
16 | 	if fds, err = syscall.Socketpair(domain, socktype|syscall.SOCK_CLOEXEC|syscall.SOCK_NONBLOCK, protocol); err != nil {
17 | 		err = os.NewSyscallError("socketpair", err)
18 | 		return
19 | 	}
20 |
21 | 	fd1 = fds[0]
22 | 	fd2 = fds[1]
23 | 	return
24 | }
25 |
--------------------------------------------------------------------------------
/pair_test.go:
--------------------------------------------------------------------------------
1 | package netx
2 |
3 | import (
4 | 	"io"
5 | 	"net"
6 | 	"testing"
7 |
8 | 	"golang.org/x/net/nettest"
9 | )
10 |
11 | func TestConnPair(t *testing.T) {
12 | 	for _, network := range [...]string{
13 | 		"unix",
14 | 		"tcp",
15 | 		"tcp4",
16 | 		"tcp6",
17 | 	} {
18 | 		network := network // capture in lambda
19 | 		t.Run(network, func(t *testing.T) {
20 | 			t.Parallel()
21 | 			nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) {
22 | 				if c1, c2, err = ConnPair(network); err == nil {
23 | 					stop = func() {
24 | 						c1.Close()
25 | 						c2.Close()
26 | 					}
27 | 				}
28 | 				return
29 | 			})
30 | 		})
31 | 	}
32 | }
33 |
34 | func TestUnixConnPairCloseWrite(t *testing.T) {
35 | 	c1, c2, err := UnixConnPair()
36 | 	if err != nil {
37 | 		t.Error(err)
38 | 		return
39 | 	}
40 | 	defer c1.Close()
41 | 	defer c2.Close()
42 |
43 | 	b := make([]byte, 100)
44 |
45 | 	if err := c1.CloseWrite(); err != nil {
46 | 		t.Error(err)
47 | 		return
48 | 	}
49 |
50 | 	if _, err := c2.Read(b); err != io.EOF {
51 | 		t.Error("expected EOF but got", err)
52 | 		return
53 | 	}
54 |
55 | 	if _, err := c2.Write([]byte("Hello World!")); err != nil {
56 | 		t.Error(err)
57 | 		return
58 | 	}
59 |
60 | 	if n, err := c1.Read(b); err != nil {
61 | 		t.Error(err)
62 | 		return
63 | 	} else if n != 12 {
64 | 		t.Error("bad number of bytes returned:", n)
65 | 		return
66 | 	}
67 |
68 | 	if s := string(b[:12]); s != "Hello World!" {
69 | 		t.Error("bad content read:", s)
70 | 		return
71 | 	}
72 | }
73 |
--------------------------------------------------------------------------------
/proxy.go:
--------------------------------------------------------------------------------
1 | package netx
2 |
3 | import (
4 | 	"bytes"
5 | 	"context"
6 | 	"encoding/binary"
7 | 	"errors"
8 | 	"fmt"
9 | 	"io"
10 | 	"net"
11 | 	"strconv"
12 | )
13 |
14 | // ProxyHandler is an interface that must be implemented by types that intend to
15 | // proxy connections.
16 | //
17 | // The ServeProxy method is called by a Proxy when it receives a new connection.
18 | // It is similar to the ServeConn method of the Handler interface but receives
19 | // an extra target argument representing the original address that the
20 | // intercepted connection intended to reach.
21 | type ProxyHandler interface {
22 | 	ServeProxy(ctx context.Context, conn net.Conn, target net.Addr)
23 | }
24 |
25 | // ProxyHandlerFunc makes it possible for simple function types to be used as
26 | // connection proxies.
27 | type ProxyHandlerFunc func(context.Context, net.Conn, net.Addr)
28 |
29 | // ServeProxy calls f.
30 | func (f ProxyHandlerFunc) ServeProxy(ctx context.Context, conn net.Conn, target net.Addr) { 31 | f(ctx, conn, target) 32 | } 33 | 34 | // A Proxy is a connection handler that forwards its connections to a proxy 35 | // handler. 36 | type Proxy struct { 37 | // Handler is the proxy handler to which connetions are forwarded to. 38 | Handler ProxyHandler 39 | 40 | // Addr is the address to which connections are proxied. 41 | Addr net.Addr 42 | } 43 | 44 | // ServeConn satsifies the Handler interface. 45 | func (p *Proxy) ServeConn(ctx context.Context, conn net.Conn) { 46 | p.Handler.ServeProxy(ctx, conn, p.Addr) 47 | } 48 | 49 | // A TransparentProxy is a connection handler for intercepted connections. 50 | // 51 | // A proper usage of this proxy requires some iptables rules to redirect TCP 52 | // connections to to the listener its attached to. 53 | type TransparentProxy struct { 54 | // Handler is called by the proxy when it receives a connection that can be 55 | // proxied. 56 | // 57 | // Calling ServeConn on the proxy will panic if this field is nil. 58 | Handler ProxyHandler 59 | } 60 | 61 | // ServeConn satisfies the Handler interface. 62 | // 63 | // The method panics to report errors. 64 | func (p *TransparentProxy) ServeConn(ctx context.Context, conn net.Conn) { 65 | target, err := OriginalTargetAddr(conn) 66 | if err != nil { 67 | panic(err) 68 | } 69 | p.Handler.ServeProxy(ctx, conn, target) 70 | } 71 | 72 | // OriginalTargetAddr returns the original address that an intercepted 73 | // connection intended to reach. 74 | // 75 | // Note that this feature is only available for TCP connections on linux, 76 | // the function always returns an error on other platforms. 77 | func OriginalTargetAddr(conn net.Conn) (net.Addr, error) { 78 | return originalTargetAddr(conn) 79 | } 80 | 81 | // ProxyProtocol is the implementation of a connection handler which speaks 82 | // the proxy protocol. 
83 | //
84 | // When the handler receives a LOCAL connection it handles the connection itself
85 | // and simply closes the connection.
86 | //
87 | // Version 1 and 2 are supported.
88 | //
89 | // http://www.haproxy.org/download/1.5/doc/proxy-protocol.txt
90 | type ProxyProtocol struct {
91 | 	Handler Handler
92 | }
93 |
94 | // ServeConn satisfies the Handler interface.
95 | //
96 | // The method panics to report errors parsing the proxy protocol header. When
97 | // the header carries no source address (for example a version 2 header using
98 | // the AF_UNSPEC family), the remote address of conn is kept.
99 | func (p *ProxyProtocol) ServeConn(ctx context.Context, conn net.Conn) {
100 | 	src, _, buf, local, err := parseProxyProto(conn)
101 |
102 | 	if err != nil {
103 | 		panic(err)
104 | 	}
105 |
106 | 	if local {
107 | 		conn.Close()
108 | 		return
109 | 	}
110 |
111 | 	if src == nil {
112 | 		src = conn.RemoteAddr()
113 | 	}
114 |
115 | 	proxyConn := &proxyProtoConn{
116 | 		Conn: conn,
117 | 		src:  src,
118 | 		buf:  buf,
119 | 	}
120 | 	p.Handler.ServeConn(ctx, proxyConn)
121 | }
122 |
123 | type proxyProtoConn struct {
124 | 	net.Conn
125 | 	src net.Addr
126 | 	buf []byte
127 | }
128 |
129 | func (c *proxyProtoConn) Base() net.Conn {
130 | 	return c.Conn
131 | }
132 |
133 | func (c *proxyProtoConn) RemoteAddr() net.Addr {
134 | 	return c.src
135 | }
136 |
137 | func (c *proxyProtoConn) Read(b []byte) (n int, err error) {
138 | 	if len(c.buf) != 0 {
139 | 		n = copy(b, c.buf)
140 | 		c.buf = c.buf[n:]
141 | 		return
142 | 	}
143 | 	return c.Conn.Read(b)
144 | }
145 |
146 | var (
147 | 	proxy     = [...]byte{'P', 'R', 'O', 'X', 'Y'}
148 | 	tcp4      = [...]byte{'T', 'C', 'P', '4'}
149 | 	tcp6      = [...]byte{'T', 'C', 'P', '6'}
150 | 	crlf      = [...]byte{'\r', '\n'}
151 | 	signature = [...]byte{'\x0D', '\x0A', '\x0D', '\x0A', '\x00', '\x0D', '\x0A', '\x51', '\x55', '\x49', '\x54', '\x0A'}
152 | )
153 |
154 | func appendProxyProtoV1(b []byte, src net.Addr, dst net.Addr) []byte {
155 | 	var srcPortBuf [8]byte
156 | 	var dstPortBuf [8]byte
157 | 	var family []byte
158 |
159 | 	srcTCP := src.(*net.TCPAddr)
160 | 	dstTCP := dst.(*net.TCPAddr)
161 |
162 | 	srcPort := strconv.AppendUint(srcPortBuf[:0], uint64(srcTCP.Port), 10)
163 | 	dstPort := strconv.AppendUint(dstPortBuf[:0], uint64(dstTCP.Port), 10)
164 |
165 | 	if srcTCP.IP.To4() != nil {
166 | 		family = tcp4[:]
167 | 	} else {
168 | 		family = tcp6[:]
169 | 	}
170 |
171 | 	b = append(b, proxy[:]...)
172 | 	b = append(b, ' ')
173 | 	b = append(b, family...)
174 | 	b = append(b, ' ')
175 | 	b = append(b, srcTCP.IP.String()...)
176 | 	b = append(b, ' ')
177 | 	b = append(b, dstTCP.IP.String()...)
178 | 	b = append(b, ' ')
179 | 	b = append(b, srcPort...)
180 | 	b = append(b, ' ')
181 | 	b = append(b, dstPort...)
182 | 	b = append(b, crlf[:]...)
183 | 	return b
184 | }
185 |
186 | // appendProxyProtoV2 appends a version 2 header to b: the signature, the
187 | // version/command byte, the family/protocol byte, the length of the address
188 | // block (2 bytes, big endian), then the address block itself.
189 | func appendProxyProtoV2(b []byte, src net.Addr, dst net.Addr, local bool) []byte {
190 | 	const (
191 | 		AF_UNSPEC = 0
192 | 		AF_INET   = 1
193 | 		AF_INET6  = 2
194 | 		AF_UNIX   = 3
195 |
196 | 		UNSPEC = 0
197 | 		STREAM = 1
198 | 		DGRAM  = 2
199 |
200 | 		PROXY   = 1
201 | 		VERSION = 2
202 | 	)
203 |
204 | 	var (
205 | 		vercmd   byte = VERSION << 4
206 | 		family   byte = AF_UNSPEC
207 | 		socktype byte = UNSPEC
208 |
209 | 		srcIP net.IP
210 | 		dstIP net.IP
211 |
212 | 		srcAddr []byte
213 | 		dstAddr []byte
214 | 		srcPort []byte
215 | 		dstPort []byte
216 |
217 | 		srcAddrBuf [108]byte
218 | 		dstAddrBuf [108]byte
219 | 		srcPortBuf [2]byte
220 | 		dstPortBuf [2]byte
221 | 	)
222 |
223 | 	if !local {
224 | 		vercmd |= PROXY
225 | 	}
226 |
227 | 	switch a := src.(type) {
228 | 	case *net.TCPAddr:
229 | 		b := dst.(*net.TCPAddr)
230 | 		socktype, srcIP, dstIP = STREAM, a.IP, b.IP
231 | 		srcPort = srcPortBuf[:]
232 | 		dstPort = dstPortBuf[:]
233 | 		binary.BigEndian.PutUint16(srcPort, uint16(a.Port))
234 | 		binary.BigEndian.PutUint16(dstPort, uint16(b.Port))
235 |
236 | 	case *net.UDPAddr:
237 | 		b := dst.(*net.UDPAddr)
238 | 		socktype, srcIP, dstIP = DGRAM, a.IP, b.IP
239 | 		srcPort = srcPortBuf[:]
240 | 		dstPort = dstPortBuf[:]
241 | 		binary.BigEndian.PutUint16(srcPort, uint16(a.Port))
242 | 		binary.BigEndian.PutUint16(dstPort, uint16(b.Port))
243 |
244 | 	case *net.UnixAddr:
245 | 		b := dst.(*net.UnixAddr)
246 | 		family = AF_UNIX
247 | 		srcAddr = srcAddrBuf[:]
248 | 		dstAddr = dstAddrBuf[:]
249 | 		copy(srcAddr, a.Name)
250 | 		copy(dstAddr, b.Name)
251 |
252 | 		switch a.Net {
253 | 		case "unix":
254 | 			socktype = STREAM
255 | 		case "unixgram":
256 | 			socktype = DGRAM
257 | 		}
258 | 	}
259 |
260 | 	if srcIP != nil {
261 | 		if ip := srcIP.To4(); ip != nil {
262 | 			family = AF_INET
263 | 			srcAddr = ip
264 | 			dstAddr = dstIP.To4()
265 | 		} else {
266 | 			family = AF_INET6
267 | 			srcAddr = srcIP.To16()
268 | 			dstAddr = dstIP.To16()
269 | 		}
270 | 	}
271 |
272 | 	length := len(srcAddr) + len(dstAddr) + len(srcPort) + len(dstPort)
273 |
274 | 	b = append(b, signature[:]...)
275 | 	b = append(b, vercmd)
276 | 	b = append(b, (family<<4)|socktype)
277 | 	b = append(b, byte(length>>8), byte(length))
278 | 	b = append(b, srcAddr...)
279 | 	b = append(b, dstAddr...)
280 | 	b = append(b, srcPort...)
281 | 	b = append(b, dstPort...)
282 | 	return b
283 | }
284 |
285 | // parseProxyProto reads a version 1 or version 2 proxy protocol header from r.
286 | //
287 | // It returns the source and destination addresses found in the header (nil
288 | // when the header carries none), the bytes that were read from r past the end
289 | // of the header, and whether the version 2 LOCAL command was received. Any
290 | // data following the addresses in a version 2 address block (TLVs) is skipped.
291 | func parseProxyProto(r io.Reader) (src net.Addr, dst net.Addr, buf []byte, local bool, err error) {
292 | 	// Maximum length of a version 1 header (CRLF included), and length of the
293 | 	// fixed part of a version 2 header, as set by the specification.
294 | 	const (
295 | 		maxV1Len   = 107
296 | 		v2FixedLen = 16
297 | 	)
298 |
299 | 	var a [256]byte
300 | 	var b []byte
301 | 	var n int
302 |
303 | 	if n, err = io.ReadAtLeast(r, a[:], 14); err != nil {
304 | 		return
305 | 	}
306 | 	b = a[:n]
307 |
308 | 	switch {
309 | 	case bytes.HasPrefix(b, proxy[:]):
310 | 		limit := len(b)
311 | 		if limit > maxV1Len {
312 | 			limit = maxV1Len
313 | 		}
314 | 		i := bytes.Index(b[:limit], crlf[:])
315 |
316 | 		for i < 0 {
317 | 			if len(b) >= maxV1Len {
318 | 				err = errors.New("no '\r\n' sequence found in the first 107 bytes of a proxy protocol connection")
319 | 				return
320 | 			}
321 | 			if n, err = r.Read(a[len(b):maxV1Len]); n == 0 {
322 | 				if err == nil {
323 | 					continue // per io.Reader, (0, nil) means nothing happened
324 | 				}
325 | 				if err == io.EOF {
326 | 					err = io.ErrUnexpectedEOF
327 | 				}
328 | 				return
329 | 			}
330 | 			b = a[:len(b)+n]
331 | 			i = bytes.Index(b, crlf[:])
332 | 		}
333 |
334 | 		src, dst, err = parseProxyProtoV1(b[:i])
335 | 		buf = b[i+2:]
336 | 		return
337 |
338 | 	case bytes.HasPrefix(b, signature[:]):
339 | 		if len(b) < v2FixedLen {
340 | 			if _, err = io.ReadFull(r, a[len(b):v2FixedLen]); err != nil {
341 | 				if err == io.EOF {
342 | 					err = io.ErrUnexpectedEOF
343 | 				}
344 | 				return
345 | 			}
346 | 			b = a[:v2FixedLen]
347 | 		}
348 | 		b = b[len(signature):]
349 |
350 | 		if version := b[0] >> 4; version != 2 {
351 | 			err = fmt.Errorf("invalid proxy protocol version: %#x", version)
352 | 			return
353 | 		}
354 |
355 | 		switch cmd := b[0] & 0xF; cmd {
356 | 		case 0:
357 | 			local = true
358 | 		case 1:
359 | 		default:
360 | 			err = fmt.Errorf("invalid proxy protocol command: %#x", cmd)
361 | 			return
362 | 		}
363 |
364 | 		var makeStreamAddr = makeTCPAddr
365 | 		var makeDgramAddr = makeUDPAddr
366 | 		var makeAddr func(int, []byte, []byte) net.Addr
367 | 		var addrLen int
368 | 		var portLen int
369 | 		var socktype int
370 |
371 | 		switch family := b[1] >> 4; family {
372 | 		case 0: // AF_UNSPEC
373 | 		case 1: // AF_INET
374 | 			addrLen, portLen = 4, 2
375 | 		case 2: // AF_INET6
376 | 			addrLen, portLen = 16, 2
377 | 		case 3: // AF_UNIX
378 | 			addrLen, portLen = 108, 0
379 | 			makeStreamAddr, makeDgramAddr = makeUnixAddr, makeUnixAddr
380 | 		default:
381 | 			err = fmt.Errorf("invalid socket family found in proxy protocol header: %#x", family)
382 | 			return
383 | 		}
384 |
385 | 		switch socktype = int(b[1] & 0xF); socktype {
386 | 		case 0: // UNSPEC
387 | 		case 1: // STREAM
388 | 			makeAddr = makeStreamAddr
389 | 		case 2: // DGRAM
390 | 			makeAddr = makeDgramAddr
391 | 		default:
392 | 			err = fmt.Errorf("invalid socket type found in proxy protocol header: %#x", socktype)
393 | 			return
394 | 		}
395 |
396 | 		// Length of the address block that follows the fixed part of the
397 | 		// header; it may hold extensions (TLVs) after the addresses.
398 | 		length := int(binary.BigEndian.Uint16(b[2:4]))
399 | 		b = b[4:]
400 |
401 | 		n1 := 2*addrLen + 2*portLen
402 | 		if n1 > length {
403 | 			err = fmt.Errorf("proxy protocol address block too short for its socket family: %d < %d bytes", length, n1)
404 | 			return
405 | 		}
406 |
407 | 		if n2 := len(b); length > n2 {
408 | 			if length > cap(b) {
409 | 				// The address block doesn't fit in the fixed size buffer.
410 | 				c := make([]byte, length)
411 | 				copy(c, b)
412 | 				b = c
413 | 			} else {
414 | 				b = b[:length]
415 | 			}
416 | 			if _, err = io.ReadFull(r, b[n2:]); err != nil {
417 | 				if err == io.EOF {
418 | 					err = io.ErrUnexpectedEOF
419 | 				}
420 | 				return
421 | 			}
422 | 		}
423 |
424 | 		// AF_UNSPEC carries no address; calling makeAddr with empty slices
425 | 		// would panic when decoding the port.
426 | 		if makeAddr != nil && addrLen != 0 {
427 | 			src = makeAddr(socktype, b[:addrLen], b[2*addrLen:2*addrLen+portLen])
428 | 			dst = makeAddr(socktype, b[addrLen:2*addrLen], b[2*addrLen+portLen:n1])
429 | 		}
430 |
431 | 		buf = b[length:]
432 | 		return
433 | 	}
434 |
435 | 	err = errors.New("invalid signature found in proxy protocol connection")
436 | 	return
437 | }
438 |
439 | func parseProxyProtoV1(b []byte) (src net.Addr, dst net.Addr, err error) {
440 | 	var family, srcIP, srcPort, dstIP, dstPort []byte
441 |
442 | 	if !bytes.HasPrefix(b, proxy[:]) {
443 | 		err = errors.New("expected 'PROXY' at the beginning of the proxy protocol connection")
444 | 		return
445 | 	}
446 |
447 | 	if b = b[len(proxy):]; len(b) != 0 && b[0] == ' ' {
448 | 		b = b[1:]
449 | 	}
450 |
451 | 	family, b = parseProxyProtoWord(b)
452 | 	srcIP, b = parseProxyProtoWord(b)
453 | 	dstIP, b = parseProxyProtoWord(b)
454 | 	srcPort, b = parseProxyProtoWord(b)
455 | 	dstPort, b = parseProxyProtoWord(b)
456 |
457 | 	switch {
458 | 	case bytes.Equal(family, tcp4[:]):
459 | 	case bytes.Equal(family, tcp6[:]):
460 | 	default:
461 | 		err = fmt.Errorf("invalid socket family found in proxy protocol header: %s", string(family))
462 | 		return
463 | 	}
464 |
465 | 	var srcTCP net.TCPAddr
466 | 	var dstTCP net.TCPAddr
467 |
468 | 	if srcTCP.IP = net.ParseIP(string(srcIP)); srcTCP.IP == nil {
469 | 		err = fmt.Errorf("invalid source address found in proxy protocol header: %s", string(srcIP))
470 | 		return
471 | 	}
472 |
473 | 	if dstTCP.IP = net.ParseIP(string(dstIP)); dstTCP.IP == nil {
474 | 		err = fmt.Errorf("invalid destination address found in proxy protocol header: %s", string(dstIP))
475 | 		return
476 | 	}
477 |
478 | 	if srcTCP.Port, err = strconv.Atoi(string(srcPort)); err != nil {
479 | 		err = fmt.Errorf("invalid source port found in proxy protocol header: %s", string(srcPort))
480 | 		return
481 | 	}
482 |
483 | 	if dstTCP.Port, err = strconv.Atoi(string(dstPort)); err != nil {
484 | 		err = fmt.Errorf("invalid destination port found in proxy protocol header: %s", string(dstPort))
485 | 		return
486 | 	}
487 |
488 | 	if len(b) != 0 {
489 | 		err = errors.New("invalid extra bytes found at the end of a proxy protocol header")
490 | 		return
491 | 	}
492 |
493 | 	src, dst = &srcTCP, &dstTCP
494 | 	return
495 | }
496 |
497 | func parseProxyProtoWord(b []byte) (word []byte, tail []byte) {
498 | 	for i, n := 0, len(b); i != n; i++ {
499 | 		if b[i] == ' ' {
500 | 			word, tail = b[:i], b[i+1:]
501 | 			return
502 | 		}
503 | 	}
504 | 	word = b
505 | 	return
506 | }
507 |
508 | func makeTCPAddr(_ int, ip []byte, port []byte) net.Addr {
509 | 	return &net.TCPAddr{
510 | 		IP:   makeIP(ip),
511 | 		Port: int(binary.BigEndian.Uint16(port)),
512 | 	}
513 | }
514 |
515 | func makeUDPAddr(_ int, ip []byte, port []byte) net.Addr {
516 | 	return &net.UDPAddr{
517 | 		IP:   makeIP(ip),
518 | 		Port: int(binary.BigEndian.Uint16(port)),
519 | 	}
520 | }
521 |
522 | func makeUnixAddr(socktype int, name []byte, _ []byte) net.Addr {
523 | 	off := bytes.IndexByte(name, 0)
524 | 	if off < 0 {
525 | 		off = len(name)
526 | 	}
527 |
528 | 	addr := &net.UnixAddr{
529 | 		Name: string(name[:off]),
530 | 	}
531 |
532 | 	switch {
533 | 	case socktype == 1: // STREAM
534 | 		addr.Net = "unix"
535 | 	case socktype == 2: // DGRAM
536 | 		addr.Net = "unixgram"
537 | 	}
538 |
539 | 	return addr
540 | }
541 |
542 | func makeIP(b []byte) net.IP {
543 | 	if len(b) == 4 {
544 | 		return net.IPv4(b[0], b[1], b[2], b[3])
545 | 	}
546 | 	ip := make(net.IP, len(b))
547 | 	copy(ip, b)
548 | 	return ip
549 | }
550 |
--------------------------------------------------------------------------------
/proxy_darwin.go:
--------------------------------------------------------------------------------
1 | package netx
2 |
3 | import (
4 | 	"errors"
5 | 	"net"
6 | )
7 |
8 | func originalTargetAddr(conn net.Conn) (net.Addr, error) {
9 | 	return nil, errors.New("netx.OriginalTargetAddr is not implemented on darwin")
10 | }
11 |
--------------------------------------------------------------------------------
/proxy_linux.go:
--------------------------------------------------------------------------------
1 | package netx
2 |
3 | import (
4 | 	"net"
5 | 	"os"
6 | 	"syscall"
7 | 	"unsafe"
8 | )
9 |
10 | func originalTargetAddr(conn net.Conn) (n *net.TCPAddr, err error) {
11 | 	const (
12 | 		SO_ORIGINAL_DST = 80 // missing from the syscall package
13 | 	)
14 |
15 | 	// Calling conn.File will put the socket in blocking mode, we make sure to
16 | 	// set it back to non-blocking before returning to prevent the runtime from
17 | 	// creating tons of threads because it deals with blocking syscalls all over
18 | 	// the place.
19 | 	var f *os.File
20 | 	if f, err = BaseConn(conn).(fileConn).File(); err != nil {
21 | 		return
22 | 	}
23 | 	defer f.Close()
24 | 	defer syscall.SetNonblock(int(f.Fd()), true)
25 |
26 | 	sock := f.Fd()
27 | 	addr := syscall.RawSockaddrAny{}
28 | 	size := uint32(unsafe.Sizeof(addr))
29 |
30 | 	_, _, e := syscall.RawSyscall6(
31 | 		uintptr(syscall.SYS_GETSOCKOPT),
32 | 		uintptr(sock),
33 | 		uintptr(syscall.SOL_IP),
34 | 		uintptr(SO_ORIGINAL_DST),
35 | 		uintptr(unsafe.Pointer(&addr)),
36 | 		uintptr(unsafe.Pointer(&size)),
37 | 		uintptr(0),
38 | 	)
39 |
40 | 	if e != 0 {
41 | 		err = e
42 | 		return
43 | 	}
44 |
45 | 	switch addr.Addr.Family {
46 | 	case syscall.AF_INET:
47 | 		a := (*syscall.RawSockaddrInet4)(unsafe.Pointer(&addr))
48 | 		n = &net.TCPAddr{
49 | 			IP:   net.IP(a.Addr[:]),
50 | 			Port: int((a.Port >> 8) | (a.Port << 8)),
51 | 		}
52 |
53 | 	case syscall.AF_INET6:
54 | 		a := (*syscall.RawSockaddrInet6)(unsafe.Pointer(&addr))
55 | 		n = &net.TCPAddr{
56 | 			IP:   net.IP(a.Addr[:]),
57 | 			Port: int((a.Port >> 8) | (a.Port << 8)),
58 | 		}
59 |
60 | 	default:
61 | 		err = syscall.EAFNOSUPPORT
62 | 	}
63 |
64 | 	return
65 | }
66 |
--------------------------------------------------------------------------------
/proxy_test.go:
--------------------------------------------------------------------------------
1 | package netx
2 |
3 | import (
4 | 	"fmt"
5 | 	"io"
6 | 	"net"
7 | 	"reflect"
8 | 	"testing"
9 | )
10 |
11 | type readOneByOne struct {
12 | 	b []byte
13 | }
14 |
15 | func (r *readOneByOne) Read(b []byte) (n int, err error) {
16 | 	if len(r.b) == 0 {
17 | 		err = io.EOF
18 | 	} else if len(b) != 0 {
19 | 		n, b[0], r.b = 1, r.b[0], r.b[1:]
20 | 	}
21 | 	return
22 | }
23 |
24 | func TestProxyProtoV1(t *testing.T) {
25 | 	tests := []struct {
26 | 		src net.Addr
27 | 		dst net.Addr
28 | 		str string
29 | 	}{
30 | 		{
31 | 			src: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 56789},
32 | 			dst: &net.TCPAddr{IP: net.ParseIP("192.1.0.123"), Port: 4242},
33 | 			str: "PROXY TCP4 127.0.0.1 192.1.0.123 56789 4242\r\n",
34 | 		},
35 | 		{
36 | 			src: &net.TCPAddr{IP: net.ParseIP("::1"), Port: 56789},
37 | 			dst: &net.TCPAddr{IP: net.ParseIP("fe80::f65c:89ff:feac:29cb"), Port: 4242},
38 | 			str: "PROXY TCP6 ::1 fe80::f65c:89ff:feac:29cb 56789 4242\r\n",
39 | 		},
40 | 	}
41 |
42 | 	for _, test := range tests {
43 | 		t.Run(fmt.Sprintf("%s->%s", test.src, test.dst), func(t *testing.T) {
44 | 			b := appendProxyProtoV1(nil, test.src, test.dst)
45 |
46 | 			if s := string(b); s != test.str {
47 | 				t.Error("bad proxy proto header:", s)
48 | 				return
49 | 			}
50 |
51 | 			r := &readOneByOne{b}
52 | 			a1, a2, buf, local, err := parseProxyProto(r)
53 |
54 | 			if err != nil {
55 | 				t.Error(err)
56 | 			}
57 |
58 | 			if len(r.b) != 0 || len(buf) != 0 {
59 | 				t.Error("unexpected trailing bytes")
60 | 			}
61 |
62 | 			if !reflect.DeepEqual(test.src, a1) {
63 | 				t.Errorf("bad source: %#v", a1)
64 | 			}
65 |
66 | 			if !reflect.DeepEqual(test.dst, a2) {
67 | 				t.Errorf("bad destination: %#v", a2)
68 | 			}
69 |
70 | 			if local {
71 | 				t.Errorf("bad local: %t", local)
72 | 			}
73 | 		})
74 | 	}
75 | }
76 |
77 | func TestProxyProtoV2(t *testing.T) {
78 | 	tests := []struct {
79 | 		src net.Addr
80 | 		dst net.Addr
81 | 		str string
82 | 	}{
83 | 		{
84 | 			src: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 56789},
85 | 			dst: &net.TCPAddr{IP: net.ParseIP("192.1.0.123"), Port: 4242},
86 | 		},
87 | 		{
88 | 			src: &net.TCPAddr{IP: net.ParseIP("::1"), Port: 56789},
89 | 			dst: &net.TCPAddr{IP: net.ParseIP("fe80::f65c:89ff:feac:29cb"), Port: 4242},
90 | 		},
91 | 		{
92 | 			src: &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 56789},
93 | 			dst: &net.UDPAddr{IP: net.ParseIP("192.1.0.123"), Port: 4242},
94 | 		},
95 | 		{
96 | 			src: &net.UDPAddr{IP: net.ParseIP("::1"), Port: 56789},
97 | 			dst: &net.UDPAddr{IP: net.ParseIP("fe80::f65c:89ff:feac:29cb"), Port: 4242},
98 | 		},
99 | 		{
100 | 			src: &net.UnixAddr{Net: "unix", Name: "/var/lib/src.sock"},
101 | 			dst: &net.UnixAddr{Net: "unix", Name: "/var/lib/dst.sock"},
102 | 		},
103 | 		{
104 | 			src: &net.UnixAddr{Net: "unixgram", Name: "/var/lib/src.sock"},
105 | 			dst: &net.UnixAddr{Net: "unixgram", Name: "/var/lib/dst.sock"},
106 | 		},
107 | 	}
108 |
109 | 	for _, test := range tests {
110 | 		t.Run(fmt.Sprintf("%s://%s->%s", test.src.Network(), test.src, test.dst), func(t *testing.T) {
111 | 			b := appendProxyProtoV2(nil, test.src, test.dst, false)
112 | 			r := &readOneByOne{b}
113 | 			a1, a2, buf, local, err := parseProxyProto(r)
114 |
115 | 			if err != nil {
116 | 				t.Error(err)
117 | 			}
118 |
119 | 			if len(r.b) != 0 || len(buf) != 0 {
120 | 				t.Errorf("unexpected trailing bytes: %#v %#v", r.b, buf)
121 | 			}
122 |
123 | 			if !reflect.DeepEqual(test.src, a1) {
124 | 				t.Errorf("bad source: %#v", a1)
125 | 			}
126 |
127 | 			if !reflect.DeepEqual(test.dst, a2) {
128 | 				t.Errorf("bad destination: %#v", a2)
129 | 			}
130 |
131 | 			if local {
132 | 				t.Errorf("bad local state: %t", local)
133 | 			}
134 | 		})
135 | 	}
136 | }
137 |
138 | func TestProxyProtoV2Local(t *testing.T) {
139 | 	b := appendProxyProtoV2(nil, &NetAddr{}, &NetAddr{}, true)
140 | 	r := &readOneByOne{b}
141 | 	src, dst, buf, local, err := parseProxyProto(r)
142 |
143 | 	if err != nil {
144 | 		t.Error(err)
145 | 	}
146 |
147 | 	if len(r.b) != 0 || len(buf) != 0 {
148 | 		t.Errorf("unexpected trailing bytes: %#v %#v", r.b, buf)
149 | 	}
150 |
151 | 	if src != nil {
152 | 		t.Errorf("bad source: %#v", src)
153 | 	}
154 |
155 | 	if dst != nil {
156 | 		t.Errorf("bad destination: %#v", dst)
157 | 	}
158 |
159 | 	if !local {
160 | 		t.Errorf("bad local state: %t", local)
161 | 	}
162 | }
163 |
164 | func TestProxyProtoV2Extensions(t *testing.T) {
165 | 	src := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 56789}
166 | 	dst := &net.TCPAddr{IP: net.ParseIP("192.1.0.123"), Port: 4242}
167 |
168 | 	// Append a PP2_TYPE_NOOP extension (type 0x04, 1 byte long) to the address
169 | 	// block, update the length stored in bytes 14 and 15 of the header, then
170 | 	// append some data that must not be consumed as part of the header.
171 | 	b := appendProxyProtoV2(nil, src, dst, false)
172 | 	tlv := []byte{0x04, 0x00, 0x01, 0x00}
173 | 	size := (int(b[14])<<8 | int(b[15])) + len(tlv)
174 | 	b[14], b[15] = byte(size>>8), byte(size)
175 | 	b = append(b, tlv...)
176 | 	b = append(b, "Hello"...)
177 |
178 | 	r := &readOneByOne{b}
179 | 	a1, a2, buf, local, err := parseProxyProto(r)
180 |
181 | 	if err != nil {
182 | 		t.Fatal(err)
183 | 	}
184 |
185 | 	if s := string(buf) + string(r.b); s != "Hello" {
186 | 		t.Errorf("bad data after the header: %q", s)
187 | 	}
188 |
189 | 	if !reflect.DeepEqual(src, a1) {
190 | 		t.Errorf("bad source: %#v", a1)
191 | 	}
192 |
193 | 	if !reflect.DeepEqual(dst, a2) {
194 | 		t.Errorf("bad destination: %#v", a2)
195 | 	}
196 |
197 | 	if local {
198 | 		t.Errorf("bad local state: %t", local)
199 | 	}
200 | }
201 |
--------------------------------------------------------------------------------
/server.go:
--------------------------------------------------------------------------------
1 | package netx
2 |
3 | import (
4 | 	"context"
5 | 	"io"
6 | 	"log"
7 | 	"net"
8 | 	"runtime"
9 | 	"sync"
10 | 	"time"
11 | )
12 |
13 | // ListenAndServe listens on the address addr and then call Serve to handle
14 | // the incoming connections.
15 | func ListenAndServe(addr string, handler Handler) error { 16 | return (&Server{ 17 | Addr: addr, 18 | Handler: handler, 19 | }).ListenAndServe() 20 | } 21 | 22 | // Serve accepts incoming connections on the Listener lstn, creating a new 23 | // service goroutine for each. The service goroutines simply invoke the 24 | // handler's ServeConn method. 25 | func Serve(lstn net.Listener, handler Handler) error { 26 | return (&Server{ 27 | Handler: handler, 28 | }).Serve(lstn) 29 | } 30 | 31 | // A Server defines parameters for running servers that accept connections over 32 | // TCP or unix domains. 33 | type Server struct { 34 | Addr string // address to listen on 35 | Handler Handler // handler to invoke on new connections 36 | ErrorLog *log.Logger // the logger used to output internal errors 37 | Context context.Context // the base context used by the server 38 | } 39 | 40 | // ListenAndServe listens on the server address and then call Serve to handle 41 | // the incoming connections. 42 | func (s *Server) ListenAndServe() (err error) { 43 | var lstn net.Listener 44 | 45 | if lstn, err = Listen(s.Addr); err == nil { 46 | err = s.Serve(lstn) 47 | } 48 | 49 | return 50 | } 51 | 52 | // Serve accepts incoming connections on the Listener lstn, creating a new 53 | // service goroutine for each. The service goroutines simply invoke the 54 | // handler's ServeConn method. 55 | // 56 | // The server becomes the owner of the listener which will be closed by the time 57 | // the Serve method returns. 
58 | func (s *Server) Serve(lstn net.Listener) error { 59 | defer lstn.Close() 60 | 61 | join := &sync.WaitGroup{} 62 | defer join.Wait() 63 | 64 | ctx := s.Context 65 | if ctx == nil { 66 | ctx = context.Background() 67 | } 68 | 69 | ctx, cancel := context.WithCancel(ctx) 70 | defer cancel() 71 | 72 | done := ctx.Done() 73 | errs := make(chan error) 74 | conns := make(chan net.Conn) 75 | 76 | join.Add(1) 77 | go s.accept(ctx, lstn, conns, errs, join) 78 | 79 | for conns != nil || errs != nil { 80 | select { 81 | case <-done: 82 | lstn.Close() 83 | done = nil 84 | 85 | case err, ok := <-errs: 86 | if !ok { 87 | errs = nil 88 | continue 89 | } 90 | return err 91 | 92 | case conn, ok := <-conns: 93 | if !ok { 94 | conns = nil 95 | continue 96 | } 97 | join.Add(1) 98 | go s.serve(ctx, conn, join) 99 | } 100 | } 101 | 102 | return nil 103 | } 104 | 105 | func (s *Server) accept(ctx context.Context, lstn net.Listener, conns chan<- net.Conn, errs chan<- error, join *sync.WaitGroup) { 106 | defer join.Done() 107 | defer close(errs) 108 | defer close(conns) 109 | 110 | const maxBackoff = 1 * time.Second 111 | for { 112 | var conn net.Conn 113 | var err error 114 | 115 | for attempt := 0; true; attempt++ { 116 | if conn, err = lstn.Accept(); err == nil { 117 | break 118 | } 119 | if !IsTemporary(err) { 120 | break 121 | } 122 | 123 | // Backoff strategy for handling temporary errors, this prevents from 124 | // retrying too fast when errors like running out of file descriptors 125 | // occur. 126 | backoff := time.Duration(attempt*attempt) * 10 * time.Millisecond 127 | if backoff > maxBackoff { 128 | backoff = maxBackoff 129 | } 130 | s.logf("Accept error: %v; retrying in %v", err, backoff) 131 | select { 132 | case <-time.After(backoff): 133 | case <-ctx.Done(): 134 | return 135 | } 136 | } 137 | 138 | if err != nil { 139 | switch e := err.(type) { 140 | case *net.OpError: 141 | // Don't report EOF, this is a normal termination of the listener. 
142 | if e.Err == io.EOF { 143 | err = nil 144 | } 145 | } 146 | if err != nil { 147 | select { 148 | case <-ctx.Done(): 149 | // Don't report errors when the server stopped because its 150 | // context was canceled. 151 | default: 152 | errs <- err 153 | } 154 | } 155 | return 156 | } 157 | 158 | conns <- conn 159 | } 160 | } 161 | 162 | func (s *Server) serve(ctx context.Context, conn net.Conn, join *sync.WaitGroup) { 163 | defer func() { Recover(recover(), conn, s.ErrorLog) }() 164 | 165 | defer join.Done() 166 | defer conn.Close() 167 | 168 | ctx, cancel := context.WithCancel(ctx) 169 | defer cancel() 170 | 171 | s.Handler.ServeConn(ctx, conn) 172 | } 173 | 174 | func (s *Server) logf(format string, args ...interface{}) { 175 | logf(s.ErrorLog)(format, args...) 176 | } 177 | 178 | // Recover is intended to be used by servers that gracefully handle panics from 179 | // their handlers. 180 | func Recover(err interface{}, conn net.Conn, logger *log.Logger) { 181 | if err == nil { 182 | return 183 | } 184 | 185 | logf := logf(logger) 186 | laddr := conn.LocalAddr() 187 | raddr := conn.RemoteAddr() 188 | 189 | buf := make([]byte, 262144) 190 | buf = buf[:runtime.Stack(buf, false)] 191 | logf("panic serving %s->%s: %v\n%s", laddr, raddr, err, string(buf)) 192 | } 193 | 194 | func logf(logger *log.Logger) func(string, ...interface{}) { 195 | if logger == nil { 196 | return log.Printf 197 | } 198 | return logger.Printf 199 | } 200 | -------------------------------------------------------------------------------- /server_test.go: -------------------------------------------------------------------------------- 1 | package netx 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "log" 7 | "net" 8 | "os" 9 | "sync" 10 | "testing" 11 | "time" 12 | 13 | "golang.org/x/net/nettest" 14 | ) 15 | 16 | func TestServerConn(t *testing.T) { 17 | nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) { 18 | cnch := make(chan net.Conn) 19 | done := make(chan struct{}) 20 
| 21 | a, f := listenAndServe(HandlerFunc(func(ctx context.Context, conn net.Conn) { 22 | cnch <- conn 23 | <-done 24 | })) 25 | 26 | if c1, err = net.Dial(a.Network(), a.String()); err != nil { 27 | return 28 | } 29 | 30 | c2 = <-cnch 31 | 32 | stop = func() { 33 | close(done) 34 | c2.Close() 35 | c1.Close() 36 | f() 37 | } 38 | return 39 | }) 40 | } 41 | 42 | func TestEchoServer(t *testing.T) { 43 | for _, test := range []struct { 44 | network string 45 | address string 46 | }{ 47 | {"unix", "/tmp/echo-server.sock"}, 48 | {"tcp", "127.0.0.1:56789"}, 49 | } { 50 | test := test 51 | t.Run(test.address, func(t *testing.T) { 52 | t.Parallel() 53 | 54 | ctx, cancel := context.WithCancel(context.Background()) 55 | defer cancel() 56 | 57 | server := &Server{ 58 | Addr: test.address, 59 | Context: ctx, 60 | Handler: Echo, 61 | } 62 | 63 | done := &sync.WaitGroup{} 64 | done.Add(1) 65 | go func() { 66 | defer done.Done() 67 | server.ListenAndServe() 68 | }() 69 | 70 | // Give a bit of time to the server to bind the socket. 71 | time.Sleep(50 * time.Millisecond) 72 | join := &sync.WaitGroup{} 73 | 74 | for i := 0; i != 10; i++ { 75 | join.Add(1) 76 | go func() { 77 | defer join.Done() 78 | 79 | conn, err := net.Dial(test.network, test.address) 80 | if err != nil { 81 | t.Error(err) 82 | return 83 | } 84 | defer conn.Close() 85 | 86 | b := [12]byte{} 87 | 88 | if _, err := conn.Write([]byte("Hello World!")); err != nil { 89 | t.Error(err) 90 | return 91 | } 92 | 93 | if _, err := io.ReadFull(conn, b[:]); err != nil { 94 | t.Error(err) 95 | return 96 | } 97 | 98 | if s := string(b[:]); s != "Hello World!" 
{ 99 | t.Error(s) 100 | return 101 | } 102 | }() 103 | } 104 | 105 | join.Wait() 106 | cancel() 107 | done.Wait() 108 | }) 109 | } 110 | } 111 | 112 | func listenAndServe(h Handler) (addr net.Addr, close func()) { 113 | lstn, err := Listen("127.0.0.1:0") 114 | if err != nil { 115 | panic(err) 116 | } 117 | 118 | ctx, cancel := context.WithCancel(context.Background()) 119 | 120 | join := &sync.WaitGroup{} 121 | join.Add(1) 122 | 123 | go func() { 124 | defer join.Done() 125 | (&Server{ 126 | Handler: h, 127 | Context: ctx, 128 | ErrorLog: log.New(os.Stderr, "listen: ", 0), 129 | }).Serve(lstn) 130 | }() 131 | 132 | addr, close = lstn.Addr(), func() { 133 | cancel() 134 | join.Wait() 135 | } 136 | return 137 | } 138 | -------------------------------------------------------------------------------- /tunnel.go: -------------------------------------------------------------------------------- 1 | package netx 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "io" 7 | "net" 8 | "time" 9 | ) 10 | 11 | // TunnelHandler is an interface that must be implemented by types that intend 12 | // to provide logic for tunnelling connections. 13 | // 14 | // The ServeTunnel method is called by a Tunnel after establishing a connection 15 | // to a remote target address. 16 | type TunnelHandler interface { 17 | ServeTunnel(ctx context.Context, from net.Conn, to net.Conn) 18 | } 19 | 20 | // TunnelHandlerFunc makes it possible for simple function types to be used as 21 | // connection proxies. 22 | type TunnelHandlerFunc func(context.Context, net.Conn, net.Conn) 23 | 24 | // ServeTunnel calls f. 25 | func (f TunnelHandlerFunc) ServeTunnel(ctx context.Context, from net.Conn, to net.Conn) { 26 | f(ctx, from, to) 27 | } 28 | 29 | // A Tunnel is a proxy handler that establishes a second connection to a 30 | // target address for every incoming connection it receives. 
31 | type Tunnel struct { 32 | // Handler is called by the tunnel when it successfully established a 33 | // connection to its target. 34 | // 35 | // Calling one of the tunnel's method will panic if this field is nil. 36 | Handler TunnelHandler 37 | 38 | // DialContext can be set to a dialing function to configure how the tunnel 39 | // establishes new connections. 40 | DialContext func(context.Context, string, string) (net.Conn, error) 41 | } 42 | 43 | // ServeProxy satisfies the ProxyHandler interface. 44 | // 45 | // When called the tunnel establishes a connection to target, then delegate to 46 | // its handler. 47 | // 48 | // The method panics to report errors. 49 | func (t *Tunnel) ServeProxy(ctx context.Context, from net.Conn, target net.Addr) { 50 | dial := t.DialContext 51 | 52 | if dial == nil { 53 | dial = (&net.Dialer{Timeout: 10 * time.Second /* safeguard */}).DialContext 54 | } 55 | 56 | to, err := dial(ctx, target.Network(), target.String()) 57 | if err != nil { 58 | panic(err) 59 | } 60 | 61 | defer to.Close() 62 | t.Handler.ServeTunnel(ctx, from, to) 63 | } 64 | 65 | var ( 66 | // TunnelRaw is the implementation of a tunnel handler which passes bytes 67 | // back and forth between the two ends of a tunnel. 68 | // 69 | // The implementation supports cancelltations and closes the connections 70 | // before it returns (because it doesn't know anything about the underlying 71 | // protocol being spoken and could leave the connections in an unreusable 72 | // state). 73 | TunnelRaw TunnelHandler = TunnelHandlerFunc(tunnelRaw) 74 | 75 | // TunnelLine is the implementation of a tunnel handler which speaks a line 76 | // based protocol like TELNET, expecting the client not to send more than 77 | // one line before getting a response. 78 | // 79 | // The implementation supports cancellations and ensures that no partial 80 | // lines are read from the connection. 81 | // 82 | // The maximum line length is limited to 8192 bytes. 
83 | TunnelLine TunnelHandler = TunnelHandlerFunc(tunnelLine) 84 | ) 85 | 86 | func tunnelRaw(ctx context.Context, from net.Conn, to net.Conn) { 87 | ctx, cancel := context.WithCancel(ctx) 88 | 89 | copy := func(w io.Writer, r io.Reader) { 90 | defer cancel() 91 | Copy(w, r) 92 | } 93 | 94 | go copy(to, from) 95 | go copy(from, to) 96 | 97 | <-ctx.Done() 98 | from.Close() 99 | } 100 | 101 | func tunnelLine(ctx context.Context, from net.Conn, to net.Conn) { 102 | r1 := bufio.NewReaderSize(from, 8192) 103 | r2 := bufio.NewReaderSize(to, 8192) 104 | 105 | for { 106 | line, err := readLine(ctx, from, r1) 107 | 108 | switch err { 109 | case nil: 110 | case io.EOF, context.Canceled: 111 | return 112 | default: 113 | fatal(from, err) 114 | } 115 | 116 | if _, err := to.Write(line); err != nil { 117 | fatal(from, err) 118 | } 119 | 120 | if line, err = readLine(context.Background(), to, r2); err != nil { 121 | fatal(from, err) 122 | } 123 | 124 | if _, err := from.Write(line); err != nil { 125 | fatal(from, err) 126 | } 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /tunnel_test.go: -------------------------------------------------------------------------------- 1 | package netx 2 | 3 | import ( 4 | "io" 5 | "net" 6 | "testing" 7 | ) 8 | 9 | func TestTunnel(t *testing.T) { 10 | tests := []struct { 11 | name string 12 | tunnel TunnelHandler 13 | }{ 14 | { 15 | name: "TunnelRaw", 16 | tunnel: TunnelRaw, 17 | }, 18 | { 19 | name: "TunnelLine", 20 | tunnel: TunnelLine, 21 | }, 22 | } 23 | 24 | for _, test := range tests { 25 | t.Run(test.name, func(t *testing.T) { 26 | addr1, close1 := listenAndServe(Echo) 27 | defer close1() 28 | 29 | addr2, close2 := listenAndServe(&Proxy{ 30 | Addr: addr1, 31 | Handler: &Tunnel{ 32 | Handler: test.tunnel, 33 | }, 34 | }) 35 | defer close2() 36 | 37 | conn, err := net.Dial(addr2.Network(), addr2.String()) 38 | if err != nil { 39 | t.Error(err) 40 | return 41 | } 42 | defer conn.Close() 
43 | 44 | if _, err := io.WriteString(conn, "Hello World!\r\n"); err != nil { 45 | t.Error(err) 46 | return 47 | } 48 | 49 | b := [14]byte{} 50 | 51 | if _, err := io.ReadFull(conn, b[:]); err != nil { 52 | t.Error(err) 53 | return 54 | } 55 | 56 | if s := string(b[:]); s != "Hello World!\r\n" { 57 | t.Error(s) 58 | return 59 | } 60 | }) 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /unix.go: -------------------------------------------------------------------------------- 1 | package netx 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "net" 8 | "os" 9 | "sync/atomic" 10 | "syscall" 11 | ) 12 | 13 | // SendUnixConn sends a file descriptor embedded in conn over the unix domain 14 | // socket. 15 | // On success conn is closed because the owner is now the process that received 16 | // the file descriptor. 17 | // 18 | // conn must be a *net.TCPConn or similar (providing a File method) or the 19 | // function will panic. 20 | func SendUnixConn(socket *net.UnixConn, conn net.Conn) (err error) { 21 | return sendUnixFileConn(socket, BaseConn(conn).(fileConn), conn) 22 | } 23 | 24 | // SendUnixPacketConn sends a file descriptor embedded in conn over the unix 25 | // domain socket. 26 | // On success conn is closed because the owner is now the process that received 27 | // the file descriptor. 28 | // 29 | // conn must be a *net.UDPConn or similar (providing a File method) or the 30 | // function will panic. 
31 | func SendUnixPacketConn(socket *net.UnixConn, conn net.PacketConn) (err error) { 32 | return sendUnixFileConn(socket, BasePacketConn(conn).(fileConn), conn) 33 | } 34 | 35 | func sendUnixFileConn(socket *net.UnixConn, conn fileConn, close io.Closer) (err error) { 36 | var f *os.File 37 | 38 | if f, err = conn.File(); err != nil { 39 | return 40 | } 41 | defer f.Close() 42 | 43 | if err = SendUnixFile(socket, f); err != nil { 44 | return 45 | } 46 | 47 | close.Close() 48 | return 49 | } 50 | 51 | // SendUnixFile sends a file descriptor embedded in file over the unix domain 52 | // socket. 53 | // On success the file is closed because the owner is now the process that 54 | // received the file descriptor. 55 | func SendUnixFile(socket *net.UnixConn, file *os.File) (err error) { 56 | var fds = [1]int{int(file.Fd())} 57 | var oob = syscall.UnixRights(fds[:]...) 58 | 59 | if _, _, err = socket.WriteMsgUnix(nil, oob, nil); err != nil { 60 | return 61 | } 62 | 63 | file.Close() 64 | return 65 | } 66 | 67 | // RecvUnixConn receives a network connection from a unix domain socket. 68 | func RecvUnixConn(socket *net.UnixConn) (conn net.Conn, err error) { 69 | var f *os.File 70 | if f, err = RecvUnixFile(socket); err != nil { 71 | return 72 | } 73 | defer f.Close() 74 | return net.FileConn(f) 75 | } 76 | 77 | // RecvUnixPacketConn receives a packet oriented network connection from a unix 78 | // domain socket. 79 | func RecvUnixPacketConn(socket *net.UnixConn) (conn net.PacketConn, err error) { 80 | var f *os.File 81 | if f, err = RecvUnixFile(socket); err != nil { 82 | return 83 | } 84 | defer f.Close() 85 | return net.FilePacketConn(f) 86 | } 87 | 88 | // RecvUnixFile receives a file descriptor from a unix domain socket. 
func RecvUnixFile(socket *net.UnixConn) (file *os.File, err error) {
	// Room for one control message carrying a single 4-byte descriptor.
	oob := make([]byte, syscall.CmsgSpace(4))

	_, oobn, _, _, err := socket.ReadMsgUnix(nil, oob)
	if err != nil {
		return nil, err
	}
	if oobn == 0 {
		// A message arrived without any ancillary data: no descriptor.
		return nil, io.EOF
	}

	// Only parse the control bytes that were actually received.
	msgs, err := syscall.ParseSocketControlMessage(oob[:oobn])
	if err != nil {
		return nil, os.NewSyscallError("ParseSocketControlMessage", err)
	}

	if len(msgs) != 1 {
		return nil, fmt.Errorf("invalid number of socket control messages, expected 1 but found %d", len(msgs))
	}

	fds, err := syscall.ParseUnixRights(&msgs[0])
	if err != nil {
		return nil, os.NewSyscallError("ParseUnixRights", err)
	}

	if len(fds) != 1 {
		// Don't leak descriptors that will not be handed back to the caller.
		for _, fd := range fds {
			syscall.Close(fd)
		}
		return nil, fmt.Errorf("invalid number of file descriptors in a single control message, expected 1 but found %d (all were closed)", len(fds))
	}

	return os.NewFile(uintptr(fds[0]), ""), nil
}

// NewRecvUnixListener returns a new listener which accepts connection by
// reading file descriptors from a unix domain socket.
//
// The function doesn't make a copy of socket, so the returned listener should
// be considered the new owner of that object, which means closing the listener
// will actually close the original socket (and vice versa).
func NewRecvUnixListener(socket *net.UnixConn) *RecvUnixListener {
	return &RecvUnixListener{socket: *socket}
}

// RecvUnixListener is a listener which accepts connections by reading file
// descriptors from a unix domain socket.
type RecvUnixListener struct {
	socket net.UnixConn
}

// Accept receives a file descriptor from the listener's unix domain socket.
146 | func (l *RecvUnixListener) Accept() (net.Conn, error) { 147 | return RecvUnixConn(&l.socket) 148 | } 149 | 150 | // Addr returns the address of the listener's unix domain socket. 151 | func (l *RecvUnixListener) Addr() net.Addr { 152 | return l.socket.LocalAddr() 153 | } 154 | 155 | // Close closes the underlying unix domain socket. 156 | func (l *RecvUnixListener) Close() error { 157 | return l.socket.Close() 158 | } 159 | 160 | // UnixConn returns a pointer to the underlying unix domain socket. 161 | func (l *RecvUnixListener) UnixConn() *net.UnixConn { 162 | return &l.socket 163 | } 164 | 165 | // NewSendUnixHandler wraps handler so the connetions it receives will be sent 166 | // back to socket when handler returns without closing them. 167 | func NewSendUnixHandler(socket *net.UnixConn, handler Handler) *SendUnixHandler { 168 | return &SendUnixHandler{ 169 | handler: handler, 170 | socket: *socket, 171 | } 172 | } 173 | 174 | // SendUnixHandler is a connection handler which sends the connections it 175 | // handles back through a unix domain socket. 176 | type SendUnixHandler struct { 177 | handler Handler 178 | socket net.UnixConn 179 | } 180 | 181 | // ServeConn satisfies the Handler interface. 182 | func (h *SendUnixHandler) ServeConn(ctx context.Context, conn net.Conn) { 183 | c := &sendUnixConn{Conn: conn} 184 | h.handler.ServeConn(ctx, c) 185 | 186 | if atomic.LoadUint32(&c.closed) == 0 { 187 | if err := SendUnixConn(&h.socket, conn); err != nil { 188 | panic(fmt.Errorf("sending connection back over unix domain socket: %s", err)) 189 | } 190 | } 191 | } 192 | 193 | // UnixConn returns a pointer to the underlying unix domain socket. 
194 | func (h *SendUnixHandler) UnixConn() *net.UnixConn { 195 | return &h.socket 196 | } 197 | 198 | type sendUnixConn struct { 199 | net.Conn 200 | closed uint32 201 | } 202 | 203 | func (c *sendUnixConn) Base() net.Conn { 204 | return c.Conn 205 | } 206 | 207 | func (c *sendUnixConn) Close() (err error) { 208 | atomic.StoreUint32(&c.closed, 1) 209 | return c.Conn.Close() 210 | } 211 | 212 | func (c *sendUnixConn) Read(b []byte) (n int, err error) { 213 | if n, err = c.Conn.Read(b); err != nil && !IsTemporary(err) { 214 | atomic.StoreUint32(&c.closed, 1) 215 | } 216 | return 217 | } 218 | 219 | func (c *sendUnixConn) Write(b []byte) (n int, err error) { 220 | if n, err = c.Conn.Write(b); err != nil && !IsTemporary(err) { 221 | atomic.StoreUint32(&c.closed, 1) 222 | } 223 | return 224 | } 225 | -------------------------------------------------------------------------------- /unix_test.go: -------------------------------------------------------------------------------- 1 | // +build darwin dragonfly freebsd linux netbsd openbsd solaris 2 | 3 | package netx 4 | 5 | import ( 6 | "net" 7 | "testing" 8 | 9 | "golang.org/x/net/nettest" 10 | ) 11 | 12 | func TestSendRecvUnixConn(t *testing.T) { 13 | nettest.TestConn(t, func() (c3 net.Conn, c4 net.Conn, stop func(), err error) { 14 | var c1 net.Conn 15 | var c2 net.Conn 16 | var u1 *net.UnixConn 17 | var u2 *net.UnixConn 18 | 19 | if u1, u2, err = UnixConnPair(); err != nil { 20 | return 21 | } 22 | defer u1.Close() 23 | defer u2.Close() 24 | 25 | if c1, c2, err = ConnPair("tcp"); err != nil { 26 | return 27 | } 28 | 29 | if err = SendUnixConn(u1, c1); err != nil { 30 | return 31 | } 32 | 33 | if err = SendUnixConn(u2, c2); err != nil { 34 | return 35 | } 36 | 37 | if c3, err = RecvUnixConn(u2); err != nil { 38 | return 39 | } 40 | 41 | if c4, err = RecvUnixConn(u1); err != nil { 42 | return 43 | } 44 | 45 | stop = func() { 46 | c3.Close() 47 | c4.Close() 48 | } 49 | return 50 | }) 51 | } 52 | 
--------------------------------------------------------------------------------