├── .gitignore ├── LICENSE ├── README.md ├── go.mod ├── protocol.go └── protocol_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | *.test 2 | *~ 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014 Armon Dadgar 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # proxyproto 2 | 3 | This library provides the `proxyproto` package which can be used for servers 4 | listening behind HAProxy or Amazon ELB load balancers.
Those load balancers 5 | support the use of a proxy protocol (http://haproxy.1wt.eu/download/1.5/doc/proxy-protocol.txt), 6 | which provides a simple mechanism for the server to get the address of the client 7 | instead of the load balancer. 8 | 9 | This library provides both a net.Listener and net.Conn implementation that 10 | can be used to handle situations in which you may be using the proxy protocol. 11 | Only proxy protocol version 1, the human-readable form, is understood. 12 | 13 | The only caveat is that we check for the "PROXY " prefix to determine if the protocol 14 | is being used. If that string may occur as part of your input, then it is ambiguous 15 | if the protocol is being used and you may have problems. 16 | 17 | # Documentation 18 | 19 | Full documentation can be found [here](http://godoc.org/github.com/armon/go-proxyproto). 20 | 21 | # Examples 22 | 23 | Using the library is very simple: 24 | 25 | ``` 26 | 27 | // Create a listener 28 | list, err := net.Listen("tcp", "...") 29 | 30 | // Wrap listener in a proxyproto listener 31 | proxyList := &proxyproto.Listener{Listener: list} 32 | conn, err := proxyList.Accept() 33 | 34 | ...
35 | ``` 36 | 37 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/armon/go-proxyproto 2 | 3 | go 1.20 4 | -------------------------------------------------------------------------------- /protocol.go: -------------------------------------------------------------------------------- 1 | package proxyproto 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "log" 10 | "net" 11 | "strconv" 12 | "strings" 13 | "sync" 14 | "time" 15 | ) 16 | 17 | var ( 18 | // prefix is the string we look for at the start of a connection 19 | // to check if this connection is using the proxy protocol 20 | prefix = []byte("PROXY ") 21 | prefixLen = len(prefix) 22 | // ErrInvalidUpstream is returned by a SourceChecker to reject a peer; Accept() then closes that connection and waits for the next one. 23 | ErrInvalidUpstream = errors.New("upstream connection address not trusted for PROXY information") 24 | ) 25 | 26 | // SourceChecker can be used to decide whether to trust the PROXY info or pass 27 | // the original connection address through. If set, the connecting address is 28 | // passed in as an argument. If the source is disallowed outright, the function 29 | // should return ErrInvalidUpstream. 30 | // 31 | // If error is ErrInvalidUpstream, Accept() closes the connection and waits 32 | // for the next one. Any other non-nil error is returned by Accept(), so the 33 | // call to Accept() fails. 34 | // 35 | // If bool is true, the PROXY-set address is used. 36 | // 37 | // If bool is false, the connection's own remote and local addresses are used, 38 | // rather than the addresses claimed in the PROXY info. 39 | type SourceChecker func(net.Addr) (bool, error) 40 | 41 | // Listener is used to wrap an underlying listener, 42 | // whose connections may be using the HAProxy Proxy Protocol (version 1). 43 | // If the connection is using the protocol, the RemoteAddr() will return 44 | // the correct client address.
45 | // 46 | // Optionally define ProxyHeaderTimeout to set a maximum time to 47 | // receive the Proxy Protocol Header. Zero means no timeout. 48 | type Listener struct { 49 | Listener net.Listener 50 | ProxyHeaderTimeout time.Duration 51 | SourceCheck SourceChecker 52 | UnknownOK bool // allow PROXY UNKNOWN 53 | } 54 | 55 | // Conn is used to wrap and underlying connection which 56 | // may be speaking the Proxy Protocol. If it is, the RemoteAddr() will 57 | // return the address of the client instead of the proxy address. 58 | type Conn struct { 59 | bufReader *bufio.Reader 60 | conn net.Conn 61 | dstAddr *net.TCPAddr 62 | srcAddr *net.TCPAddr 63 | useConnAddr bool 64 | once sync.Once 65 | proxyHeaderTimeout time.Duration 66 | unknownOK bool 67 | } 68 | 69 | // Accept waits for and returns the next connection to the listener. 70 | func (p *Listener) Accept() (net.Conn, error) { 71 | // Get the underlying connection 72 | for { 73 | conn, err := p.Listener.Accept() 74 | if err != nil { 75 | return nil, err 76 | } 77 | var useConnAddr bool 78 | if p.SourceCheck != nil { 79 | allowed, err := p.SourceCheck(conn.RemoteAddr()) 80 | if err != nil { 81 | if err == ErrInvalidUpstream { 82 | conn.Close() 83 | continue 84 | } 85 | return nil, err 86 | } 87 | if !allowed { 88 | useConnAddr = true 89 | } 90 | } 91 | newConn := NewConn(conn, p.ProxyHeaderTimeout) 92 | newConn.useConnAddr = useConnAddr 93 | newConn.unknownOK = p.UnknownOK 94 | return newConn, nil 95 | } 96 | } 97 | 98 | // Close closes the underlying listener. 99 | func (p *Listener) Close() error { 100 | return p.Listener.Close() 101 | } 102 | 103 | // Addr returns the underlying listener's network address. 
104 | func (p *Listener) Addr() net.Addr { 105 | return p.Listener.Addr() 106 | } 107 | 108 | // NewConn is used to wrap a net.Conn that may be speaking 109 | // the proxy protocol into a proxyproto.Conn 110 | func NewConn(conn net.Conn, timeout time.Duration) *Conn { 111 | pConn := &Conn{ 112 | bufReader: bufio.NewReader(conn), 113 | conn: conn, 114 | proxyHeaderTimeout: timeout, 115 | } 116 | return pConn 117 | } 118 | 119 | // Read is check for the proxy protocol header when doing 120 | // the initial scan. If there is an error parsing the header, 121 | // it is returned and the socket is closed. 122 | func (p *Conn) Read(b []byte) (int, error) { 123 | var err error 124 | p.once.Do(func() { err = p.checkPrefix() }) 125 | if err != nil { 126 | return 0, err 127 | } 128 | return p.bufReader.Read(b) 129 | } 130 | 131 | func (p *Conn) ReadFrom(r io.Reader) (int64, error) { 132 | if rf, ok := p.conn.(io.ReaderFrom); ok { 133 | return rf.ReadFrom(r) 134 | } 135 | return io.Copy(p.conn, r) 136 | } 137 | 138 | func (p *Conn) WriteTo(w io.Writer) (int64, error) { 139 | var err error 140 | p.once.Do(func() { err = p.checkPrefix() }) 141 | if err != nil { 142 | return 0, err 143 | } 144 | return p.bufReader.WriteTo(w) 145 | } 146 | 147 | func (p *Conn) Write(b []byte) (int, error) { 148 | return p.conn.Write(b) 149 | } 150 | 151 | func (p *Conn) Close() error { 152 | return p.conn.Close() 153 | } 154 | 155 | func (p *Conn) LocalAddr() net.Addr { 156 | p.checkPrefixOnce() 157 | if p.dstAddr != nil && !p.useConnAddr { 158 | return p.dstAddr 159 | } 160 | return p.conn.LocalAddr() 161 | } 162 | 163 | // RemoteAddr returns the address of the client if the proxy 164 | // protocol is being used, otherwise just returns the address of 165 | // the socket peer. If there is an error parsing the header, the 166 | // address of the client is not returned, and the socket is closed. 167 | // Once implication of this is that the call could block if the 168 | // client is slow. 
Using a Deadline is recommended if this is called 169 | // before Read() 170 | func (p *Conn) RemoteAddr() net.Addr { 171 | p.checkPrefixOnce() 172 | if p.srcAddr != nil && !p.useConnAddr { 173 | return p.srcAddr 174 | } 175 | return p.conn.RemoteAddr() 176 | } 177 | 178 | func (p *Conn) SetDeadline(t time.Time) error { 179 | return p.conn.SetDeadline(t) 180 | } 181 | 182 | func (p *Conn) SetReadDeadline(t time.Time) error { 183 | return p.conn.SetReadDeadline(t) 184 | } 185 | 186 | func (p *Conn) SetWriteDeadline(t time.Time) error { 187 | return p.conn.SetWriteDeadline(t) 188 | } 189 | 190 | func (p *Conn) checkPrefixOnce() { 191 | p.once.Do(func() { 192 | if err := p.checkPrefix(); err != nil && err != io.EOF { 193 | log.Printf("[ERR] Failed to read proxy prefix: %v", err) 194 | p.Close() 195 | p.bufReader = bufio.NewReader(p.conn) 196 | } 197 | }) 198 | } 199 | 200 | func (p *Conn) checkPrefix() error { 201 | if p.proxyHeaderTimeout != 0 { 202 | readDeadLine := time.Now().Add(p.proxyHeaderTimeout) 203 | p.conn.SetReadDeadline(readDeadLine) 204 | defer p.conn.SetReadDeadline(time.Time{}) 205 | } 206 | 207 | // Incrementally check each byte of the prefix 208 | for i := 1; i <= prefixLen; i++ { 209 | inp, err := p.bufReader.Peek(i) 210 | 211 | if err != nil { 212 | if neterr, ok := err.(net.Error); ok && neterr.Timeout() { 213 | return nil 214 | } else { 215 | return err 216 | } 217 | } 218 | 219 | // Check for a prefix mis-match, quit early 220 | if !bytes.Equal(inp, prefix[:i]) { 221 | return nil 222 | } 223 | } 224 | 225 | // Read the header line 226 | header, err := p.bufReader.ReadString('\n') 227 | if err != nil { 228 | p.conn.Close() 229 | return err 230 | } 231 | 232 | // Strip the carriage return and new line 233 | header = header[:len(header)-2] 234 | 235 | // Split on spaces, should be (PROXY ) 236 | parts := strings.Split(header, " ") 237 | if len(parts) < 2 { 238 | p.conn.Close() 239 | return fmt.Errorf("Invalid header line: %s", header) 240 | } 241 | 
242 | // Verify the type is known 243 | switch parts[1] { 244 | case "UNKNOWN": 245 | if !p.unknownOK || len(parts) != 2 { 246 | p.conn.Close() 247 | return fmt.Errorf("Invalid UNKNOWN header line: %s", header) 248 | } 249 | p.useConnAddr = true 250 | return nil 251 | case "TCP4": 252 | case "TCP6": 253 | default: 254 | p.conn.Close() 255 | return fmt.Errorf("Unhandled address type: %s", parts[1]) 256 | } 257 | 258 | if len(parts) != 6 { 259 | p.conn.Close() 260 | return fmt.Errorf("Invalid header line: %s", header) 261 | } 262 | 263 | // Parse out the source address 264 | ip := net.ParseIP(parts[2]) 265 | if ip == nil { 266 | p.conn.Close() 267 | return fmt.Errorf("Invalid source ip: %s", parts[2]) 268 | } 269 | port, err := strconv.Atoi(parts[4]) 270 | if err != nil { 271 | p.conn.Close() 272 | return fmt.Errorf("Invalid source port: %s", parts[4]) 273 | } 274 | p.srcAddr = &net.TCPAddr{IP: ip, Port: port} 275 | 276 | // Parse out the destination address 277 | ip = net.ParseIP(parts[3]) 278 | if ip == nil { 279 | p.conn.Close() 280 | return fmt.Errorf("Invalid destination ip: %s", parts[3]) 281 | } 282 | port, err = strconv.Atoi(parts[5]) 283 | if err != nil { 284 | p.conn.Close() 285 | return fmt.Errorf("Invalid destination port: %s", parts[5]) 286 | } 287 | p.dstAddr = &net.TCPAddr{IP: ip, Port: port} 288 | 289 | return nil 290 | } 291 | -------------------------------------------------------------------------------- /protocol_test.go: -------------------------------------------------------------------------------- 1 | package proxyproto 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "net" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | const ( 12 | goodAddr = "127.0.0.1" 13 | badAddr = "127.0.0.2" 14 | errAddr = "9999.0.0.2" 15 | ) 16 | 17 | var ( 18 | checkAddr string 19 | ) 20 | 21 | func TestPassthrough(t *testing.T) { 22 | l, err := net.Listen("tcp", "127.0.0.1:0") 23 | if err != nil { 24 | t.Fatalf("err: %v", err) 25 | } 26 | 27 | pl := &Listener{Listener: l} 28 | 29 
| go func() { 30 | conn, err := net.Dial("tcp", pl.Addr().String()) 31 | if err != nil { 32 | t.Fatalf("err: %v", err) 33 | } 34 | defer conn.Close() 35 | 36 | conn.Write([]byte("ping")) 37 | recv := make([]byte, 4) 38 | _, err = conn.Read(recv) 39 | if err != nil { 40 | t.Fatalf("err: %v", err) 41 | } 42 | if !bytes.Equal(recv, []byte("pong")) { 43 | t.Fatalf("bad: %v", recv) 44 | } 45 | }() 46 | 47 | conn, err := pl.Accept() 48 | if err != nil { 49 | t.Fatalf("err: %v", err) 50 | } 51 | defer conn.Close() 52 | 53 | recv := make([]byte, 4) 54 | _, err = conn.Read(recv) 55 | if err != nil { 56 | t.Fatalf("err: %v", err) 57 | } 58 | if !bytes.Equal(recv, []byte("ping")) { 59 | t.Fatalf("bad: %v", recv) 60 | } 61 | 62 | if _, err := conn.Write([]byte("pong")); err != nil { 63 | t.Fatalf("err: %v", err) 64 | } 65 | } 66 | 67 | func TestTimeout(t *testing.T) { 68 | l, err := net.Listen("tcp", "127.0.0.1:0") 69 | if err != nil { 70 | t.Fatalf("err: %v", err) 71 | } 72 | 73 | clientWriteDelay := 200 * time.Millisecond 74 | proxyHeaderTimeout := 50 * time.Millisecond 75 | pl := &Listener{Listener: l, ProxyHeaderTimeout: proxyHeaderTimeout} 76 | 77 | go func() { 78 | conn, err := net.Dial("tcp", pl.Addr().String()) 79 | if err != nil { 80 | t.Fatalf("err: %v", err) 81 | } 82 | defer conn.Close() 83 | 84 | // Do not send data for a while 85 | time.Sleep(clientWriteDelay) 86 | 87 | conn.Write([]byte("ping")) 88 | recv := make([]byte, 4) 89 | _, err = conn.Read(recv) 90 | if err != nil { 91 | t.Fatalf("err: %v", err) 92 | } 93 | if !bytes.Equal(recv, []byte("pong")) { 94 | t.Fatalf("bad: %v", recv) 95 | } 96 | }() 97 | 98 | conn, err := pl.Accept() 99 | if err != nil { 100 | t.Fatalf("err: %v", err) 101 | } 102 | defer conn.Close() 103 | 104 | // Check the remote addr is the original 127.0.0.1 105 | remoteAddrStartTime := time.Now() 106 | addr := conn.RemoteAddr().(*net.TCPAddr) 107 | if addr.IP.String() != "127.0.0.1" { 108 | t.Fatalf("bad: %v", addr) 109 | } 110 | 
remoteAddrDuration := time.Since(remoteAddrStartTime) 111 | 112 | // Check RemoteAddr() call did timeout 113 | if remoteAddrDuration >= clientWriteDelay { 114 | t.Fatalf("RemoteAddr() took longer than the specified timeout: %v < %v", proxyHeaderTimeout, remoteAddrDuration) 115 | } 116 | 117 | recv := make([]byte, 4) 118 | _, err = conn.Read(recv) 119 | if err != nil { 120 | t.Fatalf("err: %v", err) 121 | } 122 | if !bytes.Equal(recv, []byte("ping")) { 123 | t.Fatalf("bad: %v", recv) 124 | } 125 | 126 | if _, err := conn.Write([]byte("pong")); err != nil { 127 | t.Fatalf("err: %v", err) 128 | } 129 | } 130 | 131 | func TestParse_ipv4(t *testing.T) { 132 | l, err := net.Listen("tcp", "127.0.0.1:0") 133 | if err != nil { 134 | t.Fatalf("err: %v", err) 135 | } 136 | 137 | pl := &Listener{Listener: l} 138 | 139 | go func() { 140 | conn, err := net.Dial("tcp", pl.Addr().String()) 141 | if err != nil { 142 | t.Fatalf("err: %v", err) 143 | } 144 | defer conn.Close() 145 | 146 | // Write out the header! 
147 | header := "PROXY TCP4 10.1.1.1 20.2.2.2 1000 2000\r\n" 148 | conn.Write([]byte(header)) 149 | 150 | conn.Write([]byte("ping")) 151 | recv := make([]byte, 4) 152 | _, err = conn.Read(recv) 153 | if err != nil { 154 | t.Fatalf("err: %v", err) 155 | } 156 | if !bytes.Equal(recv, []byte("pong")) { 157 | t.Fatalf("bad: %v", recv) 158 | } 159 | }() 160 | 161 | conn, err := pl.Accept() 162 | if err != nil { 163 | t.Fatalf("err: %v", err) 164 | } 165 | defer conn.Close() 166 | 167 | recv := make([]byte, 4) 168 | _, err = conn.Read(recv) 169 | if err != nil { 170 | t.Fatalf("err: %v", err) 171 | } 172 | if !bytes.Equal(recv, []byte("ping")) { 173 | t.Fatalf("bad: %v", recv) 174 | } 175 | 176 | if _, err := conn.Write([]byte("pong")); err != nil { 177 | t.Fatalf("err: %v", err) 178 | } 179 | 180 | // Check the remote addr 181 | addr := conn.RemoteAddr().(*net.TCPAddr) 182 | if addr.IP.String() != "10.1.1.1" { 183 | t.Fatalf("bad: %v", addr) 184 | } 185 | if addr.Port != 1000 { 186 | t.Fatalf("bad: %v", addr) 187 | } 188 | } 189 | 190 | func TestParse_ipv6(t *testing.T) { 191 | l, err := net.Listen("tcp", "127.0.0.1:0") 192 | if err != nil { 193 | t.Fatalf("err: %v", err) 194 | } 195 | 196 | pl := &Listener{Listener: l} 197 | 198 | go func() { 199 | conn, err := net.Dial("tcp", pl.Addr().String()) 200 | if err != nil { 201 | t.Fatalf("err: %v", err) 202 | } 203 | defer conn.Close() 204 | 205 | // Write out the header! 
206 | header := "PROXY TCP6 ffff::ffff ffff::ffff 1000 2000\r\n" 207 | conn.Write([]byte(header)) 208 | 209 | conn.Write([]byte("ping")) 210 | recv := make([]byte, 4) 211 | _, err = conn.Read(recv) 212 | if err != nil { 213 | t.Fatalf("err: %v", err) 214 | } 215 | if !bytes.Equal(recv, []byte("pong")) { 216 | t.Fatalf("bad: %v", recv) 217 | } 218 | }() 219 | 220 | conn, err := pl.Accept() 221 | if err != nil { 222 | t.Fatalf("err: %v", err) 223 | } 224 | defer conn.Close() 225 | 226 | recv := make([]byte, 4) 227 | _, err = conn.Read(recv) 228 | if err != nil { 229 | t.Fatalf("err: %v", err) 230 | } 231 | if !bytes.Equal(recv, []byte("ping")) { 232 | t.Fatalf("bad: %v", recv) 233 | } 234 | 235 | if _, err := conn.Write([]byte("pong")); err != nil { 236 | t.Fatalf("err: %v", err) 237 | } 238 | 239 | // Check the remote addr 240 | addr := conn.RemoteAddr().(*net.TCPAddr) 241 | if addr.IP.String() != "ffff::ffff" { 242 | t.Fatalf("bad: %v", addr) 243 | } 244 | if addr.Port != 1000 { 245 | t.Fatalf("bad: %v", addr) 246 | } 247 | } 248 | 249 | func TestParse_Unknown(t *testing.T) { 250 | l, err := net.Listen("tcp", "127.0.0.1:0") 251 | if err != nil { 252 | t.Fatalf("err: %v", err) 253 | } 254 | 255 | pl := &Listener{Listener: l, UnknownOK: true} 256 | 257 | go func() { 258 | conn, err := net.Dial("tcp", pl.Addr().String()) 259 | if err != nil { 260 | t.Fatalf("err: %v", err) 261 | } 262 | defer conn.Close() 263 | 264 | // Write out the header! 
265 | header := "PROXY UNKNOWN\r\n" 266 | conn.Write([]byte(header)) 267 | 268 | conn.Write([]byte("ping")) 269 | recv := make([]byte, 4) 270 | _, err = conn.Read(recv) 271 | if err != nil { 272 | t.Fatalf("err: %v", err) 273 | } 274 | if !bytes.Equal(recv, []byte("pong")) { 275 | t.Fatalf("bad: %v", recv) 276 | } 277 | }() 278 | 279 | conn, err := pl.Accept() 280 | if err != nil { 281 | t.Fatalf("err: %v", err) 282 | } 283 | defer conn.Close() 284 | 285 | recv := make([]byte, 4) 286 | _, err = conn.Read(recv) 287 | if err != nil { 288 | t.Fatalf("err: %v", err) 289 | } 290 | if !bytes.Equal(recv, []byte("ping")) { 291 | t.Fatalf("bad: %v", recv) 292 | } 293 | 294 | if _, err := conn.Write([]byte("pong")); err != nil { 295 | t.Fatalf("err: %v", err) 296 | } 297 | 298 | } 299 | 300 | func TestParse_BadHeader(t *testing.T) { 301 | l, err := net.Listen("tcp", "127.0.0.1:0") 302 | if err != nil { 303 | t.Fatalf("err: %v", err) 304 | } 305 | 306 | pl := &Listener{Listener: l} 307 | 308 | go func() { 309 | conn, err := net.Dial("tcp", pl.Addr().String()) 310 | if err != nil { 311 | t.Fatalf("err: %v", err) 312 | } 313 | defer conn.Close() 314 | 315 | // Write out the header! 
316 | header := "PROXY TCP4 what 127.0.0.1 1000 2000\r\n" 317 | conn.Write([]byte(header)) 318 | 319 | conn.Write([]byte("ping")) 320 | 321 | recv := make([]byte, 4) 322 | _, err = conn.Read(recv) 323 | if err == nil { 324 | t.Fatalf("err: %v", err) 325 | } 326 | }() 327 | 328 | conn, err := pl.Accept() 329 | if err != nil { 330 | t.Fatalf("err: %v", err) 331 | } 332 | defer conn.Close() 333 | 334 | // Check the remote addr, should be the local addr 335 | addr := conn.RemoteAddr().(*net.TCPAddr) 336 | if addr.IP.String() != "127.0.0.1" { 337 | t.Fatalf("bad: %v", addr) 338 | } 339 | 340 | // Read should fail 341 | recv := make([]byte, 4) 342 | _, err = conn.Read(recv) 343 | if err == nil { 344 | t.Fatalf("err: %v", err) 345 | } 346 | } 347 | 348 | func TestParse_ipv4_checkfunc(t *testing.T) { 349 | checkAddr = goodAddr 350 | testParse_ipv4_checkfunc(t) 351 | checkAddr = badAddr 352 | testParse_ipv4_checkfunc(t) 353 | checkAddr = errAddr 354 | testParse_ipv4_checkfunc(t) 355 | } 356 | 357 | func testParse_ipv4_checkfunc(t *testing.T) { 358 | l, err := net.Listen("tcp", "127.0.0.1:0") 359 | if err != nil { 360 | t.Fatalf("err: %v", err) 361 | } 362 | 363 | checkFunc := func(addr net.Addr) (bool, error) { 364 | tcpAddr := addr.(*net.TCPAddr) 365 | if tcpAddr.IP.String() == checkAddr { 366 | return true, nil 367 | } 368 | return false, nil 369 | } 370 | 371 | pl := &Listener{Listener: l, SourceCheck: checkFunc} 372 | 373 | go func() { 374 | conn, err := net.Dial("tcp", pl.Addr().String()) 375 | if err != nil { 376 | t.Fatalf("err: %v", err) 377 | } 378 | defer conn.Close() 379 | 380 | // Write out the header! 
381 | header := "PROXY TCP4 10.1.1.1 20.2.2.2 1000 2000\r\n" 382 | conn.Write([]byte(header)) 383 | 384 | conn.Write([]byte("ping")) 385 | recv := make([]byte, 4) 386 | _, err = conn.Read(recv) 387 | if err != nil { 388 | t.Fatalf("err: %v", err) 389 | } 390 | if !bytes.Equal(recv, []byte("pong")) { 391 | t.Fatalf("bad: %v", recv) 392 | } 393 | }() 394 | 395 | conn, err := pl.Accept() 396 | if err != nil { 397 | if checkAddr == badAddr { 398 | return 399 | } 400 | t.Fatalf("err: %v", err) 401 | } 402 | defer conn.Close() 403 | 404 | recv := make([]byte, 4) 405 | _, err = conn.Read(recv) 406 | if err != nil { 407 | t.Fatalf("err: %v", err) 408 | } 409 | if !bytes.Equal(recv, []byte("ping")) { 410 | t.Fatalf("bad: %v", recv) 411 | } 412 | 413 | if _, err := conn.Write([]byte("pong")); err != nil { 414 | t.Fatalf("err: %v", err) 415 | } 416 | 417 | // Check the remote addr 418 | addr := conn.RemoteAddr().(*net.TCPAddr) 419 | switch checkAddr { 420 | case goodAddr: 421 | if addr.IP.String() != "10.1.1.1" { 422 | t.Fatalf("bad: %v", addr) 423 | } 424 | if addr.Port != 1000 { 425 | t.Fatalf("bad: %v", addr) 426 | } 427 | case badAddr: 428 | if addr.IP.String() != "127.0.0.1" { 429 | t.Fatalf("bad: %v", addr) 430 | } 431 | if addr.Port == 1000 { 432 | t.Fatalf("bad: %v", addr) 433 | } 434 | } 435 | } 436 | 437 | type testConn struct { 438 | readFromCalledWith io.Reader 439 | net.Conn // nil; crash on any unexpected use 440 | } 441 | 442 | func (c *testConn) ReadFrom(r io.Reader) (int64, error) { 443 | c.readFromCalledWith = r 444 | return 0, nil 445 | } 446 | func (c *testConn) Write(p []byte) (int, error) { 447 | return len(p), nil 448 | } 449 | func (c *testConn) Read(p []byte) (int, error) { 450 | return 1, nil 451 | } 452 | 453 | func TestCopyToWrappedConnection(t *testing.T) { 454 | innerConn := &testConn{} 455 | wrappedConn := NewConn(innerConn, 0) 456 | dummySrc := &testConn{} 457 | 458 | io.Copy(wrappedConn, dummySrc) 459 | if innerConn.readFromCalledWith != 
dummySrc { 460 | t.Error("Expected io.Copy to delegate to ReadFrom function of inner destination connection") 461 | } 462 | } 463 | 464 | func TestCopyFromWrappedConnection(t *testing.T) { 465 | wrappedConn := NewConn(&testConn{}, 0) 466 | dummyDst := &testConn{} 467 | 468 | io.Copy(dummyDst, wrappedConn) 469 | if dummyDst.readFromCalledWith != wrappedConn.conn { 470 | t.Errorf("Expected io.Copy to pass inner source connection to ReadFrom method of destination") 471 | } 472 | } 473 | 474 | func TestCopyFromWrappedConnectionToWrappedConnection(t *testing.T) { 475 | innerConn1 := &testConn{} 476 | wrappedConn1 := NewConn(innerConn1, 0) 477 | innerConn2 := &testConn{} 478 | wrappedConn2 := NewConn(innerConn2, 0) 479 | 480 | io.Copy(wrappedConn1, wrappedConn2) 481 | if innerConn1.readFromCalledWith != innerConn2 { 482 | t.Errorf("Expected io.Copy to pass inner source connection to ReadFrom of inner destination connection") 483 | } 484 | 485 | } 486 | --------------------------------------------------------------------------------