├── .github
    ├── FUNDING.yml
    └── workflows
    │   ├── golangci-lint.yml
    │   ├── release.yml
    │   └── test.yml
├── .gitignore
├── LICENSE
├── README.md
├── addr_proto.go
├── addr_proto_test.go
├── examples
    ├── client
    │   └── client.go
    ├── httpserver
    │   └── httpserver.go
    └── server
    │   └── server.go
├── go.mod
├── go.sum
├── header.go
├── header_test.go
├── helper
    └── http2
    │   ├── http2.go
    │   ├── http2_test.go
    │   └── listener.go
├── policy.go
├── policy_test.go
├── protocol.go
├── protocol_test.go
├── tlv.go
├── tlv_test.go
├── tlvparse
    ├── aws.go
    ├── aws_test.go
    ├── azure.go
    ├── azure_test.go
    ├── gcp.go
    ├── gcp_test.go
    ├── ssl.go
    ├── ssl_test.go
    └── test.go
├── v1.go
├── v1_test.go
├── v2.go
├── v2_test.go
├── version_cmd.go
└── version_cmd_test.go


/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | github: pires
2 | 
--------------------------------------------------------------------------------
/.github/workflows/golangci-lint.yml:
--------------------------------------------------------------------------------
1 | name: golangci-lint
2 | 
3 | on:
4 |   push:
5 |     tags:
6 |       - v*
7 |     branches:
8 |       - main
9 |   pull_request:
10 | 
11 | permissions:
12 |   # Required: allow read access to the content for analysis.
13 |   contents: read
14 |   # Optional: allow read access to pull request. Use with `only-new-issues` option.
15 |   pull-requests: read
16 |   # Optional: allow write access to checks to allow the action to annotate code in the PR.
17 |   checks: write
18 | 
19 | jobs:
20 |   golangci:
21 |     name: lint
22 |     runs-on: ubuntu-latest
23 |     env:
24 |       GOTOOLCHAIN: local
25 |     strategy:
26 |       matrix:
27 |         go: ['1.23', '1.24']
28 |     steps:
29 |       - uses: actions/checkout@v4
30 |       - uses: actions/setup-go@v5
31 |         with:
32 |           go-version: ${{ matrix.go }}
33 | 
34 |       - name: Tidy
35 |         run: go mod tidy -diff
36 | 
37 |       - name: Format
38 |         run: test -z "$(gofmt -l . | tee /dev/stderr)"
39 | 
40 |       - name: Vet
41 |         run: go vet ./...
42 | 
43 |       - name: lint
44 |         uses: golangci/golangci-lint-action@v6
45 |         #with:
46 |           # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
47 |           #version: v1.29
48 | 
49 |           # Optional: working directory, useful for monorepos
50 |           # working-directory: somedir
51 | 
52 |           # Optional: golangci-lint command line arguments.
53 |           # args: --issues-exit-code=0
54 | 
55 |           # Optional: show only new issues if it's a pull request. The default value is `false`.
56 |           # only-new-issues: true
57 | 
58 |           # Optional: if set to true then the all caching functionality will be complete disabled,
59 |           #           takes precedence over all other caching options.
60 |           # skip-cache: true
61 | 
62 |           # Optional: if set to true then the action don't cache or restore ~/go/pkg.
63 |           # skip-pkg-cache: true
64 | 
65 |           # Optional: if set to true then the action don't cache or restore ~/.cache/go-build.
66 |           # skip-build-cache: true
67 | 
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: release
2 | 
3 | on:
4 |   push:
5 |     tags:
6 |       - "v*.*.*"
7 | 
8 | permissions:
9 |   # Required: softprops/action-gh-release creates the GitHub release with GITHUB_TOKEN.
10 |   contents: write
11 | 
12 | jobs:
13 |   release:
14 |     runs-on: ubuntu-latest
15 |     steps:
16 |       - uses: actions/checkout@v4
17 |       - name: Release
18 |         uses: softprops/action-gh-release@v2
19 |         with:
20 |           generate_release_notes: true
21 | 
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: test
2 | 
3 | on:
4 |   pull_request:
5 |   push:
6 | 
7 | jobs:
8 |   test:
9 |     runs-on: ubuntu-latest
10 |     env:
11 |       GOTOOLCHAIN: local
12 |     strategy:
13 |       fail-fast: false
14 |       matrix:
15 |         go: ['1.23', '1.24']
16 |     steps:
17 |       - uses: actions/checkout@v4
18 |       - uses: actions/setup-go@v5
19 |         with:
20 |           go-version: ${{ matrix.go }}
21 | 
22 |       - name: Get dependencies
23 |         run: |
24 |           go mod download
25 |           go mod verify
26 | 
27 |       - name: Test
28 |         run: go test -race -v -covermode=atomic -coverprofile=coverage.out ./...
29 | 
30 |       - name: Send coverage
31 |         uses: shogo82148/actions-goveralls@v1
32 |         with:
33 |           github-token: ${{ secrets.GITHUB_TOKEN }}
34 |           path-to-profile: coverage.out
35 |           flag-name: Go-${{ matrix.go }}
36 |           parallel: true
37 | 
38 |   # notifies that all test jobs are finished.
39 |   finish:
40 |     needs: test
41 |     runs-on: ubuntu-latest
42 |     steps:
43 |       - uses: shogo82148/actions-goveralls@v1
44 |         with:
45 |           parallel-finished: true
46 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Compiled Object files, Static and Dynamic libs (Shared Objects)
2 | *.o
3 | *.a
4 | *.so
5 | 
6 | # Folders
7 | .idea
8 | bin
9 | pkg
10 | 
11 | *.out
12 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |                                  Apache License
2 |                            Version 2.0, January 2004
3 |                         http://www.apache.org/licenses/
4 | 
5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 |    1. Definitions.
8 | 
9 |       "License" shall mean the terms and conditions for use, reproduction,
10 |       and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 |       "Licensor" shall mean the copyright owner or entity authorized by
13 |       the copyright owner that is granting the License.
14 | 
15 |       "Legal Entity" shall mean the union of the acting entity and all
16 |       other entities that control, are controlled by, or are under common
17 |       control with that entity. For the purposes of this definition,
18 |       "control" means (i) the power, direct or indirect, to cause the
19 |       direction or management of such entity, whether by contract or
20 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 |       outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 |       "You" (or "Your") shall mean an individual or Legal Entity
24 |       exercising permissions granted by this License.
25 | 
26 |       "Source" form shall mean the preferred form for making modifications,
27 |       including but not limited to software source code, documentation
28 |       source, and configuration files.
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2016 Paulo Pires 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # go-proxyproto 2 | 3 | [![Actions Status](https://github.com/pires/go-proxyproto/workflows/test/badge.svg)](https://github.com/pires/go-proxyproto/actions) 4 | [![Coverage Status](https://coveralls.io/repos/github/pires/go-proxyproto/badge.svg?branch=master)](https://coveralls.io/github/pires/go-proxyproto?branch=master) 5 | [![Go Report Card](https://goreportcard.com/badge/github.com/pires/go-proxyproto)](https://goreportcard.com/report/github.com/pires/go-proxyproto) 6 | [![](https://godoc.org/github.com/pires/go-proxyproto?status.svg)](https://pkg.go.dev/github.com/pires/go-proxyproto?tab=doc) 7 | 8 | 9 | A Go library implementation of the [PROXY protocol, versions 1 and 2](https://www.haproxy.org/download/2.3/doc/proxy-protocol.txt), 10 | which provides, as per specification: 11 | > (...) a convenient way to safely transport connection 12 | > information such as a client's address across multiple layers of NAT or TCP 13 | > proxies. 
It is designed to require little changes to existing components and
14 | > to limit the performance impact caused by the processing of the transported
15 | > information.
16 | 
17 | This library can be used by proxy clients, proxy servers, or both, whenever they need to support the PROXY protocol.
18 | Both protocol versions, 1 (text-based) and 2 (binary-based), are supported.
19 | 
20 | ## Installation
21 | 
22 | ```shell
23 | $ go get -u github.com/pires/go-proxyproto
24 | ```
25 | 
26 | ## Usage
27 | 
28 | ### Client
29 | 
30 | ```go
31 | package main
32 | 
33 | import (
34 | 	"io"
35 | 	"log"
36 | 	"net"
37 | 
38 | 	proxyproto "github.com/pires/go-proxyproto"
39 | )
40 | 
41 | func chkErr(err error) {
42 | 	if err != nil {
43 | 		log.Fatalf("Error: %s", err.Error())
44 | 	}
45 | }
46 | 
47 | func main() {
48 | 	// Dial some proxy listener e.g. https://github.com/mailgun/proxyproto
49 | 	target, err := net.ResolveTCPAddr("tcp", "127.0.0.1:2319")
50 | 	chkErr(err)
51 | 
52 | 	conn, err := net.DialTCP("tcp", nil, target)
53 | 	chkErr(err)
54 | 
55 | 	defer conn.Close()
56 | 
57 | 	// Create a proxyprotocol header or use HeaderProxyFromAddrs() if you
58 | 	// have two conn's
59 | 	header := &proxyproto.Header{
60 | 		Version:           1,
61 | 		Command:           proxyproto.PROXY,
62 | 		TransportProtocol: proxyproto.TCPv4,
63 | 		SourceAddr: &net.TCPAddr{
64 | 			IP:   net.ParseIP("10.1.1.1"),
65 | 			Port: 1000,
66 | 		},
67 | 		DestinationAddr: &net.TCPAddr{
68 | 			IP:   net.ParseIP("20.2.2.2"),
69 | 			Port: 2000,
70 | 		},
71 | 	}
72 | 	// After the connection was created write the proxy headers first
73 | 	_, err = header.WriteTo(conn)
74 | 	chkErr(err)
75 | 	// Then your data...
e.g.:
76 | 	_, err = io.WriteString(conn, "HELO")
77 | 	chkErr(err)
78 | }
79 | ```
80 | 
81 | ### Server
82 | 
83 | ```go
84 | package main
85 | 
86 | import (
87 | 	"log"
88 | 	"net"
89 | 
90 | 	proxyproto "github.com/pires/go-proxyproto"
91 | )
92 | 
93 | func main() {
94 | 	// Create a listener
95 | 	addr := "localhost:9876"
96 | 	list, err := net.Listen("tcp", addr)
97 | 	if err != nil {
98 | 		log.Fatalf("couldn't listen to %q: %q\n", addr, err.Error())
99 | 	}
100 | 
101 | 	// Wrap listener in a proxyproto listener
102 | 	proxyListener := &proxyproto.Listener{Listener: list}
103 | 	defer proxyListener.Close()
104 | 
105 | 	// Wait for a connection and accept it
106 | 	conn, err := proxyListener.Accept()
107 | 	if err != nil {
108 | 		log.Fatalf("couldn't accept connection: %q\n", err.Error())
109 | 	}
110 | 	defer conn.Close()
111 | 	// Print connection details
112 | 	if conn.LocalAddr() == nil {
113 | 		log.Fatal("couldn't retrieve local address")
114 | 	}
115 | 	log.Printf("local address: %q", conn.LocalAddr().String())
116 | 
117 | 	if conn.RemoteAddr() == nil {
118 | 		log.Fatal("couldn't retrieve remote address")
119 | 	}
120 | 	log.Printf("remote address: %q", conn.RemoteAddr().String())
121 | }
122 | ```
123 | 
124 | ### HTTP Server
125 | ```go
126 | package main
127 | 
128 | import (
129 | 	"net"
130 | 	"net/http"
131 | 	"time"
132 | 
133 | 	"github.com/pires/go-proxyproto"
134 | )
135 | 
136 | func main() {
137 | 	server := http.Server{Addr: ":8080"}
138 | 
139 | 	ln, err := net.Listen("tcp", server.Addr)
140 | 	if err != nil {
141 | 		panic(err)
142 | 	}
143 | 
144 | 	proxyListener := &proxyproto.Listener{
145 | 		Listener:          ln,
146 | 		ReadHeaderTimeout: 10 * time.Second,
147 | 	}
148 | 	defer proxyListener.Close()
149 | 
150 | 	server.Serve(proxyListener)
151 | }
152 | ```
153 | 
154 | ## Special notes
155 | 
156 | ### AWS
157 | 
158 | AWS Network Load Balancer (NLB) does not push the PPV2 header until the client starts sending data. This is a problem if your server speaks first, e.g. SMTP, FTP, SSH, etc.
159 | 160 | By default, NLB target group attribute `proxy_protocol_v2.client_to_server.header_placement` has the value `on_first_ack_with_payload`. You need to contact AWS support to change it to `on_first_ack`, instead. 161 | 162 | Just to be clear, you need this fix only if your server is designed to speak first. 163 | -------------------------------------------------------------------------------- /addr_proto.go: -------------------------------------------------------------------------------- 1 | package proxyproto 2 | 3 | // AddressFamilyAndProtocol represents address family and transport protocol. 4 | type AddressFamilyAndProtocol byte 5 | 6 | const ( 7 | UNSPEC AddressFamilyAndProtocol = '\x00' 8 | TCPv4 AddressFamilyAndProtocol = '\x11' 9 | UDPv4 AddressFamilyAndProtocol = '\x12' 10 | TCPv6 AddressFamilyAndProtocol = '\x21' 11 | UDPv6 AddressFamilyAndProtocol = '\x22' 12 | UnixStream AddressFamilyAndProtocol = '\x31' 13 | UnixDatagram AddressFamilyAndProtocol = '\x32' 14 | ) 15 | 16 | // IsIPv4 returns true if the address family is IPv4 (AF_INET4), false otherwise. 17 | func (ap AddressFamilyAndProtocol) IsIPv4() bool { 18 | return ap&0xF0 == 0x10 19 | } 20 | 21 | // IsIPv6 returns true if the address family is IPv6 (AF_INET6), false otherwise. 22 | func (ap AddressFamilyAndProtocol) IsIPv6() bool { 23 | return ap&0xF0 == 0x20 24 | } 25 | 26 | // IsUnix returns true if the address family is UNIX (AF_UNIX), false otherwise. 27 | func (ap AddressFamilyAndProtocol) IsUnix() bool { 28 | return ap&0xF0 == 0x30 29 | } 30 | 31 | // IsStream returns true if the transport protocol is TCP or STREAM (SOCK_STREAM), false otherwise. 32 | func (ap AddressFamilyAndProtocol) IsStream() bool { 33 | return ap&0x0F == 0x01 34 | } 35 | 36 | // IsDatagram returns true if the transport protocol is UDP or DGRAM (SOCK_DGRAM), false otherwise. 
37 | func (ap AddressFamilyAndProtocol) IsDatagram() bool {
38 | 	return ap&0x0F == 0x02
39 | }
40 | 
41 | // IsUnspec returns true if the transport protocol or address family is unspecified, false otherwise.
42 | func (ap AddressFamilyAndProtocol) IsUnspec() bool {
43 | 	return (ap&0xF0 == 0x00) || (ap&0x0F == 0x00)
44 | }
45 | 
46 | // toByte normalizes ap to one of the defined AddressFamilyAndProtocol values,
47 | // returning UNSPEC for any unknown family/protocol combination.
48 | func (ap AddressFamilyAndProtocol) toByte() byte {
49 | 	switch {
50 | 	case ap.IsIPv4() && ap.IsStream():
51 | 		return byte(TCPv4)
52 | 	case ap.IsIPv4() && ap.IsDatagram():
53 | 		return byte(UDPv4)
54 | 	case ap.IsIPv6() && ap.IsStream():
55 | 		return byte(TCPv6)
56 | 	case ap.IsIPv6() && ap.IsDatagram():
57 | 		return byte(UDPv6)
58 | 	case ap.IsUnix() && ap.IsStream():
59 | 		return byte(UnixStream)
60 | 	case ap.IsUnix() && ap.IsDatagram():
61 | 		return byte(UnixDatagram)
62 | 	default:
63 | 		return byte(UNSPEC)
64 | 	}
65 | }
66 | 
--------------------------------------------------------------------------------
/addr_proto_test.go:
--------------------------------------------------------------------------------
1 | package proxyproto
2 | 
3 | import (
4 | 	"testing"
5 | )
6 | 
7 | func TestTCPoverIPv4(t *testing.T) {
8 | 	b := byte(TCPv4)
9 | 	if !AddressFamilyAndProtocol(b).IsIPv4() {
10 | 		t.Fail()
11 | 	}
12 | 	if !AddressFamilyAndProtocol(b).IsStream() {
13 | 		t.Fail()
14 | 	}
15 | 	if AddressFamilyAndProtocol(b).toByte() != b {
16 | 		t.Fail()
17 | 	}
18 | }
19 | 
20 | func TestTCPoverIPv6(t *testing.T) {
21 | 	b := byte(TCPv6)
22 | 	if !AddressFamilyAndProtocol(b).IsIPv6() {
23 | 		t.Fail()
24 | 	}
25 | 	if !AddressFamilyAndProtocol(b).IsStream() {
26 | 		t.Fail()
27 | 	}
28 | 	if AddressFamilyAndProtocol(b).toByte() != b {
29 | 		t.Fail()
30 | 	}
31 | }
32 | 
33 | func TestUDPoverIPv4(t *testing.T) {
34 | 	b := byte(UDPv4)
35 | 	if !AddressFamilyAndProtocol(b).IsIPv4() {
36 | 		t.Fail()
37 | 	}
38 | 	if !AddressFamilyAndProtocol(b).IsDatagram() {
39 | 		t.Fail()
40 | 	}
41 | 	if AddressFamilyAndProtocol(b).toByte() != b {
42 | 		t.Fail()
43 | 	}
44 | }
45 | 
46 | func TestUDPoverIPv6(t *testing.T) {
47 | 	b := byte(UDPv6)
48 | 	if !AddressFamilyAndProtocol(b).IsIPv6() {
49 | 		t.Fail()
50 | 	}
51 | 	if !AddressFamilyAndProtocol(b).IsDatagram() {
52 | 		t.Fail()
53 | 	}
54 | 	if AddressFamilyAndProtocol(b).toByte() != b {
55 | 		t.Fail()
56 | 	}
57 | }
58 | 
59 | func TestUnixStream(t *testing.T) {
60 | 	b := byte(UnixStream)
61 | 	if !AddressFamilyAndProtocol(b).IsUnix() {
62 | 		t.Fail()
63 | 	}
64 | 	if !AddressFamilyAndProtocol(b).IsStream() {
65 | 		t.Fail()
66 | 	}
67 | 	if AddressFamilyAndProtocol(b).toByte() != b {
68 | 		t.Fail()
69 | 	}
70 | }
71 | 
72 | func TestUnixDatagram(t *testing.T) {
73 | 	b := byte(UnixDatagram)
74 | 	if !AddressFamilyAndProtocol(b).IsUnix() {
75 | 		t.Fail()
76 | 	}
77 | 	if !AddressFamilyAndProtocol(b).IsDatagram() {
78 | 		t.Fail()
79 | 	}
80 | 	if AddressFamilyAndProtocol(b).toByte() != b {
81 | 		t.Fail()
82 | 	}
83 | }
84 | 
85 | func TestInvalidAddressFamilyAndProtocol(t *testing.T) {
86 | 	b := byte(UNSPEC)
87 | 	if !AddressFamilyAndProtocol(b).IsUnspec() {
88 | 		t.Fail()
89 | 	}
90 | 	if AddressFamilyAndProtocol(b).toByte() != b {
91 | 		t.Fail()
92 | 	}
93 | }
94 | 
95 | func TestUnknownAddressFamilyAndProtocolToByte(t *testing.T) {
96 | 	for _, b := range []byte{0x01, 0x10, 0x13, 0x41, 0xFF} {
97 | 		if got := AddressFamilyAndProtocol(b).toByte(); got != byte(UNSPEC) {
98 | 			t.Errorf("toByte(%#x) = %#x, want %#x", b, got, byte(UNSPEC))
99 | 		}
100 | 	}
101 | }
102 | 
--------------------------------------------------------------------------------
/examples/client/client.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"io"
5 | 	"log"
6 | 	"net"
7 | 
8 | 	proxyproto "github.com/pires/go-proxyproto"
9 | )
10 | 
11 | func chkErr(err error) {
12 | 	if err != nil {
13 | 		log.Fatalf("Error: %s", err.Error())
14 | 	}
15 | }
16 | 
17 | func main() {
18 | 	// Dial some proxy listener e.g. https://github.com/mailgun/proxyproto
19 | 	target, err := net.ResolveTCPAddr("tcp", "127.0.0.1:9876")
20 | 	chkErr(err)
21 | 
22 | 	conn, err := net.DialTCP("tcp", nil, target)
23 | 	chkErr(err)
24 | 
25 | 	defer conn.Close()
26 | 
27 | 	// Create a proxyprotocol header or use HeaderProxyFromAddrs() if you
28 | 	// have two conn's
29 | 	header := &proxyproto.Header{
30 | 		Version:           1,
31 | 		Command:           proxyproto.PROXY,
32 | 		TransportProtocol: proxyproto.TCPv4,
33 | 		SourceAddr: &net.TCPAddr{
34 | 			IP:   net.ParseIP("10.1.1.1"),
35 | 			Port: 1000,
36 | 		},
37 | 		DestinationAddr: &net.TCPAddr{
38 | 			IP:   net.ParseIP("20.2.2.2"),
39 | 			Port: 2000,
40 | 		},
41 | 	}
42 | 	// After the connection was created write the proxy headers first
43 | 	_, err = header.WriteTo(conn)
44 | 	chkErr(err)
45 | 	// Then your data... e.g.:
46 | 	_, err = io.WriteString(conn, "HELO")
47 | 	chkErr(err)
48 | }
49 | 
--------------------------------------------------------------------------------
/examples/httpserver/httpserver.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"log"
5 | 	"net"
6 | 	"net/http"
7 | 	"time"
8 | 
9 | 	"github.com/pires/go-proxyproto"
10 | 	h2proxy "github.com/pires/go-proxyproto/helper/http2"
11 | )
12 | 
13 | // TODO: add httpclient example
14 | 
15 | func main() {
16 | 	server := http.Server{
17 | 		Addr: ":8080",
18 | 		ConnState: func(c net.Conn, s http.ConnState) {
19 | 			if s == http.StateNew {
20 | 				log.Printf("[ConnState] %s -> %s", c.LocalAddr().String(), c.RemoteAddr().String())
21 | 			}
22 | 		},
23 | 		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
24 | 			log.Printf("[Handler] remote ip %q", r.RemoteAddr)
25 | 		}),
26 | 	}
27 | 
28 | 	ln, err := net.Listen("tcp", server.Addr)
29 | 	if err != nil {
30 | 		panic(err)
31 | 	}
32 | 
33 | 	proxyListener := &proxyproto.Listener{
34 | 		Listener:          ln,
35 | 		ReadHeaderTimeout: 10 * time.Second,
36 | 	}
37 | 	defer proxyListener.Close()
38 | 
39 | 	// Create an HTTP server which can handle proxied incoming connections for
40 | 	// both HTTP/1 and HTTP/2. HTTP/2 support relies on TLS ALPN, the reverse
41 | 	// proxy needs to be configured to accept "h2".
42 | 	h2proxy.NewServer(&server, nil).Serve(proxyListener)
43 | }
44 | 
--------------------------------------------------------------------------------
/examples/server/server.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"log"
5 | 	"net"
6 | 
7 | 	proxyproto "github.com/pires/go-proxyproto"
8 | )
9 | 
10 | func main() {
11 | 	// Create a listener
12 | 	addr := "localhost:9876"
13 | 	list, err := net.Listen("tcp", addr)
14 | 	if err != nil {
15 | 		log.Fatalf("couldn't listen to %q: %q\n", addr, err.Error())
16 | 	}
17 | 
18 | 	// Wrap listener in a proxyproto listener
19 | 	proxyListener := &proxyproto.Listener{Listener: list}
20 | 	defer proxyListener.Close()
21 | 
22 | 	// Wait for a connection and accept it; conn must not be used if Accept
23 | 	// failed, so bail out before deferring Close or reading its addresses.
24 | 	conn, err := proxyListener.Accept()
25 | 	if err != nil {
26 | 		log.Fatalf("couldn't accept connection: %q\n", err.Error())
27 | 	}
28 | 	defer conn.Close()
29 | 
30 | 	// Print connection details
31 | 	if conn.LocalAddr() == nil {
32 | 		log.Fatal("couldn't retrieve local address")
33 | 	}
34 | 	log.Printf("local address: %q", conn.LocalAddr().String())
35 | 
36 | 	if conn.RemoteAddr() == nil {
37 | 		log.Fatal("couldn't retrieve remote address")
38 | 	}
39 | 	log.Printf("remote address: %q", conn.RemoteAddr().String())
40 | }
41 | 
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/pires/go-proxyproto
2 | 
3 | go 1.23
4 | 
5 | require golang.org/x/net v0.39.0
6 | 
7 | require golang.org/x/text v0.24.0 // indirect
8 | 
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
2 | golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
3 | golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
4 | golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
5 | 
--------------------------------------------------------------------------------
/header.go:
--------------------------------------------------------------------------------
1 | // Package proxyproto implements Proxy Protocol (v1 and v2) parser and writer, as per specification:
2 | // https://www.haproxy.org/download/2.3/doc/proxy-protocol.txt
3 | package proxyproto
4 | 
5 | import (
6 | 	"bufio"
7 | 	"bytes"
8 | 	"errors"
9 | 	"io"
10 | 	"net"
11 | 	"time"
12 | )
13 | 
14 | var (
15 | 	// Protocol
16 | 	SIGV1 = []byte{'\x50', '\x52', '\x4F', '\x58', '\x59'}
17 | 	SIGV2 = []byte{'\x0D', '\x0A', '\x0D', '\x0A', '\x00', '\x0D', '\x0A', '\x51', '\x55', '\x49', '\x54', '\x0A'}
18 | 
19 | 	ErrCantReadVersion1Header               = errors.New("proxyproto: can't read version 1 header")
20 | 	ErrVersion1HeaderTooLong                = errors.New("proxyproto: version 1 header must be 107 bytes or less")
21 | 	ErrLineMustEndWithCrlf                  = errors.New("proxyproto: version 1 header is invalid, must end with \\r\\n")
22 | 	ErrCantReadProtocolVersionAndCommand    = errors.New("proxyproto: can't read proxy protocol version and command")
23 | 	ErrCantReadAddressFamilyAndProtocol     = errors.New("proxyproto: can't read address family or protocol")
24 | 	ErrCantReadLength                       = errors.New("proxyproto: can't read length")
25 | 	ErrCantResolveSourceUnixAddress         = errors.New("proxyproto: can't resolve source Unix address")
26 | 	ErrCantResolveDestinationUnixAddress    = errors.New("proxyproto: can't resolve destination Unix address")
27 | 	ErrNoProxyProtocol                      = errors.New("proxyproto: proxy protocol signature not present")
28 | 	ErrUnknownProxyProtocolVersion          = errors.New("proxyproto: unknown proxy protocol version")
29 | 	ErrUnsupportedProtocolVersionAndCommand = errors.New("proxyproto: unsupported proxy protocol version and command")
30 | 	ErrUnsupportedAddressFamilyAndProtocol  = errors.New("proxyproto: unsupported address family and protocol")
31 | 	ErrInvalidLength                        = errors.New("proxyproto: invalid length")
32 | 	ErrInvalidAddress                       = errors.New("proxyproto: invalid address")
33 | 	ErrInvalidPortNumber                    = errors.New("proxyproto: invalid port number")
34 | 	ErrSuperfluousProxyHeader               = errors.New("proxyproto: upstream connection sent PROXY header but isn't allowed to send one")
35 | )
36 | 
37 | // Header is the placeholder for proxy protocol header.
38 | type Header struct {
39 | 	Version           byte
40 | 	Command           ProtocolVersionAndCommand
41 | 	TransportProtocol AddressFamilyAndProtocol
42 | 	SourceAddr        net.Addr
43 | 	DestinationAddr   net.Addr
44 | 	rawTLVs           []byte
45 | }
46 | 
47 | // HeaderProxyFromAddrs creates a new PROXY header from a source and a
48 | // destination address. If version is anything other than 1 or 2 (e.g. zero),
49 | // the latest protocol version (2) is used.
50 | //
51 | // The header is filled on a best-effort basis: if hints cannot be inferred
52 | // from the provided addresses, the header will be left unspecified.
53 | func HeaderProxyFromAddrs(version byte, sourceAddr, destAddr net.Addr) *Header { 54 | if version < 1 || version > 2 { 55 | version = 2 56 | } 57 | h := &Header{ 58 | Version: version, 59 | Command: LOCAL, 60 | TransportProtocol: UNSPEC, 61 | } 62 | switch sourceAddr := sourceAddr.(type) { 63 | case *net.TCPAddr: 64 | if _, ok := destAddr.(*net.TCPAddr); !ok { 65 | break 66 | } 67 | if len(sourceAddr.IP.To4()) == net.IPv4len { 68 | h.TransportProtocol = TCPv4 69 | } else if len(sourceAddr.IP) == net.IPv6len { 70 | h.TransportProtocol = TCPv6 71 | } 72 | case *net.UDPAddr: 73 | if _, ok := destAddr.(*net.UDPAddr); !ok { 74 | break 75 | } 76 | if len(sourceAddr.IP.To4()) == net.IPv4len { 77 | h.TransportProtocol = UDPv4 78 | } else if len(sourceAddr.IP) == net.IPv6len { 79 | h.TransportProtocol = UDPv6 80 | } 81 | case *net.UnixAddr: 82 | if _, ok := destAddr.(*net.UnixAddr); !ok { 83 | break 84 | } 85 | switch sourceAddr.Net { 86 | case "unix": 87 | h.TransportProtocol = UnixStream 88 | case "unixgram": 89 | h.TransportProtocol = UnixDatagram 90 | } 91 | } 92 | if h.TransportProtocol != UNSPEC { 93 | h.Command = PROXY 94 | h.SourceAddr = sourceAddr 95 | h.DestinationAddr = destAddr 96 | } 97 | return h 98 | } 99 | 100 | func (header *Header) TCPAddrs() (sourceAddr, destAddr *net.TCPAddr, ok bool) { 101 | if !header.TransportProtocol.IsStream() { 102 | return nil, nil, false 103 | } 104 | sourceAddr, sourceOK := header.SourceAddr.(*net.TCPAddr) 105 | destAddr, destOK := header.DestinationAddr.(*net.TCPAddr) 106 | return sourceAddr, destAddr, sourceOK && destOK 107 | } 108 | 109 | func (header *Header) UDPAddrs() (sourceAddr, destAddr *net.UDPAddr, ok bool) { 110 | if !header.TransportProtocol.IsDatagram() { 111 | return nil, nil, false 112 | } 113 | sourceAddr, sourceOK := header.SourceAddr.(*net.UDPAddr) 114 | destAddr, destOK := header.DestinationAddr.(*net.UDPAddr) 115 | return sourceAddr, destAddr, sourceOK && destOK 116 | } 117 | 118 | func (header *Header) 
UnixAddrs() (sourceAddr, destAddr *net.UnixAddr, ok bool) { 119 | if !header.TransportProtocol.IsUnix() { 120 | return nil, nil, false 121 | } 122 | sourceAddr, sourceOK := header.SourceAddr.(*net.UnixAddr) 123 | destAddr, destOK := header.DestinationAddr.(*net.UnixAddr) 124 | return sourceAddr, destAddr, sourceOK && destOK 125 | } 126 | 127 | func (header *Header) IPs() (sourceIP, destIP net.IP, ok bool) { 128 | if sourceAddr, destAddr, ok := header.TCPAddrs(); ok { 129 | return sourceAddr.IP, destAddr.IP, true 130 | } else if sourceAddr, destAddr, ok := header.UDPAddrs(); ok { 131 | return sourceAddr.IP, destAddr.IP, true 132 | } else { 133 | return nil, nil, false 134 | } 135 | } 136 | 137 | func (header *Header) Ports() (sourcePort, destPort int, ok bool) { 138 | if sourceAddr, destAddr, ok := header.TCPAddrs(); ok { 139 | return sourceAddr.Port, destAddr.Port, true 140 | } else if sourceAddr, destAddr, ok := header.UDPAddrs(); ok { 141 | return sourceAddr.Port, destAddr.Port, true 142 | } else { 143 | return 0, 0, false 144 | } 145 | } 146 | 147 | // EqualTo returns true if headers are equivalent, false otherwise. 148 | // Deprecated: use EqualsTo instead. This method will eventually be removed. 149 | func (header *Header) EqualTo(otherHeader *Header) bool { 150 | return header.EqualsTo(otherHeader) 151 | } 152 | 153 | // EqualsTo returns true if headers are equivalent, false otherwise. 
154 | func (header *Header) EqualsTo(otherHeader *Header) bool { 155 | if otherHeader == nil { 156 | return false 157 | } 158 | if header.Version != otherHeader.Version || header.Command != otherHeader.Command || header.TransportProtocol != otherHeader.TransportProtocol { 159 | return false 160 | } 161 | // TLVs only exist for version 2 162 | if header.Version == 2 && !bytes.Equal(header.rawTLVs, otherHeader.rawTLVs) { 163 | return false 164 | } 165 | // Return early for header with LOCAL command, which contains no address information 166 | if header.Command == LOCAL { 167 | return true 168 | } 169 | return header.SourceAddr.String() == otherHeader.SourceAddr.String() && 170 | header.DestinationAddr.String() == otherHeader.DestinationAddr.String() 171 | } 172 | 173 | // WriteTo renders a proxy protocol header in a format and writes it to an io.Writer. 174 | func (header *Header) WriteTo(w io.Writer) (int64, error) { 175 | buf, err := header.Format() 176 | if err != nil { 177 | return 0, err 178 | } 179 | 180 | return bytes.NewBuffer(buf).WriteTo(w) 181 | } 182 | 183 | // Format renders a proxy protocol header in a format to write over the wire. 184 | func (header *Header) Format() ([]byte, error) { 185 | switch header.Version { 186 | case 1: 187 | return header.formatVersion1() 188 | case 2: 189 | return header.formatVersion2() 190 | default: 191 | return nil, ErrUnknownProxyProtocolVersion 192 | } 193 | } 194 | 195 | // TLVs returns the TLVs stored into this header, if they exist. TLVs are optional for v2 of the protocol. 196 | func (header *Header) TLVs() ([]TLV, error) { 197 | return SplitTLVs(header.rawTLVs) 198 | } 199 | 200 | // SetTLVs sets the TLVs stored in this header. This method replaces any 201 | // previous TLV. 
202 | func (header *Header) SetTLVs(tlvs []TLV) error { 203 | raw, err := JoinTLVs(tlvs) 204 | if err != nil { 205 | return err 206 | } 207 | header.rawTLVs = raw 208 | return nil 209 | } 210 | 211 | // Read identifies the proxy protocol version and reads the remaining of 212 | // the header, accordingly. 213 | // 214 | // If proxy protocol header signature is not present, the reader buffer remains untouched 215 | // and is safe for reading outside of this code. 216 | // 217 | // If proxy protocol header signature is present but an error is raised while processing 218 | // the remaining header, assume the reader buffer to be in a corrupt state. 219 | // Also, this operation will block until enough bytes are available for peeking. 220 | func Read(reader *bufio.Reader) (*Header, error) { 221 | // In order to improve speed for small non-PROXYed packets, take a peek at the first byte alone. 222 | b1, err := reader.Peek(1) 223 | if err != nil { 224 | if err == io.EOF { 225 | return nil, ErrNoProxyProtocol 226 | } 227 | return nil, err 228 | } 229 | 230 | if bytes.Equal(b1[:1], SIGV1[:1]) || bytes.Equal(b1[:1], SIGV2[:1]) { 231 | signature, err := reader.Peek(5) 232 | if err != nil { 233 | if err == io.EOF { 234 | return nil, ErrNoProxyProtocol 235 | } 236 | return nil, err 237 | } 238 | if bytes.Equal(signature[:5], SIGV1) { 239 | return parseVersion1(reader) 240 | } 241 | 242 | signature, err = reader.Peek(12) 243 | if err != nil { 244 | if err == io.EOF { 245 | return nil, ErrNoProxyProtocol 246 | } 247 | return nil, err 248 | } 249 | if bytes.Equal(signature[:12], SIGV2) { 250 | return parseVersion2(reader) 251 | } 252 | } 253 | 254 | return nil, ErrNoProxyProtocol 255 | } 256 | 257 | // ReadTimeout acts as Read but takes a timeout. If that timeout is reached, it's assumed 258 | // there's no proxy protocol header. 
259 | func ReadTimeout(reader *bufio.Reader, timeout time.Duration) (*Header, error) { 260 | type header struct { 261 | h *Header 262 | e error 263 | } 264 | read := make(chan *header, 1) 265 | 266 | go func() { 267 | h := &header{} 268 | h.h, h.e = Read(reader) 269 | read <- h 270 | }() 271 | 272 | timer := time.NewTimer(timeout) 273 | select { 274 | case result := <-read: 275 | timer.Stop() 276 | return result.h, result.e 277 | case <-timer.C: 278 | return nil, ErrNoProxyProtocol 279 | } 280 | } 281 | -------------------------------------------------------------------------------- /header_test.go: -------------------------------------------------------------------------------- 1 | package proxyproto 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "errors" 7 | "net" 8 | "reflect" 9 | "testing" 10 | "time" 11 | ) 12 | 13 | // Stuff to be used in both versions tests. 14 | 15 | const ( 16 | NO_PROTOCOL = "There is no spoon" 17 | IP4_ADDR = "127.0.0.1" 18 | IP4IN6_ADDR = "::ffff:127.0.0.1" 19 | IP6_ADDR = "::1" 20 | IP6_LONG_ADDR = "1234:5678:9abc:def0:cafe:babe:dead:2bad" 21 | PORT = 65533 22 | INVALID_PORT = 99999 23 | ) 24 | 25 | var ( 26 | v4ip = net.ParseIP(IP4_ADDR).To4() 27 | v6ip = net.ParseIP(IP6_ADDR).To16() 28 | 29 | v4addr net.Addr = &net.TCPAddr{IP: v4ip, Port: PORT} 30 | v6addr net.Addr = &net.TCPAddr{IP: v6ip, Port: PORT} 31 | 32 | v4UDPAddr net.Addr = &net.UDPAddr{IP: v4ip, Port: PORT} 33 | v6UDPAddr net.Addr = &net.UDPAddr{IP: v6ip, Port: PORT} 34 | 35 | unixStreamAddr net.Addr = &net.UnixAddr{Net: "unix", Name: "socket"} 36 | unixDatagramAddr net.Addr = &net.UnixAddr{Net: "unixgram", Name: "socket"} 37 | 38 | errReadIntentionallyBroken = errors.New("read is intentionally broken") 39 | ) 40 | 41 | type timeoutReader []byte 42 | 43 | func (t *timeoutReader) Read([]byte) (int, error) { 44 | time.Sleep(500 * time.Millisecond) 45 | return 0, nil 46 | } 47 | 48 | type errorReader []byte 49 | 50 | func (e *errorReader) Read([]byte) (int, error) { 51 | return 
0, errReadIntentionallyBroken 52 | } 53 | 54 | func TestReadTimeoutV1Invalid(t *testing.T) { 55 | var b timeoutReader 56 | reader := bufio.NewReader(&b) 57 | _, err := ReadTimeout(reader, 50*time.Millisecond) 58 | if err == nil { 59 | t.Fatalf("expected error %s", ErrNoProxyProtocol) 60 | } else if err != ErrNoProxyProtocol { 61 | t.Fatalf("expected %s, actual %s", ErrNoProxyProtocol, err) 62 | } 63 | } 64 | 65 | func TestReadTimeoutPropagatesReadError(t *testing.T) { 66 | var e errorReader 67 | reader := bufio.NewReader(&e) 68 | _, err := ReadTimeout(reader, 50*time.Millisecond) 69 | 70 | if err == nil { 71 | t.Fatalf("expected error %s", errReadIntentionallyBroken) 72 | } else if err != errReadIntentionallyBroken { 73 | t.Fatalf("expected error %s, actual %s", errReadIntentionallyBroken, err) 74 | } 75 | } 76 | 77 | func TestEqualsTo(t *testing.T) { 78 | var headersEqual = []struct { 79 | this, that *Header 80 | expected bool 81 | }{ 82 | { 83 | &Header{ 84 | Version: 1, 85 | Command: PROXY, 86 | TransportProtocol: TCPv4, 87 | SourceAddr: &net.TCPAddr{ 88 | IP: net.ParseIP("10.1.1.1"), 89 | Port: 1000, 90 | }, 91 | DestinationAddr: &net.TCPAddr{ 92 | IP: net.ParseIP("20.2.2.2"), 93 | Port: 2000, 94 | }, 95 | }, 96 | nil, 97 | false, 98 | }, 99 | { 100 | &Header{ 101 | Version: 1, 102 | Command: PROXY, 103 | TransportProtocol: TCPv4, 104 | SourceAddr: &net.TCPAddr{ 105 | IP: net.ParseIP("10.1.1.1"), 106 | Port: 1000, 107 | }, 108 | DestinationAddr: &net.TCPAddr{ 109 | IP: net.ParseIP("20.2.2.2"), 110 | Port: 2000, 111 | }, 112 | }, 113 | &Header{ 114 | Version: 2, 115 | Command: PROXY, 116 | TransportProtocol: TCPv4, 117 | SourceAddr: &net.TCPAddr{ 118 | IP: net.ParseIP("10.1.1.1"), 119 | Port: 1000, 120 | }, 121 | DestinationAddr: &net.TCPAddr{ 122 | IP: net.ParseIP("20.2.2.2"), 123 | Port: 2000, 124 | }, 125 | }, 126 | false, 127 | }, 128 | { 129 | &Header{ 130 | Version: 1, 131 | Command: PROXY, 132 | TransportProtocol: TCPv4, 133 | SourceAddr: &net.TCPAddr{ 
134 | IP: net.ParseIP("10.1.1.1"), 135 | Port: 1000, 136 | }, 137 | DestinationAddr: &net.TCPAddr{ 138 | IP: net.ParseIP("20.2.2.2"), 139 | Port: 2000, 140 | }, 141 | }, 142 | &Header{ 143 | Version: 1, 144 | Command: PROXY, 145 | TransportProtocol: TCPv4, 146 | SourceAddr: &net.TCPAddr{ 147 | IP: net.ParseIP("10.1.1.1"), 148 | Port: 1000, 149 | }, 150 | DestinationAddr: &net.TCPAddr{ 151 | IP: net.ParseIP("20.2.2.2"), 152 | Port: 2000, 153 | }, 154 | }, 155 | true, 156 | }, 157 | } 158 | 159 | for _, tt := range headersEqual { 160 | if actual := tt.this.EqualsTo(tt.that); actual != tt.expected { 161 | t.Fatalf("expected %t, actual %t", tt.expected, actual) 162 | } 163 | } 164 | } 165 | 166 | // This is here just because of coveralls 167 | func TestEqualTo(t *testing.T) { 168 | TestEqualsTo(t) 169 | } 170 | 171 | func TestGetters(t *testing.T) { 172 | var tests = []struct { 173 | name string 174 | header *Header 175 | tcpSourceAddr, tcpDestAddr *net.TCPAddr 176 | udpSourceAddr, udpDestAddr *net.UDPAddr 177 | unixSourceAddr, unixDestAddr *net.UnixAddr 178 | ipSource, ipDest net.IP 179 | portSource, portDest int 180 | }{ 181 | { 182 | name: "TCPv4", 183 | header: &Header{ 184 | Version: 1, 185 | Command: PROXY, 186 | TransportProtocol: TCPv4, 187 | SourceAddr: &net.TCPAddr{ 188 | IP: net.ParseIP("10.1.1.1"), 189 | Port: 1000, 190 | }, 191 | DestinationAddr: &net.TCPAddr{ 192 | IP: net.ParseIP("20.2.2.2"), 193 | Port: 2000, 194 | }, 195 | }, 196 | tcpSourceAddr: &net.TCPAddr{ 197 | IP: net.ParseIP("10.1.1.1"), 198 | Port: 1000, 199 | }, 200 | tcpDestAddr: &net.TCPAddr{ 201 | IP: net.ParseIP("20.2.2.2"), 202 | Port: 2000, 203 | }, 204 | ipSource: net.ParseIP("10.1.1.1"), 205 | ipDest: net.ParseIP("20.2.2.2"), 206 | portSource: 1000, 207 | portDest: 2000, 208 | }, 209 | { 210 | name: "UDPv4", 211 | header: &Header{ 212 | Version: 2, 213 | Command: PROXY, 214 | TransportProtocol: UDPv6, 215 | SourceAddr: &net.UDPAddr{ 216 | IP: net.ParseIP("10.1.1.1"), 217 | Port: 1000, 
218 | }, 219 | DestinationAddr: &net.UDPAddr{ 220 | IP: net.ParseIP("20.2.2.2"), 221 | Port: 2000, 222 | }, 223 | }, 224 | udpSourceAddr: &net.UDPAddr{ 225 | IP: net.ParseIP("10.1.1.1"), 226 | Port: 1000, 227 | }, 228 | udpDestAddr: &net.UDPAddr{ 229 | IP: net.ParseIP("20.2.2.2"), 230 | Port: 2000, 231 | }, 232 | ipSource: net.ParseIP("10.1.1.1"), 233 | ipDest: net.ParseIP("20.2.2.2"), 234 | portSource: 1000, 235 | portDest: 2000, 236 | }, 237 | { 238 | name: "UnixStream", 239 | header: &Header{ 240 | Version: 2, 241 | Command: PROXY, 242 | TransportProtocol: UnixStream, 243 | SourceAddr: &net.UnixAddr{ 244 | Net: "unix", 245 | Name: "src", 246 | }, 247 | DestinationAddr: &net.UnixAddr{ 248 | Net: "unix", 249 | Name: "dst", 250 | }, 251 | }, 252 | unixSourceAddr: &net.UnixAddr{ 253 | Net: "unix", 254 | Name: "src", 255 | }, 256 | unixDestAddr: &net.UnixAddr{ 257 | Net: "unix", 258 | Name: "dst", 259 | }, 260 | }, 261 | { 262 | name: "UnixDatagram", 263 | header: &Header{ 264 | Version: 2, 265 | Command: PROXY, 266 | TransportProtocol: UnixDatagram, 267 | SourceAddr: &net.UnixAddr{ 268 | Net: "unix", 269 | Name: "src", 270 | }, 271 | DestinationAddr: &net.UnixAddr{ 272 | Net: "unix", 273 | Name: "dst", 274 | }, 275 | }, 276 | unixSourceAddr: &net.UnixAddr{ 277 | Net: "unix", 278 | Name: "src", 279 | }, 280 | unixDestAddr: &net.UnixAddr{ 281 | Net: "unix", 282 | Name: "dst", 283 | }, 284 | }, 285 | { 286 | name: "Unspec", 287 | header: &Header{ 288 | Version: 1, 289 | Command: PROXY, 290 | TransportProtocol: UNSPEC, 291 | }, 292 | }, 293 | } 294 | 295 | for _, test := range tests { 296 | t.Run(test.name, func(t *testing.T) { 297 | tcpSourceAddr, tcpDestAddr, _ := test.header.TCPAddrs() 298 | if test.tcpSourceAddr != nil && !reflect.DeepEqual(tcpSourceAddr, test.tcpSourceAddr) { 299 | t.Errorf("TCPAddrs() source = %v, want %v", tcpSourceAddr, test.tcpSourceAddr) 300 | } 301 | if test.tcpDestAddr != nil && !reflect.DeepEqual(tcpDestAddr, test.tcpDestAddr) { 302 | 
t.Errorf("TCPAddrs() dest = %v, want %v", tcpDestAddr, test.tcpDestAddr) 303 | } 304 | 305 | udpSourceAddr, udpDestAddr, _ := test.header.UDPAddrs() 306 | if test.udpSourceAddr != nil && !reflect.DeepEqual(udpSourceAddr, test.udpSourceAddr) { 307 | t.Errorf("TCPAddrs() source = %v, want %v", udpSourceAddr, test.udpSourceAddr) 308 | } 309 | if test.udpDestAddr != nil && !reflect.DeepEqual(udpDestAddr, test.udpDestAddr) { 310 | t.Errorf("TCPAddrs() dest = %v, want %v", udpDestAddr, test.udpDestAddr) 311 | } 312 | 313 | unixSourceAddr, unixDestAddr, _ := test.header.UnixAddrs() 314 | if test.unixSourceAddr != nil && !reflect.DeepEqual(unixSourceAddr, test.unixSourceAddr) { 315 | t.Errorf("UnixAddrs() source = %v, want %v", unixSourceAddr, test.unixSourceAddr) 316 | } 317 | if test.unixDestAddr != nil && !reflect.DeepEqual(unixDestAddr, test.unixDestAddr) { 318 | t.Errorf("UnixAddrs() dest = %v, want %v", unixDestAddr, test.unixDestAddr) 319 | } 320 | 321 | ipSource, ipDest, _ := test.header.IPs() 322 | if test.ipSource != nil && !ipSource.Equal(test.ipSource) { 323 | t.Errorf("IPs() source = %v, want %v", ipSource, test.ipSource) 324 | } 325 | if test.ipDest != nil && !ipDest.Equal(test.ipDest) { 326 | t.Errorf("IPs() dest = %v, want %v", ipDest, test.ipDest) 327 | } 328 | 329 | portSource, portDest, _ := test.header.Ports() 330 | if test.portSource != 0 && portSource != test.portSource { 331 | t.Errorf("Ports() source = %v, want %v", portSource, test.portSource) 332 | } 333 | if test.portDest != 0 && portDest != test.portDest { 334 | t.Errorf("Ports() dest = %v, want %v", portDest, test.portDest) 335 | } 336 | }) 337 | } 338 | } 339 | 340 | func TestSetTLVs(t *testing.T) { 341 | tests := []struct { 342 | header *Header 343 | name string 344 | tlvs []TLV 345 | expectErr bool 346 | }{ 347 | { 348 | name: "add authority TLV", 349 | header: &Header{ 350 | Version: 1, 351 | Command: PROXY, 352 | TransportProtocol: TCPv4, 353 | SourceAddr: &net.TCPAddr{ 354 | IP: 
net.ParseIP("10.1.1.1"), 355 | Port: 1000, 356 | }, 357 | DestinationAddr: &net.TCPAddr{ 358 | IP: net.ParseIP("20.2.2.2"), 359 | Port: 2000, 360 | }, 361 | }, 362 | tlvs: []TLV{{ 363 | Type: PP2_TYPE_AUTHORITY, 364 | Value: []byte("example.org"), 365 | }}, 366 | }, 367 | { 368 | name: "add too long TLV", 369 | header: &Header{ 370 | Version: 1, 371 | Command: PROXY, 372 | TransportProtocol: TCPv4, 373 | SourceAddr: &net.TCPAddr{ 374 | IP: net.ParseIP("10.1.1.1"), 375 | Port: 1000, 376 | }, 377 | DestinationAddr: &net.TCPAddr{ 378 | IP: net.ParseIP("20.2.2.2"), 379 | Port: 2000, 380 | }, 381 | }, 382 | tlvs: []TLV{{ 383 | Type: PP2_TYPE_AUTHORITY, 384 | Value: append(bytes.Repeat([]byte("a"), 0xFFFF), []byte(".example.org")...), 385 | }}, 386 | expectErr: true, 387 | }, 388 | } 389 | for _, tt := range tests { 390 | err := tt.header.SetTLVs(tt.tlvs) 391 | if err != nil && !tt.expectErr { 392 | t.Fatalf("shouldn't have thrown error %q", err.Error()) 393 | } 394 | } 395 | } 396 | 397 | func TestWriteTo(t *testing.T) { 398 | var buf bytes.Buffer 399 | 400 | validHeader := &Header{ 401 | Version: 1, 402 | Command: PROXY, 403 | TransportProtocol: TCPv4, 404 | SourceAddr: &net.TCPAddr{ 405 | IP: net.ParseIP("10.1.1.1"), 406 | Port: 1000, 407 | }, 408 | DestinationAddr: &net.TCPAddr{ 409 | IP: net.ParseIP("20.2.2.2"), 410 | Port: 2000, 411 | }, 412 | } 413 | 414 | if _, err := validHeader.WriteTo(&buf); err != nil { 415 | t.Fatalf("shouldn't have thrown error %q", err.Error()) 416 | } 417 | 418 | invalidHeader := &Header{ 419 | SourceAddr: &net.TCPAddr{ 420 | IP: net.ParseIP("10.1.1.1"), 421 | Port: 1000, 422 | }, 423 | DestinationAddr: &net.TCPAddr{ 424 | IP: net.ParseIP("20.2.2.2"), 425 | Port: 2000, 426 | }, 427 | } 428 | 429 | if _, err := invalidHeader.WriteTo(&buf); err == nil { 430 | t.Fatalf("should have thrown error %q", err.Error()) 431 | } 432 | } 433 | 434 | func TestFormat(t *testing.T) { 435 | validHeader := &Header{ 436 | Version: 1, 437 | Command: PROXY, 
438 | TransportProtocol: TCPv4, 439 | SourceAddr: &net.TCPAddr{ 440 | IP: net.ParseIP("10.1.1.1"), 441 | Port: 1000, 442 | }, 443 | DestinationAddr: &net.TCPAddr{ 444 | IP: net.ParseIP("20.2.2.2"), 445 | Port: 2000, 446 | }, 447 | } 448 | 449 | if _, err := validHeader.Format(); err != nil { 450 | t.Fatalf("shouldn't have thrown error %q", err.Error()) 451 | } 452 | } 453 | 454 | func TestFormatInvalid(t *testing.T) { 455 | tests := []struct { 456 | name string 457 | header *Header 458 | err error 459 | }{ 460 | { 461 | name: "invalidVersion", 462 | header: &Header{ 463 | Version: 3, 464 | Command: PROXY, 465 | TransportProtocol: TCPv4, 466 | SourceAddr: v4addr, 467 | DestinationAddr: v4addr, 468 | }, 469 | err: ErrUnknownProxyProtocolVersion, 470 | }, 471 | { 472 | name: "v2MismatchTCPv4_UDPv4", 473 | header: &Header{ 474 | Version: 2, 475 | Command: PROXY, 476 | TransportProtocol: TCPv4, 477 | SourceAddr: v4UDPAddr, 478 | DestinationAddr: v4addr, 479 | }, 480 | err: ErrInvalidAddress, 481 | }, 482 | { 483 | name: "v2MismatchTCPv4_TCPv6", 484 | header: &Header{ 485 | Version: 2, 486 | Command: PROXY, 487 | TransportProtocol: TCPv4, 488 | SourceAddr: v4addr, 489 | DestinationAddr: v6addr, 490 | }, 491 | err: ErrInvalidAddress, 492 | }, 493 | { 494 | name: "v2MismatchUnixStream_TCPv4", 495 | header: &Header{ 496 | Version: 2, 497 | Command: PROXY, 498 | TransportProtocol: UnixStream, 499 | SourceAddr: v4addr, 500 | DestinationAddr: unixStreamAddr, 501 | }, 502 | err: ErrInvalidAddress, 503 | }, 504 | { 505 | name: "v1MismatchTCPv4_TCPv6", 506 | header: &Header{ 507 | Version: 1, 508 | Command: PROXY, 509 | TransportProtocol: TCPv4, 510 | SourceAddr: v6addr, 511 | DestinationAddr: v4addr, 512 | }, 513 | err: ErrInvalidAddress, 514 | }, 515 | { 516 | name: "v1MismatchTCPv4_UDPv4", 517 | header: &Header{ 518 | Version: 1, 519 | Command: PROXY, 520 | TransportProtocol: TCPv4, 521 | SourceAddr: v4UDPAddr, 522 | DestinationAddr: v4addr, 523 | }, 524 | err: 
ErrInvalidAddress, 525 | }, 526 | } 527 | 528 | for _, test := range tests { 529 | t.Run(test.name, func(t *testing.T) { 530 | if _, err := test.header.Format(); err == nil { 531 | t.Errorf("Header.Format() succeeded, want an error") 532 | } else if err != test.err { 533 | t.Errorf("Header.Format() = %q, want %q", err, test.err) 534 | } 535 | }) 536 | } 537 | } 538 | 539 | func TestHeaderProxyFromAddrs(t *testing.T) { 540 | unspec := &Header{ 541 | Version: 2, 542 | Command: LOCAL, 543 | TransportProtocol: UNSPEC, 544 | } 545 | 546 | tests := []struct { 547 | name string 548 | version byte 549 | sourceAddr, destAddr net.Addr 550 | expected *Header 551 | }{ 552 | { 553 | name: "TCPv4", 554 | sourceAddr: &net.TCPAddr{ 555 | IP: net.ParseIP("10.1.1.1"), 556 | Port: 1000, 557 | }, 558 | destAddr: &net.TCPAddr{ 559 | IP: net.ParseIP("20.2.2.2"), 560 | Port: 2000, 561 | }, 562 | expected: &Header{ 563 | Version: 2, 564 | Command: PROXY, 565 | TransportProtocol: TCPv4, 566 | SourceAddr: &net.TCPAddr{ 567 | IP: net.ParseIP("10.1.1.1"), 568 | Port: 1000, 569 | }, 570 | DestinationAddr: &net.TCPAddr{ 571 | IP: net.ParseIP("20.2.2.2"), 572 | Port: 2000, 573 | }, 574 | }, 575 | }, 576 | { 577 | name: "TCPv6", 578 | sourceAddr: &net.TCPAddr{ 579 | IP: net.ParseIP("fde7::372"), 580 | Port: 1000, 581 | }, 582 | destAddr: &net.TCPAddr{ 583 | IP: net.ParseIP("fde7::1"), 584 | Port: 2000, 585 | }, 586 | expected: &Header{ 587 | Version: 2, 588 | Command: PROXY, 589 | TransportProtocol: TCPv6, 590 | SourceAddr: &net.TCPAddr{ 591 | IP: net.ParseIP("fde7::372"), 592 | Port: 1000, 593 | }, 594 | DestinationAddr: &net.TCPAddr{ 595 | IP: net.ParseIP("fde7::1"), 596 | Port: 2000, 597 | }, 598 | }, 599 | }, 600 | { 601 | name: "UDPv4", 602 | sourceAddr: &net.UDPAddr{ 603 | IP: net.ParseIP("10.1.1.1"), 604 | Port: 1000, 605 | }, 606 | destAddr: &net.UDPAddr{ 607 | IP: net.ParseIP("20.2.2.2"), 608 | Port: 2000, 609 | }, 610 | expected: &Header{ 611 | Version: 2, 612 | Command: PROXY, 613 | 
TransportProtocol: UDPv4, 614 | SourceAddr: &net.TCPAddr{ 615 | IP: net.ParseIP("10.1.1.1"), 616 | Port: 1000, 617 | }, 618 | DestinationAddr: &net.TCPAddr{ 619 | IP: net.ParseIP("20.2.2.2"), 620 | Port: 2000, 621 | }, 622 | }, 623 | }, 624 | { 625 | name: "UDPv6", 626 | sourceAddr: &net.UDPAddr{ 627 | IP: net.ParseIP("fde7::372"), 628 | Port: 1000, 629 | }, 630 | destAddr: &net.UDPAddr{ 631 | IP: net.ParseIP("fde7::1"), 632 | Port: 2000, 633 | }, 634 | expected: &Header{ 635 | Version: 2, 636 | Command: PROXY, 637 | TransportProtocol: UDPv6, 638 | SourceAddr: &net.TCPAddr{ 639 | IP: net.ParseIP("fde7::372"), 640 | Port: 1000, 641 | }, 642 | DestinationAddr: &net.TCPAddr{ 643 | IP: net.ParseIP("fde7::1"), 644 | Port: 2000, 645 | }, 646 | }, 647 | }, 648 | { 649 | name: "UnixStream", 650 | sourceAddr: &net.UnixAddr{ 651 | Net: "unix", 652 | Name: "src", 653 | }, 654 | destAddr: &net.UnixAddr{ 655 | Net: "unix", 656 | Name: "dst", 657 | }, 658 | expected: &Header{ 659 | Version: 2, 660 | Command: PROXY, 661 | TransportProtocol: UnixStream, 662 | SourceAddr: &net.UnixAddr{ 663 | Net: "unix", 664 | Name: "src", 665 | }, 666 | DestinationAddr: &net.UnixAddr{ 667 | Net: "unix", 668 | Name: "dst", 669 | }, 670 | }, 671 | }, 672 | { 673 | name: "UnixDatagram", 674 | sourceAddr: &net.UnixAddr{ 675 | Net: "unixgram", 676 | Name: "src", 677 | }, 678 | destAddr: &net.UnixAddr{ 679 | Net: "unixgram", 680 | Name: "dst", 681 | }, 682 | expected: &Header{ 683 | Version: 2, 684 | Command: PROXY, 685 | TransportProtocol: UnixDatagram, 686 | SourceAddr: &net.UnixAddr{ 687 | Net: "unixgram", 688 | Name: "src", 689 | }, 690 | DestinationAddr: &net.UnixAddr{ 691 | Net: "unixgram", 692 | Name: "dst", 693 | }, 694 | }, 695 | }, 696 | { 697 | name: "Version1", 698 | version: 1, 699 | sourceAddr: &net.TCPAddr{ 700 | IP: net.ParseIP("10.1.1.1"), 701 | Port: 1000, 702 | }, 703 | destAddr: &net.TCPAddr{ 704 | IP: net.ParseIP("20.2.2.2"), 705 | Port: 2000, 706 | }, 707 | expected: &Header{ 708 
| Version: 1, 709 | Command: PROXY, 710 | TransportProtocol: TCPv4, 711 | SourceAddr: &net.TCPAddr{ 712 | IP: net.ParseIP("10.1.1.1"), 713 | Port: 1000, 714 | }, 715 | DestinationAddr: &net.TCPAddr{ 716 | IP: net.ParseIP("20.2.2.2"), 717 | Port: 2000, 718 | }, 719 | }, 720 | }, 721 | { 722 | name: "TCPInvalidIP", 723 | sourceAddr: &net.TCPAddr{ 724 | IP: nil, 725 | Port: 1000, 726 | }, 727 | destAddr: &net.TCPAddr{ 728 | IP: nil, 729 | Port: 2000, 730 | }, 731 | expected: unspec, 732 | }, 733 | { 734 | name: "UDPInvalidIP", 735 | sourceAddr: &net.UDPAddr{ 736 | IP: nil, 737 | Port: 1000, 738 | }, 739 | destAddr: &net.UDPAddr{ 740 | IP: nil, 741 | Port: 2000, 742 | }, 743 | expected: unspec, 744 | }, 745 | { 746 | name: "TCPAddrTypeMismatch", 747 | sourceAddr: &net.TCPAddr{ 748 | IP: net.ParseIP("10.1.1.1"), 749 | Port: 1000, 750 | }, 751 | destAddr: &net.UDPAddr{ 752 | IP: net.ParseIP("20.2.2.2"), 753 | Port: 2000, 754 | }, 755 | expected: unspec, 756 | }, 757 | { 758 | name: "UDPAddrTypeMismatch", 759 | sourceAddr: &net.UDPAddr{ 760 | IP: net.ParseIP("10.1.1.1"), 761 | Port: 1000, 762 | }, 763 | destAddr: &net.TCPAddr{ 764 | IP: net.ParseIP("20.2.2.2"), 765 | Port: 2000, 766 | }, 767 | expected: unspec, 768 | }, 769 | { 770 | name: "UnixAddrTypeMismatch", 771 | sourceAddr: &net.UnixAddr{ 772 | Net: "unix", 773 | }, 774 | destAddr: &net.TCPAddr{ 775 | IP: net.ParseIP("20.2.2.2"), 776 | Port: 2000, 777 | }, 778 | expected: unspec, 779 | }, 780 | } 781 | 782 | for _, tt := range tests { 783 | t.Run(tt.name, func(t *testing.T) { 784 | h := HeaderProxyFromAddrs(tt.version, tt.sourceAddr, tt.destAddr) 785 | 786 | if !h.EqualsTo(tt.expected) { 787 | t.Errorf("expected %+v, actual %+v for source %+v and destination %+v", tt.expected, h, tt.sourceAddr, tt.destAddr) 788 | } 789 | }) 790 | } 791 | } 792 | -------------------------------------------------------------------------------- /helper/http2/http2.go: 
-------------------------------------------------------------------------------- 1 | // Package http2 provides helpers for HTTP/2. 2 | package http2 3 | 4 | import ( 5 | "crypto/tls" 6 | "fmt" 7 | "log" 8 | "net" 9 | "net/http" 10 | "sync" 11 | "time" 12 | 13 | "github.com/pires/go-proxyproto" 14 | "golang.org/x/net/http2" 15 | ) 16 | 17 | const listenerRetryBaseDelay = 5 * time.Millisecond 18 | 19 | // Server is an HTTP server accepting both regular and proxied, both HTTP/1 and 20 | // HTTP/2 connections. 21 | // 22 | // HTTP/2 is negotiated using TLS ALPN, either directly via a tls.Conn, either 23 | // indirectly via the PROXY protocol. When the PROXY protocol is used, the 24 | // TLS-terminating proxy in front of the server must be configured to accept 25 | // the "h2" TLS ALPN protocol. 26 | // 27 | // The server is closed when the http.Server is. 28 | type Server struct { 29 | h1 *http.Server // regular HTTP/1 server 30 | h2 *http2.Server // HTTP/2 server 31 | h2Err error // HTTP/2 server setup error, if any 32 | h1Listener h1Listener // pipe listener for the HTTP/1 server 33 | 34 | // The following fields are protected by the mutex 35 | mu sync.Mutex 36 | closed bool 37 | listeners map[net.Listener]struct{} 38 | } 39 | 40 | // NewServer creates a new HTTP server. 41 | // 42 | // A nil h2 is equivalent to a zero http2.Server. 
43 | func NewServer(h1 *http.Server, h2 *http2.Server) *Server { 44 | if h2 == nil { 45 | h2 = new(http2.Server) 46 | } 47 | srv := &Server{ 48 | h1: h1, 49 | h2: h2, 50 | h2Err: http2.ConfigureServer(h1, h2), 51 | listeners: make(map[net.Listener]struct{}), 52 | } 53 | srv.h1Listener = h1Listener{newPipeListener(), srv} 54 | go func() { 55 | // proxyListener.Accept never fails 56 | _ = h1.Serve(srv.h1Listener) 57 | }() 58 | return srv 59 | } 60 | 61 | func (srv *Server) errorLog() *log.Logger { 62 | if srv.h1.ErrorLog != nil { 63 | return srv.h1.ErrorLog 64 | } 65 | return log.Default() 66 | } 67 | 68 | // Serve accepts incoming connections on the listener ln. 69 | func (srv *Server) Serve(ln net.Listener) error { 70 | if srv.h2Err != nil { 71 | return srv.h2Err 72 | } 73 | 74 | srv.mu.Lock() 75 | ok := !srv.closed 76 | if ok { 77 | srv.listeners[ln] = struct{}{} 78 | } 79 | srv.mu.Unlock() 80 | if !ok { 81 | return http.ErrServerClosed 82 | } 83 | 84 | defer func() { 85 | srv.mu.Lock() 86 | delete(srv.listeners, ln) 87 | srv.mu.Unlock() 88 | }() 89 | 90 | // net.Listener.Accept can fail for temporary failures, e.g. too many open 91 | // files or other timeout conditions. In that case, wait and retry later. 92 | // This mirrors what the net/http package does. 
93 | var delay time.Duration 94 | for { 95 | conn, err := ln.Accept() 96 | if ne, ok := err.(net.Error); ok && ne.Timeout() { 97 | if delay == 0 { 98 | delay = listenerRetryBaseDelay 99 | } else { 100 | delay *= 2 101 | } 102 | if max := 1 * time.Second; delay > max { 103 | delay = max 104 | } 105 | srv.errorLog().Printf("listener %q: accept error (retrying in %v): %v", ln.Addr(), delay, err) 106 | time.Sleep(delay) 107 | } else if err != nil { 108 | return fmt.Errorf("failed to accept connection: %w", err) 109 | } 110 | 111 | delay = 0 112 | 113 | go func() { 114 | if err := srv.serveConn(conn); err != nil { 115 | srv.errorLog().Printf("listener %q: %v", ln.Addr(), err) 116 | } 117 | }() 118 | } 119 | } 120 | 121 | func (srv *Server) serveConn(conn net.Conn) error { 122 | var proto string 123 | switch conn := conn.(type) { 124 | case *tls.Conn: 125 | proto = conn.ConnectionState().NegotiatedProtocol 126 | case *proxyproto.Conn: 127 | if proxyHeader := conn.ProxyHeader(); proxyHeader != nil { 128 | tlvs, err := proxyHeader.TLVs() 129 | if err != nil { 130 | conn.Close() 131 | return err 132 | } 133 | for _, tlv := range tlvs { 134 | if tlv.Type == proxyproto.PP2_TYPE_ALPN { 135 | proto = string(tlv.Value) 136 | break 137 | } 138 | } 139 | } 140 | } 141 | 142 | // See https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml#alpn-protocol-ids 143 | switch proto { 144 | case http2.NextProtoTLS, "h2c": 145 | defer conn.Close() 146 | opts := http2.ServeConnOpts{Handler: srv.h1.Handler} 147 | srv.h2.ServeConn(conn, &opts) 148 | return nil 149 | case "", "http/1.0", "http/1.1": 150 | return srv.h1Listener.ServeConn(conn) 151 | default: 152 | conn.Close() 153 | return fmt.Errorf("unsupported protocol %q", proto) 154 | } 155 | } 156 | 157 | func (srv *Server) closeListeners() error { 158 | srv.mu.Lock() 159 | defer srv.mu.Unlock() 160 | 161 | srv.closed = true 162 | 163 | var err error 164 | for ln := range srv.listeners { 165 | if cerr := 
ln.Close(); cerr != nil { 166 | err = cerr 167 | } 168 | } 169 | return err 170 | } 171 | 172 | // h1Listener is used to signal back http.Server's Close and Shutdown to the 173 | // HTTP/2 server. 174 | type h1Listener struct { 175 | *pipeListener 176 | srv *Server 177 | } 178 | 179 | func (ln h1Listener) Close() error { 180 | // pipeListener.Close never fails 181 | _ = ln.pipeListener.Close() 182 | return ln.srv.closeListeners() 183 | } 184 | -------------------------------------------------------------------------------- /helper/http2/http2_test.go: -------------------------------------------------------------------------------- 1 | package http2_test 2 | 3 | import ( 4 | "errors" 5 | "log" 6 | "net" 7 | "net/http" 8 | "testing" 9 | 10 | "github.com/pires/go-proxyproto" 11 | h2proxy "github.com/pires/go-proxyproto/helper/http2" 12 | "golang.org/x/net/http2" 13 | ) 14 | 15 | func ExampleServer() { 16 | ln, err := net.Listen("tcp", "localhost:80") 17 | if err != nil { 18 | log.Fatalf("failed to listen: %v", err) 19 | } 20 | 21 | proxyLn := &proxyproto.Listener{ 22 | Listener: ln, 23 | } 24 | 25 | server := h2proxy.NewServer(&http.Server{ 26 | Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 27 | _, _ = w.Write([]byte("Hello world!\n")) 28 | }), 29 | }, nil) 30 | if err := server.Serve(proxyLn); err != nil { 31 | log.Fatalf("failed to serve: %v", err) 32 | } 33 | } 34 | 35 | func TestServer_h1(t *testing.T) { 36 | addr, server := newTestServer(t) 37 | defer server.Close() 38 | 39 | resp, err := http.Get("http://" + addr) 40 | if err != nil { 41 | t.Fatalf("failed to perform HTTP request: %v", err) 42 | } 43 | resp.Body.Close() 44 | } 45 | 46 | func TestServer_h2(t *testing.T) { 47 | addr, server := newTestServer(t) 48 | defer server.Close() 49 | 50 | conn, err := net.Dial("tcp", addr) 51 | if err != nil { 52 | t.Fatalf("failed to dial: %v", err) 53 | } 54 | defer conn.Close() 55 | 56 | proxyHeader := proxyproto.Header{ 57 | Version: 2, 58 | 
Command: proxyproto.LOCAL, 59 | TransportProtocol: proxyproto.UNSPEC, 60 | } 61 | tlvs := []proxyproto.TLV{{ 62 | Type: proxyproto.PP2_TYPE_ALPN, 63 | Value: []byte("h2"), 64 | }} 65 | if err := proxyHeader.SetTLVs(tlvs); err != nil { 66 | t.Fatalf("failed to set TLVs: %v", err) 67 | } 68 | if _, err := proxyHeader.WriteTo(conn); err != nil { 69 | t.Fatalf("failed to write PROXY header: %v", err) 70 | } 71 | 72 | h2Conn, err := new(http2.Transport).NewClientConn(conn) 73 | if err != nil { 74 | t.Fatalf("failed to create HTTP connection: %v", err) 75 | } 76 | 77 | req, err := http.NewRequest(http.MethodGet, "http://"+addr, nil) 78 | if err != nil { 79 | t.Fatalf("failed to create HTTP request: %v", err) 80 | } 81 | 82 | resp, err := h2Conn.RoundTrip(req) 83 | if err != nil { 84 | t.Fatalf("failed to perform HTTP request: %v", err) 85 | } 86 | resp.Body.Close() 87 | } 88 | 89 | func newTestServer(t *testing.T) (addr string, server *http.Server) { 90 | ln, err := net.Listen("tcp", "localhost:0") 91 | if err != nil { 92 | t.Fatalf("failed to listen: %v", err) 93 | } 94 | 95 | server = &http.Server{ 96 | Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 97 | }), 98 | } 99 | 100 | h2Server := h2proxy.NewServer(server, nil) 101 | done := make(chan error, 1) 102 | go func() { 103 | done <- h2Server.Serve(&proxyproto.Listener{Listener: ln}) 104 | }() 105 | 106 | t.Cleanup(func() { 107 | err := <-done 108 | if err != nil && !errors.Is(err, net.ErrClosed) { 109 | t.Fatalf("failed to serve: %v", err) 110 | } 111 | }) 112 | 113 | return ln.Addr().String(), server 114 | } 115 | -------------------------------------------------------------------------------- /helper/http2/listener.go: -------------------------------------------------------------------------------- 1 | package http2 2 | 3 | import ( 4 | "net" 5 | "sync" 6 | ) 7 | 8 | // pipeListener is a hack to workaround the lack of http.Server.ServeConn. 
9 | // See: https://github.com/golang/go/issues/36673 10 | type pipeListener struct { 11 | ch chan net.Conn 12 | closed bool 13 | mu sync.Mutex 14 | } 15 | 16 | func newPipeListener() *pipeListener { 17 | return &pipeListener{ 18 | ch: make(chan net.Conn, 64), 19 | } 20 | } 21 | 22 | func (ln *pipeListener) Accept() (net.Conn, error) { 23 | conn, ok := <-ln.ch 24 | if !ok { 25 | return nil, net.ErrClosed 26 | } 27 | return conn, nil 28 | } 29 | 30 | func (ln *pipeListener) Close() error { 31 | ln.mu.Lock() 32 | defer ln.mu.Unlock() 33 | 34 | if ln.closed { 35 | return net.ErrClosed 36 | } 37 | ln.closed = true 38 | close(ln.ch) 39 | return nil 40 | } 41 | 42 | // ServeConn enqueues a new connection. The connection will be returned in the 43 | // next Accept call. 44 | func (ln *pipeListener) ServeConn(conn net.Conn) error { 45 | ln.mu.Lock() 46 | defer ln.mu.Unlock() 47 | 48 | if ln.closed { 49 | return net.ErrClosed 50 | } 51 | ln.ch <- conn 52 | return nil 53 | } 54 | 55 | func (ln *pipeListener) Addr() net.Addr { 56 | return pipeAddr{} 57 | } 58 | 59 | type pipeAddr struct{} 60 | 61 | func (pipeAddr) Network() string { 62 | return "pipe" 63 | } 64 | 65 | func (pipeAddr) String() string { 66 | return "pipe" 67 | } 68 | -------------------------------------------------------------------------------- /policy.go: -------------------------------------------------------------------------------- 1 | package proxyproto 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "strings" 7 | ) 8 | 9 | // PolicyFunc can be used to decide whether to trust the PROXY info from 10 | // upstream. If set, the connecting address is passed in as an argument. 11 | // 12 | // See below for the different policies. 13 | // 14 | // In case an error is returned the connection is denied. 15 | type PolicyFunc func(upstream net.Addr) (Policy, error) 16 | 17 | // ConnPolicyFunc can be used to decide whether to trust the PROXY info 18 | // based on connection policy options. 
// If set, it receives both the remote and the local address of the
// connection.
//
// The returned Policy selects how the connection is treated (see the Policy
// constants). A non-nil error means the connection is denied.
type ConnPolicyFunc func(opts ConnPolicyOptions) (Policy, error)

// ConnPolicyOptions describes the connection a ConnPolicyFunc is asked about.
type ConnPolicyOptions struct {
	// Upstream is the remote address of the connection (the connecting peer).
	Upstream net.Addr

	// Downstream is the local address the connection was accepted on.
	Downstream net.Addr
}

// Policy defines how a connection with a PROXY header address is treated.
type Policy int

// Policies a PolicyFunc or ConnPolicyFunc can return. Their numeric values
// follow declaration order, from USE (0) to SKIP (4).
const (
	// USE trusts the PROXY header and exposes the addresses it carries.
	USE Policy = iota

	// IGNORE accepts the connection but discards the addresses from the
	// PROXY header.
	IGNORE

	// REJECT refuses connections that send a PROXY header: reading from such
	// a connection fails with ErrSuperfluousProxyHeader.
	REJECT

	// REQUIRE refuses connections that do not send a PROXY header: reading
	// from such a connection fails with ErrNoProxyProtocol.
	REQUIRE

	// SKIP hands the connection over as-is, without looking for a PROXY
	// header. SkipProxyHeaderForCIDR shows an example use.
	SKIP
)

// SkipProxyHeaderForCIDR returns a PolicyFunc that answers SKIP for upstream
// addresses inside skipHeaderCIDR, so such peers (e.g. Kubernetes pod-local
// traffic) are accepted without a PROXY header. Every other upstream address
// gets the def policy.
60 | func SkipProxyHeaderForCIDR(skipHeaderCIDR *net.IPNet, def Policy) PolicyFunc { 61 | return func(upstream net.Addr) (Policy, error) { 62 | ip, err := ipFromAddr(upstream) 63 | if err != nil { 64 | return def, err 65 | } 66 | 67 | if skipHeaderCIDR != nil && skipHeaderCIDR.Contains(ip) { 68 | return SKIP, nil 69 | } 70 | 71 | return def, nil 72 | } 73 | } 74 | 75 | // WithPolicy adds given policy to a connection when passed as option to NewConn() 76 | func WithPolicy(p Policy) func(*Conn) { 77 | return func(c *Conn) { 78 | c.ProxyHeaderPolicy = p 79 | } 80 | } 81 | 82 | // LaxWhiteListPolicy returns a PolicyFunc which decides whether the 83 | // upstream ip is allowed to send a proxy header based on a list of allowed 84 | // IP addresses and IP ranges. In case upstream IP is not in list the proxy 85 | // header will be ignored. If one of the provided IP addresses or IP ranges 86 | // is invalid it will return an error instead of a PolicyFunc. 87 | func LaxWhiteListPolicy(allowed []string) (PolicyFunc, error) { 88 | allowFrom, err := parse(allowed) 89 | if err != nil { 90 | return nil, err 91 | } 92 | 93 | return whitelistPolicy(allowFrom, IGNORE), nil 94 | } 95 | 96 | // MustLaxWhiteListPolicy returns a LaxWhiteListPolicy but will panic if one 97 | // of the provided IP addresses or IP ranges is invalid. 98 | func MustLaxWhiteListPolicy(allowed []string) PolicyFunc { 99 | pfunc, err := LaxWhiteListPolicy(allowed) 100 | if err != nil { 101 | panic(err) 102 | } 103 | 104 | return pfunc 105 | } 106 | 107 | // StrictWhiteListPolicy returns a PolicyFunc which decides whether the 108 | // upstream ip is allowed to send a proxy header based on a list of allowed 109 | // IP addresses and IP ranges. In case upstream IP is not in list reading on 110 | // the connection will be refused on the first read. Please note: subsequent 111 | // reads do not error. It is the task of the code using the connection to 112 | // handle that case properly. 
If one of the provided IP addresses or IP 113 | // ranges is invalid it will return an error instead of a PolicyFunc. 114 | func StrictWhiteListPolicy(allowed []string) (PolicyFunc, error) { 115 | allowFrom, err := parse(allowed) 116 | if err != nil { 117 | return nil, err 118 | } 119 | 120 | return whitelistPolicy(allowFrom, REJECT), nil 121 | } 122 | 123 | // MustStrictWhiteListPolicy returns a StrictWhiteListPolicy but will panic 124 | // if one of the provided IP addresses or IP ranges is invalid. 125 | func MustStrictWhiteListPolicy(allowed []string) PolicyFunc { 126 | pfunc, err := StrictWhiteListPolicy(allowed) 127 | if err != nil { 128 | panic(err) 129 | } 130 | 131 | return pfunc 132 | } 133 | 134 | func whitelistPolicy(allowed []func(net.IP) bool, def Policy) PolicyFunc { 135 | return func(upstream net.Addr) (Policy, error) { 136 | upstreamIP, err := ipFromAddr(upstream) 137 | if err != nil { 138 | // something is wrong with the source IP, better reject the connection 139 | return REJECT, err 140 | } 141 | 142 | for _, allowFrom := range allowed { 143 | if allowFrom(upstreamIP) { 144 | return USE, nil 145 | } 146 | } 147 | 148 | return def, nil 149 | } 150 | } 151 | 152 | func parse(allowed []string) ([]func(net.IP) bool, error) { 153 | a := make([]func(net.IP) bool, len(allowed)) 154 | for i, allowFrom := range allowed { 155 | if strings.LastIndex(allowFrom, "/") > 0 { 156 | _, ipRange, err := net.ParseCIDR(allowFrom) 157 | if err != nil { 158 | return nil, fmt.Errorf("proxyproto: given string %q is not a valid IP range: %v", allowFrom, err) 159 | } 160 | 161 | a[i] = ipRange.Contains 162 | } else { 163 | allowed := net.ParseIP(allowFrom) 164 | if allowed == nil { 165 | return nil, fmt.Errorf("proxyproto: given string %q is not a valid IP address", allowFrom) 166 | } 167 | 168 | a[i] = allowed.Equal 169 | } 170 | } 171 | 172 | return a, nil 173 | } 174 | 175 | func ipFromAddr(upstream net.Addr) (net.IP, error) { 176 | upstreamString, _, err := 
net.SplitHostPort(upstream.String()) 177 | if err != nil { 178 | return nil, err 179 | } 180 | 181 | upstreamIP := net.ParseIP(upstreamString) 182 | if nil == upstreamIP { 183 | return nil, fmt.Errorf("proxyproto: invalid IP address") 184 | } 185 | 186 | return upstreamIP, nil 187 | } 188 | 189 | // IgnoreProxyHeaderNotOnInterface retuns a ConnPolicyFunc which can be used to 190 | // decide whether to use or ignore PROXY headers depending on the connection 191 | // being made on a specific interface. This policy can be used when the server 192 | // is bound to multiple interfaces but wants to allow on only one interface. 193 | func IgnoreProxyHeaderNotOnInterface(allowedIP net.IP) ConnPolicyFunc { 194 | return func(connOpts ConnPolicyOptions) (Policy, error) { 195 | ip, err := ipFromAddr(connOpts.Downstream) 196 | if err != nil { 197 | return REJECT, err 198 | } 199 | 200 | if allowedIP.Equal(ip) { 201 | return USE, nil 202 | } 203 | 204 | return IGNORE, nil 205 | } 206 | } 207 | -------------------------------------------------------------------------------- /policy_test.go: -------------------------------------------------------------------------------- 1 | package proxyproto 2 | 3 | import ( 4 | "net" 5 | "testing" 6 | ) 7 | 8 | type failingAddr struct{} 9 | 10 | func (f failingAddr) Network() string { return "failing" } 11 | func (f failingAddr) String() string { return "failing" } 12 | 13 | func TestWhitelistPolicyReturnsErrorOnInvalidAddress(t *testing.T) { 14 | var cases = []struct { 15 | name string 16 | policy PolicyFunc 17 | }{ 18 | {"strict whitelist policy", MustStrictWhiteListPolicy([]string{"10.0.0.2", "10.0.0.3", "10.0.0.4", "10.0.0.0/30"})}, 19 | {"lax whitelist policy", MustLaxWhiteListPolicy([]string{"10.0.0.2", "10.0.0.3", "10.0.0.4", "10.0.0.0/30"})}, 20 | } 21 | 22 | for _, tc := range cases { 23 | t.Run(tc.name, func(t *testing.T) { 24 | _, err := tc.policy(failingAddr{}) 25 | if err == nil { 26 | t.Fatal("Expected error, got none") 27 | } 28 
| }) 29 | } 30 | } 31 | 32 | func TestStrictWhitelistPolicyReturnsRejectWhenUpstreamIpAddrNotInWhitelist(t *testing.T) { 33 | p := MustStrictWhiteListPolicy([]string{"10.0.0.2", "10.0.0.3", "10.0.0.4", "10.0.0.0/30"}) 34 | 35 | upstream, err := net.ResolveTCPAddr("tcp", "10.0.0.5:45738") 36 | if err != nil { 37 | t.Fatalf("err: %v", err) 38 | } 39 | 40 | policy, err := p(upstream) 41 | if err != nil { 42 | t.Fatalf("err: %v", err) 43 | } 44 | 45 | if policy != REJECT { 46 | t.Fatalf("Expected policy REJECT, got %v", policy) 47 | } 48 | } 49 | 50 | func TestLaxWhitelistPolicyReturnsIgnoreWhenUpstreamIpAddrNotInWhitelist(t *testing.T) { 51 | p := MustLaxWhiteListPolicy([]string{"10.0.0.2", "10.0.0.3", "10.0.0.4", "10.0.0.0/30"}) 52 | 53 | upstream, err := net.ResolveTCPAddr("tcp", "10.0.0.5:45738") 54 | if err != nil { 55 | t.Fatalf("err: %v", err) 56 | } 57 | 58 | policy, err := p(upstream) 59 | if err != nil { 60 | t.Fatalf("err: %v", err) 61 | } 62 | 63 | if policy != IGNORE { 64 | t.Fatalf("Expected policy IGNORE, got %v", policy) 65 | } 66 | } 67 | 68 | func TestWhitelistPolicyReturnsUseWhenUpstreamIpAddrInWhitelist(t *testing.T) { 69 | var cases = []struct { 70 | name string 71 | policy PolicyFunc 72 | }{ 73 | {"strict whitelist policy", MustStrictWhiteListPolicy([]string{"10.0.0.2", "10.0.0.3", "10.0.0.4"})}, 74 | {"lax whitelist policy", MustLaxWhiteListPolicy([]string{"10.0.0.2", "10.0.0.3", "10.0.0.4"})}, 75 | } 76 | 77 | upstream, err := net.ResolveTCPAddr("tcp", "10.0.0.3:45738") 78 | if err != nil { 79 | t.Fatalf("err: %v", err) 80 | } 81 | 82 | for _, tc := range cases { 83 | t.Run(tc.name, func(t *testing.T) { 84 | policy, err := tc.policy(upstream) 85 | if err != nil { 86 | t.Fatalf("err: %v", err) 87 | } 88 | 89 | if policy != USE { 90 | t.Fatalf("Expected policy USE, got %v", policy) 91 | } 92 | }) 93 | } 94 | } 95 | 96 | func TestWhitelistPolicyReturnsUseWhenUpstreamIpAddrInWhitelistRange(t *testing.T) { 97 | var cases = []struct { 98 | name string 
99 | policy PolicyFunc 100 | }{ 101 | {"strict whitelist policy", MustStrictWhiteListPolicy([]string{"10.0.0.0/29"})}, 102 | {"lax whitelist policy", MustLaxWhiteListPolicy([]string{"10.0.0.0/29"})}, 103 | } 104 | 105 | upstream, err := net.ResolveTCPAddr("tcp", "10.0.0.3:45738") 106 | if err != nil { 107 | t.Fatalf("err: %v", err) 108 | } 109 | 110 | for _, tc := range cases { 111 | t.Run(tc.name, func(t *testing.T) { 112 | policy, err := tc.policy(upstream) 113 | if err != nil { 114 | t.Fatalf("err: %v", err) 115 | } 116 | 117 | if policy != USE { 118 | t.Fatalf("Expected policy USE, got %v", policy) 119 | } 120 | }) 121 | } 122 | } 123 | 124 | func Test_CreateWhitelistPolicyWithInvalidCidrReturnsError(t *testing.T) { 125 | _, err := StrictWhiteListPolicy([]string{"20/80"}) 126 | if err == nil { 127 | t.Error("Expected error, got none") 128 | } 129 | } 130 | 131 | func Test_CreateWhitelistPolicyWithInvalidIpAddressReturnsError(t *testing.T) { 132 | _, err := StrictWhiteListPolicy([]string{"855.222.233.11"}) 133 | if err == nil { 134 | t.Error("Expected error, got none") 135 | } 136 | } 137 | 138 | func Test_CreateLaxPolicyWithInvalidCidrReturnsError(t *testing.T) { 139 | _, err := LaxWhiteListPolicy([]string{"20/80"}) 140 | if err == nil { 141 | t.Error("Expected error, got none") 142 | } 143 | } 144 | 145 | func Test_CreateLaxPolicyWithInvalidIpAddresseturnsError(t *testing.T) { 146 | _, err := LaxWhiteListPolicy([]string{"855.222.233.11"}) 147 | if err == nil { 148 | t.Error("Expected error, got none") 149 | } 150 | } 151 | 152 | func Test_MustLaxWhiteListPolicyPanicsWithInvalidIpAddress(t *testing.T) { 153 | defer func() { 154 | if r := recover(); r == nil { 155 | t.Error("Expected a panic, but got none") 156 | } 157 | }() 158 | 159 | MustLaxWhiteListPolicy([]string{"855.222.233.11"}) 160 | } 161 | 162 | func Test_MustLaxWhiteListPolicyPanicsWithInvalidIpRange(t *testing.T) { 163 | defer func() { 164 | if r := recover(); r == nil { 165 | t.Error("Expected a 
panic, but got none") 166 | } 167 | }() 168 | 169 | MustLaxWhiteListPolicy([]string{"20/80"}) 170 | } 171 | 172 | func Test_MustStrictWhiteListPolicyPanicsWithInvalidIpAddress(t *testing.T) { 173 | defer func() { 174 | if r := recover(); r == nil { 175 | t.Error("Expected a panic, but got none") 176 | } 177 | }() 178 | 179 | MustStrictWhiteListPolicy([]string{"855.222.233.11"}) 180 | } 181 | 182 | func Test_MustStrictWhiteListPolicyPanicsWithInvalidIpRange(t *testing.T) { 183 | defer func() { 184 | if r := recover(); r == nil { 185 | t.Error("Expected a panic, but got none") 186 | } 187 | }() 188 | 189 | MustStrictWhiteListPolicy([]string{"20/80"}) 190 | } 191 | 192 | func TestSkipProxyHeaderForCIDR(t *testing.T) { 193 | _, cidr, _ := net.ParseCIDR("192.0.2.1/24") 194 | f := SkipProxyHeaderForCIDR(cidr, REJECT) 195 | 196 | upstream, _ := net.ResolveTCPAddr("tcp", "192.0.2.255:12345") 197 | policy, err := f(upstream) 198 | if err != nil { 199 | t.Fatalf("err: %v", err) 200 | } 201 | if policy != SKIP { 202 | t.Errorf("Expected a SKIP policy for the %s address", upstream) 203 | } 204 | 205 | upstream, _ = net.ResolveTCPAddr("tcp", "8.8.8.8:12345") 206 | policy, err = f(upstream) 207 | if err != nil { 208 | t.Fatalf("err: %v", err) 209 | } 210 | if policy != REJECT { 211 | t.Errorf("Expected a REJECT policy for the %s address", upstream) 212 | } 213 | } 214 | 215 | func TestIgnoreProxyHeaderNotOnInterface(t *testing.T) { 216 | downstream, err := net.ResolveTCPAddr("tcp", "10.0.0.3:45738") 217 | if err != nil { 218 | t.Fatalf("err: %v", err) 219 | } 220 | 221 | var cases = []struct { 222 | name string 223 | policy ConnPolicyFunc 224 | downstreamAddress net.Addr 225 | expectedPolicy Policy 226 | expectError bool 227 | }{ 228 | {"ignore header for requests non on interface", IgnoreProxyHeaderNotOnInterface(net.ParseIP("192.0.2.1")), downstream, IGNORE, false}, 229 | {"use headers for requests on interface", IgnoreProxyHeaderNotOnInterface(net.ParseIP("10.0.0.3")), 
downstream, USE, false}, 230 | {"invalid address should return error", IgnoreProxyHeaderNotOnInterface(net.ParseIP("10.0.0.3")), failingAddr{}, REJECT, true}, 231 | } 232 | 233 | for _, tc := range cases { 234 | t.Run(tc.name, func(t *testing.T) { 235 | policy, err := tc.policy(ConnPolicyOptions{ 236 | Downstream: tc.downstreamAddress, 237 | }) 238 | if !tc.expectError && err != nil { 239 | t.Fatalf("err: %v", err) 240 | } 241 | if tc.expectError && err == nil { 242 | t.Fatal("Expected error, got none") 243 | } 244 | 245 | if policy != tc.expectedPolicy { 246 | t.Fatalf("Expected policy %v, got %v", tc.expectedPolicy, policy) 247 | } 248 | }) 249 | } 250 | 251 | } 252 | -------------------------------------------------------------------------------- /protocol.go: -------------------------------------------------------------------------------- 1 | package proxyproto 2 | 3 | import ( 4 | "bufio" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "net" 9 | "sync" 10 | "sync/atomic" 11 | "time" 12 | ) 13 | 14 | var ( 15 | // DefaultReadHeaderTimeout is how long header processing waits for header to 16 | // be read from the wire, if Listener.ReaderHeaderTimeout is not set. 17 | // It's kept as a global variable so to make it easier to find and override, 18 | // e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s" 19 | DefaultReadHeaderTimeout = 10 * time.Second 20 | 21 | // ErrInvalidUpstream should be returned when an upstream connection address 22 | // is not trusted, and therefore is invalid. 23 | ErrInvalidUpstream = fmt.Errorf("proxyproto: upstream connection address not trusted for PROXY information") 24 | ) 25 | 26 | // Listener is used to wrap an underlying listener, 27 | // whose connections may be using the HAProxy Proxy Protocol. 28 | // If the connection is using the protocol, the RemoteAddr() will return 29 | // the correct client address. 
ReadHeaderTimeout will be applied to all 30 | // connections in order to prevent blocking operations. If no ReadHeaderTimeout 31 | // is set, a default of 10s will be used. This can be disabled by setting the 32 | // timeout to < 0. 33 | // 34 | // Only one of Policy or ConnPolicy should be provided. If both are provided then 35 | // a panic would occur during accept. 36 | type Listener struct { 37 | Listener net.Listener 38 | // Deprecated: use ConnPolicyFunc instead. This will be removed in future release. 39 | Policy PolicyFunc 40 | ConnPolicy ConnPolicyFunc 41 | ValidateHeader Validator 42 | ReadHeaderTimeout time.Duration 43 | } 44 | 45 | // Conn is used to wrap and underlying connection which 46 | // may be speaking the Proxy Protocol. If it is, the RemoteAddr() will 47 | // return the address of the client instead of the proxy address. Each connection 48 | // will have its own readHeaderTimeout and readDeadline set by the Accept() call. 49 | type Conn struct { 50 | readDeadline atomic.Value // time.Time 51 | once sync.Once 52 | readErr error 53 | conn net.Conn 54 | bufReader *bufio.Reader 55 | reader io.Reader 56 | header *Header 57 | ProxyHeaderPolicy Policy 58 | Validate Validator 59 | readHeaderTimeout time.Duration 60 | } 61 | 62 | // Validator receives a header and decides whether it is a valid one 63 | // In case the header is not deemed valid it should return an error. 
64 | type Validator func(*Header) error 65 | 66 | // ValidateHeader adds given validator for proxy headers to a connection when passed as option to NewConn() 67 | func ValidateHeader(v Validator) func(*Conn) { 68 | return func(c *Conn) { 69 | if v != nil { 70 | c.Validate = v 71 | } 72 | } 73 | } 74 | 75 | // SetReadHeaderTimeout sets the readHeaderTimeout for a connection when passed as option to NewConn() 76 | func SetReadHeaderTimeout(t time.Duration) func(*Conn) { 77 | return func(c *Conn) { 78 | if t >= 0 { 79 | c.readHeaderTimeout = t 80 | } 81 | } 82 | } 83 | 84 | // Accept waits for and returns the next valid connection to the listener. 85 | func (p *Listener) Accept() (net.Conn, error) { 86 | for { 87 | // Get the underlying connection 88 | conn, err := p.Listener.Accept() 89 | if err != nil { 90 | return nil, err 91 | } 92 | 93 | proxyHeaderPolicy := USE 94 | if p.Policy != nil && p.ConnPolicy != nil { 95 | panic("only one of policy or connpolicy must be provided.") 96 | } 97 | if p.Policy != nil || p.ConnPolicy != nil { 98 | if p.Policy != nil { 99 | proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr()) 100 | } else { 101 | proxyHeaderPolicy, err = p.ConnPolicy(ConnPolicyOptions{ 102 | Upstream: conn.RemoteAddr(), 103 | Downstream: conn.LocalAddr(), 104 | }) 105 | } 106 | if err != nil { 107 | // can't decide the policy, we can't accept the connection 108 | conn.Close() 109 | 110 | if errors.Is(err, ErrInvalidUpstream) { 111 | // keep listening for other connections 112 | continue 113 | } 114 | 115 | return nil, err 116 | } 117 | // Handle a connection as a regular one 118 | if proxyHeaderPolicy == SKIP { 119 | return conn, nil 120 | } 121 | } 122 | 123 | newConn := NewConn( 124 | conn, 125 | WithPolicy(proxyHeaderPolicy), 126 | ValidateHeader(p.ValidateHeader), 127 | ) 128 | 129 | // If the ReadHeaderTimeout for the listener is unset, use the default timeout. 
130 | if p.ReadHeaderTimeout == 0 { 131 | p.ReadHeaderTimeout = DefaultReadHeaderTimeout 132 | } 133 | 134 | // Set the readHeaderTimeout of the new conn to the value of the listener 135 | newConn.readHeaderTimeout = p.ReadHeaderTimeout 136 | 137 | return newConn, nil 138 | } 139 | } 140 | 141 | // Close closes the underlying listener. 142 | func (p *Listener) Close() error { 143 | return p.Listener.Close() 144 | } 145 | 146 | // Addr returns the underlying listener's network address. 147 | func (p *Listener) Addr() net.Addr { 148 | return p.Listener.Addr() 149 | } 150 | 151 | // NewConn is used to wrap a net.Conn that may be speaking 152 | // the proxy protocol into a proxyproto.Conn 153 | func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn { 154 | // For v1 the header length is at most 108 bytes. 155 | // For v2 the header length is at most 52 bytes plus the length of the TLVs. 156 | // We use 256 bytes to be safe. 157 | const bufSize = 256 158 | br := bufio.NewReaderSize(conn, bufSize) 159 | 160 | pConn := &Conn{ 161 | bufReader: br, 162 | reader: io.MultiReader(br, conn), 163 | conn: conn, 164 | } 165 | 166 | for _, opt := range opts { 167 | opt(pConn) 168 | } 169 | 170 | return pConn 171 | } 172 | 173 | // Read is check for the proxy protocol header when doing 174 | // the initial scan. If there is an error parsing the header, 175 | // it is returned and the socket is closed. 176 | func (p *Conn) Read(b []byte) (int, error) { 177 | p.once.Do(func() { 178 | p.readErr = p.readHeader() 179 | }) 180 | if p.readErr != nil { 181 | return 0, p.readErr 182 | } 183 | 184 | return p.reader.Read(b) 185 | } 186 | 187 | // Write wraps original conn.Write 188 | func (p *Conn) Write(b []byte) (int, error) { 189 | return p.conn.Write(b) 190 | } 191 | 192 | // Close wraps original conn.Close 193 | func (p *Conn) Close() error { 194 | return p.conn.Close() 195 | } 196 | 197 | // ProxyHeader returns the proxy protocol header, if any. 
If an error occurs 198 | // while reading the proxy header, nil is returned. 199 | func (p *Conn) ProxyHeader() *Header { 200 | p.once.Do(func() { p.readErr = p.readHeader() }) 201 | return p.header 202 | } 203 | 204 | // LocalAddr returns the address of the server if the proxy 205 | // protocol is being used, otherwise just returns the address of 206 | // the socket server. In case an error happens on reading the 207 | // proxy header the original LocalAddr is returned, not the one 208 | // from the proxy header even if the proxy header itself is 209 | // syntactically correct. 210 | func (p *Conn) LocalAddr() net.Addr { 211 | p.once.Do(func() { p.readErr = p.readHeader() }) 212 | if p.header == nil || p.header.Command.IsLocal() || p.readErr != nil { 213 | return p.conn.LocalAddr() 214 | } 215 | 216 | return p.header.DestinationAddr 217 | } 218 | 219 | // RemoteAddr returns the address of the client if the proxy 220 | // protocol is being used, otherwise just returns the address of 221 | // the socket peer. In case an error happens on reading the 222 | // proxy header the original RemoteAddr is returned, not the one 223 | // from the proxy header even if the proxy header itself is 224 | // syntactically correct. 225 | func (p *Conn) RemoteAddr() net.Addr { 226 | p.once.Do(func() { p.readErr = p.readHeader() }) 227 | if p.header == nil || p.header.Command.IsLocal() || p.readErr != nil { 228 | return p.conn.RemoteAddr() 229 | } 230 | 231 | return p.header.SourceAddr 232 | } 233 | 234 | // Raw returns the underlying connection which can be casted to 235 | // a concrete type, allowing access to specialized functions. 236 | // 237 | // Use this ONLY if you know exactly what you are doing. 238 | func (p *Conn) Raw() net.Conn { 239 | return p.conn 240 | } 241 | 242 | // TCPConn returns the underlying TCP connection, 243 | // allowing access to specialized functions. 244 | // 245 | // Use this ONLY if you know exactly what you are doing. 
246 | func (p *Conn) TCPConn() (conn *net.TCPConn, ok bool) { 247 | conn, ok = p.conn.(*net.TCPConn) 248 | return 249 | } 250 | 251 | // UnixConn returns the underlying Unix socket connection, 252 | // allowing access to specialized functions. 253 | // 254 | // Use this ONLY if you know exactly what you are doing. 255 | func (p *Conn) UnixConn() (conn *net.UnixConn, ok bool) { 256 | conn, ok = p.conn.(*net.UnixConn) 257 | return 258 | } 259 | 260 | // UDPConn returns the underlying UDP connection, 261 | // allowing access to specialized functions. 262 | // 263 | // Use this ONLY if you know exactly what you are doing. 264 | func (p *Conn) UDPConn() (conn *net.UDPConn, ok bool) { 265 | conn, ok = p.conn.(*net.UDPConn) 266 | return 267 | } 268 | 269 | // SetDeadline wraps original conn.SetDeadline 270 | func (p *Conn) SetDeadline(t time.Time) error { 271 | p.readDeadline.Store(t) 272 | return p.conn.SetDeadline(t) 273 | } 274 | 275 | // SetReadDeadline wraps original conn.SetReadDeadline 276 | func (p *Conn) SetReadDeadline(t time.Time) error { 277 | // Set a local var that tells us the desired deadline. This is 278 | // needed in order to reset the read deadline to the one that is 279 | // desired by the user, rather than an empty deadline. 280 | p.readDeadline.Store(t) 281 | return p.conn.SetReadDeadline(t) 282 | } 283 | 284 | // SetWriteDeadline wraps original conn.SetWriteDeadline 285 | func (p *Conn) SetWriteDeadline(t time.Time) error { 286 | return p.conn.SetWriteDeadline(t) 287 | } 288 | 289 | func (p *Conn) readHeader() error { 290 | // If the connection's readHeaderTimeout is more than 0, 291 | // push our deadline back to now plus the timeout. This should only 292 | // run on the connection, as we don't want to override the previous 293 | // read deadline the user may have used. 
294 | if p.readHeaderTimeout > 0 { 295 | if err := p.conn.SetReadDeadline(time.Now().Add(p.readHeaderTimeout)); err != nil { 296 | return err 297 | } 298 | } 299 | 300 | header, err := Read(p.bufReader) 301 | 302 | // If the connection's readHeaderTimeout is more than 0, undo the change to the 303 | // deadline that we made above. Because we retain the readDeadline as part of our 304 | // SetReadDeadline override, we know the user's desired deadline so we use that. 305 | // Therefore, we check whether the error is a net.Timeout and if it is, we decide 306 | // the proxy proto does not exist and set the error accordingly. 307 | if p.readHeaderTimeout > 0 { 308 | t := p.readDeadline.Load() 309 | if t == nil { 310 | t = time.Time{} 311 | } 312 | if err := p.conn.SetReadDeadline(t.(time.Time)); err != nil { 313 | return err 314 | } 315 | if netErr, ok := err.(net.Error); ok && netErr.Timeout() { 316 | err = ErrNoProxyProtocol 317 | } 318 | } 319 | 320 | // For the purpose of this wrapper shamefully stolen from armon/go-proxyproto 321 | // let's act as if there was no error when PROXY protocol is not present. 
322 | if err == ErrNoProxyProtocol { 323 | // but not if it is required that the connection has one 324 | if p.ProxyHeaderPolicy == REQUIRE { 325 | return err 326 | } 327 | 328 | return nil 329 | } 330 | 331 | // proxy protocol header was found 332 | if err == nil && header != nil { 333 | switch p.ProxyHeaderPolicy { 334 | case REJECT: 335 | // this connection is not allowed to send one 336 | return ErrSuperfluousProxyHeader 337 | case USE, REQUIRE: 338 | if p.Validate != nil { 339 | err = p.Validate(header) 340 | if err != nil { 341 | return err 342 | } 343 | } 344 | 345 | p.header = header 346 | } 347 | } 348 | 349 | return err 350 | } 351 | 352 | // ReadFrom implements the io.ReaderFrom ReadFrom method 353 | func (p *Conn) ReadFrom(r io.Reader) (int64, error) { 354 | if rf, ok := p.conn.(io.ReaderFrom); ok { 355 | return rf.ReadFrom(r) 356 | } 357 | return io.Copy(p.conn, r) 358 | } 359 | 360 | // WriteTo implements io.WriterTo 361 | func (p *Conn) WriteTo(w io.Writer) (int64, error) { 362 | p.once.Do(func() { p.readErr = p.readHeader() }) 363 | if p.readErr != nil { 364 | return 0, p.readErr 365 | } 366 | 367 | b := make([]byte, p.bufReader.Buffered()) 368 | if _, err := p.bufReader.Read(b); err != nil { 369 | return 0, err // this should never as we read buffered data 370 | } 371 | 372 | var n int64 373 | { 374 | nn, err := w.Write(b) 375 | n += int64(nn) 376 | if err != nil { 377 | return n, err 378 | } 379 | } 380 | { 381 | nn, err := io.Copy(w, p.conn) 382 | n += nn 383 | if err != nil { 384 | return n, err 385 | } 386 | } 387 | 388 | return n, nil 389 | } 390 | -------------------------------------------------------------------------------- /tlv.go: -------------------------------------------------------------------------------- 1 | // Type-Length-Value splitting and parsing for proxy protocol V2 2 | // See spec https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt sections 2.2 to 2.7 and 3 | 4 | package proxyproto 5 | 6 | import ( 7 | 
"encoding/binary" 8 | "errors" 9 | "fmt" 10 | "math" 11 | ) 12 | 13 | const ( 14 | // Section 2.2 15 | PP2_TYPE_ALPN PP2Type = 0x01 16 | PP2_TYPE_AUTHORITY PP2Type = 0x02 17 | PP2_TYPE_CRC32C PP2Type = 0x03 18 | PP2_TYPE_NOOP PP2Type = 0x04 19 | PP2_TYPE_UNIQUE_ID PP2Type = 0x05 20 | PP2_TYPE_SSL PP2Type = 0x20 21 | PP2_SUBTYPE_SSL_VERSION PP2Type = 0x21 22 | PP2_SUBTYPE_SSL_CN PP2Type = 0x22 23 | PP2_SUBTYPE_SSL_CIPHER PP2Type = 0x23 24 | PP2_SUBTYPE_SSL_SIG_ALG PP2Type = 0x24 25 | PP2_SUBTYPE_SSL_KEY_ALG PP2Type = 0x25 26 | PP2_TYPE_NETNS PP2Type = 0x30 27 | 28 | // Section 2.2.7, reserved types 29 | PP2_TYPE_MIN_CUSTOM PP2Type = 0xE0 30 | PP2_TYPE_MAX_CUSTOM PP2Type = 0xEF 31 | PP2_TYPE_MIN_EXPERIMENT PP2Type = 0xF0 32 | PP2_TYPE_MAX_EXPERIMENT PP2Type = 0xF7 33 | PP2_TYPE_MIN_FUTURE PP2Type = 0xF8 34 | PP2_TYPE_MAX_FUTURE PP2Type = 0xFF 35 | ) 36 | 37 | var ( 38 | ErrTruncatedTLV = errors.New("proxyproto: truncated TLV") 39 | ErrMalformedTLV = errors.New("proxyproto: malformed TLV Value") 40 | ErrIncompatibleTLV = errors.New("proxyproto: incompatible TLV type") 41 | ) 42 | 43 | // PP2Type is the proxy protocol v2 type 44 | type PP2Type byte 45 | 46 | // TLV is a uninterpreted Type-Length-Value for V2 protocol, see section 2.2 47 | type TLV struct { 48 | Type PP2Type 49 | Value []byte 50 | } 51 | 52 | // SplitTLVs splits the Type-Length-Value vector, returns the vector or an error. 
53 | func SplitTLVs(raw []byte) ([]TLV, error) { 54 | var tlvs []TLV 55 | for i := 0; i < len(raw); { 56 | tlv := TLV{ 57 | Type: PP2Type(raw[i]), 58 | } 59 | if len(raw)-i <= 2 { 60 | return nil, ErrTruncatedTLV 61 | } 62 | tlvLen := int(binary.BigEndian.Uint16(raw[i+1 : i+3])) // Max length = 65K 63 | i += 3 64 | if i+tlvLen > len(raw) { 65 | return nil, ErrTruncatedTLV 66 | } 67 | // Ignore no-op padding 68 | if tlv.Type != PP2_TYPE_NOOP { 69 | tlv.Value = make([]byte, tlvLen) 70 | copy(tlv.Value, raw[i:i+tlvLen]) 71 | } 72 | i += tlvLen 73 | tlvs = append(tlvs, tlv) 74 | } 75 | return tlvs, nil 76 | } 77 | 78 | // JoinTLVs joins multiple Type-Length-Value records. 79 | func JoinTLVs(tlvs []TLV) ([]byte, error) { 80 | var raw []byte 81 | for _, tlv := range tlvs { 82 | if len(tlv.Value) > math.MaxUint16 { 83 | return nil, fmt.Errorf("proxyproto: cannot format TLV %v with length %d", tlv.Type, len(tlv.Value)) 84 | } 85 | var length [2]byte 86 | binary.BigEndian.PutUint16(length[:], uint16(len(tlv.Value))) 87 | raw = append(raw, byte(tlv.Type)) 88 | raw = append(raw, length[:]...) 89 | raw = append(raw, tlv.Value...) 
90 | } 91 | return raw, nil 92 | } 93 | 94 | // Registered is true if the type is registered in the spec, see section 2.2 95 | func (p PP2Type) Registered() bool { 96 | switch p { 97 | case PP2_TYPE_ALPN, 98 | PP2_TYPE_AUTHORITY, 99 | PP2_TYPE_CRC32C, 100 | PP2_TYPE_NOOP, 101 | PP2_TYPE_UNIQUE_ID, 102 | PP2_TYPE_SSL, 103 | PP2_SUBTYPE_SSL_VERSION, 104 | PP2_SUBTYPE_SSL_CN, 105 | PP2_SUBTYPE_SSL_CIPHER, 106 | PP2_SUBTYPE_SSL_SIG_ALG, 107 | PP2_SUBTYPE_SSL_KEY_ALG, 108 | PP2_TYPE_NETNS: 109 | return true 110 | } 111 | return false 112 | } 113 | 114 | // App is true if the type is reserved for application specific data, see section 2.2.7 115 | func (p PP2Type) App() bool { 116 | return p >= PP2_TYPE_MIN_CUSTOM && p <= PP2_TYPE_MAX_CUSTOM 117 | } 118 | 119 | // Experiment is true if the type is reserved for temporary experimental use by application developers, see section 2.2.7 120 | func (p PP2Type) Experiment() bool { 121 | return p >= PP2_TYPE_MIN_EXPERIMENT && p <= PP2_TYPE_MAX_EXPERIMENT 122 | } 123 | 124 | // Future is true is the type is reserved for future use, see section 2.2.7 125 | func (p PP2Type) Future() bool { 126 | return p >= PP2_TYPE_MIN_FUTURE 127 | } 128 | 129 | // Spec is true if the type is covered by the spec, see section 2.2 and 2.2.7 130 | func (p PP2Type) Spec() bool { 131 | return p.Registered() || p.App() || p.Experiment() || p.Future() 132 | } 133 | -------------------------------------------------------------------------------- /tlv_test.go: -------------------------------------------------------------------------------- 1 | package proxyproto 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "testing" 7 | ) 8 | 9 | var ( 10 | fixtureOneByteTLV = []byte{byte(PP2_TYPE_MIN_CUSTOM) + 1} 11 | fixtureTwoByteTLV = []byte{byte(PP2_TYPE_MIN_CUSTOM) + 2, 0x00} 12 | fixtureEmptyLenTLV = []byte{byte(PP2_TYPE_MIN_CUSTOM) + 3, 0x00, 0x01} 13 | fixturePartialLenTLV = []byte{byte(PP2_TYPE_MIN_CUSTOM) + 3, 0x00, 0x02, 0x00} 14 | ) 15 | 16 | var invalidTLVTests 
= []struct { 17 | name string 18 | reader *bufio.Reader 19 | expectedError error 20 | }{ 21 | { 22 | name: "One byte TLV", 23 | reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, 24 | fixtureOneByteTLV)...)), 25 | expectedError: ErrTruncatedTLV, 26 | }, 27 | { 28 | name: "Two byte TLV", 29 | reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, 30 | fixtureTwoByteTLV)...)), 31 | expectedError: ErrTruncatedTLV, 32 | }, 33 | { 34 | name: "Empty Len TLV", 35 | reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, 36 | fixtureEmptyLenTLV)...)), 37 | expectedError: ErrTruncatedTLV, 38 | }, 39 | { 40 | name: "Partial Len TLV", 41 | reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, 42 | fixturePartialLenTLV)...)), 43 | expectedError: ErrTruncatedTLV, 44 | }, 45 | } 46 | 47 | func TestValid0Length(t *testing.T) { 48 | r := bufio.NewReader(bytes.NewReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, []byte{byte(PP2_TYPE_MIN_CUSTOM), 0x00, 0x00})...))) 49 | h, err := Read(r) 50 | if err != nil { 51 | t.Fatalf("unexpected error: %v", err) 52 | } 53 | tlvs, err := h.TLVs() 54 | if err != nil { 55 | t.Fatalf("unexpected error: %v", err) 56 | } 57 | if len(tlvs) != 1 { 58 | t.Fatalf("expected 1 tlv, got %d", len(tlvs)) 59 | } 60 | if len(tlvs[0].Value) != 0 { 61 | t.Fatalf("expected 0 byte tlv value, got %x", tlvs[0].Value) 62 | } 63 | } 64 | 65 | func TestInvalidV2TLV(t *testing.T) { 66 | for _, tc := range invalidTLVTests { 67 | t.Run(tc.name, func(t *testing.T) { 68 | if hdr, err := Read(tc.reader); err != nil { 69 | t.Fatalf("TestInvalidV2TLV %s: unexpected error reading proxy protocol %#v", tc.name, err) 70 | } else if _, err := hdr.TLVs(); err != 
tc.expectedError { 71 | t.Fatalf("TestInvalidV2TLV %s: expected %#v, actual %#v", tc.name, tc.expectedError, err) 72 | } 73 | }) 74 | } 75 | } 76 | 77 | func TestV2TLVPP2Registered(t *testing.T) { 78 | pp2RegTypes := []PP2Type{ 79 | PP2_TYPE_ALPN, PP2_TYPE_AUTHORITY, PP2_TYPE_CRC32C, PP2_TYPE_NOOP, PP2_TYPE_UNIQUE_ID, 80 | PP2_TYPE_SSL, PP2_SUBTYPE_SSL_VERSION, PP2_SUBTYPE_SSL_CN, 81 | PP2_SUBTYPE_SSL_CIPHER, PP2_SUBTYPE_SSL_SIG_ALG, PP2_SUBTYPE_SSL_KEY_ALG, 82 | PP2_TYPE_NETNS, 83 | } 84 | pp2RegMap := make(map[PP2Type]bool) 85 | for _, p := range pp2RegTypes { 86 | pp2RegMap[p] = true 87 | if !p.Registered() { 88 | t.Fatalf("TestV2TLVPP2Registered: type %x should be registered", p) 89 | } 90 | if !p.Spec() { 91 | t.Fatalf("TestV2TLVPP2Registered: type %x should be in spec", p) 92 | } 93 | if p.App() { 94 | t.Fatalf("TestV2TLVPP2Registered: type %x unexpectedly app", p) 95 | } 96 | if p.Experiment() { 97 | t.Fatalf("TestV2TLVPP2Registered: type %x unexpectedly experiment", p) 98 | } 99 | if p.Future() { 100 | t.Fatalf("TestV2TLVPP2Registered: type %x unexpectedly future", p) 101 | } 102 | } 103 | 104 | lastType := PP2Type(0xFF) 105 | for i := PP2Type(0x00); i < lastType; i++ { 106 | if !pp2RegMap[i] { 107 | if i.Registered() { 108 | t.Fatalf("TestV2TLVPP2Registered: type %x unexpectedly registered", i) 109 | } 110 | } 111 | } 112 | 113 | if lastType.Registered() { 114 | t.Fatalf("TestV2TLVPP2Registered: type %x unexpectedly registered", lastType) 115 | } 116 | } 117 | 118 | func TestJoinTLVs(t *testing.T) { 119 | tests := []struct { 120 | name string 121 | raw []byte 122 | tlvs []TLV 123 | }{ 124 | { 125 | name: "authority TLV", 126 | raw: append([]byte{byte(PP2_TYPE_AUTHORITY), 0x00, 0x0B}, []byte("example.org")...), 127 | tlvs: []TLV{{ 128 | Type: PP2_TYPE_AUTHORITY, 129 | Value: []byte("example.org"), 130 | }}, 131 | }, 132 | { 133 | name: "empty TLV", 134 | raw: []byte{byte(PP2_TYPE_NOOP), 0x00, 0x00}, 135 | tlvs: []TLV{{ 136 | Type: PP2_TYPE_NOOP, 137 | 
Value: nil, 138 | }}, 139 | }, 140 | } 141 | for _, tc := range tests { 142 | t.Run(tc.name, func(t *testing.T) { 143 | if raw, err := JoinTLVs(tc.tlvs); err != nil { 144 | t.Fatalf("unexpected error: %v", err) 145 | } else if !bytes.Equal(raw, tc.raw) { 146 | t.Errorf("expected %#v, got %#v", tc.raw, raw) 147 | } 148 | }) 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /tlvparse/aws.go: -------------------------------------------------------------------------------- 1 | // Amazon's application extension to TLVs for NLB VPC endpoint services 2 | // https://docs.aws.amazon.com/elasticloadbalancing/latest/network/load-balancer-target-groups.html#proxy-protocol 3 | 4 | package tlvparse 5 | 6 | import ( 7 | "regexp" 8 | 9 | "github.com/pires/go-proxyproto" 10 | ) 11 | 12 | const ( 13 | // Amazon's extension 14 | PP2_TYPE_AWS = 0xEA 15 | PP2_SUBTYPE_AWS_VPCE_ID = 0x01 16 | ) 17 | 18 | var vpceRe = regexp.MustCompile("^[A-Za-z0-9-]*$") 19 | 20 | func IsAWSVPCEndpointID(tlv proxyproto.TLV) bool { 21 | return tlv.Type == PP2_TYPE_AWS && len(tlv.Value) > 0 && tlv.Value[0] == PP2_SUBTYPE_AWS_VPCE_ID 22 | } 23 | 24 | func AWSVPCEndpointID(tlv proxyproto.TLV) (string, error) { 25 | if !IsAWSVPCEndpointID(tlv) { 26 | return "", proxyproto.ErrIncompatibleTLV 27 | } 28 | vpce := string(tlv.Value[1:]) 29 | if !vpceRe.MatchString(vpce) { 30 | return "", proxyproto.ErrMalformedTLV 31 | } 32 | return vpce, nil 33 | } 34 | 35 | // FindAWSVPCEndpointID returns the first AWS VPC ID in the TLV if it exists and is well-formed. 
36 | func FindAWSVPCEndpointID(tlvs []proxyproto.TLV) string { 37 | for _, tlv := range tlvs { 38 | if vpc, err := AWSVPCEndpointID(tlv); err == nil && vpc != "" { 39 | return vpc 40 | } 41 | } 42 | return "" 43 | } 44 | -------------------------------------------------------------------------------- /tlvparse/aws_test.go: -------------------------------------------------------------------------------- 1 | package tlvparse 2 | 3 | import ( 4 | "encoding/binary" 5 | "testing" 6 | 7 | "github.com/pires/go-proxyproto" 8 | ) 9 | 10 | var awsTestCases = []struct { 11 | name string 12 | raw []byte 13 | types []proxyproto.PP2Type 14 | valid func(*testing.T, string, []proxyproto.TLV) 15 | }{ 16 | { 17 | name: "VPCE example", 18 | // https://github.com/aws/elastic-load-balancing-tools/blob/c8eee30ab991ab4c57dc37d1c58f09f67bd534aa/proprot/tst/com/amazonaws/proprot/Compatibility_AwsNetworkLoadBalancerTest.java#L41..L67 19 | raw: []byte{ 20 | 0x0d, 0x0a, 0x0d, 0x0a, /* Start of Sig */ 21 | 0x00, 0x0d, 0x0a, 0x51, 22 | 0x55, 0x49, 0x54, 0x0a, /* End of Sig */ 23 | 0x21, 0x11, 0x00, 0x54, /* ver_cmd, fam and len */ 24 | 0xac, 0x1f, 0x07, 0x71, /* Caller src ip */ 25 | 0xac, 0x1f, 0x0a, 0x1f, /* Endpoint dst ip */ 26 | 0xc8, 0xf2, 0x00, 0x50, /* Proxy src port & dst port */ 27 | 0x03, 0x00, 0x04, 0xe8, /* CRC TLV start */ 28 | 0xd6, 0x89, 0x2d, 0xea, /* CRC TLV cont, VPCE id TLV start */ 29 | 0x00, 0x17, 0x01, 0x76, 30 | 0x70, 0x63, 0x65, 0x2d, 31 | 0x30, 0x38, 0x64, 0x32, 32 | 0x62, 0x66, 0x31, 0x35, 33 | 0x66, 0x61, 0x63, 0x35, 34 | 0x30, 0x30, 0x31, 0x63, 35 | 0x39, 0x04, 0x00, 0x24, /* VPCE id TLV end, NOOP TLV start*/ 36 | 0x00, 0x00, 0x00, 0x00, 37 | 0x00, 0x00, 0x00, 0x00, 38 | 0x00, 0x00, 0x00, 0x00, 39 | 0x00, 0x00, 0x00, 0x00, 40 | 0x00, 0x00, 0x00, 0x00, 41 | 0x00, 0x00, 0x00, 0x00, 42 | 0x00, 0x00, 0x00, 0x00, 43 | 0x00, 0x00, 0x00, 0x00, 44 | 0x00, 0x00, 0x00, 0x00, /* NOOP TLV end */ 45 | }, 46 | types: []proxyproto.PP2Type{proxyproto.PP2_TYPE_CRC32C, PP2_TYPE_AWS, 
proxyproto.PP2_TYPE_NOOP}, 47 | valid: func(t *testing.T, name string, tlvs []proxyproto.TLV) { 48 | if !IsAWSVPCEndpointID(tlvs[1]) { 49 | t.Fatalf("TestParseV2TLV %s: Expected tlvs[1] to be an AWSVPCEndpointID type", name) 50 | } 51 | 52 | vpce := "vpce-08d2bf15fac5001c9" 53 | if vpca, err := AWSVPCEndpointID(tlvs[1]); err != nil { 54 | t.Fatalf("TestParseV2TLV %s: Unexpected error when parsing AWSVPCEndpointID", name) 55 | } else if vpca != vpce { 56 | t.Fatalf("TestParseV2TLV %s: Unexpected VPC ID from tlvs[1] expected %#v, actual %#v", name, vpce, vpca) 57 | } 58 | 59 | if vpca := FindAWSVPCEndpointID(tlvs); vpca == "" { 60 | t.Fatalf("TestParseV2TLV %s: Expected to find AWSVPCEndpointID %#v in TLVs", name, vpce) 61 | } else if vpca != vpce { 62 | t.Fatalf("TestParseV2TLV %s: Unexpected AWSVPCEndpointID from header expected %#v, actual %#v", name, vpce, vpca) 63 | } 64 | 65 | }, 66 | }, 67 | { 68 | name: "VPCE capture", 69 | raw: []byte{ 70 | 0x0d, 0x0a, 0x0d, 0x0a, 71 | 0x00, 0x0d, 0x0a, 0x51, 72 | 0x55, 0x49, 0x54, 0x0a, 73 | 0x21, 0x11, 0x00, 0x54, 74 | 0xc0, 0xa8, 0x2c, 0x0a, 75 | 0xc0, 0xa8, 0x2c, 0x07, 76 | 0xcc, 0x3e, 0x24, 0x1b, 77 | 0x03, 0x00, 0x04, 0xb9, 78 | 0x28, 0x6f, 0xa6, 0xea, 79 | 0x00, 0x17, 0x01, 0x76, 80 | 0x70, 0x63, 0x65, 0x2d, 81 | 0x30, 0x30, 0x65, 0x61, 82 | 0x66, 0x63, 0x34, 0x35, 83 | 0x38, 0x65, 0x63, 0x39, 84 | 0x37, 0x62, 0x38, 0x33, 85 | 0x33, 0x04, 0x00, 0x24, 86 | 0x00, 0x00, 0x00, 0x00, 87 | 0x00, 0x00, 0x00, 0x00, 88 | 0x00, 0x00, 0x00, 0x00, 89 | 0x00, 0x00, 0x00, 0x00, 90 | 0x00, 0x00, 0x00, 0x00, 91 | 0x00, 0x00, 0x00, 0x00, 92 | 0x00, 0x00, 0x00, 0x00, 93 | 0x00, 0x00, 0x00, 0x00, 94 | 0x00, 0x00, 0x00, 0x00, 95 | }, 96 | types: []proxyproto.PP2Type{proxyproto.PP2_TYPE_CRC32C, PP2_TYPE_AWS, proxyproto.PP2_TYPE_NOOP}, 97 | valid: func(t *testing.T, name string, tlvs []proxyproto.TLV) { 98 | if !IsAWSVPCEndpointID(tlvs[1]) { 99 | t.Fatalf("TestParseV2TLV %s: Expected tlvs[1] to be an AWS VPC endpoint ID type", name) 100 | 
} 101 | 102 | vpce := "vpce-00eafc458ec97b833" 103 | if vpca, err := AWSVPCEndpointID(tlvs[1]); err != nil { 104 | t.Fatalf("TestParseV2TLV %s: Unexpected error when parsing AWS VPC ID", name) 105 | } else if vpca != vpce { 106 | t.Fatalf("TestParseV2TLV %s: Unexpected VPC ID from tlvs[1] expected %#v, actual %#v", name, vpce, vpca) 107 | } 108 | 109 | if vpca := FindAWSVPCEndpointID(tlvs); vpca == "" { 110 | t.Fatalf("TestParseV2TLV %s: Expected to find VPC ID %#v in TLVs", name, vpce) 111 | } else if vpca != vpce { 112 | t.Fatalf("TestParseV2TLV %s: Unexpected VPC ID from header expected %#v, actual %#v", name, vpce, vpca) 113 | } 114 | 115 | }, 116 | }, 117 | } 118 | 119 | func TestV2TLVAWSVPCEBadChars(t *testing.T) { 120 | badVPCE := "vcpe-!?***&&&&&&&" 121 | rawTLVs := vpceTLV(badVPCE) 122 | tlvs, err := proxyproto.SplitTLVs(rawTLVs) 123 | if len(tlvs) != 1 { 124 | t.Fatalf("TestV2TLVAWSVPCEBadChars: unexpected TLV length expected: %#v, actual: %#v", 1, tlvs) 125 | } 126 | if err != nil { 127 | t.Fatalf("TestV2TLVAWSVPCEBadChars: unexpected TLV parsing error %#v", err) 128 | } 129 | 130 | _, err = AWSVPCEndpointID(tlvs[0]) 131 | if err != proxyproto.ErrMalformedTLV { 132 | t.Fatalf("TestV2TLVAWSVPCEBadChars: unexpected error actual: %#v", err) 133 | } 134 | 135 | if FindAWSVPCEndpointID(tlvs) != "" { 136 | t.Fatal("TestV2TLVAWSVPCEBadChars: AWSVPCEndpointID unexpectedly found") 137 | } 138 | 139 | rawTLVs = vpceTLV("") 140 | tlvs, err = proxyproto.SplitTLVs(rawTLVs) 141 | if len(tlvs) != 1 { 142 | t.Fatalf("TestV2TLVAWSVPCEBadChars: unexpected TLV length expected: %#v, actual: %#v", 1, tlvs) 143 | } 144 | if err != nil { 145 | t.Fatalf("TestV2TLVAWSVPCEBadChars: unexpected TLV parsing error %#v", err) 146 | } 147 | 148 | parsedVPCE, err := AWSVPCEndpointID(tlvs[0]) 149 | if err != nil { 150 | t.Fatalf("TestV2TLVAWSVPCEBadChars: unexpected error actual: %#v", err) 151 | } 152 | 153 | if parsedVPCE != "" { 154 | t.Fatalf("TestV2TLVAWSVPCEBadChars: found 
non-empty vpce, actual: %#v", parsedVPCE) 155 | } 156 | 157 | parsedVPCE = FindAWSVPCEndpointID(tlvs) 158 | if parsedVPCE != "" { 159 | t.Fatal("TestV2TLVAWSVPCEBadChars: AWSVPECID unexpectedly found") 160 | } 161 | } 162 | 163 | func TestParseAWSVPCEndpointIDTLVs(t *testing.T) { 164 | for _, tc := range awsTestCases { 165 | t.Run(tc.name, func(t *testing.T) { 166 | tlvs := checkTLVs(t, tc.name, tc.raw, tc.types) 167 | tc.valid(t, tc.name, tlvs) 168 | }) 169 | } 170 | } 171 | 172 | func TestV2TLVAWSUnknownSubtype(t *testing.T) { 173 | vpce := "vpce-abc1234" 174 | 175 | rawTLVs := vpceTLV(vpce) 176 | tlvs, err := proxyproto.SplitTLVs(rawTLVs) 177 | if len(tlvs) != 1 { 178 | t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected TLV length expected: %#v, actual: %#v", 1, tlvs) 179 | } 180 | if err != nil { 181 | t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected TLV parsing error %#v", err) 182 | } 183 | 184 | avpce, err := AWSVPCEndpointID(tlvs[0]) 185 | if err != nil { 186 | t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected AWSVPCEndpointID error actual: %#v", err) 187 | } 188 | if avpce != vpce { 189 | t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected vpce value expected: %#v, actual: %#v", vpce, avpce) 190 | } 191 | avpce = FindAWSVPCEndpointID(tlvs) 192 | if avpce == "" { 193 | t.Fatal("TestV2TLVAWSUnknownSubtype: AWSVPCEndpointID unexpectedly missing") 194 | } 195 | if avpce != vpce { 196 | t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected AWSVPCEndpointID value expected: %#v, actual: %#v", vpce, avpce) 197 | } 198 | 199 | subtypeIndex := 3 200 | // Sanity check 201 | if rawTLVs[subtypeIndex] != PP2_SUBTYPE_AWS_VPCE_ID { 202 | t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected subtype expected %x, actual %x", PP2_SUBTYPE_AWS_VPCE_ID, rawTLVs[subtypeIndex]) 203 | } 204 | 205 | rawTLVs[subtypeIndex] = PP2_SUBTYPE_AWS_VPCE_ID + 1 206 | 207 | tlvs, err = proxyproto.SplitTLVs(rawTLVs) 208 | if len(tlvs) != 1 { 209 | t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected TLV length 
expected: %#v, actual: %#v", 1, tlvs) 210 | } 211 | if err != nil { 212 | t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected TLV parsing error %#v", err) 213 | } 214 | 215 | if IsAWSVPCEndpointID(tlvs[0]) { 216 | t.Fatalf("TestV2TLVAWSUnknownSubtype: AWSVPCEType() unexpectedly true after changing subtype") 217 | } 218 | 219 | _, err = AWSVPCEndpointID(tlvs[0]) 220 | if err != proxyproto.ErrIncompatibleTLV { 221 | t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected AWSVPCEndpointID error expected %#v, actual: %#v", proxyproto.ErrIncompatibleTLV, err) 222 | } 223 | 224 | if FindAWSVPCEndpointID(tlvs) != "" { 225 | t.Fatal("TestV2TLVAWSUnknownSubtype: AWSVPCEndpointID unexpectedly exists despite invalid subtype") 226 | } 227 | } 228 | 229 | func vpceTLV(vpce string) []byte { 230 | tlv := []byte{ 231 | PP2_TYPE_AWS, 0x00, 0x00, PP2_SUBTYPE_AWS_VPCE_ID, 232 | } 233 | binary.BigEndian.PutUint16(tlv[1:3], uint16(len(vpce)+1)) // +1 for subtype 234 | return append(tlv, []byte(vpce)...) 235 | } 236 | -------------------------------------------------------------------------------- /tlvparse/azure.go: -------------------------------------------------------------------------------- 1 | // Azure's application extension to TLVs for Private Link Services 2 | // https://docs.microsoft.com/en-us/azure/private-link/private-link-service-overview#getting-connection-information-using-tcp-proxy-v2 3 | 4 | package tlvparse 5 | 6 | import ( 7 | "encoding/binary" 8 | 9 | "github.com/pires/go-proxyproto" 10 | ) 11 | 12 | const ( 13 | // Azure's extension 14 | PP2_TYPE_AZURE = 0xEE 15 | PP2_SUBTYPE_AZURE_PRIVATEENDPOINT_LINKID = 0x01 16 | ) 17 | 18 | // IsAzurePrivateEndpointLinkID returns true if given TLV matches Azure Private Endpoint LinkID format 19 | func isAzurePrivateEndpointLinkID(tlv proxyproto.TLV) bool { 20 | return tlv.Type == PP2_TYPE_AZURE && len(tlv.Value) == 5 && tlv.Value[0] == PP2_SUBTYPE_AZURE_PRIVATEENDPOINT_LINKID 21 | } 22 | 23 | // AzurePrivateEndpointLinkID returns linkID 
if given TLV matches Azure Private Endpoint LinkID format 24 | // 25 | // Format description: 26 | // Field Length (Octets) Description 27 | // Type 1 PP2_TYPE_AZURE (0xEE) 28 | // Length 2 Length of value 29 | // Value 1 PP2_SUBTYPE_AZURE_PRIVATEENDPOINT_LINKID (0x01) 30 | // 4 UINT32 (4 bytes) representing the LINKID of the private endpoint. Encoded in little endian format. 31 | func azurePrivateEndpointLinkID(tlv proxyproto.TLV) (uint32, error) { 32 | if !isAzurePrivateEndpointLinkID(tlv) { 33 | return 0, proxyproto.ErrIncompatibleTLV 34 | } 35 | linkID := binary.LittleEndian.Uint32(tlv.Value[1:]) 36 | return linkID, nil 37 | } 38 | 39 | // FindAzurePrivateEndpointLinkID returns the first Azure Private Endpoint LinkID if it exists in the TLV collection 40 | // and a boolean indicating if it was found. 41 | func FindAzurePrivateEndpointLinkID(tlvs []proxyproto.TLV) (uint32, bool) { 42 | for _, tlv := range tlvs { 43 | if linkID, err := azurePrivateEndpointLinkID(tlv); err == nil { 44 | return linkID, true 45 | } 46 | } 47 | return 0, false 48 | } 49 | -------------------------------------------------------------------------------- /tlvparse/azure_test.go: -------------------------------------------------------------------------------- 1 | package tlvparse 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/pires/go-proxyproto" 7 | ) 8 | 9 | func TestFindAzurePrivateEndpointLinkID(t *testing.T) { 10 | tests := []struct { 11 | name string 12 | tlvs []proxyproto.TLV 13 | wantLinkID uint32 14 | wantFound bool 15 | }{ 16 | { 17 | name: "nil TLVs", 18 | tlvs: nil, 19 | wantLinkID: 0, 20 | wantFound: false, 21 | }, 22 | { 23 | name: "empty TLVs", 24 | tlvs: []proxyproto.TLV{}, 25 | wantLinkID: 0, 26 | wantFound: false, 27 | }, 28 | { 29 | name: "AWS VPC endpoint ID", 30 | tlvs: []proxyproto.TLV{ 31 | { 32 | Type: 0xEA, 33 | Value: []byte{0x01, 0x76, 0x70, 0x63, 0x65, 0x2d, 0x61, 0x62, 0x63, 0x31, 0x32, 0x33}, 34 | }, 35 | }, 36 | wantLinkID: 0, 37 | wantFound: false, 38 
| }, 39 | { 40 | name: "Azure but wrong subtype", 41 | tlvs: []proxyproto.TLV{ 42 | { 43 | Type: 0xEE, 44 | Value: []byte{0x02, 0x01, 0x01, 0x01, 0x01}, 45 | }, 46 | }, 47 | wantLinkID: 0, 48 | wantFound: false, 49 | }, 50 | { 51 | name: "Azure but wrong length", 52 | tlvs: []proxyproto.TLV{ 53 | { 54 | Type: 0xEE, 55 | Value: []byte{0x02, 0x01, 0x01}, 56 | }, 57 | }, 58 | wantLinkID: 0, 59 | wantFound: false, 60 | }, 61 | { 62 | name: "Azure link ID", 63 | tlvs: []proxyproto.TLV{ 64 | { 65 | Type: 0xEE, 66 | Value: []byte{0x1, 0xc1, 0x45, 0x0, 0x21}, 67 | }, 68 | }, 69 | wantLinkID: 0x210045c1, 70 | wantFound: true, 71 | }, 72 | { 73 | name: "Multiple TLVs", 74 | tlvs: []proxyproto.TLV{ 75 | { // AWS 76 | Type: 0xEA, 77 | Value: []byte{0x01, 0x76, 0x70, 0x63, 0x65, 0x2d, 0x61, 0x62, 0x63, 0x31, 0x32, 0x33}, 78 | }, 79 | { // Azure but wrong subtype 80 | Type: 0xEE, 81 | Value: []byte{0x02, 0x01, 0x01, 0x01, 0x01}, 82 | }, 83 | { // Azure but wrong length 84 | Type: 0xEE, 85 | Value: []byte{0x02, 0x01, 0x01}, 86 | }, 87 | { // Correct 88 | Type: 0xEE, 89 | Value: []byte{0x1, 0xc1, 0x45, 0x0, 0x21}, 90 | }, 91 | { // Also correct, but second in line 92 | Type: 0xEE, 93 | Value: []byte{0x1, 0xc1, 0x45, 0x0, 0x22}, 94 | }, 95 | }, 96 | wantLinkID: 0x210045c1, 97 | wantFound: true, 98 | }, 99 | } 100 | for _, tt := range tests { 101 | t.Run(tt.name, func(t *testing.T) { 102 | gotLinkID, gotFound := FindAzurePrivateEndpointLinkID(tt.tlvs) 103 | if gotFound != tt.wantFound { 104 | t.Errorf("FindAzurePrivateEndpointLinkID() got1 = %v, want %v", gotFound, tt.wantFound) 105 | } 106 | if gotLinkID != tt.wantLinkID { 107 | t.Errorf("FindAzurePrivateEndpointLinkID() got = %v, want %v", gotLinkID, tt.wantLinkID) 108 | } 109 | }) 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /tlvparse/gcp.go: -------------------------------------------------------------------------------- 1 | package tlvparse 2 | 3 | import ( 4 | 
"encoding/binary" 5 | 6 | "github.com/pires/go-proxyproto" 7 | ) 8 | 9 | const ( 10 | // PP2_TYPE_GCP indicates a Google Cloud Platform header 11 | PP2_TYPE_GCP proxyproto.PP2Type = 0xE0 12 | ) 13 | 14 | // ExtractPSCConnectionID returns the first PSC Connection ID in the TLV if it exists and is well-formed and 15 | // a bool indicating one was found. 16 | func ExtractPSCConnectionID(tlvs []proxyproto.TLV) (uint64, bool) { 17 | for _, tlv := range tlvs { 18 | if linkID, err := pscConnectionID(tlv); err == nil { 19 | return linkID, true 20 | } 21 | } 22 | return 0, false 23 | } 24 | 25 | // pscConnectionID returns the ID of a GCP PSC extension TLV or errors with ErrIncompatibleTLV or 26 | // ErrMalformedTLV if it's the wrong TLV type or is malformed. 27 | // 28 | // Field Length (bytes) Description 29 | // Type 1 PP2_TYPE_GCP (0xE0) 30 | // Length 2 Length of value (always 0x0008) 31 | // Value 8 The 8-byte PSC Connection ID (decode to uint64; big endian) 32 | // 33 | // For example proxyproto.TLV{Type:0xea, Length:8, Value:[]byte{0xff, 0xff, 0xff, 0xff, 0xc0, 0xa8, 0x64, 0x02}} 34 | // will be decoded as 18446744072646845442. 
35 | // 36 | // See https://cloud.google.com/vpc/docs/configure-private-service-connect-producer 37 | func pscConnectionID(t proxyproto.TLV) (uint64, error) { 38 | if !isPSCConnectionID(t) { 39 | return 0, proxyproto.ErrIncompatibleTLV 40 | } 41 | linkID := binary.BigEndian.Uint64(t.Value) 42 | return linkID, nil 43 | } 44 | 45 | func isPSCConnectionID(t proxyproto.TLV) bool { 46 | return t.Type == PP2_TYPE_GCP && len(t.Value) == 8 47 | } 48 | -------------------------------------------------------------------------------- /tlvparse/gcp_test.go: -------------------------------------------------------------------------------- 1 | package tlvparse 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/pires/go-proxyproto" 7 | ) 8 | 9 | func TestExtractPSCConnectionID(t *testing.T) { 10 | tests := []struct { 11 | name string 12 | tlvs []proxyproto.TLV 13 | wantPSCConnectionID uint64 14 | wantFound bool 15 | }{ 16 | { 17 | name: "nil TLVs", 18 | tlvs: nil, 19 | wantFound: false, 20 | }, 21 | { 22 | name: "empty TLVs", 23 | tlvs: []proxyproto.TLV{}, 24 | wantFound: false, 25 | }, 26 | { 27 | name: "AWS VPC endpoint ID", 28 | tlvs: []proxyproto.TLV{ 29 | { 30 | Type: 0xEA, 31 | Value: []byte{0x01, 0x76, 0x70, 0x63, 0x65, 0x2d, 0x61, 0x62, 0x63, 0x31, 0x32, 0x33}, 32 | }, 33 | }, 34 | wantFound: false, 35 | }, 36 | { 37 | name: "GCP link ID", 38 | tlvs: []proxyproto.TLV{ 39 | { 40 | Type: PP2_TYPE_GCP, 41 | Value: []byte{'\xff', '\xff', '\xff', '\xff', '\xc0', '\xa8', '\x64', '\x02'}, 42 | }, 43 | }, 44 | wantPSCConnectionID: 18446744072646845442, 45 | wantFound: true, 46 | }, 47 | { 48 | name: "Multiple TLVs", 49 | tlvs: []proxyproto.TLV{ 50 | { // AWS 51 | Type: 0xEA, 52 | Value: []byte{0x01, 0x76, 0x70, 0x63, 0x65, 0x2d, 0x61, 0x62, 0x63, 0x31, 0x32, 0x33}, 53 | }, 54 | { // Azure 55 | Type: 0xEE, 56 | Value: []byte{0x02, 0x01, 0x01, 0x01, 0x01}, 57 | }, 58 | { // GCP but wrong length 59 | Type: 0xE0, 60 | Value: []byte{0xff, 0xff, 0xff}, 61 | }, 62 | { // Correct 63 | 
Type: 0xE0, 64 | Value: []byte{'\xff', '\xff', '\xff', '\xff', '\xc0', '\xa8', '\x64', '\x02'}, 65 | }, 66 | }, 67 | wantPSCConnectionID: 18446744072646845442, 68 | wantFound: true, 69 | }, 70 | } 71 | for _, tt := range tests { 72 | t.Run(tt.name, func(t *testing.T) { 73 | linkID, hasLinkID := ExtractPSCConnectionID(tt.tlvs) 74 | if hasLinkID != tt.wantFound { 75 | t.Errorf("ExtractPSCConnectionID() got1 = %v, want %v", hasLinkID, tt.wantFound) 76 | } 77 | if linkID != tt.wantPSCConnectionID { 78 | t.Errorf("ExtractPSCConnectionID() got = %v, want %v", linkID, tt.wantPSCConnectionID) 79 | } 80 | }) 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /tlvparse/ssl.go: -------------------------------------------------------------------------------- 1 | package tlvparse 2 | 3 | import ( 4 | "encoding/binary" 5 | "unicode" 6 | "unicode/utf8" 7 | 8 | "github.com/pires/go-proxyproto" 9 | ) 10 | 11 | const ( 12 | // pp2_tlv_ssl.client bit fields 13 | PP2_BITFIELD_CLIENT_SSL uint8 = 0x01 14 | PP2_BITFIELD_CLIENT_CERT_CONN uint8 = 0x02 15 | PP2_BITFIELD_CLIENT_CERT_SESS uint8 = 0x04 16 | 17 | tlvSSLMinLen = 5 // len(pp2_tlv_ssl.client) + len(pp2_tlv_ssl.verify) 18 | ) 19 | 20 | // 2.2.5. The PP2_TYPE_SSL type and subtypes 21 | /* 22 | struct pp2_tlv_ssl { 23 | uint8_t client; 24 | uint32_t verify; 25 | struct pp2_tlv sub_tlv[0]; 26 | }; 27 | */ 28 | type PP2SSL struct { 29 | Client uint8 // The field is made of a bit field from the following values, 30 | // indicating which element is present: PP2_BITFIELD_CLIENT_SSL, 31 | // PP2_BITFIELD_CLIENT_CERT_CONN, PP2_BITFIELD_CLIENT_CERT_SESS 32 | Verify uint32 // Verify will be zero if the client presented a certificate 33 | // and it was successfully verified, and non-zero otherwise. 
34 | TLV []proxyproto.TLV 35 | } 36 | 37 | // Verified is true if the client presented a certificate and it was successfully verified 38 | func (s PP2SSL) Verified() bool { 39 | return s.Verify == 0 40 | } 41 | 42 | // ClientSSL indicates that the client connected over SSL/TLS. When true, SSLVersion will return the version. 43 | func (s PP2SSL) ClientSSL() bool { 44 | return s.Client&PP2_BITFIELD_CLIENT_SSL == PP2_BITFIELD_CLIENT_SSL 45 | } 46 | 47 | // ClientCertConn indicates that the client provided a certificate over the current connection. 48 | func (s PP2SSL) ClientCertConn() bool { 49 | return s.Client&PP2_BITFIELD_CLIENT_CERT_CONN == PP2_BITFIELD_CLIENT_CERT_CONN 50 | } 51 | 52 | // ClientCertSess indicates that the client provided a certificate at least once over the TLS session this 53 | // connection belongs to. 54 | func (s PP2SSL) ClientCertSess() bool { 55 | return s.Client&PP2_BITFIELD_CLIENT_CERT_SESS == PP2_BITFIELD_CLIENT_CERT_SESS 56 | } 57 | 58 | // SSLVersion returns the US-ASCII string representation of the TLS version and whether that extension exists. 59 | func (s PP2SSL) SSLVersion() (string, bool) { 60 | for _, tlv := range s.TLV { 61 | if tlv.Type == proxyproto.PP2_SUBTYPE_SSL_VERSION { 62 | return string(tlv.Value), true 63 | } 64 | } 65 | return "", false 66 | } 67 | 68 | // SSLCipher returns the US-ASCII string representation of the used TLS cipher and whether that extension exists. 69 | func (s PP2SSL) SSLCipher() (string, bool) { 70 | for _, tlv := range s.TLV { 71 | if tlv.Type == proxyproto.PP2_SUBTYPE_SSL_CIPHER { 72 | return string(tlv.Value), true 73 | } 74 | } 75 | return "", false 76 | } 77 | 78 | // Marshal formats the PP2SSL structure as a TLV. 
79 | func (s PP2SSL) Marshal() (proxyproto.TLV, error) { 80 | v := make([]byte, 5) 81 | v[0] = s.Client 82 | binary.BigEndian.PutUint32(v[1:5], s.Verify) 83 | 84 | tlvs, err := proxyproto.JoinTLVs(s.TLV) 85 | if err != nil { 86 | return proxyproto.TLV{}, err 87 | } 88 | v = append(v, tlvs...) 89 | 90 | return proxyproto.TLV{ 91 | Type: proxyproto.PP2_TYPE_SSL, 92 | Value: v, 93 | }, nil 94 | } 95 | 96 | // ClientCN returns the string representation (in UTF8) of the Common Name field (OID: 2.5.4.3) of the client 97 | // certificate's Distinguished Name and whether that extension exists. 98 | func (s PP2SSL) ClientCN() (string, bool) { 99 | for _, tlv := range s.TLV { 100 | if tlv.Type == proxyproto.PP2_SUBTYPE_SSL_CN { 101 | return string(tlv.Value), true 102 | } 103 | } 104 | return "", false 105 | } 106 | 107 | // SSLType is true if the TLV is type SSL 108 | func IsSSL(t proxyproto.TLV) bool { 109 | return t.Type == proxyproto.PP2_TYPE_SSL && len(t.Value) >= tlvSSLMinLen 110 | } 111 | 112 | // SSL returns the pp2_tlv_ssl from section 2.2.5 or errors with ErrIncompatibleTLV or ErrMalformedTLV 113 | func SSL(t proxyproto.TLV) (PP2SSL, error) { 114 | ssl := PP2SSL{} 115 | if !IsSSL(t) { 116 | return ssl, proxyproto.ErrIncompatibleTLV 117 | } 118 | if len(t.Value) < tlvSSLMinLen { 119 | return ssl, proxyproto.ErrMalformedTLV 120 | } 121 | ssl.Client = t.Value[0] 122 | ssl.Verify = binary.BigEndian.Uint32(t.Value[1:5]) 123 | var err error 124 | ssl.TLV, err = proxyproto.SplitTLVs(t.Value[5:]) 125 | if err != nil { 126 | return PP2SSL{}, err 127 | } 128 | versionFound := !ssl.ClientSSL() 129 | for _, tlv := range ssl.TLV { 130 | switch tlv.Type { 131 | case proxyproto.PP2_SUBTYPE_SSL_VERSION: 132 | /* 133 | The PP2_CLIENT_SSL flag indicates that the client connected over SSL/TLS. When 134 | this field is present, the US-ASCII string representation of the TLS version is 135 | appended at the end of the field in the TLV format using the type 136 | PP2_SUBTYPE_SSL_VERSION. 
137 | */ 138 | if len(tlv.Value) == 0 || !isASCII(tlv.Value) { 139 | return PP2SSL{}, proxyproto.ErrMalformedTLV 140 | } 141 | versionFound = true 142 | case proxyproto.PP2_SUBTYPE_SSL_CN: 143 | /* 144 | In all cases, the string representation (in UTF8) of the Common Name field 145 | (OID: 2.5.4.3) of the client certificate's Distinguished Name, is appended 146 | using the TLV format and the type PP2_SUBTYPE_SSL_CN. E.g. "example.com". 147 | */ 148 | if len(tlv.Value) == 0 || !utf8.Valid(tlv.Value) { 149 | return PP2SSL{}, proxyproto.ErrMalformedTLV 150 | } 151 | case proxyproto.PP2_SUBTYPE_SSL_CIPHER: 152 | /* 153 | The second level TLV PP2_SUBTYPE_SSL_CIPHER provides the US-ASCII string name 154 | of the used cipher, for example "ECDHE-RSA-AES128-GCM-SHA256". 155 | */ 156 | if len(tlv.Value) == 0 || !isASCII(tlv.Value) { 157 | return PP2SSL{}, proxyproto.ErrMalformedTLV 158 | } 159 | } 160 | } 161 | if !versionFound { 162 | return PP2SSL{}, proxyproto.ErrMalformedTLV 163 | } 164 | return ssl, nil 165 | } 166 | 167 | // SSL returns the first PP2SSL if it exists and is well formed as well as bool indicating if it was found. 168 | func FindSSL(tlvs []proxyproto.TLV) (PP2SSL, bool) { 169 | for _, t := range tlvs { 170 | if ssl, err := SSL(t); err == nil { 171 | return ssl, true 172 | } 173 | } 174 | return PP2SSL{}, false 175 | } 176 | 177 | // isASCII checks whether a byte slice has all characters that fit in the ascii character set, including the null byte. 
178 | func isASCII(b []byte) bool { 179 | for _, c := range b { 180 | if c > unicode.MaxASCII { 181 | return false 182 | } 183 | } 184 | return true 185 | } 186 | -------------------------------------------------------------------------------- /tlvparse/ssl_test.go: -------------------------------------------------------------------------------- 1 | package tlvparse 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/pires/go-proxyproto" 8 | ) 9 | 10 | var testCases = []struct { 11 | name string 12 | raw []byte 13 | types []proxyproto.PP2Type 14 | valid func(*testing.T, string, []proxyproto.TLV) 15 | }{ 16 | { 17 | name: "SSL haproxy cn", 18 | raw: []byte{ 19 | 0x0d, 0x0a, 0x0d, 0x0a, 20 | 0x00, 0x0d, 0x0a, 0x51, 21 | 0x55, 0x49, 0x54, 0x0a, 22 | 0x21, 0x11, 0x00, 0x40, 23 | 0x7f, 0x00, 0x00, 0x01, 24 | 0x7f, 0x00, 0x00, 0x01, 25 | 0xcc, 0x8a, 0x23, 0x2e, 26 | 0x20, 0x00, 0x31, 0x07, 27 | 0x00, 0x00, 0x00, 0x00, 28 | 0x21, 0x00, 0x07, 0x54, 29 | 0x4c, 0x53, 0x76, 0x31, 30 | 0x2e, 0x33, 0x22, 0x00, 31 | 0x1f, 0x45, 0x78, 0x61, 32 | 0x6d, 0x70, 0x6c, 0x65, 33 | 0x20, 0x43, 0x6f, 0x6d, 34 | 0x6d, 0x6f, 0x6e, 0x20, 35 | 0x4e, 0x61, 0x6d, 0x65, 36 | 0x20, 0x43, 0x6c, 0x69, 37 | 0x65, 0x6e, 0x74, 0x20, 38 | 0x43, 0x65, 0x72, 0x74, 39 | }, 40 | types: []proxyproto.PP2Type{proxyproto.PP2_TYPE_SSL}, 41 | valid: func(t *testing.T, name string, tlvs []proxyproto.TLV) { 42 | if !IsSSL(tlvs[0]) { 43 | t.Fatalf("TestParseV2TLV %s: Expected tlvs[0] to be the SSL type", name) 44 | } 45 | 46 | ssl, err := SSL(tlvs[0]) 47 | if err != nil { 48 | t.Fatalf("TestParseV2TLV %s: Unexpected error when parsing SSL %#v", name, err) 49 | } 50 | 51 | if !ssl.ClientSSL() { 52 | t.Fatalf("TestParseV2TLV %s: Expected ClientSSL() to be true", name) 53 | } 54 | 55 | if !ssl.ClientCertConn() { 56 | t.Fatalf("TestParseV2TLV %s: Expected ClientCertConn() to be true", name) 57 | } 58 | 59 | if !ssl.ClientCertSess() { 60 | t.Fatalf("TestParseV2TLV %s: Expected ClientCertSess() to be true", 
name) 61 | } 62 | 63 | ecn := "Example Common Name Client Cert" 64 | if acn, ok := ssl.ClientCN(); !ok { 65 | t.Fatalf("TestParseV2TLV %s: Expected ClientCN to exist", name) 66 | } else if acn != ecn { 67 | t.Fatalf("TestParseV2TLV %s: Unexpected ClientCN expected %#v, actual %#v", name, ecn, acn) 68 | } 69 | 70 | esslVer := "TLSv1.3" 71 | if asslVer, ok := ssl.SSLVersion(); !ok { 72 | t.Fatalf("TestParseV2TLV %s: Expected SSLVersion to exist", name) 73 | } else if asslVer != esslVer { 74 | t.Fatalf("TestParseV2TLV %s: Unexpected SSLVersion expected %#v, actual %#v", name, esslVer, asslVer) 75 | } 76 | 77 | if _, ok := ssl.SSLCipher(); ok { 78 | t.Fatalf("TestParseV2TLV %s: Unexpected SSLCipher", name) 79 | } 80 | 81 | if !ssl.Verified() { 82 | t.Fatalf("TestParseV2TLV %s: Expected Verified to be true", name) 83 | } 84 | }, 85 | }, 86 | { 87 | name: "SSL haproxy cipher", 88 | raw: []byte{ 89 | 0x0d, 0x0a, 0x0d, 0x0a, 90 | 0x00, 0x0d, 0x0a, 0x51, 91 | 0x55, 0x49, 0x54, 0x0a, 92 | 0x21, 0x21, 0x00, 0x4f, 93 | 0x00, 0x00, 0x00, 0x00, 94 | 0x00, 0x00, 0x00, 0x00, 95 | 0x00, 0x00, 0xff, 0xff, 96 | 0x0a, 0x01, 0x5b, 0x0e, 97 | 0x00, 0x00, 0x00, 0x00, 98 | 0x00, 0x00, 0x00, 0x00, 99 | 0x00, 0x00, 0xff, 0xff, 100 | 0x0a, 0x01, 0x01, 0x9f, 101 | 0xf4, 0x7c, 0x01, 0xbb, 102 | 0x20, 0x00, 0x28, 0x01, 103 | 0x00, 0x00, 0x00, 0x00, 104 | 0x21, 0x00, 0x07, 0x54, 105 | 0x4c, 0x53, 0x76, 0x31, 106 | 0x2e, 0x33, 0x23, 0x00, 107 | 0x16, 0x54, 0x4c, 0x53, 108 | 0x5f, 0x41, 0x45, 0x53, 109 | 0x5f, 0x32, 0x35, 0x36, 110 | 0x5f, 0x47, 0x43, 0x4d, 111 | 0x5f, 0x53, 0x48, 0x41, 112 | 0x33, 0x38, 0x34, 113 | }, 114 | types: []proxyproto.PP2Type{proxyproto.PP2_TYPE_SSL}, 115 | valid: func(t *testing.T, name string, tlvs []proxyproto.TLV) { 116 | if !IsSSL(tlvs[0]) { 117 | t.Fatalf("TestParseV2TLV %s: Expected tlvs[0] to be the SSL type", name) 118 | } 119 | 120 | ssl, err := SSL(tlvs[0]) 121 | if err != nil { 122 | t.Fatalf("TestParseV2TLV %s: Unexpected error when parsing SSL %#v", name, 
err) 123 | } 124 | 125 | if !ssl.ClientSSL() { 126 | t.Fatalf("TestParseV2TLV %s: Expected ClientSSL() to be true", name) 127 | } 128 | 129 | if ssl.ClientCertConn() { 130 | t.Fatalf("TestParseV2TLV %s: Expected ClientCertConn() to be false", name) 131 | } 132 | 133 | if ssl.ClientCertSess() { 134 | t.Fatalf("TestParseV2TLV %s: Expected ClientCertSess() to be false", name) 135 | } 136 | 137 | if _, ok := ssl.ClientCN(); ok { 138 | t.Fatalf("TestParseV2TLV %s: Expected ClientCN to not exist", name) 139 | } 140 | 141 | esslVer := "TLSv1.3" 142 | if asslVer, ok := ssl.SSLVersion(); !ok { 143 | t.Fatalf("TestParseV2TLV %s: Expected SSLVersion to exist", name) 144 | } else if asslVer != esslVer { 145 | t.Fatalf("TestParseV2TLV %s: Unexpected SSLVersion expected %#v, actual %#v", name, esslVer, asslVer) 146 | } 147 | 148 | esslCipher := "TLS_AES_256_GCM_SHA384" 149 | if asslCipher, ok := ssl.SSLCipher(); !ok { 150 | t.Fatalf("TestParseV2TLV %s: Expected SSLCipher to exist", name) 151 | } else if asslCipher != esslCipher { 152 | t.Fatalf("TestParseV2TLV %s: Unexpected SSLCipher expected %#v, actual %#v", name, esslCipher, asslCipher) 153 | } 154 | }, 155 | }, 156 | } 157 | 158 | func TestParseV2TLV(t *testing.T) { 159 | for _, tc := range testCases { 160 | t.Run(tc.name, func(t *testing.T) { 161 | tlvs := checkTLVs(t, tc.name, tc.raw, tc.types) 162 | tc.valid(t, tc.name, tlvs) 163 | }) 164 | } 165 | } 166 | 167 | func TestPP2SSLMarshal(t *testing.T) { 168 | ver := "TLSv1.3" 169 | cn := "example.org" 170 | pp2 := PP2SSL{ 171 | Client: PP2_BITFIELD_CLIENT_SSL, 172 | Verify: 0, 173 | TLV: []proxyproto.TLV{ 174 | { 175 | Type: proxyproto.PP2_SUBTYPE_SSL_VERSION, 176 | Value: []byte(ver), 177 | }, 178 | { 179 | Type: proxyproto.PP2_SUBTYPE_SSL_CN, 180 | Value: []byte(cn), 181 | }, 182 | }, 183 | } 184 | 185 | raw := []byte{0x1, 0x0, 0x0, 0x0, 0x0, 0x21, 0x0, 0x7, 0x54, 0x4c, 0x53, 0x76, 0x31, 0x2e, 0x33, 0x22, 0x0, 0xb, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x6f, 
0x72, 0x67} 186 | want := proxyproto.TLV{ 187 | Type: proxyproto.PP2_TYPE_SSL, 188 | Value: raw, 189 | } 190 | 191 | tlv, err := pp2.Marshal() 192 | if err != nil { 193 | t.Fatalf("PP2SSL.Marshal() = %v", err) 194 | } 195 | 196 | if !reflect.DeepEqual(tlv, want) { 197 | t.Errorf("PP2SSL.Marshal() = %#v, want %#v", tlv, want) 198 | } 199 | } 200 | -------------------------------------------------------------------------------- /tlvparse/test.go: -------------------------------------------------------------------------------- 1 | package tlvparse 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "testing" 7 | 8 | "github.com/pires/go-proxyproto" 9 | ) 10 | 11 | func checkTLVs(t *testing.T, name string, raw []byte, expected []proxyproto.PP2Type) []proxyproto.TLV { 12 | header, err := proxyproto.Read(bufio.NewReader(bytes.NewReader(raw))) 13 | if err != nil { 14 | t.Fatalf("%s: Unexpected error reading header %#v", name, err) 15 | } 16 | 17 | tlvs, err := header.TLVs() 18 | if err != nil { 19 | t.Fatalf("%s: Unexpected error splitting TLVS %#v", name, err) 20 | } 21 | 22 | if len(tlvs) != len(expected) { 23 | t.Fatalf("%s: Expected %d TLVs, actual %d", name, len(expected), len(tlvs)) 24 | } 25 | 26 | for i, et := range expected { 27 | if at := tlvs[i].Type; at != et { 28 | t.Fatalf("%s: Expected type %X, actual %X", name, et, at) 29 | } 30 | } 31 | 32 | return tlvs 33 | } 34 | -------------------------------------------------------------------------------- /v1.go: -------------------------------------------------------------------------------- 1 | package proxyproto 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "fmt" 7 | "net" 8 | "net/netip" 9 | "strconv" 10 | "strings" 11 | ) 12 | 13 | const ( 14 | crlf = "\r\n" 15 | separator = " " 16 | ) 17 | 18 | func initVersion1() *Header { 19 | header := new(Header) 20 | header.Version = 1 21 | // Command doesn't exist in v1 22 | header.Command = PROXY 23 | return header 24 | } 25 | 26 | func parseVersion1(reader *bufio.Reader) 
(*Header, error) { 27 | //The header cannot be more than 107 bytes long. Per spec: 28 | // 29 | // (...) 30 | // - worst case (optional fields set to 0xff) : 31 | // "PROXY UNKNOWN ffff:f...f:ffff ffff:f...f:ffff 65535 65535\r\n" 32 | // => 5 + 1 + 7 + 1 + 39 + 1 + 39 + 1 + 5 + 1 + 5 + 2 = 107 chars 33 | // 34 | // So a 108-byte buffer is always enough to store all the line and a 35 | // trailing zero for string processing. 36 | // 37 | // It must also be CRLF terminated, as above. The header does not otherwise 38 | // contain a CR or LF byte. 39 | // 40 | // ISSUE #69 41 | // We can't use Peek here as it will block trying to fill the buffer, which 42 | // will never happen if the header is TCP4 or TCP6 (max. 56 and 104 bytes 43 | // respectively) and the server is expected to speak first. 44 | // 45 | // Similarly, we can't use ReadString or ReadBytes as these will keep reading 46 | // until the delimiter is found; an abusive client could easily disrupt a 47 | // server by sending a large amount of data that do not contain a LF byte. 48 | // Another means of attack would be to start connections and simply not send 49 | // data after the initial PROXY signature bytes, accumulating a large 50 | // number of blocked goroutines on the server. ReadSlice will also block for 51 | // a delimiter when the internal buffer does not fill up. 52 | // 53 | // A plain Read is also problematic since we risk reading past the end of the 54 | // header without being able to easily put the excess bytes back into the reader's 55 | // buffer (with the current implementation's design). 56 | // 57 | // So we use a ReadByte loop, which solves the overflow problem and avoids 58 | // reading beyond the end of the header. However, we need one more trick to harden 59 | // against partial header attacks (slow loris) - per spec: 60 | // 61 | // (..) 
The sender must always ensure that the header is sent at once, so that 62 | // the transport layer maintains atomicity along the path to the receiver. The 63 | // receiver may be tolerant to partial headers or may simply drop the connection 64 | // when receiving a partial header. Recommendation is to be tolerant, but 65 | // implementation constraints may not always easily permit this. 66 | // 67 | // We are subject to such implementation constraints. So we return an error if 68 | // the header cannot be fully extracted with a single read of the underlying 69 | // reader. 70 | buf := make([]byte, 0, 107) 71 | for { 72 | b, err := reader.ReadByte() 73 | if err != nil { 74 | return nil, fmt.Errorf(ErrCantReadVersion1Header.Error()+": %v", err) 75 | } 76 | buf = append(buf, b) 77 | if b == '\n' { 78 | // End of header found 79 | break 80 | } 81 | if len(buf) == 107 { 82 | // No delimiter in first 107 bytes 83 | return nil, ErrVersion1HeaderTooLong 84 | } 85 | if reader.Buffered() == 0 { 86 | // Header was not buffered in a single read. Since we can't 87 | // differentiate between genuine slow writers and DoS agents, 88 | // we abort. On healthy networks, this should never happen. 89 | return nil, ErrCantReadVersion1Header 90 | } 91 | } 92 | 93 | // Check for CR before LF. 94 | if len(buf) < 2 || buf[len(buf)-2] != '\r' { 95 | return nil, ErrLineMustEndWithCrlf 96 | } 97 | 98 | // Check full signature. 99 | tokens := strings.Split(string(buf[:len(buf)-2]), separator) 100 | 101 | // Expect at least 2 tokens: "PROXY" and the transport protocol. 
102 | if len(tokens) < 2 { 103 | return nil, ErrCantReadAddressFamilyAndProtocol 104 | } 105 | 106 | // Read address family and protocol 107 | var transportProtocol AddressFamilyAndProtocol 108 | switch tokens[1] { 109 | case "TCP4": 110 | transportProtocol = TCPv4 111 | case "TCP6": 112 | transportProtocol = TCPv6 113 | case "UNKNOWN": 114 | transportProtocol = UNSPEC // doesn't exist in v1 but fits UNKNOWN 115 | default: 116 | return nil, ErrCantReadAddressFamilyAndProtocol 117 | } 118 | 119 | // Expect 6 tokens only when UNKNOWN is not present. 120 | if transportProtocol != UNSPEC && len(tokens) < 6 { 121 | return nil, ErrCantReadAddressFamilyAndProtocol 122 | } 123 | 124 | // When a signature is found, allocate a v1 header with Command set to PROXY. 125 | // Command doesn't exist in v1 but set it for other parts of this library 126 | // to rely on it for determining connection details. 127 | header := initVersion1() 128 | 129 | // Transport protocol has been processed already. 130 | header.TransportProtocol = transportProtocol 131 | 132 | // When UNKNOWN, set the command to LOCAL and return early 133 | if header.TransportProtocol == UNSPEC { 134 | header.Command = LOCAL 135 | return header, nil 136 | } 137 | 138 | // Otherwise, continue to read addresses and ports 139 | sourceIP, err := parseV1IPAddress(header.TransportProtocol, tokens[2]) 140 | if err != nil { 141 | return nil, err 142 | } 143 | destIP, err := parseV1IPAddress(header.TransportProtocol, tokens[3]) 144 | if err != nil { 145 | return nil, err 146 | } 147 | sourcePort, err := parseV1PortNumber(tokens[4]) 148 | if err != nil { 149 | return nil, err 150 | } 151 | destPort, err := parseV1PortNumber(tokens[5]) 152 | if err != nil { 153 | return nil, err 154 | } 155 | header.SourceAddr = &net.TCPAddr{ 156 | IP: sourceIP, 157 | Port: sourcePort, 158 | } 159 | header.DestinationAddr = &net.TCPAddr{ 160 | IP: destIP, 161 | Port: destPort, 162 | } 163 | 164 | return header, nil 165 | } 166 | 167 | func 
(header *Header) formatVersion1() ([]byte, error) { 168 | // As of version 1, only "TCP4" ( \x54 \x43 \x50 \x34 ) for TCP over IPv4, 169 | // and "TCP6" ( \x54 \x43 \x50 \x36 ) for TCP over IPv6 are allowed. 170 | var proto string 171 | switch header.TransportProtocol { 172 | case TCPv4: 173 | proto = "TCP4" 174 | case TCPv6: 175 | proto = "TCP6" 176 | default: 177 | // Unknown connection (short form) 178 | return []byte("PROXY UNKNOWN" + crlf), nil 179 | } 180 | 181 | sourceAddr, sourceOK := header.SourceAddr.(*net.TCPAddr) 182 | destAddr, destOK := header.DestinationAddr.(*net.TCPAddr) 183 | if !sourceOK || !destOK { 184 | return nil, ErrInvalidAddress 185 | } 186 | 187 | sourceIP, destIP := sourceAddr.IP, destAddr.IP 188 | switch header.TransportProtocol { 189 | case TCPv4: 190 | sourceIP = sourceIP.To4() 191 | destIP = destIP.To4() 192 | case TCPv6: 193 | sourceIP = sourceIP.To16() 194 | destIP = destIP.To16() 195 | } 196 | if sourceIP == nil || destIP == nil { 197 | return nil, ErrInvalidAddress 198 | } 199 | 200 | buf := bytes.NewBuffer(make([]byte, 0, 108)) 201 | buf.Write(SIGV1) 202 | buf.WriteString(separator) 203 | buf.WriteString(proto) 204 | buf.WriteString(separator) 205 | buf.WriteString(sourceIP.String()) 206 | buf.WriteString(separator) 207 | buf.WriteString(destIP.String()) 208 | buf.WriteString(separator) 209 | buf.WriteString(strconv.Itoa(sourceAddr.Port)) 210 | buf.WriteString(separator) 211 | buf.WriteString(strconv.Itoa(destAddr.Port)) 212 | buf.WriteString(crlf) 213 | 214 | return buf.Bytes(), nil 215 | } 216 | 217 | func parseV1PortNumber(portStr string) (int, error) { 218 | port, err := strconv.Atoi(portStr) 219 | if err != nil || port < 0 || port > 65535 { 220 | return 0, ErrInvalidPortNumber 221 | } 222 | return port, nil 223 | } 224 | 225 | func parseV1IPAddress(protocol AddressFamilyAndProtocol, addrStr string) (net.IP, error) { 226 | addr, err := netip.ParseAddr(addrStr) 227 | if err != nil { 228 | return nil, ErrInvalidAddress 229 | } 
230 | 231 | switch protocol { 232 | case TCPv4: 233 | if addr.Is4() { 234 | return net.IP(addr.AsSlice()), nil 235 | } 236 | case TCPv6: 237 | if addr.Is6() || addr.Is4In6() { 238 | return net.IP(addr.AsSlice()), nil 239 | } 240 | } 241 | 242 | return nil, ErrInvalidAddress 243 | } 244 | -------------------------------------------------------------------------------- /v1_test.go: -------------------------------------------------------------------------------- 1 | package proxyproto 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "net" 10 | "strconv" 11 | "strings" 12 | "testing" 13 | "time" 14 | ) 15 | 16 | var ( 17 | IPv4AddressesAndPorts = strings.Join([]string{IP4_ADDR, IP4_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator) 18 | IPv4In6AddressesAndPorts = strings.Join([]string{IP4IN6_ADDR, IP4IN6_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator) 19 | IPv4AddressesAndInvalidPorts = strings.Join([]string{IP4_ADDR, IP4_ADDR, strconv.Itoa(INVALID_PORT), strconv.Itoa(INVALID_PORT)}, separator) 20 | IPv6AddressesAndPorts = strings.Join([]string{IP6_ADDR, IP6_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator) 21 | IPv6LongAddressesAndPorts = strings.Join([]string{IP6_LONG_ADDR, IP6_LONG_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator) 22 | 23 | fixtureTCP4V1 = "PROXY TCP4 " + IPv4AddressesAndPorts + crlf + "GET /" 24 | fixtureTCP6V1 = "PROXY TCP6 " + IPv6AddressesAndPorts + crlf + "GET /" 25 | fixtureTCP4IN6V1 = "PROXY TCP6 " + IPv4In6AddressesAndPorts + crlf + "GET /" 26 | 27 | fixtureTCP6V1Overflow = "PROXY TCP6 " + IPv6LongAddressesAndPorts 28 | 29 | fixtureUnknown = "PROXY UNKNOWN" + crlf 30 | fixtureUnknownWithAddresses = "PROXY UNKNOWN " + IPv4AddressesAndInvalidPorts + crlf 31 | ) 32 | 33 | var invalidParseV1Tests = []struct { 34 | desc string 35 | reader *bufio.Reader 36 | expectedError error 37 | }{ 38 | { 39 | desc: "no signature", 40 | reader: newBufioReader([]byte(NO_PROTOCOL)), 41 | expectedError: 
ErrNoProxyProtocol, 42 | }, 43 | { 44 | desc: "prox", 45 | reader: newBufioReader([]byte("PROX")), 46 | expectedError: ErrNoProxyProtocol, 47 | }, 48 | { 49 | desc: "proxy lf", 50 | reader: newBufioReader([]byte("PROXY \n")), 51 | expectedError: ErrLineMustEndWithCrlf, 52 | }, 53 | { 54 | desc: "proxy crlf", 55 | reader: newBufioReader([]byte("PROXY " + crlf)), 56 | expectedError: ErrCantReadAddressFamilyAndProtocol, 57 | }, 58 | { 59 | desc: "proxy no space crlf", 60 | reader: newBufioReader([]byte("PROXY" + crlf)), 61 | expectedError: ErrCantReadAddressFamilyAndProtocol, 62 | }, 63 | { 64 | desc: "proxy something crlf", 65 | reader: newBufioReader([]byte("PROXY SOMETHING" + crlf)), 66 | expectedError: ErrCantReadAddressFamilyAndProtocol, 67 | }, 68 | { 69 | desc: "incomplete signature TCP4", 70 | reader: newBufioReader([]byte("PROXY TCP4 " + IPv4AddressesAndPorts)), 71 | expectedError: ErrCantReadVersion1Header, 72 | }, 73 | { 74 | desc: "invalid IP address", 75 | reader: newBufioReader([]byte("PROXY TCP4 invalid invalid 65533 65533" + crlf)), 76 | expectedError: ErrInvalidAddress, 77 | }, 78 | { 79 | desc: "TCP6 with IPv4 addresses", 80 | reader: newBufioReader([]byte("PROXY TCP6 " + IPv4AddressesAndPorts + crlf)), 81 | expectedError: ErrInvalidAddress, 82 | }, 83 | { 84 | desc: "TCP4 with IPv6 addresses", 85 | reader: newBufioReader([]byte("PROXY TCP4 " + IPv6AddressesAndPorts + crlf)), 86 | expectedError: ErrInvalidAddress, 87 | }, 88 | { 89 | desc: "TCP4 with IPv4 mapped addresses", 90 | reader: newBufioReader([]byte("PROXY TCP4 " + IPv4In6AddressesAndPorts + crlf)), 91 | expectedError: ErrInvalidAddress, 92 | }, 93 | { 94 | desc: "TCP4 with invalid port", 95 | reader: newBufioReader([]byte("PROXY TCP4 " + IPv4AddressesAndInvalidPorts + crlf)), 96 | expectedError: ErrInvalidPortNumber, 97 | }, 98 | { 99 | desc: "header too long", 100 | reader: newBufioReader([]byte("PROXY UNKNOWN " + IPv6LongAddressesAndPorts + " " + crlf)), 101 | expectedError: 
ErrVersion1HeaderTooLong, 102 | }, 103 | } 104 | 105 | func TestReadV1Invalid(t *testing.T) { 106 | for _, tt := range invalidParseV1Tests { 107 | t.Run(tt.desc, func(t *testing.T) { 108 | if _, err := Read(tt.reader); err != tt.expectedError { 109 | t.Fatalf("expected %s, actual %v", tt.expectedError, err) 110 | } 111 | }) 112 | } 113 | } 114 | 115 | var validParseAndWriteV1Tests = []struct { 116 | desc string 117 | reader *bufio.Reader 118 | expectedHeader *Header 119 | skipWrite bool 120 | }{ 121 | { 122 | desc: "TCP4", 123 | reader: bufio.NewReader(strings.NewReader(fixtureTCP4V1)), 124 | expectedHeader: &Header{ 125 | Version: 1, 126 | Command: PROXY, 127 | TransportProtocol: TCPv4, 128 | SourceAddr: v4addr, 129 | DestinationAddr: v4addr, 130 | }, 131 | }, 132 | { 133 | desc: "TCP6", 134 | reader: bufio.NewReader(strings.NewReader(fixtureTCP6V1)), 135 | expectedHeader: &Header{ 136 | Version: 1, 137 | Command: PROXY, 138 | TransportProtocol: TCPv6, 139 | SourceAddr: v6addr, 140 | DestinationAddr: v6addr, 141 | }, 142 | }, 143 | { 144 | desc: "TCP4IN6", 145 | reader: bufio.NewReader(strings.NewReader(fixtureTCP4IN6V1)), 146 | expectedHeader: &Header{ 147 | Version: 1, 148 | Command: PROXY, 149 | TransportProtocol: TCPv6, 150 | SourceAddr: v4addr, 151 | DestinationAddr: v4addr, 152 | }, 153 | // we skip write test because net.ParseIP converts ::ffff:127.0.0.1 to v4 154 | // instead of preserving the v4 in v6 form, so, after serializing the header, 155 | // we end up with v6 protocol and a v4 IP which is invalid 156 | skipWrite: true, 157 | }, 158 | { 159 | desc: "unknown", 160 | reader: bufio.NewReader(strings.NewReader(fixtureUnknown)), 161 | expectedHeader: &Header{ 162 | Version: 1, 163 | Command: LOCAL, 164 | TransportProtocol: UNSPEC, 165 | SourceAddr: nil, 166 | DestinationAddr: nil, 167 | }, 168 | }, 169 | { 170 | desc: "unknown with addresses and ports", 171 | reader: bufio.NewReader(strings.NewReader(fixtureUnknownWithAddresses)), 172 | expectedHeader: 
&Header{ 173 | Version: 1, 174 | Command: LOCAL, 175 | TransportProtocol: UNSPEC, 176 | SourceAddr: nil, 177 | DestinationAddr: nil, 178 | }, 179 | }, 180 | } 181 | 182 | func TestParseV1Valid(t *testing.T) { 183 | for _, tt := range validParseAndWriteV1Tests { 184 | t.Run(tt.desc, func(t *testing.T) { 185 | header, err := Read(tt.reader) 186 | if err != nil { 187 | t.Fatal("unexpected error", err.Error()) 188 | } 189 | if !header.EqualsTo(tt.expectedHeader) { 190 | t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, header) 191 | } 192 | }) 193 | } 194 | } 195 | 196 | func TestWriteV1Valid(t *testing.T) { 197 | for _, tt := range validParseAndWriteV1Tests { 198 | if tt.skipWrite { 199 | continue 200 | } 201 | t.Run(tt.desc, func(t *testing.T) { 202 | var b bytes.Buffer 203 | w := bufio.NewWriter(&b) 204 | if _, err := tt.expectedHeader.WriteTo(w); err != nil { 205 | t.Fatal("unexpected error ", err) 206 | } 207 | w.Flush() 208 | 209 | // Read written bytes to validate written header 210 | r := bufio.NewReader(&b) 211 | newHeader, err := Read(r) 212 | if err != nil { 213 | t.Fatal("unexpected error ", err) 214 | } 215 | 216 | if !newHeader.EqualsTo(tt.expectedHeader) { 217 | t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, newHeader) 218 | } 219 | }) 220 | } 221 | } 222 | 223 | // Tests for parseVersion1 overflow - issue #69. 
224 | 225 | type dataSource struct { 226 | NBytes int 227 | NRead int 228 | } 229 | 230 | func (ds *dataSource) Read(b []byte) (int, error) { 231 | if ds.NRead >= ds.NBytes { 232 | return 0, io.EOF 233 | } 234 | avail := ds.NBytes - ds.NRead 235 | if len(b) < avail { 236 | avail = len(b) 237 | } 238 | for i := 0; i < avail; i++ { 239 | b[i] = 0x20 240 | } 241 | ds.NRead += avail 242 | return avail, nil 243 | } 244 | 245 | func TestParseVersion1Overflow(t *testing.T) { 246 | ds := &dataSource{} 247 | reader := bufio.NewReader(ds) 248 | bufSize := reader.Size() 249 | ds.NBytes = bufSize * 16 250 | _, _ = parseVersion1(reader) 251 | if ds.NRead > bufSize { 252 | t.Fatalf("read: expected max %d bytes, actual %d\n", bufSize, ds.NRead) 253 | } 254 | } 255 | 256 | func listen(t *testing.T) *Listener { 257 | l, err := net.Listen("tcp", "127.0.0.1:0") 258 | if err != nil { 259 | t.Fatalf("listen: %v", err) 260 | } 261 | return &Listener{Listener: l} 262 | } 263 | 264 | func client(t *testing.T, addr, header string, length int, terminate bool, wait time.Duration, done chan struct{}, 265 | result chan error, 266 | ) { 267 | c, err := net.Dial("tcp", addr) 268 | if err != nil { 269 | result <- fmt.Errorf("dial: %w", err) 270 | return 271 | } 272 | defer c.Close() 273 | 274 | if terminate && length < 2 { 275 | length = 2 276 | } 277 | 278 | buf := make([]byte, len(header)+length) 279 | copy(buf, []byte(header)) 280 | for i := 0; i < length-2; i++ { 281 | buf[i+len(header)] = 0x20 282 | } 283 | if terminate { 284 | copy(buf[len(header)+length-2:], []byte(crlf)) 285 | } 286 | 287 | n, err := c.Write(buf) 288 | if err != nil { 289 | result <- fmt.Errorf("write: %w", err) 290 | return 291 | } 292 | if n != len(buf) { 293 | result <- errors.New("write; short write") 294 | return 295 | } 296 | 297 | close(result) 298 | time.Sleep(wait) 299 | close(done) 300 | } 301 | 302 | func TestVersion1Overflow(t *testing.T) { 303 | done := make(chan struct{}) 304 | cliResult := make(chan error) 
305 | 306 | l := listen(t) 307 | go client(t, l.Addr().String(), fixtureTCP6V1Overflow, 10240, true, 10*time.Second, done, cliResult) 308 | 309 | c, err := l.Accept() 310 | if err != nil { 311 | t.Fatalf("accept: %v", err) 312 | } 313 | 314 | b := []byte{} 315 | _, err = c.Read(b) 316 | if err == nil { 317 | t.Fatalf("net.Conn: no error reported for oversized header") 318 | } 319 | err = <-cliResult 320 | if err != nil { 321 | t.Fatalf("client error: %v", err) 322 | } 323 | } 324 | 325 | func TestVersion1SlowLoris(t *testing.T) { 326 | done := make(chan struct{}) 327 | cliResult := make(chan error) 328 | timeout := make(chan error) 329 | 330 | l := listen(t) 331 | go client(t, l.Addr().String(), fixtureTCP6V1Overflow, 0, false, 10*time.Second, done, cliResult) 332 | 333 | c, err := l.Accept() 334 | if err != nil { 335 | t.Fatalf("accept: %v", err) 336 | } 337 | 338 | go func() { 339 | b := []byte{} 340 | _, err = c.Read(b) 341 | timeout <- err 342 | }() 343 | 344 | select { 345 | case <-done: 346 | t.Fatalf("net.Conn: reader still blocked after 10 seconds") 347 | case err := <-timeout: 348 | if err == nil { 349 | t.Fatalf("net.Conn: no error reported for incomplete header") 350 | } 351 | } 352 | err = <-cliResult 353 | if err != nil { 354 | t.Fatalf("client error: %v", err) 355 | } 356 | } 357 | 358 | func TestVersion1SlowLorisOverflow(t *testing.T) { 359 | done := make(chan struct{}) 360 | cliResult := make(chan error) 361 | timeout := make(chan error) 362 | 363 | l := listen(t) 364 | go client(t, l.Addr().String(), fixtureTCP6V1Overflow, 10240, false, 10*time.Second, done, cliResult) 365 | 366 | c, err := l.Accept() 367 | if err != nil { 368 | t.Fatalf("accept: %v", err) 369 | } 370 | 371 | go func() { 372 | b := []byte{} 373 | _, err = c.Read(b) 374 | timeout <- err 375 | }() 376 | 377 | select { 378 | case <-done: 379 | t.Fatalf("net.Conn: reader still blocked after 10 seconds") 380 | case err := <-timeout: 381 | if err == nil { 382 | t.Fatalf("net.Conn: no 
error reported for incomplete and overflowed header") 383 | } 384 | } 385 | err = <-cliResult 386 | if err != nil { 387 | t.Fatalf("client error: %v", err) 388 | } 389 | } 390 | -------------------------------------------------------------------------------- /v2.go: -------------------------------------------------------------------------------- 1 | package proxyproto 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "encoding/binary" 7 | "errors" 8 | "io" 9 | "net" 10 | ) 11 | 12 | var ( 13 | lengthUnspec = uint16(0) 14 | lengthV4 = uint16(12) 15 | lengthV6 = uint16(36) 16 | lengthUnix = uint16(216) 17 | lengthUnspecBytes = func() []byte { 18 | a := make([]byte, 2) 19 | binary.BigEndian.PutUint16(a, lengthUnspec) 20 | return a 21 | }() 22 | lengthV4Bytes = func() []byte { 23 | a := make([]byte, 2) 24 | binary.BigEndian.PutUint16(a, lengthV4) 25 | return a 26 | }() 27 | lengthV6Bytes = func() []byte { 28 | a := make([]byte, 2) 29 | binary.BigEndian.PutUint16(a, lengthV6) 30 | return a 31 | }() 32 | lengthUnixBytes = func() []byte { 33 | a := make([]byte, 2) 34 | binary.BigEndian.PutUint16(a, lengthUnix) 35 | return a 36 | }() 37 | errUint16Overflow = errors.New("proxyproto: uint16 overflow") 38 | ) 39 | 40 | type _ports struct { 41 | SrcPort uint16 42 | DstPort uint16 43 | } 44 | 45 | type _addr4 struct { 46 | Src [4]byte 47 | Dst [4]byte 48 | SrcPort uint16 49 | DstPort uint16 50 | } 51 | 52 | type _addr6 struct { 53 | Src [16]byte 54 | Dst [16]byte 55 | _ports 56 | } 57 | 58 | type _addrUnix struct { 59 | Src [108]byte 60 | Dst [108]byte 61 | } 62 | 63 | func parseVersion2(reader *bufio.Reader) (header *Header, err error) { 64 | // Skip first 12 bytes (signature) 65 | for i := 0; i < 12; i++ { 66 | if _, err = reader.ReadByte(); err != nil { 67 | return nil, ErrCantReadProtocolVersionAndCommand 68 | } 69 | } 70 | 71 | header = new(Header) 72 | header.Version = 2 73 | 74 | // Read the 13th byte, protocol version and command 75 | b13, err := reader.ReadByte() 76 | if 
err != nil { 77 | return nil, ErrCantReadProtocolVersionAndCommand 78 | } 79 | header.Command = ProtocolVersionAndCommand(b13) 80 | if _, ok := supportedCommand[header.Command]; !ok { 81 | return nil, ErrUnsupportedProtocolVersionAndCommand 82 | } 83 | 84 | // Read the 14th byte, address family and protocol 85 | b14, err := reader.ReadByte() 86 | if err != nil { 87 | return nil, ErrCantReadAddressFamilyAndProtocol 88 | } 89 | header.TransportProtocol = AddressFamilyAndProtocol(b14) 90 | // UNSPEC is only supported when LOCAL is set. 91 | if header.TransportProtocol == UNSPEC && header.Command != LOCAL { 92 | return nil, ErrUnsupportedAddressFamilyAndProtocol 93 | } 94 | 95 | // Make sure there are bytes available as specified in length 96 | var length uint16 97 | if err := binary.Read(io.LimitReader(reader, 2), binary.BigEndian, &length); err != nil { 98 | return nil, ErrCantReadLength 99 | } 100 | if !header.validateLength(length) { 101 | return nil, ErrInvalidLength 102 | } 103 | 104 | // Return early if the length is zero, which means that 105 | // there's no address information and TLVs present for UNSPEC. 106 | if length == 0 { 107 | return header, nil 108 | } 109 | 110 | if _, err := reader.Peek(int(length)); err != nil { 111 | return nil, ErrInvalidLength 112 | } 113 | 114 | // Length-limited reader for payload section 115 | payloadReader := io.LimitReader(reader, int64(length)).(*io.LimitedReader) 116 | 117 | // Read addresses and ports for protocols other than UNSPEC. 118 | // Ignore address information for UNSPEC, and skip straight to read TLVs, 119 | // since the length is greater than zero. 
120 | if header.TransportProtocol != UNSPEC { 121 | if header.TransportProtocol.IsIPv4() { 122 | var addr _addr4 123 | if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil { 124 | return nil, ErrInvalidAddress 125 | } 126 | header.SourceAddr = newIPAddr(header.TransportProtocol, addr.Src[:], addr.SrcPort) 127 | header.DestinationAddr = newIPAddr(header.TransportProtocol, addr.Dst[:], addr.DstPort) 128 | } else if header.TransportProtocol.IsIPv6() { 129 | var addr _addr6 130 | if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil { 131 | return nil, ErrInvalidAddress 132 | } 133 | header.SourceAddr = newIPAddr(header.TransportProtocol, addr.Src[:], addr.SrcPort) 134 | header.DestinationAddr = newIPAddr(header.TransportProtocol, addr.Dst[:], addr.DstPort) 135 | } else if header.TransportProtocol.IsUnix() { 136 | var addr _addrUnix 137 | if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil { 138 | return nil, ErrInvalidAddress 139 | } 140 | 141 | network := "unix" 142 | if header.TransportProtocol.IsDatagram() { 143 | network = "unixgram" 144 | } 145 | 146 | header.SourceAddr = &net.UnixAddr{ 147 | Net: network, 148 | Name: parseUnixName(addr.Src[:]), 149 | } 150 | header.DestinationAddr = &net.UnixAddr{ 151 | Net: network, 152 | Name: parseUnixName(addr.Dst[:]), 153 | } 154 | } 155 | } 156 | 157 | // Copy bytes for optional Type-Length-Value vector 158 | header.rawTLVs = make([]byte, payloadReader.N) // Allocate minimum size slice 159 | if _, err = io.ReadFull(payloadReader, header.rawTLVs); err != nil && err != io.EOF { 160 | return nil, err 161 | } 162 | 163 | return header, nil 164 | } 165 | 166 | func (header *Header) formatVersion2() ([]byte, error) { 167 | var buf bytes.Buffer 168 | buf.Write(SIGV2) 169 | buf.WriteByte(header.Command.toByte()) 170 | buf.WriteByte(header.TransportProtocol.toByte()) 171 | if header.TransportProtocol.IsUnspec() { 172 | // For UNSPEC, write no addresses and ports but 
only TLVs if they are present 173 | hdrLen, err := addTLVLen(lengthUnspecBytes, len(header.rawTLVs)) 174 | if err != nil { 175 | return nil, err 176 | } 177 | buf.Write(hdrLen) 178 | } else { 179 | var addrSrc, addrDst []byte 180 | if header.TransportProtocol.IsIPv4() { 181 | hdrLen, err := addTLVLen(lengthV4Bytes, len(header.rawTLVs)) 182 | if err != nil { 183 | return nil, err 184 | } 185 | buf.Write(hdrLen) 186 | sourceIP, destIP, _ := header.IPs() 187 | addrSrc = sourceIP.To4() 188 | addrDst = destIP.To4() 189 | } else if header.TransportProtocol.IsIPv6() { 190 | hdrLen, err := addTLVLen(lengthV6Bytes, len(header.rawTLVs)) 191 | if err != nil { 192 | return nil, err 193 | } 194 | buf.Write(hdrLen) 195 | sourceIP, destIP, _ := header.IPs() 196 | addrSrc = sourceIP.To16() 197 | addrDst = destIP.To16() 198 | } else if header.TransportProtocol.IsUnix() { 199 | buf.Write(lengthUnixBytes) 200 | sourceAddr, destAddr, ok := header.UnixAddrs() 201 | if !ok { 202 | return nil, ErrInvalidAddress 203 | } 204 | addrSrc = formatUnixName(sourceAddr.Name) 205 | addrDst = formatUnixName(destAddr.Name) 206 | } 207 | 208 | if addrSrc == nil || addrDst == nil { 209 | return nil, ErrInvalidAddress 210 | } 211 | buf.Write(addrSrc) 212 | buf.Write(addrDst) 213 | 214 | if sourcePort, destPort, ok := header.Ports(); ok { 215 | portBytes := make([]byte, 2) 216 | 217 | binary.BigEndian.PutUint16(portBytes, uint16(sourcePort)) 218 | buf.Write(portBytes) 219 | 220 | binary.BigEndian.PutUint16(portBytes, uint16(destPort)) 221 | buf.Write(portBytes) 222 | } 223 | } 224 | 225 | if len(header.rawTLVs) > 0 { 226 | buf.Write(header.rawTLVs) 227 | } 228 | 229 | return buf.Bytes(), nil 230 | } 231 | 232 | func (header *Header) validateLength(length uint16) bool { 233 | if header.TransportProtocol.IsIPv4() { 234 | return length >= lengthV4 235 | } else if header.TransportProtocol.IsIPv6() { 236 | return length >= lengthV6 237 | } else if header.TransportProtocol.IsUnix() { 238 | return length >= 
lengthUnix 239 | } else if header.TransportProtocol.IsUnspec() { 240 | return length >= lengthUnspec 241 | } 242 | return false 243 | } 244 | 245 | // addTLVLen adds the length of the TLV to the header length or errors on uint16 overflow. 246 | func addTLVLen(cur []byte, tlvLen int) ([]byte, error) { 247 | if tlvLen == 0 { 248 | return cur, nil 249 | } 250 | curLen := binary.BigEndian.Uint16(cur) 251 | newLen := int(curLen) + tlvLen 252 | if newLen >= 1<<16 { 253 | return nil, errUint16Overflow 254 | } 255 | a := make([]byte, 2) 256 | binary.BigEndian.PutUint16(a, uint16(newLen)) 257 | return a, nil 258 | } 259 | 260 | func newIPAddr(transport AddressFamilyAndProtocol, ip net.IP, port uint16) net.Addr { 261 | if transport.IsStream() { 262 | return &net.TCPAddr{IP: ip, Port: int(port)} 263 | } else if transport.IsDatagram() { 264 | return &net.UDPAddr{IP: ip, Port: int(port)} 265 | } else { 266 | return nil 267 | } 268 | } 269 | 270 | func parseUnixName(b []byte) string { 271 | i := bytes.IndexByte(b, 0) 272 | if i < 0 { 273 | return string(b) 274 | } 275 | return string(b[:i]) 276 | } 277 | 278 | func formatUnixName(name string) []byte { 279 | n := int(lengthUnix) / 2 280 | if len(name) >= n { 281 | return []byte(name[:n]) 282 | } 283 | pad := make([]byte, n-len(name)) 284 | return append([]byte(name), pad...) 
285 | } 286 | -------------------------------------------------------------------------------- /v2_test.go: -------------------------------------------------------------------------------- 1 | package proxyproto 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | iorand "crypto/rand" 7 | "encoding/binary" 8 | "math/rand" 9 | "reflect" 10 | "testing" 11 | ) 12 | 13 | var ( 14 | invalidRune = byte('\x99') 15 | 16 | // Lengths to use in tests 17 | lengthPadded = uint16(84) 18 | 19 | lengthEmptyBytes = func() []byte { 20 | a := make([]byte, 2) 21 | binary.BigEndian.PutUint16(a, 0) 22 | return a 23 | }() 24 | lengthPaddedBytes = func() []byte { 25 | a := make([]byte, 2) 26 | binary.BigEndian.PutUint16(a, lengthPadded) 27 | return a 28 | }() 29 | 30 | // If life gives you lemons, make mojitos 31 | portBytes = func() []byte { 32 | a := make([]byte, 2) 33 | binary.BigEndian.PutUint16(a, PORT) 34 | return a 35 | }() 36 | 37 | unixBytes = pad([]byte("socket"), 108) 38 | 39 | // Tests don't care if source and destination addresses and ports are the same 40 | addressesIPv4 = append(v4ip.To4(), v4ip.To4()...) 41 | addressesIPv6 = append(v6ip.To16(), v6ip.To16()...) 42 | ports = append(portBytes, portBytes...) 43 | 44 | // Fixtures to use in tests 45 | fixtureIPv4Address = append(addressesIPv4, ports...) 46 | fixtureIPv4V2 = append(lengthV4Bytes, fixtureIPv4Address...) 47 | fixtureIPv4V2Padded = append(append(lengthPaddedBytes, fixtureIPv4Address...), make([]byte, lengthPadded-lengthV4)...) 48 | fixtureIPv6Address = append(addressesIPv6, ports...) 49 | fixtureIPv6V2 = append(lengthV6Bytes, fixtureIPv6Address...) 50 | fixtureIPv6V2Padded = append(append(lengthPaddedBytes, fixtureIPv6Address...), make([]byte, lengthPadded-lengthV6)...) 51 | fixtureUnixAddress = append(unixBytes, unixBytes...) 52 | fixtureUnixV2 = append(lengthUnixBytes, fixtureUnixAddress...) 
53 | fixtureTLV = func() []byte { 54 | tlv := make([]byte, 2+rand.Intn(1<<12)) // Not enough to overflow, at least size two 55 | _, _ = iorand.Read(tlv) 56 | return tlv 57 | }() 58 | fixtureIPv4V2TLV = fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, fixtureTLV) 59 | fixtureIPv6V2TLV = fixtureWithTLV(lengthV6Bytes, fixtureIPv6Address, fixtureTLV) 60 | fixtureUnspecTLV = fixtureWithTLV(lengthUnspecBytes, []byte{}, fixtureTLV) 61 | 62 | // Arbitrary bytes following proxy bytes 63 | arbitraryTailBytes = []byte{'\x99', '\x97', '\x98'} 64 | ) 65 | 66 | func pad(b []byte, n int) []byte { 67 | padding := make([]byte, n-len(b)) 68 | return append(b, padding...) 69 | } 70 | 71 | var invalidParseV2Tests = []struct { 72 | desc string 73 | reader *bufio.Reader 74 | expectedError error 75 | }{ 76 | { 77 | desc: "no signature", 78 | reader: newBufioReader([]byte(NO_PROTOCOL)), 79 | expectedError: ErrNoProxyProtocol, 80 | }, 81 | { 82 | desc: "truncated v2 signature", 83 | reader: newBufioReader(SIGV2[2:]), 84 | expectedError: ErrNoProxyProtocol, 85 | }, 86 | { 87 | desc: "v2 signature and nothing else", 88 | reader: newBufioReader(SIGV2), 89 | expectedError: ErrCantReadProtocolVersionAndCommand, 90 | }, 91 | { 92 | desc: "v2 signature with invalid command", 93 | reader: newBufioReader(append(SIGV2, invalidRune)), 94 | expectedError: ErrUnsupportedProtocolVersionAndCommand, 95 | }, 96 | { 97 | desc: "v2 signature with command but nothing else", 98 | reader: newBufioReader(append(SIGV2, byte(PROXY))), 99 | expectedError: ErrCantReadAddressFamilyAndProtocol, 100 | }, 101 | { 102 | desc: "command proxy but inet family unspec", 103 | reader: newBufioReader(append(SIGV2, byte(PROXY), byte(UNSPEC))), 104 | expectedError: ErrUnsupportedAddressFamilyAndProtocol, 105 | }, 106 | { 107 | desc: "v2 signature with command and invalid inet family", // translated to UNSPEC 108 | reader: newBufioReader(append(SIGV2, byte(PROXY), invalidRune)), 109 | expectedError: ErrCantReadLength, 110 | }, 
111 | { 112 | desc: "TCPv4 but no length", 113 | reader: newBufioReader(append(SIGV2, byte(PROXY), byte(TCPv4))), 114 | expectedError: ErrCantReadLength, 115 | }, 116 | { 117 | desc: "TCPv4 but invalid length", 118 | reader: newBufioReader(append(SIGV2, byte(PROXY), byte(TCPv4), invalidRune)), 119 | expectedError: ErrCantReadLength, 120 | }, 121 | { 122 | desc: "unspec but no length", 123 | reader: newBufioReader(append(SIGV2, byte(LOCAL), byte(UNSPEC))), 124 | expectedError: ErrCantReadLength, 125 | }, 126 | { 127 | desc: "TCPv4 with mismatching length", 128 | reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), lengthV4Bytes...)), 129 | expectedError: ErrInvalidLength, 130 | }, 131 | { 132 | desc: "TCPv6 with mismatching length", 133 | reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv6)), lengthV6Bytes...)), 134 | expectedError: ErrInvalidLength, 135 | }, 136 | { 137 | desc: "TCPv4 length zero but with address and ports", 138 | reader: newBufioReader(append(append(append(SIGV2, byte(PROXY), byte(TCPv4)), lengthEmptyBytes...), fixtureIPv6Address...)), 139 | expectedError: ErrInvalidLength, 140 | }, 141 | { 142 | desc: "TCPv6 with IPv6 length but IPv4 address and ports", 143 | reader: newBufioReader(append(append(append(SIGV2, byte(PROXY), byte(TCPv6)), lengthV6Bytes...), fixtureIPv4Address...)), 144 | expectedError: ErrInvalidLength, 145 | }, 146 | { 147 | desc: "unspec length greater than zero but no TLVs", 148 | reader: newBufioReader(append(append(SIGV2, byte(LOCAL), byte(UNSPEC)), fixtureUnspecTLV[:2]...)), 149 | expectedError: ErrInvalidLength, 150 | }, 151 | } 152 | 153 | func TestParseV2Invalid(t *testing.T) { 154 | for _, tt := range invalidParseV2Tests { 155 | t.Run(tt.desc, func(t *testing.T) { 156 | if _, err := Read(tt.reader); err != tt.expectedError { 157 | t.Fatalf("expected %s, actual %s", tt.expectedError, err.Error()) 158 | } 159 | }) 160 | } 161 | } 162 | 163 | var validParseAndWriteV2Tests = []struct { 164 | 
desc string 165 | reader *bufio.Reader 166 | expectedHeader *Header 167 | }{ 168 | { 169 | desc: "local", 170 | reader: newBufioReader(append(append(SIGV2, byte(LOCAL), byte(TCPv4)), fixtureIPv4V2...)), 171 | expectedHeader: &Header{ 172 | Version: 2, 173 | Command: LOCAL, 174 | TransportProtocol: TCPv4, 175 | SourceAddr: v4addr, 176 | DestinationAddr: v4addr, 177 | }, 178 | }, 179 | { 180 | desc: "local unspec", 181 | reader: newBufioReader(append(append(SIGV2, byte(LOCAL), byte(UNSPEC)), lengthUnspecBytes...)), 182 | expectedHeader: &Header{ 183 | Version: 2, 184 | Command: LOCAL, 185 | TransportProtocol: UNSPEC, 186 | SourceAddr: nil, 187 | DestinationAddr: nil, 188 | }, 189 | }, 190 | { 191 | desc: "proxy TCPv4", 192 | reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureIPv4V2...)), 193 | expectedHeader: &Header{ 194 | Version: 2, 195 | Command: PROXY, 196 | TransportProtocol: TCPv4, 197 | SourceAddr: v4addr, 198 | DestinationAddr: v4addr, 199 | }, 200 | }, 201 | { 202 | desc: "proxy TCPv6", 203 | reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv6)), fixtureIPv6V2...)), 204 | expectedHeader: &Header{ 205 | Version: 2, 206 | Command: PROXY, 207 | TransportProtocol: TCPv6, 208 | SourceAddr: v6addr, 209 | DestinationAddr: v6addr, 210 | }, 211 | }, 212 | { 213 | desc: "proxy TCPv4 with TLV", 214 | reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureIPv4V2TLV...)), 215 | expectedHeader: &Header{ 216 | Version: 2, 217 | Command: PROXY, 218 | TransportProtocol: TCPv4, 219 | SourceAddr: v4addr, 220 | DestinationAddr: v4addr, 221 | rawTLVs: fixtureTLV, 222 | }, 223 | }, 224 | { 225 | desc: "proxy TCPv6 with TLV", 226 | reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv6)), fixtureIPv6V2TLV...)), 227 | expectedHeader: &Header{ 228 | Version: 2, 229 | Command: PROXY, 230 | TransportProtocol: TCPv6, 231 | SourceAddr: v6addr, 232 | DestinationAddr: v6addr, 233 | rawTLVs: fixtureTLV, 234 | 
}, 235 | }, 236 | { 237 | desc: "local unspec with TLV", 238 | reader: newBufioReader(append(append(SIGV2, byte(LOCAL), byte(UNSPEC)), fixtureUnspecTLV...)), 239 | expectedHeader: &Header{ 240 | Version: 2, 241 | Command: LOCAL, 242 | TransportProtocol: UNSPEC, 243 | SourceAddr: nil, 244 | DestinationAddr: nil, 245 | rawTLVs: fixtureTLV, 246 | }, 247 | }, 248 | { 249 | desc: "proxy UDPv4", 250 | reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(UDPv4)), fixtureIPv4V2...)), 251 | expectedHeader: &Header{ 252 | Version: 2, 253 | Command: PROXY, 254 | TransportProtocol: UDPv4, 255 | SourceAddr: v4UDPAddr, 256 | DestinationAddr: v4UDPAddr, 257 | }, 258 | }, 259 | { 260 | desc: "proxy UDPv6", 261 | reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(UDPv6)), fixtureIPv6V2...)), 262 | expectedHeader: &Header{ 263 | Version: 2, 264 | Command: PROXY, 265 | TransportProtocol: UDPv6, 266 | SourceAddr: v6UDPAddr, 267 | DestinationAddr: v6UDPAddr, 268 | }, 269 | }, 270 | { 271 | desc: "proxy unix stream", 272 | reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(UnixStream)), fixtureUnixV2...)), 273 | expectedHeader: &Header{ 274 | Version: 2, 275 | Command: PROXY, 276 | TransportProtocol: UnixStream, 277 | SourceAddr: unixStreamAddr, 278 | DestinationAddr: unixStreamAddr, 279 | }, 280 | }, 281 | { 282 | desc: "proxy unix datagram", 283 | reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(UnixDatagram)), fixtureUnixV2...)), 284 | expectedHeader: &Header{ 285 | Version: 2, 286 | Command: PROXY, 287 | TransportProtocol: UnixDatagram, 288 | SourceAddr: unixDatagramAddr, 289 | DestinationAddr: unixDatagramAddr, 290 | }, 291 | }, 292 | } 293 | 294 | func TestParseV2Valid(t *testing.T) { 295 | for _, tt := range validParseAndWriteV2Tests { 296 | t.Run(tt.desc, func(t *testing.T) { 297 | header, err := Read(tt.reader) 298 | if err != nil { 299 | t.Fatal("unexpected error", err.Error()) 300 | } 301 | if !header.EqualsTo(tt.expectedHeader) { 
302 | t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, header) 303 | } 304 | }) 305 | } 306 | } 307 | 308 | func TestWriteV2Valid(t *testing.T) { 309 | for _, tt := range validParseAndWriteV2Tests { 310 | t.Run(tt.desc, func(t *testing.T) { 311 | var b bytes.Buffer 312 | w := bufio.NewWriter(&b) 313 | if _, err := tt.expectedHeader.WriteTo(w); err != nil { 314 | t.Fatal("unexpected error ", err) 315 | } 316 | w.Flush() 317 | 318 | // Read written bytes to validate written header 319 | r := bufio.NewReader(&b) 320 | newHeader, err := Read(r) 321 | if err != nil { 322 | t.Fatal("unexpected error ", err) 323 | } 324 | 325 | if !newHeader.EqualsTo(tt.expectedHeader) { 326 | t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, newHeader) 327 | } 328 | }) 329 | } 330 | } 331 | 332 | var validParseV2PaddedTests = []struct { 333 | desc string 334 | value []byte 335 | expectedHeader *Header 336 | }{ 337 | { 338 | desc: "proxy TCPv4", 339 | value: append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureIPv4V2Padded...), 340 | expectedHeader: &Header{ 341 | Version: 2, 342 | Command: PROXY, 343 | TransportProtocol: TCPv4, 344 | SourceAddr: v4addr, 345 | DestinationAddr: v4addr, 346 | rawTLVs: make([]byte, lengthPadded-lengthV4), 347 | }, 348 | }, 349 | { 350 | desc: "proxy TCPv6", 351 | value: append(append(SIGV2, byte(PROXY), byte(TCPv6)), fixtureIPv6V2Padded...), 352 | expectedHeader: &Header{ 353 | Version: 2, 354 | Command: PROXY, 355 | TransportProtocol: TCPv6, 356 | SourceAddr: v6addr, 357 | DestinationAddr: v6addr, 358 | rawTLVs: make([]byte, lengthPadded-lengthV6), 359 | }, 360 | }, 361 | { 362 | desc: "proxy UDPv4", 363 | value: append(append(SIGV2, byte(PROXY), byte(UDPv4)), fixtureIPv4V2Padded...), 364 | expectedHeader: &Header{ 365 | Version: 2, 366 | Command: PROXY, 367 | TransportProtocol: UDPv4, 368 | SourceAddr: v4addr, 369 | DestinationAddr: v4addr, 370 | rawTLVs: make([]byte, lengthPadded-lengthV4), 371 | }, 372 | }, 373 | { 374 | desc: "proxy 
UDPv6", 375 | value: append(append(SIGV2, byte(PROXY), byte(UDPv6)), fixtureIPv6V2Padded...), 376 | expectedHeader: &Header{ 377 | Version: 2, 378 | Command: PROXY, 379 | TransportProtocol: UDPv6, 380 | SourceAddr: v6addr, 381 | DestinationAddr: v6addr, 382 | rawTLVs: make([]byte, lengthPadded-lengthV6), 383 | }, 384 | }, 385 | } 386 | 387 | func TestParseV2Padded(t *testing.T) { 388 | for _, tt := range validParseV2PaddedTests { 389 | t.Run(tt.desc, func(t *testing.T) { 390 | reader := newBufioReader(append(tt.value, arbitraryTailBytes...)) 391 | 392 | newHeader, err := Read(reader) 393 | if err != nil { 394 | t.Fatal("unexpected error ", err) 395 | } 396 | if !newHeader.EqualsTo(tt.expectedHeader) { 397 | t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, newHeader) 398 | } 399 | 400 | // Check that remaining padding bytes have been flushed 401 | nextBytes, err := reader.Peek(len(arbitraryTailBytes)) 402 | if err != nil { 403 | t.Fatal("unexpected error ", err) 404 | } 405 | if !reflect.DeepEqual(nextBytes, arbitraryTailBytes) { 406 | t.Fatalf("expected %#v, actual %#v", arbitraryTailBytes, nextBytes) 407 | } 408 | }) 409 | } 410 | } 411 | 412 | func TestV2EqualsToTLV(t *testing.T) { 413 | eHdr := &Header{ 414 | Version: 2, 415 | Command: PROXY, 416 | TransportProtocol: TCPv4, 417 | SourceAddr: v4addr, 418 | DestinationAddr: v4addr, 419 | } 420 | hdr, err := Read(newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureIPv4V2TLV...))) 421 | if err != nil { 422 | t.Fatal("unexpected error ", err) 423 | } 424 | if eHdr.EqualsTo(hdr) { 425 | t.Fatalf("unexpectedly equal created: %#v, parsed: %#v", eHdr, hdr) 426 | } 427 | eHdr.rawTLVs = fixtureTLV[:] 428 | 429 | if !eHdr.EqualsTo(hdr) { 430 | t.Fatalf("unexpectedly unequal after tlv copy created: %#v, parsed: %#v", eHdr, hdr) 431 | } 432 | 433 | eHdr.rawTLVs[0] = eHdr.rawTLVs[0] + 1 434 | if eHdr.EqualsTo(hdr) { 435 | t.Fatalf("unexpectedly equal after changing tlv created: %#v, parsed: %#v", eHdr, 
hdr) 436 | } 437 | } 438 | 439 | var tlvFormatTests = []struct { 440 | desc string 441 | header *Header 442 | }{ 443 | { 444 | desc: "proxy TCPv4", 445 | header: &Header{ 446 | Version: 2, 447 | Command: PROXY, 448 | TransportProtocol: TCPv4, 449 | SourceAddr: v4addr, 450 | DestinationAddr: v4addr, 451 | rawTLVs: make([]byte, 1<<16), 452 | }, 453 | }, 454 | { 455 | desc: "proxy TCPv6", 456 | header: &Header{ 457 | Version: 2, 458 | Command: PROXY, 459 | TransportProtocol: TCPv6, 460 | SourceAddr: v6addr, 461 | DestinationAddr: v6addr, 462 | rawTLVs: make([]byte, 1<<16), 463 | }, 464 | }, 465 | { 466 | desc: "proxy UDPv4", 467 | header: &Header{ 468 | Version: 2, 469 | Command: PROXY, 470 | TransportProtocol: UDPv4, 471 | SourceAddr: v4addr, 472 | DestinationAddr: v4addr, 473 | rawTLVs: make([]byte, 1<<16), 474 | }, 475 | }, 476 | { 477 | desc: "proxy UDPv6", 478 | header: &Header{ 479 | Version: 2, 480 | Command: PROXY, 481 | TransportProtocol: UDPv6, 482 | SourceAddr: v6addr, 483 | DestinationAddr: v6addr, 484 | rawTLVs: make([]byte, 1<<16), 485 | }, 486 | }, 487 | { 488 | desc: "local unspec", 489 | header: &Header{ 490 | Version: 2, 491 | Command: LOCAL, 492 | TransportProtocol: UNSPEC, 493 | SourceAddr: nil, 494 | DestinationAddr: nil, 495 | rawTLVs: make([]byte, 1<<16), 496 | }, 497 | }, 498 | } 499 | 500 | func TestV2TLVFormatTooLargeTLV(t *testing.T) { 501 | for _, tt := range tlvFormatTests { 502 | t.Run(tt.desc, func(t *testing.T) { 503 | if _, err := tt.header.Format(); err != errUint16Overflow { 504 | t.Fatalf("missing or expected error when formatting too-large TLV %#v", err) 505 | } 506 | }) 507 | 508 | } 509 | } 510 | 511 | func newBufioReader(b []byte) *bufio.Reader { 512 | return bufio.NewReader(bytes.NewReader(b)) 513 | } 514 | 515 | func fixtureWithTLV(cur []byte, addr []byte, tlv []byte) []byte { 516 | tlen, err := addTLVLen(cur, len(tlv)) 517 | if err != nil { 518 | panic(err) 519 | } 520 | 521 | return append(append(tlen, addr...), tlv...) 
522 | } 523 | -------------------------------------------------------------------------------- /version_cmd.go: -------------------------------------------------------------------------------- 1 | package proxyproto 2 | 3 | // ProtocolVersionAndCommand represents the command in proxy protocol v2. 4 | // Command doesn't exist in v1 but it should be set since other parts of 5 | // this library may rely on it for determining connection details. 6 | type ProtocolVersionAndCommand byte 7 | 8 | const ( 9 | // LOCAL represents the LOCAL command in v2 or UNKNOWN transport in v1, 10 | // in which case no address information is expected. 11 | LOCAL ProtocolVersionAndCommand = '\x20' 12 | // PROXY represents the PROXY command in v2 or transport is not UNKNOWN in v1, 13 | // in which case valid local/remote address and port information is expected. 14 | PROXY ProtocolVersionAndCommand = '\x21' 15 | ) 16 | 17 | var supportedCommand = map[ProtocolVersionAndCommand]bool{ 18 | LOCAL: true, 19 | PROXY: true, 20 | } 21 | 22 | // IsLocal returns true if the command in v2 is LOCAL or the transport in v1 is UNKNOWN, 23 | // i.e. when no address information is expected, false otherwise. 24 | func (pvc ProtocolVersionAndCommand) IsLocal() bool { 25 | return LOCAL == pvc 26 | } 27 | 28 | // IsProxy returns true if the command in v2 is PROXY or the transport in v1 is not UNKNOWN, 29 | // i.e. when valid local/remote address and port information is expected, false otherwise. 30 | func (pvc ProtocolVersionAndCommand) IsProxy() bool { 31 | return PROXY == pvc 32 | } 33 | 34 | // IsUnspec returns true if the command is unspecified, false otherwise. 
35 | func (pvc ProtocolVersionAndCommand) IsUnspec() bool { 36 | return !(pvc.IsLocal() || pvc.IsProxy()) 37 | } 38 | 39 | func (pvc ProtocolVersionAndCommand) toByte() byte { 40 | if pvc.IsLocal() { 41 | return byte(LOCAL) 42 | } else if pvc.IsProxy() { 43 | return byte(PROXY) 44 | } 45 | 46 | return byte(LOCAL) 47 | } 48 | -------------------------------------------------------------------------------- /version_cmd_test.go: -------------------------------------------------------------------------------- 1 | package proxyproto 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestLocal(t *testing.T) { 8 | b := byte(LOCAL) 9 | if ProtocolVersionAndCommand(b).IsUnspec() { 10 | t.Fail() 11 | } 12 | if !ProtocolVersionAndCommand(b).IsLocal() { 13 | t.Fail() 14 | } 15 | if ProtocolVersionAndCommand(b).IsProxy() { 16 | t.Fail() 17 | } 18 | if ProtocolVersionAndCommand(b).toByte() != b { 19 | t.Fail() 20 | } 21 | } 22 | 23 | func TestProxy(t *testing.T) { 24 | b := byte(PROXY) 25 | if ProtocolVersionAndCommand(b).IsUnspec() { 26 | t.Fail() 27 | } 28 | if ProtocolVersionAndCommand(b).IsLocal() { 29 | t.Fail() 30 | } 31 | if !ProtocolVersionAndCommand(b).IsProxy() { 32 | t.Fail() 33 | } 34 | if ProtocolVersionAndCommand(b).toByte() != b { 35 | t.Fail() 36 | } 37 | } 38 | 39 | func TestInvalidProtocolVersion(t *testing.T) { 40 | if !ProtocolVersionAndCommand(0x00).IsUnspec() { 41 | t.Fail() 42 | } 43 | } 44 | --------------------------------------------------------------------------------