├── .github ├── CODEOWNERS └── workflows │ └── test.yaml ├── .gitignore ├── .go-version ├── LICENSE ├── README.md ├── addr.go ├── bench_test.go ├── const.go ├── const_test.go ├── go.mod ├── go.sum ├── mux.go ├── session.go ├── session_test.go ├── spec.md ├── stream.go ├── testdata ├── README.md ├── cert.pem └── key.pem ├── util.go └── util_test.go /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Each line is a file pattern followed by one or more owners. 2 | # More on CODEOWNERS files: https://help.github.com/en/github/creating-cloning-and-archiving-repositories/about-code-owners 3 | 4 | # Default owner 5 | * @hashicorp/team-ip-compliance @hashicorp/nomad-eng 6 | 7 | # Add override rules below. Each line is a file/folder pattern followed by one or more owners. 8 | # Being an owner means those groups or individuals will be added as reviewers to PRs affecting 9 | # those areas of the code. 10 | # Examples: 11 | # /docs/ @docs-team 12 | # *.js @js-team 13 | # *.go @go-team 14 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: CI Tests 2 | on: 3 | pull_request: 4 | paths-ignore: 5 | - 'README.md' 6 | push: 7 | branches: 8 | - 'master' 9 | paths-ignore: 10 | - 'README.md' 11 | 12 | permissions: 13 | contents: read 14 | 15 | jobs: 16 | go-fmt-and-vet: 17 | runs-on: ubuntu-latest 18 | steps: 19 | - name: Checkout Code 20 | uses: actions/checkout@ac593985615ec2ede58e132d2e21d2b1cbd6127c # v3.3.0 21 | - name: Setup Go 22 | uses: actions/setup-go@6edd4406fa81c3da01a34fa6f6343087c207a568 # v3.5.0 23 | with: 24 | go-version: '1.23' 25 | cache: true 26 | - name: Go Formatting 27 | run: | 28 | files=$(go fmt ./...) 
29 | if [ -n "$files" ]; then 30 | echo "The following file(s) do not conform to go fmt:" 31 | echo "$files" 32 | exit 1 33 | fi 34 | go-test: 35 | needs: go-fmt-and-vet 36 | runs-on: ubuntu-latest 37 | steps: 38 | - name: Checkout Code 39 | uses: actions/checkout@ac593985615ec2ede58e132d2e21d2b1cbd6127c # v3.3.0 40 | - name: Setup Go 41 | uses: actions/setup-go@6edd4406fa81c3da01a34fa6f6343087c207a568 # v3.5.0 42 | with: 43 | go-version: '1.23' 44 | cache: true 45 | - name: Run golangci-lint 46 | uses: golangci/golangci-lint-action@08e2f20817b15149a52b5b3ebe7de50aff2ba8c5 47 | - name: Run test and generate coverage report 48 | run: | 49 | go test -race -v -coverprofile=coverage.out ./... 50 | - name: Upload Coverage report 51 | uses: actions/upload-artifact@65462800fd760344b1a7b4382951275a0abb4808 52 | with: 53 | path: coverage.out 54 | name: Coverage-report 55 | - name: Display coverage report 56 | run: go tool cover -func=coverage.out 57 | - name: Build Go 58 | run: go build ./... 59 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | -------------------------------------------------------------------------------- /.go-version: -------------------------------------------------------------------------------- 1 | 1.23 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2014 HashiCorp, Inc. 2 | 3 | Mozilla Public License, version 2.0 4 | 5 | 1. 
Definitions 6 | 7 | 1.1. "Contributor" 8 | 9 | means each individual or legal entity that creates, contributes to the 10 | creation of, or owns Covered Software. 11 | 12 | 1.2. "Contributor Version" 13 | 14 | means the combination of the Contributions of others (if any) used by a 15 | Contributor and that particular Contributor's Contribution. 16 | 17 | 1.3. "Contribution" 18 | 19 | means Covered Software of a particular Contributor. 20 | 21 | 1.4. "Covered Software" 22 | 23 | means Source Code Form to which the initial Contributor has attached the 24 | notice in Exhibit A, the Executable Form of such Source Code Form, and 25 | Modifications of such Source Code Form, in each case including portions 26 | thereof. 27 | 28 | 1.5. "Incompatible With Secondary Licenses" 29 | means 30 | 31 | a. that the initial Contributor has attached the notice described in 32 | Exhibit B to the Covered Software; or 33 | 34 | b. that the Covered Software was made available under the terms of 35 | version 1.1 or earlier of the License, but not also under the terms of 36 | a Secondary License. 37 | 38 | 1.6. "Executable Form" 39 | 40 | means any form of the work other than Source Code Form. 41 | 42 | 1.7. "Larger Work" 43 | 44 | means a work that combines Covered Software with other material, in a 45 | separate file or files, that is not Covered Software. 46 | 47 | 1.8. "License" 48 | 49 | means this document. 50 | 51 | 1.9. "Licensable" 52 | 53 | means having the right to grant, to the maximum extent possible, whether 54 | at the time of the initial grant or subsequently, any and all of the 55 | rights conveyed by this License. 56 | 57 | 1.10. "Modifications" 58 | 59 | means any of the following: 60 | 61 | a. any file in Source Code Form that results from an addition to, 62 | deletion from, or modification of the contents of Covered Software; or 63 | 64 | b. any new file in Source Code Form that contains any Covered Software. 65 | 66 | 1.11. 
"Patent Claims" of a Contributor 67 | 68 | means any patent claim(s), including without limitation, method, 69 | process, and apparatus claims, in any patent Licensable by such 70 | Contributor that would be infringed, but for the grant of the License, 71 | by the making, using, selling, offering for sale, having made, import, 72 | or transfer of either its Contributions or its Contributor Version. 73 | 74 | 1.12. "Secondary License" 75 | 76 | means either the GNU General Public License, Version 2.0, the GNU Lesser 77 | General Public License, Version 2.1, the GNU Affero General Public 78 | License, Version 3.0, or any later versions of those licenses. 79 | 80 | 1.13. "Source Code Form" 81 | 82 | means the form of the work preferred for making modifications. 83 | 84 | 1.14. "You" (or "Your") 85 | 86 | means an individual or a legal entity exercising rights under this 87 | License. For legal entities, "You" includes any entity that controls, is 88 | controlled by, or is under common control with You. For purposes of this 89 | definition, "control" means (a) the power, direct or indirect, to cause 90 | the direction or management of such entity, whether by contract or 91 | otherwise, or (b) ownership of more than fifty percent (50%) of the 92 | outstanding shares or beneficial ownership of such entity. 93 | 94 | 95 | 2. License Grants and Conditions 96 | 97 | 2.1. Grants 98 | 99 | Each Contributor hereby grants You a world-wide, royalty-free, 100 | non-exclusive license: 101 | 102 | a. under intellectual property rights (other than patent or trademark) 103 | Licensable by such Contributor to use, reproduce, make available, 104 | modify, display, perform, distribute, and otherwise exploit its 105 | Contributions, either on an unmodified basis, with Modifications, or 106 | as part of a Larger Work; and 107 | 108 | b. 
under Patent Claims of such Contributor to make, use, sell, offer for 109 | sale, have made, import, and otherwise transfer either its 110 | Contributions or its Contributor Version. 111 | 112 | 2.2. Effective Date 113 | 114 | The licenses granted in Section 2.1 with respect to any Contribution 115 | become effective for each Contribution on the date the Contributor first 116 | distributes such Contribution. 117 | 118 | 2.3. Limitations on Grant Scope 119 | 120 | The licenses granted in this Section 2 are the only rights granted under 121 | this License. No additional rights or licenses will be implied from the 122 | distribution or licensing of Covered Software under this License. 123 | Notwithstanding Section 2.1(b) above, no patent license is granted by a 124 | Contributor: 125 | 126 | a. for any code that a Contributor has removed from Covered Software; or 127 | 128 | b. for infringements caused by: (i) Your and any other third party's 129 | modifications of Covered Software, or (ii) the combination of its 130 | Contributions with other software (except as part of its Contributor 131 | Version); or 132 | 133 | c. under Patent Claims infringed by Covered Software in the absence of 134 | its Contributions. 135 | 136 | This License does not grant any rights in the trademarks, service marks, 137 | or logos of any Contributor (except as may be necessary to comply with 138 | the notice requirements in Section 3.4). 139 | 140 | 2.4. Subsequent Licenses 141 | 142 | No Contributor makes additional grants as a result of Your choice to 143 | distribute the Covered Software under a subsequent version of this 144 | License (see Section 10.2) or under the terms of a Secondary License (if 145 | permitted under the terms of Section 3.3). 146 | 147 | 2.5. 
Representation 148 | 149 | Each Contributor represents that the Contributor believes its 150 | Contributions are its original creation(s) or it has sufficient rights to 151 | grant the rights to its Contributions conveyed by this License. 152 | 153 | 2.6. Fair Use 154 | 155 | This License is not intended to limit any rights You have under 156 | applicable copyright doctrines of fair use, fair dealing, or other 157 | equivalents. 158 | 159 | 2.7. Conditions 160 | 161 | Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in 162 | Section 2.1. 163 | 164 | 165 | 3. Responsibilities 166 | 167 | 3.1. Distribution of Source Form 168 | 169 | All distribution of Covered Software in Source Code Form, including any 170 | Modifications that You create or to which You contribute, must be under 171 | the terms of this License. You must inform recipients that the Source 172 | Code Form of the Covered Software is governed by the terms of this 173 | License, and how they can obtain a copy of this License. You may not 174 | attempt to alter or restrict the recipients' rights in the Source Code 175 | Form. 176 | 177 | 3.2. Distribution of Executable Form 178 | 179 | If You distribute Covered Software in Executable Form then: 180 | 181 | a. such Covered Software must also be made available in Source Code Form, 182 | as described in Section 3.1, and You must inform recipients of the 183 | Executable Form how they can obtain a copy of such Source Code Form by 184 | reasonable means in a timely manner, at a charge no more than the cost 185 | of distribution to the recipient; and 186 | 187 | b. You may distribute such Executable Form under the terms of this 188 | License, or sublicense it under different terms, provided that the 189 | license for the Executable Form does not attempt to limit or alter the 190 | recipients' rights in the Source Code Form under this License. 191 | 192 | 3.3. 
Distribution of a Larger Work 193 | 194 | You may create and distribute a Larger Work under terms of Your choice, 195 | provided that You also comply with the requirements of this License for 196 | the Covered Software. If the Larger Work is a combination of Covered 197 | Software with a work governed by one or more Secondary Licenses, and the 198 | Covered Software is not Incompatible With Secondary Licenses, this 199 | License permits You to additionally distribute such Covered Software 200 | under the terms of such Secondary License(s), so that the recipient of 201 | the Larger Work may, at their option, further distribute the Covered 202 | Software under the terms of either this License or such Secondary 203 | License(s). 204 | 205 | 3.4. Notices 206 | 207 | You may not remove or alter the substance of any license notices 208 | (including copyright notices, patent notices, disclaimers of warranty, or 209 | limitations of liability) contained within the Source Code Form of the 210 | Covered Software, except that You may alter any license notices to the 211 | extent required to remedy known factual inaccuracies. 212 | 213 | 3.5. Application of Additional Terms 214 | 215 | You may choose to offer, and to charge a fee for, warranty, support, 216 | indemnity or liability obligations to one or more recipients of Covered 217 | Software. However, You may do so only on Your own behalf, and not on 218 | behalf of any Contributor. You must make it absolutely clear that any 219 | such warranty, support, indemnity, or liability obligation is offered by 220 | You alone, and You hereby agree to indemnify every Contributor for any 221 | liability incurred by such Contributor as a result of warranty, support, 222 | indemnity or liability terms You offer. You may include additional 223 | disclaimers of warranty and limitations of liability specific to any 224 | jurisdiction. 225 | 226 | 4. 
Inability to Comply Due to Statute or Regulation 227 | 228 | If it is impossible for You to comply with any of the terms of this License 229 | with respect to some or all of the Covered Software due to statute, 230 | judicial order, or regulation then You must: (a) comply with the terms of 231 | this License to the maximum extent possible; and (b) describe the 232 | limitations and the code they affect. Such description must be placed in a 233 | text file included with all distributions of the Covered Software under 234 | this License. Except to the extent prohibited by statute or regulation, 235 | such description must be sufficiently detailed for a recipient of ordinary 236 | skill to be able to understand it. 237 | 238 | 5. Termination 239 | 240 | 5.1. The rights granted under this License will terminate automatically if You 241 | fail to comply with any of its terms. However, if You become compliant, 242 | then the rights granted under this License from a particular Contributor 243 | are reinstated (a) provisionally, unless and until such Contributor 244 | explicitly and finally terminates Your grants, and (b) on an ongoing 245 | basis, if such Contributor fails to notify You of the non-compliance by 246 | some reasonable means prior to 60 days after You have come back into 247 | compliance. Moreover, Your grants from a particular Contributor are 248 | reinstated on an ongoing basis if such Contributor notifies You of the 249 | non-compliance by some reasonable means, this is the first time You have 250 | received notice of non-compliance with this License from such 251 | Contributor, and You become compliant prior to 30 days after Your receipt 252 | of the notice. 253 | 254 | 5.2. 
If You initiate litigation against any entity by asserting a patent 255 | infringement claim (excluding declaratory judgment actions, 256 | counter-claims, and cross-claims) alleging that a Contributor Version 257 | directly or indirectly infringes any patent, then the rights granted to 258 | You by any and all Contributors for the Covered Software under Section 259 | 2.1 of this License shall terminate. 260 | 261 | 5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user 262 | license agreements (excluding distributors and resellers) which have been 263 | validly granted by You or Your distributors under this License prior to 264 | termination shall survive termination. 265 | 266 | 6. Disclaimer of Warranty 267 | 268 | Covered Software is provided under this License on an "as is" basis, 269 | without warranty of any kind, either expressed, implied, or statutory, 270 | including, without limitation, warranties that the Covered Software is free 271 | of defects, merchantable, fit for a particular purpose or non-infringing. 272 | The entire risk as to the quality and performance of the Covered Software 273 | is with You. Should any Covered Software prove defective in any respect, 274 | You (not any Contributor) assume the cost of any necessary servicing, 275 | repair, or correction. This disclaimer of warranty constitutes an essential 276 | part of this License. No use of any Covered Software is authorized under 277 | this License except under this disclaimer. 278 | 279 | 7. 
Limitation of Liability 280 | 281 | Under no circumstances and under no legal theory, whether tort (including 282 | negligence), contract, or otherwise, shall any Contributor, or anyone who 283 | distributes Covered Software as permitted above, be liable to You for any 284 | direct, indirect, special, incidental, or consequential damages of any 285 | character including, without limitation, damages for lost profits, loss of 286 | goodwill, work stoppage, computer failure or malfunction, or any and all 287 | other commercial damages or losses, even if such party shall have been 288 | informed of the possibility of such damages. This limitation of liability 289 | shall not apply to liability for death or personal injury resulting from 290 | such party's negligence to the extent applicable law prohibits such 291 | limitation. Some jurisdictions do not allow the exclusion or limitation of 292 | incidental or consequential damages, so this exclusion and limitation may 293 | not apply to You. 294 | 295 | 8. Litigation 296 | 297 | Any litigation relating to this License may be brought only in the courts 298 | of a jurisdiction where the defendant maintains its principal place of 299 | business and such litigation shall be governed by laws of that 300 | jurisdiction, without reference to its conflict-of-law provisions. Nothing 301 | in this Section shall prevent a party's ability to bring cross-claims or 302 | counter-claims. 303 | 304 | 9. Miscellaneous 305 | 306 | This License represents the complete agreement concerning the subject 307 | matter hereof. If any provision of this License is held to be 308 | unenforceable, such provision shall be reformed only to the extent 309 | necessary to make it enforceable. Any law or regulation which provides that 310 | the language of a contract shall be construed against the drafter shall not 311 | be used to construe this License against a Contributor. 312 | 313 | 314 | 10. Versions of the License 315 | 316 | 10.1. 
New Versions 317 | 318 | Mozilla Foundation is the license steward. Except as provided in Section 319 | 10.3, no one other than the license steward has the right to modify or 320 | publish new versions of this License. Each version will be given a 321 | distinguishing version number. 322 | 323 | 10.2. Effect of New Versions 324 | 325 | You may distribute the Covered Software under the terms of the version 326 | of the License under which You originally received the Covered Software, 327 | or under the terms of any subsequent version published by the license 328 | steward. 329 | 330 | 10.3. Modified Versions 331 | 332 | If you create software not governed by this License, and you want to 333 | create a new license for such software, you may create and use a 334 | modified version of this License if you rename the license and remove 335 | any references to the name of the license steward (except to note that 336 | such modified license differs from this License). 337 | 338 | 10.4. Distributing Source Code Form that is Incompatible With Secondary 339 | Licenses If You choose to distribute Source Code Form that is 340 | Incompatible With Secondary Licenses under the terms of this version of 341 | the License, the notice described in Exhibit B of this License must be 342 | attached. 343 | 344 | Exhibit A - Source Code Form License Notice 345 | 346 | This Source Code Form is subject to the 347 | terms of the Mozilla Public License, v. 348 | 2.0. If a copy of the MPL was not 349 | distributed with this file, You can 350 | obtain one at 351 | http://mozilla.org/MPL/2.0/. 352 | 353 | If it is not possible or desirable to put the notice in a particular file, 354 | then You may include the notice in a location (such as a LICENSE file in a 355 | relevant directory) where a recipient would be likely to look for such a 356 | notice. 357 | 358 | You may add additional accurate notices of copyright ownership. 
359 | 360 | Exhibit B - "Incompatible With Secondary Licenses" Notice 361 | 362 | This Source Code Form is "Incompatible 363 | With Secondary Licenses", as defined by 364 | the Mozilla Public License, v. 2.0. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Yamux 2 | 3 | Yamux (Yet another Multiplexer) is a multiplexing library for Golang. 4 | It relies on an underlying connection to provide reliability 5 | and ordering, such as TCP or Unix domain sockets, and provides 6 | stream-oriented multiplexing. It is inspired by SPDY but is not 7 | interoperable with it. 8 | 9 | Yamux features include: 10 | 11 | * Bi-directional streams 12 | * Streams can be opened by either client or server 13 | * Useful for NAT traversal 14 | * Server-side push support 15 | * Flow control 16 | * Avoid starvation 17 | * Back-pressure to prevent overwhelming a receiver 18 | * Keep Alives 19 | * Enables persistent connections over a load balancer 20 | * Efficient 21 | * Enables thousands of logical streams with low overhead 22 | 23 | ## Documentation 24 | 25 | For complete documentation, see the associated [Godoc](http://godoc.org/github.com/hashicorp/yamux). 26 | 27 | ## Specification 28 | 29 | The full specification for Yamux is provided in the `spec.md` file. 30 | It can be used as a guide to implementors of interoperable libraries. 31 | 32 | ## Usage 33 | 34 | Using Yamux is remarkably simple: 35 | 36 | ```go 37 | 38 | func client() { 39 | // Get a TCP connection 40 | conn, err := net.Dial(...) 
41 | if err != nil { 42 | panic(err) 43 | } 44 | 45 | // Setup client side of yamux 46 | session, err := yamux.Client(conn, nil) 47 | if err != nil { 48 | panic(err) 49 | } 50 | 51 | // Open a new stream 52 | stream, err := session.Open() 53 | if err != nil { 54 | panic(err) 55 | } 56 | 57 | // Stream implements net.Conn 58 | stream.Write([]byte("ping")) 59 | } 60 | 61 | func server() { 62 | // Accept a TCP connection 63 | conn, err := listener.Accept() 64 | if err != nil { 65 | panic(err) 66 | } 67 | 68 | // Setup server side of yamux 69 | session, err := yamux.Server(conn, nil) 70 | if err != nil { 71 | panic(err) 72 | } 73 | 74 | // Accept a stream 75 | stream, err := session.Accept() 76 | if err != nil { 77 | panic(err) 78 | } 79 | 80 | // Listen for a message 81 | buf := make([]byte, 4) 82 | stream.Read(buf) 83 | } 84 | 85 | ``` 86 | 87 | -------------------------------------------------------------------------------- /addr.go: -------------------------------------------------------------------------------- 1 | package yamux 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | ) 7 | 8 | // hasAddr is used to get the address from the underlying connection 9 | type hasAddr interface { 10 | LocalAddr() net.Addr 11 | RemoteAddr() net.Addr 12 | } 13 | 14 | // yamuxAddr is used when we cannot get the underlying address 15 | type yamuxAddr struct { 16 | Addr string 17 | } 18 | 19 | func (*yamuxAddr) Network() string { 20 | return "yamux" 21 | } 22 | 23 | func (y *yamuxAddr) String() string { 24 | return fmt.Sprintf("yamux:%s", y.Addr) 25 | } 26 | 27 | // Addr is used to get the address of the listener. 28 | func (s *Session) Addr() net.Addr { 29 | return s.LocalAddr() 30 | } 31 | 32 | // LocalAddr is used to get the local address of the 33 | // underlying connection. 
34 | func (s *Session) LocalAddr() net.Addr { 35 | addr, ok := s.conn.(hasAddr) 36 | if !ok { 37 | return &yamuxAddr{"local"} 38 | } 39 | return addr.LocalAddr() 40 | } 41 | 42 | // RemoteAddr is used to get the address of remote end 43 | // of the underlying connection 44 | func (s *Session) RemoteAddr() net.Addr { 45 | addr, ok := s.conn.(hasAddr) 46 | if !ok { 47 | return &yamuxAddr{"remote"} 48 | } 49 | return addr.RemoteAddr() 50 | } 51 | 52 | // LocalAddr returns the local address 53 | func (s *Stream) LocalAddr() net.Addr { 54 | return s.session.LocalAddr() 55 | } 56 | 57 | // RemoteAddr returns the remote address 58 | func (s *Stream) RemoteAddr() net.Addr { 59 | return s.session.RemoteAddr() 60 | } 61 | -------------------------------------------------------------------------------- /bench_test.go: -------------------------------------------------------------------------------- 1 | package yamux 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "testing" 7 | ) 8 | 9 | func BenchmarkPing(b *testing.B) { 10 | client, _ := testClientServer(b) 11 | 12 | b.ReportAllocs() 13 | b.ResetTimer() 14 | 15 | for i := 0; i < b.N; i++ { 16 | rtt, err := client.Ping() 17 | if err != nil { 18 | b.Fatalf("err: %v", err) 19 | } 20 | if rtt == 0 { 21 | b.Fatalf("bad: %v", rtt) 22 | } 23 | } 24 | } 25 | 26 | func BenchmarkAccept(b *testing.B) { 27 | client, server := testClientServer(b) 28 | 29 | doneCh := make(chan struct{}) 30 | b.ReportAllocs() 31 | b.ResetTimer() 32 | 33 | go func() { 34 | defer close(doneCh) 35 | 36 | for i := 0; i < b.N; i++ { 37 | stream, err := server.AcceptStream() 38 | if err != nil { 39 | return 40 | } 41 | stream.Close() 42 | } 43 | }() 44 | 45 | for i := 0; i < b.N; i++ { 46 | stream, err := client.Open() 47 | if err != nil { 48 | b.Fatalf("err: %v", err) 49 | } 50 | stream.Close() 51 | } 52 | <-doneCh 53 | } 54 | 55 | func BenchmarkSendRecv32(b *testing.B) { 56 | const payloadSize = 32 57 | benchmarkSendRecv(b, payloadSize, payloadSize) 58 | } 59 | 60 | 
func BenchmarkSendRecv64(b *testing.B) { 61 | const payloadSize = 64 62 | benchmarkSendRecv(b, payloadSize, payloadSize) 63 | } 64 | 65 | func BenchmarkSendRecv128(b *testing.B) { 66 | const payloadSize = 128 67 | benchmarkSendRecv(b, payloadSize, payloadSize) 68 | } 69 | 70 | func BenchmarkSendRecv256(b *testing.B) { 71 | const payloadSize = 256 72 | benchmarkSendRecv(b, payloadSize, payloadSize) 73 | } 74 | 75 | func BenchmarkSendRecv512(b *testing.B) { 76 | const payloadSize = 512 77 | benchmarkSendRecv(b, payloadSize, payloadSize) 78 | } 79 | 80 | func BenchmarkSendRecv1024(b *testing.B) { 81 | const payloadSize = 1024 82 | benchmarkSendRecv(b, payloadSize, payloadSize) 83 | } 84 | 85 | func BenchmarkSendRecv2048(b *testing.B) { 86 | const payloadSize = 2048 87 | benchmarkSendRecv(b, payloadSize, payloadSize) 88 | } 89 | 90 | func BenchmarkSendRecv4096(b *testing.B) { 91 | const payloadSize = 4096 92 | benchmarkSendRecv(b, payloadSize, payloadSize) 93 | } 94 | 95 | func BenchmarkSendRecvLarge(b *testing.B) { 96 | const sendSize = 512 * 1024 * 1024 //512 MB 97 | const recvSize = 4 * 1024 //4 KB 98 | benchmarkSendRecv(b, sendSize, recvSize) 99 | } 100 | 101 | func benchmarkSendRecv(b *testing.B, sendSize, recvSize int) { 102 | client, server := testClientServer(b) 103 | 104 | sendBuf := make([]byte, sendSize) 105 | recvBuf := make([]byte, recvSize) 106 | errCh := make(chan error, 1) 107 | 108 | b.SetBytes(int64(sendSize)) 109 | b.ReportAllocs() 110 | b.ResetTimer() 111 | 112 | go func() { 113 | stream, err := server.AcceptStream() 114 | if err != nil { 115 | errCh <- err 116 | return 117 | } 118 | defer stream.Close() 119 | 120 | switch { 121 | case sendSize == recvSize: 122 | for i := 0; i < b.N; i++ { 123 | if _, err := stream.Read(recvBuf); err != nil { 124 | errCh <- err 125 | return 126 | } 127 | } 128 | 129 | case recvSize > sendSize: 130 | errCh <- fmt.Errorf("bad test case; recvSize was: %d and sendSize was: %d, but recvSize must be <= sendSize!", 
recvSize, sendSize) 131 | return 132 | 133 | default: 134 | chunks := sendSize / recvSize 135 | for i := 0; i < b.N; i++ { 136 | for j := 0; j < chunks; j++ { 137 | if _, err := stream.Read(recvBuf); err != nil { 138 | errCh <- err 139 | return 140 | } 141 | } 142 | } 143 | } 144 | errCh <- nil 145 | }() 146 | 147 | stream, err := client.Open() 148 | if err != nil { 149 | b.Fatalf("err: %v", err) 150 | } 151 | defer stream.Close() 152 | 153 | for i := 0; i < b.N; i++ { 154 | if _, err := stream.Write(sendBuf); err != nil { 155 | b.Fatalf("err: %v", err) 156 | } 157 | } 158 | 159 | drainErrorsUntil(b, errCh, 1, 0, "") 160 | } 161 | 162 | func BenchmarkSendRecvParallel32(b *testing.B) { 163 | const payloadSize = 32 164 | benchmarkSendRecvParallel(b, payloadSize) 165 | } 166 | 167 | func BenchmarkSendRecvParallel64(b *testing.B) { 168 | const payloadSize = 64 169 | benchmarkSendRecvParallel(b, payloadSize) 170 | } 171 | 172 | func BenchmarkSendRecvParallel128(b *testing.B) { 173 | const payloadSize = 128 174 | benchmarkSendRecvParallel(b, payloadSize) 175 | } 176 | 177 | func BenchmarkSendRecvParallel256(b *testing.B) { 178 | const payloadSize = 256 179 | benchmarkSendRecvParallel(b, payloadSize) 180 | } 181 | 182 | func BenchmarkSendRecvParallel512(b *testing.B) { 183 | const payloadSize = 512 184 | benchmarkSendRecvParallel(b, payloadSize) 185 | } 186 | 187 | func BenchmarkSendRecvParallel1024(b *testing.B) { 188 | const payloadSize = 1024 189 | benchmarkSendRecvParallel(b, payloadSize) 190 | } 191 | 192 | func BenchmarkSendRecvParallel2048(b *testing.B) { 193 | const payloadSize = 2048 194 | benchmarkSendRecvParallel(b, payloadSize) 195 | } 196 | 197 | func BenchmarkSendRecvParallel4096(b *testing.B) { 198 | const payloadSize = 4096 199 | benchmarkSendRecvParallel(b, payloadSize) 200 | } 201 | 202 | func benchmarkSendRecvParallel(b *testing.B, sendSize int) { 203 | client, server := testClientServer(b) 204 | 205 | sendBuf := make([]byte, sendSize) 206 | discarder 
:= io.Discard.(io.ReaderFrom) 207 | b.SetBytes(int64(sendSize)) 208 | b.ReportAllocs() 209 | b.ResetTimer() 210 | 211 | b.RunParallel(func(pb *testing.PB) { 212 | errCh := make(chan error, 1) 213 | go func() { 214 | stream, err := server.AcceptStream() 215 | if err != nil { 216 | errCh <- err 217 | return 218 | } 219 | defer stream.Close() 220 | 221 | if _, err := discarder.ReadFrom(stream); err != nil { 222 | errCh <- err 223 | return 224 | } 225 | errCh <- nil 226 | }() 227 | 228 | stream, err := client.Open() 229 | if err != nil { 230 | b.Fatalf("err: %v", err) 231 | } 232 | 233 | for pb.Next() { 234 | if _, err := stream.Write(sendBuf); err != nil { 235 | b.Fatalf("err: %v", err) 236 | } 237 | } 238 | 239 | stream.Close() 240 | 241 | drainErrorsUntil(b, errCh, 1, 0, "") 242 | }) 243 | } 244 | -------------------------------------------------------------------------------- /const.go: -------------------------------------------------------------------------------- 1 | package yamux 2 | 3 | import ( 4 | "encoding/binary" 5 | "fmt" 6 | ) 7 | 8 | // NetError implements net.Error 9 | type NetError struct { 10 | err error 11 | timeout bool 12 | temporary bool 13 | } 14 | 15 | func (e *NetError) Error() string { 16 | return e.err.Error() 17 | } 18 | 19 | func (e *NetError) Timeout() bool { 20 | return e.timeout 21 | } 22 | 23 | func (e *NetError) Temporary() bool { 24 | return e.temporary 25 | } 26 | 27 | var ( 28 | // ErrInvalidVersion means we received a frame with an 29 | // invalid version 30 | ErrInvalidVersion = fmt.Errorf("invalid protocol version") 31 | 32 | // ErrInvalidMsgType means we received a frame with an 33 | // invalid message type 34 | ErrInvalidMsgType = fmt.Errorf("invalid msg type") 35 | 36 | // ErrSessionShutdown is used if there is a shutdown during 37 | // an operation 38 | ErrSessionShutdown = fmt.Errorf("session shutdown") 39 | 40 | // ErrStreamsExhausted is returned if we have no more 41 | // stream ids to issue 42 | ErrStreamsExhausted = 
fmt.Errorf("streams exhausted") 43 | 44 | // ErrDuplicateStream is used if a duplicate stream is 45 | // opened inbound 46 | ErrDuplicateStream = fmt.Errorf("duplicate stream initiated") 47 | 48 | // ErrReceiveWindowExceeded indicates the window was exceeded 49 | ErrRecvWindowExceeded = fmt.Errorf("recv window exceeded") 50 | 51 | // ErrTimeout is used when we reach an IO deadline 52 | ErrTimeout = &NetError{ 53 | err: fmt.Errorf("i/o deadline reached"), 54 | 55 | // Error should meet net.Error interface for timeouts for compatability 56 | // with standard library expectations, such as http servers. 57 | timeout: true, 58 | } 59 | 60 | // ErrStreamClosed is returned when using a closed stream 61 | ErrStreamClosed = fmt.Errorf("stream closed") 62 | 63 | // ErrUnexpectedFlag is set when we get an unexpected flag 64 | ErrUnexpectedFlag = fmt.Errorf("unexpected flag") 65 | 66 | // ErrRemoteGoAway is used when we get a go away from the other side 67 | ErrRemoteGoAway = fmt.Errorf("remote end is not accepting connections") 68 | 69 | // ErrConnectionReset is sent if a stream is reset. This can happen 70 | // if the backlog is exceeded, or if there was a remote GoAway. 71 | ErrConnectionReset = fmt.Errorf("connection reset") 72 | 73 | // ErrConnectionWriteTimeout indicates that we hit the "safety valve" 74 | // timeout writing to the underlying stream connection. 75 | ErrConnectionWriteTimeout = fmt.Errorf("connection write timeout") 76 | 77 | // ErrKeepAliveTimeout is sent if a missed keepalive caused the stream close 78 | ErrKeepAliveTimeout = fmt.Errorf("keepalive timeout") 79 | ) 80 | 81 | const ( 82 | // protoVersion is the only version we support 83 | protoVersion uint8 = 0 84 | ) 85 | 86 | const ( 87 | // Data is used for data frames. They are followed 88 | // by length bytes worth of payload. 89 | typeData uint8 = iota 90 | 91 | // WindowUpdate is used to change the window of 92 | // a given stream. The length indicates the delta 93 | // update to the window. 
// header is a view over the 12-byte yamux frame header:
//
//	byte 0     - protocol version
//	byte 1     - message type
//	bytes 2-3  - flags (big endian)
//	bytes 4-7  - stream ID (big endian)
//	bytes 8-11 - length (big endian)
type header []byte

// Version returns the protocol version byte of the frame.
func (h header) Version() uint8 {
	return h[0]
}

// MsgType returns the message type byte of the frame.
func (h header) MsgType() uint8 {
	return h[1]
}

// Flags decodes the 16-bit flags field.
func (h header) Flags() uint16 {
	return binary.BigEndian.Uint16(h[2:])
}

// StreamID decodes the 32-bit stream identifier.
func (h header) StreamID() uint32 {
	return binary.BigEndian.Uint32(h[4:])
}

// Length decodes the 32-bit length field (whose meaning depends on the
// message type).
func (h header) Length() uint32 {
	return binary.BigEndian.Uint32(h[8:])
}

// String renders all header fields for logging and diagnostics.
func (h header) String() string {
	return fmt.Sprintf("Vsn:%d Type:%d Flags:%d StreamID:%d Length:%d",
		h.Version(), h.MsgType(), h.Flags(), h.StreamID(), h.Length())
}
t.Fatalf("bad: %v", hdr) 65 | } 66 | if hdr.StreamID() != 1234 { 67 | t.Fatalf("bad: %v", hdr) 68 | } 69 | if hdr.Length() != 4321 { 70 | t.Fatalf("bad: %v", hdr) 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/hashicorp/yamux 2 | 3 | go 1.23 4 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hashicorp/yamux/26e720ab6278d3ec725485927eecf8ada66adfe1/go.sum -------------------------------------------------------------------------------- /mux.go: -------------------------------------------------------------------------------- 1 | package yamux 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "os" 7 | "time" 8 | ) 9 | 10 | // Config is used to tune the Yamux session 11 | type Config struct { 12 | // AcceptBacklog is used to limit how many streams may be 13 | // waiting an accept. 14 | AcceptBacklog int 15 | 16 | // EnableKeepalive is used to do a period keep alive 17 | // messages using a ping. 18 | EnableKeepAlive bool 19 | 20 | // KeepAliveInterval is how often to perform the keep alive 21 | KeepAliveInterval time.Duration 22 | 23 | // ConnectionWriteTimeout is meant to be a "safety valve" timeout after 24 | // we which will suspect a problem with the underlying connection and 25 | // close it. This is only applied to writes, where's there's generally 26 | // an expectation that things will move along quickly. 27 | ConnectionWriteTimeout time.Duration 28 | 29 | // MaxStreamWindowSize is used to control the maximum 30 | // window size that we allow for a stream. 31 | MaxStreamWindowSize uint32 32 | 33 | // StreamOpenTimeout is the maximum amount of time that a stream will 34 | // be allowed to remain in pending state while waiting for an ack from the peer. 
35 | // Once the timeout is reached the session will be gracefully closed. 36 | // A zero value disables the StreamOpenTimeout allowing unbounded 37 | // blocking on OpenStream calls. 38 | StreamOpenTimeout time.Duration 39 | 40 | // StreamCloseTimeout is the maximum time that a stream will allowed to 41 | // be in a half-closed state when `Close` is called before forcibly 42 | // closing the connection. Forcibly closed connections will empty the 43 | // receive buffer, drop any future packets received for that stream, 44 | // and send a RST to the remote side. 45 | StreamCloseTimeout time.Duration 46 | 47 | // LogOutput is used to control the log destination. Either Logger or 48 | // LogOutput can be set, not both. 49 | LogOutput io.Writer 50 | 51 | // Logger is used to pass in the logger to be used. Either Logger or 52 | // LogOutput can be set, not both. 53 | Logger Logger 54 | } 55 | 56 | func (c *Config) Clone() *Config { 57 | c2 := *c 58 | return &c2 59 | } 60 | 61 | // DefaultConfig is used to return a default configuration 62 | func DefaultConfig() *Config { 63 | return &Config{ 64 | AcceptBacklog: 256, 65 | EnableKeepAlive: true, 66 | KeepAliveInterval: 30 * time.Second, 67 | ConnectionWriteTimeout: 10 * time.Second, 68 | MaxStreamWindowSize: initialStreamWindow, 69 | StreamCloseTimeout: 5 * time.Minute, 70 | StreamOpenTimeout: 75 * time.Second, 71 | LogOutput: os.Stderr, 72 | } 73 | } 74 | 75 | // VerifyConfig is used to verify the sanity of configuration 76 | func VerifyConfig(config *Config) error { 77 | if config.AcceptBacklog <= 0 { 78 | return fmt.Errorf("backlog must be positive") 79 | } 80 | if config.KeepAliveInterval == 0 { 81 | return fmt.Errorf("keep-alive interval must be positive") 82 | } 83 | if config.MaxStreamWindowSize < initialStreamWindow { 84 | return fmt.Errorf("MaxStreamWindowSize must be larger than %d", initialStreamWindow) 85 | } 86 | if config.LogOutput != nil && config.Logger != nil { 87 | return fmt.Errorf("both Logger and 
LogOutput may not be set, select one") 88 | } else if config.LogOutput == nil && config.Logger == nil { 89 | return fmt.Errorf("one of Logger or LogOutput must be set, select one") 90 | } 91 | return nil 92 | } 93 | 94 | // Server is used to initialize a new server-side connection. 95 | // There must be at most one server-side connection. If a nil config is 96 | // provided, the DefaultConfiguration will be used. 97 | func Server(conn io.ReadWriteCloser, config *Config) (*Session, error) { 98 | if config == nil { 99 | config = DefaultConfig() 100 | } 101 | if err := VerifyConfig(config); err != nil { 102 | return nil, err 103 | } 104 | return newSession(config, conn, false), nil 105 | } 106 | 107 | // Client is used to initialize a new client-side connection. 108 | // There must be at most one client-side connection. 109 | func Client(conn io.ReadWriteCloser, config *Config) (*Session, error) { 110 | if config == nil { 111 | config = DefaultConfig() 112 | } 113 | 114 | if err := VerifyConfig(config); err != nil { 115 | return nil, err 116 | } 117 | return newSession(config, conn, true), nil 118 | } 119 | -------------------------------------------------------------------------------- /session.go: -------------------------------------------------------------------------------- 1 | package yamux 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "context" 7 | "fmt" 8 | "io" 9 | "log" 10 | "math" 11 | "net" 12 | "strings" 13 | "sync" 14 | "sync/atomic" 15 | "time" 16 | ) 17 | 18 | // Session is used to wrap a reliable ordered connection and to 19 | // multiplex it into multiple streams. 20 | type Session struct { 21 | // remoteGoAway indicates the remote side does 22 | // not want futher connections. Must be first for alignment. 23 | remoteGoAway int32 24 | 25 | // localGoAway indicates that we should stop 26 | // accepting futher connections. Must be first for alignment. 27 | localGoAway int32 28 | 29 | // nextStreamID is the next stream we should 30 | // send. 
This depends if we are a client/server. 31 | nextStreamID uint32 32 | 33 | // config holds our configuration 34 | config *Config 35 | 36 | // logger is used for our logs 37 | logger Logger 38 | 39 | // conn is the underlying connection 40 | conn io.ReadWriteCloser 41 | 42 | // bufRead is a buffered reader 43 | bufRead *bufio.Reader 44 | 45 | // pings is used to track inflight pings 46 | pings map[uint32]chan struct{} 47 | pingID uint32 48 | pingLock sync.Mutex 49 | 50 | // streams maps a stream id to a stream, and inflight has an entry 51 | // for any outgoing stream that has not yet been established. Both are 52 | // protected by streamLock. 53 | streams map[uint32]*Stream 54 | inflight map[uint32]struct{} 55 | streamLock sync.Mutex 56 | 57 | // synCh acts like a semaphore. It is sized to the AcceptBacklog which 58 | // is assumed to be symmetric between the client and server. This allows 59 | // the client to avoid exceeding the backlog and instead blocks the open. 60 | synCh chan struct{} 61 | 62 | // acceptCh is used to pass ready streams to the client 63 | acceptCh chan *Stream 64 | 65 | // sendCh is used to mark a stream as ready to send, 66 | // or to send a header out directly. 67 | sendCh chan *sendReady 68 | 69 | // recvDoneCh is closed when recv() exits to avoid a race 70 | // between stream registration and stream shutdown 71 | recvDoneCh chan struct{} 72 | sendDoneCh chan struct{} 73 | 74 | // shutdown is used to safely close a session 75 | shutdown bool 76 | shutdownErr error 77 | shutdownCh chan struct{} 78 | shutdownLock sync.Mutex 79 | shutdownErrLock sync.Mutex 80 | } 81 | 82 | // sendReady is used to either mark a stream as ready 83 | // or to directly send a header 84 | type sendReady struct { 85 | Hdr []byte 86 | mu sync.Mutex // Protects Body from unsafe reads. 
87 | Body []byte 88 | Err chan error 89 | } 90 | 91 | // newSession is used to construct a new session 92 | func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { 93 | logger := config.Logger 94 | if logger == nil { 95 | logger = log.New(config.LogOutput, "", log.LstdFlags) 96 | } 97 | 98 | s := &Session{ 99 | config: config, 100 | logger: logger, 101 | conn: conn, 102 | bufRead: bufio.NewReader(conn), 103 | pings: make(map[uint32]chan struct{}), 104 | streams: make(map[uint32]*Stream), 105 | inflight: make(map[uint32]struct{}), 106 | synCh: make(chan struct{}, config.AcceptBacklog), 107 | acceptCh: make(chan *Stream, config.AcceptBacklog), 108 | sendCh: make(chan *sendReady, 64), 109 | recvDoneCh: make(chan struct{}), 110 | sendDoneCh: make(chan struct{}), 111 | shutdownCh: make(chan struct{}), 112 | } 113 | if client { 114 | s.nextStreamID = 1 115 | } else { 116 | s.nextStreamID = 2 117 | } 118 | go s.recv() 119 | go s.send() 120 | if config.EnableKeepAlive { 121 | go s.keepalive() 122 | } 123 | return s 124 | } 125 | 126 | // IsClosed does a safe check to see if we have shutdown 127 | func (s *Session) IsClosed() bool { 128 | select { 129 | case <-s.shutdownCh: 130 | return true 131 | default: 132 | return false 133 | } 134 | } 135 | 136 | // CloseChan returns a read-only channel which is closed as 137 | // soon as the session is closed. 
138 | func (s *Session) CloseChan() <-chan struct{} { 139 | return s.shutdownCh 140 | } 141 | 142 | // NumStreams returns the number of currently open streams 143 | func (s *Session) NumStreams() int { 144 | s.streamLock.Lock() 145 | num := len(s.streams) 146 | s.streamLock.Unlock() 147 | return num 148 | } 149 | 150 | // Open is used to create a new stream as a net.Conn 151 | func (s *Session) Open() (net.Conn, error) { 152 | conn, err := s.OpenStream() 153 | if err != nil { 154 | return nil, err 155 | } 156 | return conn, nil 157 | } 158 | 159 | // OpenStream is used to create a new stream 160 | func (s *Session) OpenStream() (*Stream, error) { 161 | if s.IsClosed() { 162 | return nil, ErrSessionShutdown 163 | } 164 | if atomic.LoadInt32(&s.remoteGoAway) == 1 { 165 | return nil, ErrRemoteGoAway 166 | } 167 | 168 | // Block if we have too many inflight SYNs 169 | select { 170 | case s.synCh <- struct{}{}: 171 | case <-s.shutdownCh: 172 | return nil, ErrSessionShutdown 173 | } 174 | 175 | GET_ID: 176 | // Get an ID, and check for stream exhaustion 177 | id := atomic.LoadUint32(&s.nextStreamID) 178 | if id >= math.MaxUint32-1 { 179 | return nil, ErrStreamsExhausted 180 | } 181 | if !atomic.CompareAndSwapUint32(&s.nextStreamID, id, id+2) { 182 | goto GET_ID 183 | } 184 | 185 | // Register the stream 186 | stream := newStream(s, id, streamInit) 187 | s.streamLock.Lock() 188 | s.streams[id] = stream 189 | s.inflight[id] = struct{}{} 190 | s.streamLock.Unlock() 191 | 192 | if s.config.StreamOpenTimeout > 0 { 193 | go s.setOpenTimeout(stream) 194 | } 195 | 196 | // Send the window update to create 197 | if err := stream.sendWindowUpdate(); err != nil { 198 | select { 199 | case <-s.synCh: 200 | default: 201 | s.logger.Printf("[ERR] yamux: aborted stream open without inflight syn semaphore") 202 | } 203 | return nil, err 204 | } 205 | return stream, nil 206 | } 207 | 208 | // setOpenTimeout implements a timeout for streams that are opened but not established. 
// AcceptStreamWithContext is used to block until the next available
// stream is ready to be accepted, or until ctx is canceled, whichever
// happens first.
func (s *Session) AcceptStreamWithContext(ctx context.Context) (*Stream, error) {
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case stream := <-s.acceptCh:
		// Send the stream's window update before handing it to the
		// caller; on failure the stream is not returned.
		if err := stream.sendWindowUpdate(); err != nil {
			return nil, err
		}
		return stream, nil
	case <-s.shutdownCh:
		return nil, s.shutdownErr
	}
}
func (s *Session) Close() error {
	// shutdownLock makes Close idempotent: concurrent or repeated
	// callers after the first see s.shutdown set and return nil.
	s.shutdownLock.Lock()
	defer s.shutdownLock.Unlock()

	if s.shutdown {
		return nil
	}
	s.shutdown = true

	// Record a generic shutdown error unless exitErr already recorded a
	// more specific cause.
	s.shutdownErrLock.Lock()
	if s.shutdownErr == nil {
		s.shutdownErr = ErrSessionShutdown
	}
	s.shutdownErrLock.Unlock()

	close(s.shutdownCh)

	s.conn.Close()
	// Wait for recv() to exit before tearing down streams; recvDoneCh
	// exists to avoid racing stream registration against shutdown.
	<-s.recvDoneCh

	s.streamLock.Lock()
	defer s.streamLock.Unlock()
	for _, stream := range s.streams {
		stream.forceClose()
	}
	// Finally wait for the send loop to drain before returning.
	<-s.sendDoneCh
	return nil
}
313 | func (s *Session) GoAway() error { 314 | return s.waitForSend(s.goAway(goAwayNormal), nil) 315 | } 316 | 317 | // goAway is used to send a goAway message 318 | func (s *Session) goAway(reason uint32) header { 319 | atomic.SwapInt32(&s.localGoAway, 1) 320 | hdr := header(make([]byte, headerSize)) 321 | hdr.encode(typeGoAway, 0, 0, reason) 322 | return hdr 323 | } 324 | 325 | // Ping is used to measure the RTT response time 326 | func (s *Session) Ping() (time.Duration, error) { 327 | // Get a channel for the ping 328 | ch := make(chan struct{}) 329 | 330 | // Get a new ping id, mark as pending 331 | s.pingLock.Lock() 332 | id := s.pingID 333 | s.pingID++ 334 | s.pings[id] = ch 335 | s.pingLock.Unlock() 336 | 337 | // Send the ping request 338 | hdr := header(make([]byte, headerSize)) 339 | hdr.encode(typePing, flagSYN, 0, id) 340 | if err := s.waitForSend(hdr, nil); err != nil { 341 | return 0, err 342 | } 343 | 344 | // Wait for a response 345 | start := time.Now() 346 | select { 347 | case <-ch: 348 | case <-time.After(s.config.ConnectionWriteTimeout): 349 | s.pingLock.Lock() 350 | delete(s.pings, id) // Ignore it if a response comes later. 351 | s.pingLock.Unlock() 352 | return 0, ErrTimeout 353 | case <-s.shutdownCh: 354 | return 0, ErrSessionShutdown 355 | } 356 | 357 | // Compute the RTT 358 | return time.Since(start), nil 359 | } 360 | 361 | // keepalive is a long running goroutine that periodically does 362 | // a ping to keep the connection alive. 
363 | func (s *Session) keepalive() { 364 | for { 365 | select { 366 | case <-time.After(s.config.KeepAliveInterval): 367 | _, err := s.Ping() 368 | if err != nil { 369 | if err != ErrSessionShutdown { 370 | s.logger.Printf("[ERR] yamux: keepalive failed: %v", err) 371 | s.exitErr(ErrKeepAliveTimeout) 372 | } 373 | return 374 | } 375 | case <-s.shutdownCh: 376 | return 377 | } 378 | } 379 | } 380 | 381 | // waitForSendErr waits to send a header, checking for a potential shutdown 382 | func (s *Session) waitForSend(hdr header, body []byte) error { 383 | errCh := make(chan error, 1) 384 | return s.waitForSendErr(hdr, body, errCh) 385 | } 386 | 387 | // waitForSendErr waits to send a header with optional data, checking for a 388 | // potential shutdown. Since there's the expectation that sends can happen 389 | // in a timely manner, we enforce the connection write timeout here. 390 | func (s *Session) waitForSendErr(hdr header, body []byte, errCh chan error) error { 391 | t := timerPool.Get() 392 | timer := t.(*time.Timer) 393 | timer.Reset(s.config.ConnectionWriteTimeout) 394 | defer func() { 395 | timer.Stop() 396 | select { 397 | case <-timer.C: 398 | default: 399 | } 400 | timerPool.Put(t) 401 | }() 402 | 403 | ready := &sendReady{Hdr: hdr, Body: body, Err: errCh} 404 | select { 405 | case s.sendCh <- ready: 406 | case <-s.shutdownCh: 407 | return ErrSessionShutdown 408 | case <-timer.C: 409 | return ErrConnectionWriteTimeout 410 | } 411 | 412 | bodyCopy := func() { 413 | if body == nil { 414 | return // A nil body is ignored. 415 | } 416 | 417 | // In the event of session shutdown or connection write timeout, 418 | // we need to prevent `send` from reading the body buffer after 419 | // returning from this function since the caller may re-use the 420 | // underlying array. 421 | ready.mu.Lock() 422 | defer ready.mu.Unlock() 423 | 424 | if ready.Body == nil { 425 | return // Body was already copied in `send`. 
426 | } 427 | newBody := make([]byte, len(body)) 428 | copy(newBody, body) 429 | ready.Body = newBody 430 | } 431 | 432 | select { 433 | case err := <-errCh: 434 | return err 435 | case <-s.shutdownCh: 436 | bodyCopy() 437 | return ErrSessionShutdown 438 | case <-timer.C: 439 | bodyCopy() 440 | return ErrConnectionWriteTimeout 441 | } 442 | } 443 | 444 | // sendNoWait does a send without waiting. Since there's the expectation that 445 | // the send happens right here, we enforce the connection write timeout if we 446 | // can't queue the header to be sent. 447 | func (s *Session) sendNoWait(hdr header) error { 448 | t := timerPool.Get() 449 | timer := t.(*time.Timer) 450 | timer.Reset(s.config.ConnectionWriteTimeout) 451 | defer func() { 452 | timer.Stop() 453 | select { 454 | case <-timer.C: 455 | default: 456 | } 457 | timerPool.Put(t) 458 | }() 459 | 460 | select { 461 | case s.sendCh <- &sendReady{Hdr: hdr}: 462 | return nil 463 | case <-s.shutdownCh: 464 | return ErrSessionShutdown 465 | case <-timer.C: 466 | return ErrConnectionWriteTimeout 467 | } 468 | } 469 | 470 | // send is a long running goroutine that sends data 471 | func (s *Session) send() { 472 | if err := s.sendLoop(); err != nil { 473 | s.exitErr(err) 474 | } 475 | } 476 | 477 | func (s *Session) sendLoop() error { 478 | defer close(s.sendDoneCh) 479 | var bodyBuf bytes.Buffer 480 | for { 481 | bodyBuf.Reset() 482 | 483 | select { 484 | case ready := <-s.sendCh: 485 | // Send a header if ready 486 | if ready.Hdr != nil { 487 | _, err := s.conn.Write(ready.Hdr) 488 | if err != nil { 489 | s.logger.Printf("[ERR] yamux: Failed to write header: %v", err) 490 | asyncSendErr(ready.Err, err) 491 | return err 492 | } 493 | } 494 | 495 | ready.mu.Lock() 496 | if ready.Body != nil { 497 | // Copy the body into the buffer to avoid 498 | // holding a mutex lock during the write. 
499 | _, err := bodyBuf.Write(ready.Body) 500 | if err != nil { 501 | ready.Body = nil 502 | ready.mu.Unlock() 503 | s.logger.Printf("[ERR] yamux: Failed to copy body into buffer: %v", err) 504 | asyncSendErr(ready.Err, err) 505 | return err 506 | } 507 | ready.Body = nil 508 | } 509 | ready.mu.Unlock() 510 | 511 | if bodyBuf.Len() > 0 { 512 | // Send data from a body if given 513 | _, err := s.conn.Write(bodyBuf.Bytes()) 514 | if err != nil { 515 | s.logger.Printf("[ERR] yamux: Failed to write body: %v", err) 516 | asyncSendErr(ready.Err, err) 517 | return err 518 | } 519 | } 520 | 521 | // No error, successful send 522 | asyncSendErr(ready.Err, nil) 523 | case <-s.shutdownCh: 524 | return nil 525 | } 526 | } 527 | } 528 | 529 | // recv is a long running goroutine that accepts new data 530 | func (s *Session) recv() { 531 | if err := s.recvLoop(); err != nil { 532 | s.exitErr(err) 533 | } 534 | } 535 | 536 | // Ensure that the index of the handler (typeData/typeWindowUpdate/etc) matches the message type 537 | var ( 538 | handlers = []func(*Session, header) error{ 539 | typeData: (*Session).handleStreamMessage, 540 | typeWindowUpdate: (*Session).handleStreamMessage, 541 | typePing: (*Session).handlePing, 542 | typeGoAway: (*Session).handleGoAway, 543 | } 544 | ) 545 | 546 | // recvLoop continues to receive data until a fatal error is encountered 547 | func (s *Session) recvLoop() error { 548 | defer close(s.recvDoneCh) 549 | hdr := header(make([]byte, headerSize)) 550 | for { 551 | // Read the header 552 | if _, err := io.ReadFull(s.bufRead, hdr); err != nil { 553 | if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") { 554 | s.logger.Printf("[ERR] yamux: Failed to read header: %v", err) 555 | } 556 | return err 557 | } 558 | 559 | // Verify the version 560 | if hdr.Version() != protoVersion { 561 | s.logger.Printf("[ERR] yamux: Invalid protocol version: %d", hdr.Version()) 562 | return ErrInvalidVersion 
563 | } 564 | 565 | mt := hdr.MsgType() 566 | if mt < typeData || mt > typeGoAway { 567 | return ErrInvalidMsgType 568 | } 569 | 570 | if err := handlers[mt](s, hdr); err != nil { 571 | return err 572 | } 573 | } 574 | } 575 | 576 | // handleStreamMessage handles either a data or window update frame 577 | func (s *Session) handleStreamMessage(hdr header) error { 578 | // Check for a new stream creation 579 | id := hdr.StreamID() 580 | flags := hdr.Flags() 581 | if flags&flagSYN == flagSYN { 582 | if err := s.incomingStream(id); err != nil { 583 | return err 584 | } 585 | } 586 | 587 | // Get the stream 588 | s.streamLock.Lock() 589 | stream := s.streams[id] 590 | s.streamLock.Unlock() 591 | 592 | // If we do not have a stream, likely we sent a RST 593 | if stream == nil { 594 | // Drain any data on the wire 595 | if hdr.MsgType() == typeData && hdr.Length() > 0 { 596 | s.logger.Printf("[WARN] yamux: Discarding data for stream: %d", id) 597 | if _, err := io.CopyN(io.Discard, s.bufRead, int64(hdr.Length())); err != nil { 598 | s.logger.Printf("[ERR] yamux: Failed to discard data: %v", err) 599 | return nil 600 | } 601 | } else { 602 | s.logger.Printf("[WARN] yamux: frame for missing stream: %v", hdr) 603 | } 604 | return nil 605 | } 606 | 607 | // Check if this is a window update 608 | if hdr.MsgType() == typeWindowUpdate { 609 | if err := stream.incrSendWindow(hdr, flags); err != nil { 610 | if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil { 611 | s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) 612 | } 613 | return err 614 | } 615 | return nil 616 | } 617 | 618 | // Read the new data 619 | if err := stream.readData(hdr, flags, s.bufRead); err != nil { 620 | if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil { 621 | s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) 622 | } 623 | return err 624 | } 625 | return nil 626 | } 627 | 628 | // handlePing is invokde for a typePing frame 629 | func 
(s *Session) handlePing(hdr header) error { 630 | flags := hdr.Flags() 631 | pingID := hdr.Length() 632 | 633 | // Check if this is a query, respond back in a separate context so we 634 | // don't interfere with the receiving thread blocking for the write. 635 | if flags&flagSYN == flagSYN { 636 | go func() { 637 | hdr := header(make([]byte, headerSize)) 638 | hdr.encode(typePing, flagACK, 0, pingID) 639 | if err := s.sendNoWait(hdr); err != nil { 640 | s.logger.Printf("[WARN] yamux: failed to send ping reply: %v", err) 641 | } 642 | }() 643 | return nil 644 | } 645 | 646 | // Handle a response 647 | s.pingLock.Lock() 648 | ch := s.pings[pingID] 649 | if ch != nil { 650 | delete(s.pings, pingID) 651 | close(ch) 652 | } 653 | s.pingLock.Unlock() 654 | return nil 655 | } 656 | 657 | // handleGoAway is invokde for a typeGoAway frame 658 | func (s *Session) handleGoAway(hdr header) error { 659 | code := hdr.Length() 660 | switch code { 661 | case goAwayNormal: 662 | atomic.SwapInt32(&s.remoteGoAway, 1) 663 | case goAwayProtoErr: 664 | s.logger.Printf("[ERR] yamux: received protocol error go away") 665 | return fmt.Errorf("yamux protocol error") 666 | case goAwayInternalErr: 667 | s.logger.Printf("[ERR] yamux: received internal error go away") 668 | return fmt.Errorf("remote yamux internal error") 669 | default: 670 | s.logger.Printf("[ERR] yamux: received unexpected go away") 671 | return fmt.Errorf("unexpected go away received") 672 | } 673 | return nil 674 | } 675 | 676 | // incomingStream is used to create a new incoming stream 677 | func (s *Session) incomingStream(id uint32) error { 678 | // Reject immediately if we are doing a go away 679 | if atomic.LoadInt32(&s.localGoAway) == 1 { 680 | hdr := header(make([]byte, headerSize)) 681 | hdr.encode(typeWindowUpdate, flagRST, id, 0) 682 | return s.sendNoWait(hdr) 683 | } 684 | 685 | // Allocate a new stream 686 | stream := newStream(s, id, streamSYNReceived) 687 | 688 | s.streamLock.Lock() 689 | defer 
s.streamLock.Unlock() 690 | 691 | // Check if stream already exists 692 | if _, ok := s.streams[id]; ok { 693 | s.logger.Printf("[ERR] yamux: duplicate stream declared") 694 | if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil { 695 | s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) 696 | } 697 | return ErrDuplicateStream 698 | } 699 | 700 | // Register the stream 701 | s.streams[id] = stream 702 | 703 | // Check if we've exceeded the backlog 704 | select { 705 | case s.acceptCh <- stream: 706 | return nil 707 | default: 708 | // Backlog exceeded! RST the stream 709 | s.logger.Printf("[WARN] yamux: backlog exceeded, forcing connection reset") 710 | delete(s.streams, id) 711 | hdr := header(make([]byte, headerSize)) 712 | hdr.encode(typeWindowUpdate, flagRST, id, 0) 713 | return s.sendNoWait(hdr) 714 | } 715 | } 716 | 717 | // closeStream is used to close a stream once both sides have 718 | // issued a close. If there was an in-flight SYN and the stream 719 | // was not yet established, then this will give the credit back. 720 | func (s *Session) closeStream(id uint32) { 721 | s.streamLock.Lock() 722 | if _, ok := s.inflight[id]; ok { 723 | select { 724 | case <-s.synCh: 725 | default: 726 | s.logger.Printf("[ERR] yamux: SYN tracking out of sync") 727 | } 728 | } 729 | delete(s.streams, id) 730 | s.streamLock.Unlock() 731 | } 732 | 733 | // establishStream is used to mark a stream that was in the 734 | // SYN Sent state as established. 
735 | func (s *Session) establishStream(id uint32) { 736 | s.streamLock.Lock() 737 | if _, ok := s.inflight[id]; ok { 738 | delete(s.inflight, id) 739 | } else { 740 | s.logger.Printf("[ERR] yamux: established stream without inflight SYN (no tracking entry)") 741 | } 742 | select { 743 | case <-s.synCh: 744 | default: 745 | s.logger.Printf("[ERR] yamux: established stream without inflight SYN (didn't have semaphore)") 746 | } 747 | s.streamLock.Unlock() 748 | } 749 | -------------------------------------------------------------------------------- /session_test.go: -------------------------------------------------------------------------------- 1 | package yamux 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "crypto/tls" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "log" 11 | "net" 12 | "reflect" 13 | "runtime" 14 | "strings" 15 | "sync" 16 | "sync/atomic" 17 | "testing" 18 | "time" 19 | ) 20 | 21 | type logCapture struct { 22 | mu sync.Mutex 23 | buf *bytes.Buffer 24 | } 25 | 26 | var _ io.Writer = (*logCapture)(nil) 27 | 28 | func (l *logCapture) Write(p []byte) (n int, err error) { 29 | l.mu.Lock() 30 | defer l.mu.Unlock() 31 | if l.buf == nil { 32 | l.buf = &bytes.Buffer{} 33 | } 34 | return l.buf.Write(p) 35 | } 36 | func (l *logCapture) String() string { 37 | l.mu.Lock() 38 | defer l.mu.Unlock() 39 | return l.buf.String() 40 | } 41 | 42 | func (l *logCapture) logs() []string { 43 | return strings.Split(strings.TrimSpace(l.String()), "\n") 44 | } 45 | 46 | func (l *logCapture) match(expect []string) bool { 47 | return reflect.DeepEqual(l.logs(), expect) 48 | } 49 | 50 | type pipeConn struct { 51 | reader *io.PipeReader 52 | writer *io.PipeWriter 53 | writeBlocker sync.Mutex 54 | } 55 | 56 | func (p *pipeConn) Read(b []byte) (int, error) { 57 | return p.reader.Read(b) 58 | } 59 | 60 | func (p *pipeConn) Write(b []byte) (int, error) { 61 | p.writeBlocker.Lock() 62 | defer p.writeBlocker.Unlock() 63 | return p.writer.Write(b) 64 | } 65 | 66 | func (p *pipeConn) Close() 
error { 67 | p.reader.Close() 68 | return p.writer.Close() 69 | } 70 | 71 | func testConnPipe(testing.TB) (io.ReadWriteCloser, io.ReadWriteCloser) { 72 | read1, write1 := io.Pipe() 73 | read2, write2 := io.Pipe() 74 | conn1 := &pipeConn{reader: read1, writer: write2} 75 | conn2 := &pipeConn{reader: read2, writer: write1} 76 | return conn1, conn2 77 | } 78 | 79 | func testConnTCP(t testing.TB) (io.ReadWriteCloser, io.ReadWriteCloser) { 80 | l, err := net.ListenTCP("tcp", nil) 81 | if err != nil { 82 | t.Fatalf("error creating listener: %v", err) 83 | } 84 | t.Cleanup(func() { _ = l.Close() }) 85 | 86 | network := l.Addr().Network() 87 | addr := l.Addr().String() 88 | 89 | var server net.Conn 90 | errCh := make(chan error, 1) 91 | go func() { 92 | defer close(errCh) 93 | var err error 94 | server, err = l.Accept() 95 | if err != nil { 96 | errCh <- err 97 | return 98 | } 99 | }() 100 | 101 | t.Logf("Connecting to %s: %s", network, addr) 102 | client, err := net.DialTimeout(network, addr, 10*time.Second) 103 | if err != nil { 104 | t.Fatalf("error dialing tls listener: %v", err) 105 | } 106 | t.Cleanup(func() { _ = client.Close() }) 107 | 108 | if err := <-errCh; err != nil { 109 | t.Fatalf("error creating tls server: %v", err) 110 | } 111 | t.Cleanup(func() { _ = server.Close() }) 112 | 113 | return client, server 114 | } 115 | 116 | func testConnTLS(t testing.TB) (io.ReadWriteCloser, io.ReadWriteCloser) { 117 | cert, err := tls.LoadX509KeyPair("testdata/cert.pem", "testdata/key.pem") 118 | if err != nil { 119 | t.Fatalf("error loading certificate: %v", err) 120 | } 121 | 122 | l, err := net.ListenTCP("tcp", nil) 123 | if err != nil { 124 | t.Fatalf("error creating listener: %v", err) 125 | } 126 | t.Cleanup(func() { _ = l.Close() }) 127 | 128 | var server net.Conn 129 | errCh := make(chan error, 1) 130 | go func() { 131 | defer close(errCh) 132 | conn, err := l.Accept() 133 | if err != nil { 134 | errCh <- err 135 | return 136 | } 137 | 138 | server = 
tls.Server(conn, &tls.Config{ 139 | Certificates: []tls.Certificate{cert}, 140 | }) 141 | }() 142 | 143 | t.Logf("Connecting to %s: %s", l.Addr().Network(), l.Addr()) 144 | client, err := net.DialTimeout(l.Addr().Network(), l.Addr().String(), 10*time.Second) 145 | if err != nil { 146 | t.Fatalf("error dialing tls listener: %v", err) 147 | } 148 | t.Cleanup(func() { _ = client.Close() }) 149 | 150 | tlsClient := tls.Client(client, &tls.Config{ 151 | // InsecureSkipVerify is safe to use here since this is only for tests. 152 | InsecureSkipVerify: true, 153 | }) 154 | 155 | if err := <-errCh; err != nil { 156 | t.Fatalf("error creating tls server: %v", err) 157 | } 158 | t.Cleanup(func() { _ = server.Close() }) 159 | 160 | return tlsClient, server 161 | } 162 | 163 | // connTypeFunc is func that returns a client and server connection for testing 164 | // like testConnTLS. 165 | // 166 | // See connTypeTest 167 | type connTypeFunc func(t testing.TB) (io.ReadWriteCloser, io.ReadWriteCloser) 168 | 169 | // connTypeTest is a test case for a specific conn type. 170 | // 171 | // See testConnType 172 | type connTypeTest struct { 173 | Name string 174 | Conns connTypeFunc 175 | } 176 | 177 | // testConnType runs subtests of the given testFunc against multiple connection 178 | // types. 
179 | func testConnTypes(t *testing.T, testFunc func(t testing.TB, client, server io.ReadWriteCloser)) { 180 | reverse := func(f connTypeFunc) connTypeFunc { 181 | return func(t testing.TB) (io.ReadWriteCloser, io.ReadWriteCloser) { 182 | c, s := f(t) 183 | return s, c 184 | } 185 | } 186 | cases := []connTypeTest{ 187 | { 188 | Name: "Pipes", 189 | Conns: testConnPipe, 190 | }, 191 | { 192 | Name: "TCP", 193 | Conns: testConnTCP, 194 | }, 195 | { 196 | Name: "TCP_Reverse", 197 | Conns: reverse(testConnTCP), 198 | }, 199 | { 200 | Name: "TLS", 201 | Conns: testConnTLS, 202 | }, 203 | { 204 | Name: "TLS_Reverse", 205 | Conns: reverse(testConnTLS), 206 | }, 207 | } 208 | for i := range cases { 209 | tc := cases[i] 210 | t.Run(tc.Name, func(t *testing.T) { 211 | client, server := tc.Conns(t) 212 | testFunc(t, client, server) 213 | }) 214 | } 215 | } 216 | 217 | func testConf() *Config { 218 | conf := DefaultConfig() 219 | conf.AcceptBacklog = 64 220 | conf.KeepAliveInterval = 100 * time.Millisecond 221 | conf.ConnectionWriteTimeout = 250 * time.Millisecond 222 | return conf 223 | } 224 | 225 | func captureLogs(conf *Config) *logCapture { 226 | buf := new(logCapture) 227 | conf.Logger = log.New(buf, "", 0) 228 | conf.LogOutput = nil 229 | return buf 230 | } 231 | 232 | func testConfNoKeepAlive() *Config { 233 | conf := testConf() 234 | conf.EnableKeepAlive = false 235 | return conf 236 | } 237 | 238 | func testClientServer(t testing.TB) (*Session, *Session) { 239 | client, server := testConnTLS(t) 240 | return testClientServerConfig(t, client, server, testConf(), testConf()) 241 | } 242 | 243 | func testClientServerConfig( 244 | t testing.TB, 245 | clientConn, serverConn io.ReadWriteCloser, 246 | clientConf, serverConf *Config, 247 | ) (clientSession *Session, serverSession *Session) { 248 | 249 | var err error 250 | 251 | clientSession, err = Client(clientConn, clientConf) 252 | if err != nil { 253 | t.Fatalf("err: %v", err) 254 | } 255 | t.Cleanup(func() { _ = 
clientSession.Close() }) 256 | 257 | serverSession, err = Server(serverConn, serverConf) 258 | if err != nil { 259 | t.Fatalf("err: %v", err) 260 | } 261 | t.Cleanup(func() { _ = serverSession.Close() }) 262 | return clientSession, serverSession 263 | } 264 | 265 | func TestPing(t *testing.T) { 266 | client, server := testClientServer(t) 267 | 268 | rtt, err := client.Ping() 269 | if err != nil { 270 | t.Fatalf("err: %v", err) 271 | } 272 | if rtt == 0 { 273 | t.Fatalf("bad: %v", rtt) 274 | } 275 | 276 | rtt, err = server.Ping() 277 | if err != nil { 278 | t.Fatalf("err: %v", err) 279 | } 280 | if rtt == 0 { 281 | t.Fatalf("bad: %v", rtt) 282 | } 283 | } 284 | 285 | func TestPing_Timeout(t *testing.T) { 286 | conf := testConfNoKeepAlive() 287 | clientPipe, serverPipe := testConnPipe(t) 288 | client, server := testClientServerConfig(t, clientPipe, serverPipe, conf.Clone(), conf.Clone()) 289 | 290 | // Prevent the client from responding 291 | clientConn := client.conn.(*pipeConn) 292 | clientConn.writeBlocker.Lock() 293 | 294 | errCh := make(chan error, 1) 295 | go func() { 296 | _, err := server.Ping() // Ping via the server session 297 | errCh <- err 298 | }() 299 | 300 | select { 301 | case err := <-errCh: 302 | if err != ErrTimeout { 303 | t.Fatalf("err: %v", err) 304 | } 305 | case <-time.After(client.config.ConnectionWriteTimeout * 2): 306 | t.Fatalf("failed to timeout within expected %v", client.config.ConnectionWriteTimeout) 307 | } 308 | 309 | // Verify that we recover, even if we gave up 310 | clientConn.writeBlocker.Unlock() 311 | 312 | go func() { 313 | _, err := server.Ping() // Ping via the server session 314 | errCh <- err 315 | }() 316 | 317 | select { 318 | case err := <-errCh: 319 | if err != nil { 320 | t.Fatalf("err: %v", err) 321 | } 322 | case <-time.After(client.config.ConnectionWriteTimeout): 323 | t.Fatalf("timeout") 324 | } 325 | } 326 | 327 | func TestCloseBeforeAck(t *testing.T) { 328 | testConnTypes(t, func(t testing.TB, clientConn, 
serverConn io.ReadWriteCloser) { 329 | cfg := testConf() 330 | cfg.AcceptBacklog = 8 331 | client, server := testClientServerConfig(t, clientConn, serverConn, cfg, cfg.Clone()) 332 | 333 | for i := 0; i < 8; i++ { 334 | s, err := client.OpenStream() 335 | if err != nil { 336 | t.Fatal(err) 337 | } 338 | s.Close() 339 | } 340 | 341 | for i := 0; i < 8; i++ { 342 | s, err := server.AcceptStream() 343 | if err != nil { 344 | t.Fatal(err) 345 | } 346 | s.Close() 347 | } 348 | 349 | errCh := make(chan error, 1) 350 | go func() { 351 | s, err := client.OpenStream() 352 | if err != nil { 353 | errCh <- err 354 | return 355 | } 356 | s.Close() 357 | errCh <- nil 358 | }() 359 | 360 | drainErrorsUntil(t, errCh, 1, time.Second*5, "timed out trying to open stream") 361 | }) 362 | } 363 | 364 | func TestAccept(t *testing.T) { 365 | client, server := testClientServer(t) 366 | 367 | if client.NumStreams() != 0 { 368 | t.Fatalf("bad") 369 | } 370 | if server.NumStreams() != 0 { 371 | t.Fatalf("bad") 372 | } 373 | 374 | errCh := make(chan error, 4) 375 | acceptOne := func(streamFunc func() (*Stream, error), expectID uint32) { 376 | stream, err := streamFunc() 377 | if err != nil { 378 | errCh <- err 379 | return 380 | } 381 | if id := stream.StreamID(); id != expectID { 382 | errCh <- fmt.Errorf("bad: %v", id) 383 | return 384 | } 385 | if err := stream.Close(); err != nil { 386 | errCh <- err 387 | return 388 | } 389 | errCh <- nil 390 | } 391 | 392 | go acceptOne(server.AcceptStream, 1) 393 | go acceptOne(client.AcceptStream, 2) 394 | go acceptOne(server.OpenStream, 2) 395 | go acceptOne(client.OpenStream, 1) 396 | 397 | drainErrorsUntil(t, errCh, 4, time.Second, "timeout") 398 | } 399 | 400 | func TestOpenStreamTimeout(t *testing.T) { 401 | const timeout = 25 * time.Millisecond 402 | 403 | testConnTypes(t, func(t testing.TB, clientConn, serverConn io.ReadWriteCloser) { 404 | serverConf := testConf() 405 | serverConf.StreamOpenTimeout = timeout 406 | 407 | clientConf := 
serverConf.Clone() 408 | clientLogs := captureLogs(clientConf) 409 | 410 | client, _ := testClientServerConfig(t, clientConn, serverConn, clientConf, serverConf) 411 | 412 | // Open a single stream without a server to acknowledge it. 413 | s, err := client.OpenStream() 414 | if err != nil { 415 | t.Fatal(err) 416 | } 417 | 418 | // Sleep for longer than the stream open timeout. 419 | // Since no ACKs are received, the stream and session should be closed. 420 | time.Sleep(timeout * 5) 421 | 422 | // Support multiple underlying connection types 423 | var dest string 424 | switch conn := clientConn.(type) { 425 | case net.Conn: 426 | dest = conn.RemoteAddr().String() 427 | case *pipeConn: 428 | dest = "yamux:remote" 429 | default: 430 | t.Fatalf("unsupported connection type %T - please update test", conn) 431 | } 432 | exp := fmt.Sprintf("[ERR] yamux: aborted stream open (destination=%s): i/o deadline reached", dest) 433 | 434 | if !clientLogs.match([]string{exp}) { 435 | t.Fatalf("server log incorect: %v\nexpected: %v", clientLogs.logs(), exp) 436 | } 437 | 438 | s.stateLock.Lock() 439 | state := s.state 440 | s.stateLock.Unlock() 441 | 442 | if state != streamClosed { 443 | t.Fatalf("stream should have been closed") 444 | } 445 | if !client.IsClosed() { 446 | t.Fatalf("session should have been closed") 447 | } 448 | }) 449 | } 450 | 451 | func TestClose_closeTimeout(t *testing.T) { 452 | conf := testConf() 453 | conf.StreamCloseTimeout = 10 * time.Millisecond 454 | clientConn, serverConn := testConnTLS(t) 455 | client, server := testClientServerConfig(t, clientConn, serverConn, conf, conf.Clone()) 456 | 457 | if client.NumStreams() != 0 { 458 | t.Fatalf("bad") 459 | } 460 | if server.NumStreams() != 0 { 461 | t.Fatalf("bad") 462 | } 463 | 464 | errCh := make(chan error, 2) 465 | 466 | // Open a stream on the client but only close it on the server. 467 | // We want to see if the stream ever gets cleaned up on the client. 
468 | 469 | var clientStream *Stream 470 | go func() { 471 | var err error 472 | clientStream, err = client.OpenStream() 473 | errCh <- err 474 | }() 475 | 476 | go func() { 477 | stream, err := server.AcceptStream() 478 | if err != nil { 479 | errCh <- err 480 | return 481 | } 482 | if err := stream.Close(); err != nil { 483 | errCh <- err 484 | return 485 | } 486 | errCh <- nil 487 | }() 488 | 489 | drainErrorsUntil(t, errCh, 2, time.Second, "timeout") 490 | 491 | // We should have zero streams after our timeout period 492 | time.Sleep(100 * time.Millisecond) 493 | 494 | if v := server.NumStreams(); v > 0 { 495 | t.Fatalf("should have zero streams: %d", v) 496 | } 497 | if v := client.NumStreams(); v > 0 { 498 | t.Fatalf("should have zero streams: %d", v) 499 | } 500 | 501 | if _, err := clientStream.Write([]byte("hello")); err == nil { 502 | t.Fatal("should error on write") 503 | } else if err.Error() != "connection reset" { 504 | t.Fatalf("expected connection reset, got %q", err) 505 | } 506 | } 507 | 508 | func TestNonNilInterface(t *testing.T) { 509 | _, server := testClientServer(t) 510 | server.Close() 511 | 512 | conn, err := server.Accept() 513 | if err == nil || !errors.Is(err, ErrSessionShutdown) || conn != nil { 514 | t.Fatal("bad: accept should return a shutdown error and a connection of nil value") 515 | } 516 | if err != nil && conn != nil { 517 | t.Error("bad: accept should return a connection of nil value") 518 | } 519 | 520 | conn, err = server.Open() 521 | if err == nil || !errors.Is(err, ErrSessionShutdown) || conn != nil { 522 | t.Fatal("bad: open should return a shutdown error and a connection of nil value") 523 | } 524 | } 525 | 526 | func TestSendData_Small(t *testing.T) { 527 | client, server := testClientServer(t) 528 | 529 | errCh := make(chan error, 2) 530 | 531 | // Accept an incoming client and perform some reads before closing 532 | go func() { 533 | stream, err := server.AcceptStream() 534 | if err != nil { 535 | errCh <- err 536 | 
return 537 | } 538 | 539 | if server.NumStreams() != 1 { 540 | errCh <- fmt.Errorf("bad") 541 | return 542 | } 543 | 544 | buf := make([]byte, 4) 545 | for i := 0; i < 1000; i++ { 546 | n, err := stream.Read(buf) 547 | if err != nil { 548 | errCh <- err 549 | return 550 | } 551 | if n != 4 { 552 | errCh <- fmt.Errorf("short read: %d", n) 553 | return 554 | } 555 | if string(buf) != "test" { 556 | errCh <- fmt.Errorf("bad: %s", buf) 557 | return 558 | } 559 | } 560 | 561 | if err := stream.Close(); err != nil { 562 | errCh <- err 563 | return 564 | } 565 | errCh <- nil 566 | }() 567 | 568 | // Open a client and perform some writes before closing 569 | go func() { 570 | stream, err := client.Open() 571 | if err != nil { 572 | errCh <- err 573 | return 574 | } 575 | 576 | if client.NumStreams() != 1 { 577 | errCh <- fmt.Errorf("bad") 578 | return 579 | } 580 | 581 | for i := 0; i < 1000; i++ { 582 | n, err := stream.Write([]byte("test")) 583 | if err != nil { 584 | errCh <- err 585 | return 586 | } 587 | if n != 4 { 588 | errCh <- fmt.Errorf("short write %d", n) 589 | return 590 | } 591 | } 592 | 593 | if err := stream.Close(); err != nil { 594 | errCh <- err 595 | return 596 | } 597 | errCh <- nil 598 | }() 599 | 600 | drainErrorsUntil(t, errCh, 2, 5*time.Second, "timeout") 601 | 602 | // Give client and server a second to receive FINs and close streams 603 | time.Sleep(time.Second) 604 | 605 | if n := client.NumStreams(); n != 0 { 606 | t.Errorf("expected 0 client streams but found %d", n) 607 | } 608 | if n := server.NumStreams(); n != 0 { 609 | t.Errorf("expected 0 server streams but found %d", n) 610 | } 611 | } 612 | 613 | func TestSendData_Large(t *testing.T) { 614 | if testing.Short() { 615 | t.Skip("skipping slow test that may time out on the race detector") 616 | } 617 | client, server := testClientServer(t) 618 | 619 | const ( 620 | sendSize = 250 * 1024 * 1024 621 | recvSize = 4 * 1024 622 | ) 623 | 624 | data := make([]byte, sendSize) 625 | for idx := 
range data { 626 | data[idx] = byte(idx % 256) 627 | } 628 | 629 | errCh := make(chan error, 2) 630 | 631 | go func() { 632 | stream, err := server.AcceptStream() 633 | if err != nil { 634 | errCh <- err 635 | return 636 | } 637 | var sz int 638 | buf := make([]byte, recvSize) 639 | for i := 0; i < sendSize/recvSize; i++ { 640 | n, err := stream.Read(buf) 641 | if err != nil { 642 | errCh <- err 643 | return 644 | } 645 | if n != recvSize { 646 | errCh <- fmt.Errorf("short read: %d", n) 647 | return 648 | } 649 | sz += n 650 | for idx := range buf { 651 | if buf[idx] != byte(idx%256) { 652 | errCh <- fmt.Errorf("bad: %v %v %v", i, idx, buf[idx]) 653 | return 654 | } 655 | } 656 | } 657 | 658 | if err := stream.Close(); err != nil { 659 | errCh <- err 660 | return 661 | } 662 | 663 | t.Logf("cap=%d, n=%d\n", stream.recvBuf.Cap(), sz) 664 | errCh <- nil 665 | }() 666 | 667 | go func() { 668 | stream, err := client.Open() 669 | if err != nil { 670 | errCh <- err 671 | return 672 | } 673 | 674 | n, err := stream.Write(data) 675 | if err != nil { 676 | errCh <- err 677 | return 678 | } 679 | if n != len(data) { 680 | errCh <- fmt.Errorf("short write %d", n) 681 | return 682 | } 683 | 684 | if err := stream.Close(); err != nil { 685 | errCh <- err 686 | return 687 | } 688 | errCh <- nil 689 | }() 690 | 691 | drainErrorsUntil(t, errCh, 2, 10*time.Second, "timeout") 692 | } 693 | 694 | func TestGoAway(t *testing.T) { 695 | client, server := testClientServer(t) 696 | 697 | if err := server.GoAway(); err != nil { 698 | t.Fatalf("err: %v", err) 699 | } 700 | 701 | // Give the other side time to process the goaway after receiving it. 
702 | time.Sleep(100 * time.Millisecond) 703 | 704 | _, err := client.Open() 705 | if err != ErrRemoteGoAway { 706 | t.Fatalf("err: %v", err) 707 | } 708 | } 709 | 710 | func TestManyStreams(t *testing.T) { 711 | client, server := testClientServer(t) 712 | 713 | const streams = 50 714 | 715 | errCh := make(chan error, 2*streams) 716 | 717 | acceptor := func() { 718 | stream, err := server.AcceptStream() 719 | if err != nil { 720 | errCh <- err 721 | return 722 | } 723 | defer stream.Close() 724 | 725 | buf := make([]byte, 512) 726 | for { 727 | n, err := stream.Read(buf) 728 | if err == io.EOF { 729 | errCh <- nil 730 | return 731 | } 732 | if err != nil { 733 | errCh <- err 734 | return 735 | } 736 | if n == 0 { 737 | errCh <- fmt.Errorf("no bytes read") 738 | return 739 | } 740 | } 741 | } 742 | sender := func(id int) { 743 | stream, err := client.Open() 744 | if err != nil { 745 | errCh <- err 746 | return 747 | } 748 | defer stream.Close() 749 | 750 | msg := fmt.Sprintf("%08d", id) 751 | for i := 0; i < 1000; i++ { 752 | n, err := stream.Write([]byte(msg)) 753 | if err != nil { 754 | errCh <- err 755 | return 756 | } 757 | if n != len(msg) { 758 | errCh <- fmt.Errorf("short write %d", n) 759 | return 760 | } 761 | } 762 | errCh <- nil 763 | } 764 | 765 | for i := 0; i < streams; i++ { 766 | go acceptor() 767 | go sender(i) 768 | } 769 | 770 | drainErrorsUntil(t, errCh, 2*streams, 0, "") 771 | } 772 | 773 | func TestManyStreams_PingPong(t *testing.T) { 774 | client, server := testClientServer(t) 775 | 776 | const streams = 50 777 | 778 | errCh := make(chan error, 2*streams) 779 | 780 | ping := []byte("ping") 781 | pong := []byte("pong") 782 | 783 | acceptor := func() { 784 | stream, err := server.AcceptStream() 785 | if err != nil { 786 | errCh <- err 787 | return 788 | } 789 | defer stream.Close() 790 | 791 | buf := make([]byte, 4) 792 | for { 793 | // Read the 'ping' 794 | n, err := stream.Read(buf) 795 | if err == io.EOF { 796 | errCh <- nil 797 | return 798 
| } 799 | if err != nil { 800 | errCh <- err 801 | return 802 | } 803 | if n != 4 { 804 | errCh <- fmt.Errorf("short read %d", n) 805 | return 806 | } 807 | if !bytes.Equal(buf, ping) { 808 | errCh <- fmt.Errorf("bad: %s", buf) 809 | return 810 | } 811 | 812 | // Shrink the internal buffer! 813 | stream.Shrink() 814 | 815 | // Write out the 'pong' 816 | n, err = stream.Write(pong) 817 | if err != nil { 818 | errCh <- err 819 | return 820 | } 821 | if n != 4 { 822 | errCh <- fmt.Errorf("short write %d", n) 823 | return 824 | } 825 | } 826 | } 827 | sender := func() { 828 | stream, err := client.OpenStream() 829 | if err != nil { 830 | errCh <- err 831 | return 832 | } 833 | defer stream.Close() 834 | 835 | buf := make([]byte, 4) 836 | for i := 0; i < 1000; i++ { 837 | // Send the 'ping' 838 | n, err := stream.Write(ping) 839 | if err != nil { 840 | errCh <- err 841 | return 842 | } 843 | if n != 4 { 844 | errCh <- fmt.Errorf("short write %d", n) 845 | return 846 | } 847 | 848 | // Read the 'pong' 849 | n, err = stream.Read(buf) 850 | if err != nil { 851 | errCh <- err 852 | return 853 | } 854 | if n != 4 { 855 | errCh <- fmt.Errorf("short read %d", n) 856 | return 857 | } 858 | if !bytes.Equal(buf, pong) { 859 | errCh <- fmt.Errorf("bad: %s", buf) 860 | return 861 | } 862 | 863 | // Shrink the buffer 864 | stream.Shrink() 865 | } 866 | errCh <- nil 867 | } 868 | 869 | for i := 0; i < streams; i++ { 870 | go acceptor() 871 | go sender() 872 | } 873 | 874 | drainErrorsUntil(t, errCh, 2*streams, 0, "") 875 | } 876 | 877 | // TestHalfClose asserts that half closed streams can still read. 
878 | func TestHalfClose(t *testing.T) { 879 | testConnTypes(t, func(t testing.TB, clientConn, serverConn io.ReadWriteCloser) { 880 | client, server := testClientServerConfig(t, clientConn, serverConn, testConf(), testConf()) 881 | 882 | clientStream, err := client.Open() 883 | if err != nil { 884 | t.Fatalf("err: %v", err) 885 | } 886 | if _, err = clientStream.Write([]byte("a")); err != nil { 887 | t.Fatalf("err: %v", err) 888 | } 889 | 890 | serverStream, err := server.Accept() 891 | if err != nil { 892 | t.Fatalf("err: %v", err) 893 | } 894 | serverStream.Close() // Half close 895 | 896 | // Server reads 1 byte written by Client 897 | buf := make([]byte, 4) 898 | n, err := serverStream.Read(buf) 899 | if err != nil { 900 | t.Fatalf("err: %v", err) 901 | } 902 | if n != 1 { 903 | t.Fatalf("bad: %v", n) 904 | } 905 | 906 | // Send more 907 | if _, err = clientStream.Write([]byte("bcd")); err != nil { 908 | t.Fatalf("err: %v", err) 909 | } 910 | clientStream.Close() 911 | 912 | // Read after close always returns the bytes written but may or may not 913 | // receive the EOF. 
914 | n, err = serverStream.Read(buf) 915 | if err != nil { 916 | t.Fatalf("err: %v", err) 917 | } 918 | if n != 3 { 919 | t.Fatalf("bad: %v", n) 920 | } 921 | 922 | // EOF after close 923 | n, err = serverStream.Read(buf) 924 | if err != io.EOF { 925 | t.Fatalf("err: %v", err) 926 | } 927 | if n != 0 { 928 | t.Fatalf("bad: %v", n) 929 | } 930 | }) 931 | } 932 | 933 | func TestHalfCloseSessionShutdown(t *testing.T) { 934 | client, server := testClientServer(t) 935 | 936 | // dataSize must be large enough to ensure the server will send a window 937 | // update 938 | dataSize := int64(server.config.MaxStreamWindowSize) 939 | 940 | data := make([]byte, dataSize) 941 | for idx := range data { 942 | data[idx] = byte(idx % 256) 943 | } 944 | 945 | stream, err := client.Open() 946 | if err != nil { 947 | t.Fatalf("err: %v", err) 948 | } 949 | if _, err = stream.Write(data); err != nil { 950 | t.Fatalf("err: %v", err) 951 | } 952 | 953 | stream2, err := server.Accept() 954 | if err != nil { 955 | t.Fatalf("err: %v", err) 956 | } 957 | 958 | if err := stream.Close(); err != nil { 959 | t.Fatalf("err: %v", err) 960 | } 961 | 962 | // Shut down the session of the sending side. This should not cause reads 963 | // to fail on the receiving side. 
964 | if err := client.Close(); err != nil { 965 | t.Fatalf("err: %v", err) 966 | } 967 | 968 | buf := make([]byte, dataSize) 969 | n, err := stream2.Read(buf) 970 | if err != nil { 971 | t.Fatalf("err: %v", err) 972 | } 973 | if int64(n) != dataSize { 974 | t.Fatalf("bad: %v", n) 975 | } 976 | 977 | // EOF after close 978 | n, err = stream2.Read(buf) 979 | if err != io.EOF { 980 | t.Fatalf("err: %v", err) 981 | } 982 | if n != 0 { 983 | t.Fatalf("bad: %v", n) 984 | } 985 | } 986 | 987 | func TestReadDeadline(t *testing.T) { 988 | client, server := testClientServer(t) 989 | 990 | stream, err := client.Open() 991 | if err != nil { 992 | t.Fatalf("err: %v", err) 993 | } 994 | defer stream.Close() 995 | 996 | stream2, err := server.Accept() 997 | if err != nil { 998 | t.Fatalf("err: %v", err) 999 | } 1000 | defer stream2.Close() 1001 | 1002 | if err := stream.SetReadDeadline(time.Now().Add(5 * time.Millisecond)); err != nil { 1003 | t.Fatalf("err: %v", err) 1004 | } 1005 | 1006 | buf := make([]byte, 4) 1007 | _, err = stream.Read(buf) 1008 | if err != ErrTimeout { 1009 | t.Fatalf("err: %v", err) 1010 | } 1011 | 1012 | // See https://github.com/hashicorp/yamux/issues/90 1013 | // The standard library's http server package will read from connections in 1014 | // the background to detect if they are alive. 1015 | // 1016 | // It sets a read deadline on connections and detect if the returned error 1017 | // is a network timeout error which implements net.Error. 1018 | // 1019 | // The HTTP server will cancel all server requests if it isn't timeout error 1020 | // from the connection. 1021 | // 1022 | // We assert that we return an error meeting the interface to avoid 1023 | // accidently breaking yamux session compatability with the standard 1024 | // library's http server implementation. 
1025 | if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { 1026 | t.Fatalf("reading timeout error is expected to implement net.Error and return true when calling Timeout()") 1027 | } 1028 | } 1029 | 1030 | func TestReadDeadline_BlockedRead(t *testing.T) { 1031 | client, server := testClientServer(t) 1032 | 1033 | stream, err := client.Open() 1034 | if err != nil { 1035 | t.Fatalf("err: %v", err) 1036 | } 1037 | defer stream.Close() 1038 | 1039 | stream2, err := server.Accept() 1040 | if err != nil { 1041 | t.Fatalf("err: %v", err) 1042 | } 1043 | defer stream2.Close() 1044 | 1045 | // Start a read that will block 1046 | errCh := make(chan error, 1) 1047 | go func() { 1048 | buf := make([]byte, 4) 1049 | _, err := stream.Read(buf) 1050 | errCh <- err 1051 | close(errCh) 1052 | }() 1053 | 1054 | // Wait to ensure the read has started. 1055 | time.Sleep(5 * time.Millisecond) 1056 | 1057 | // Update the read deadline 1058 | if err := stream.SetReadDeadline(time.Now().Add(5 * time.Millisecond)); err != nil { 1059 | t.Fatalf("err: %v", err) 1060 | } 1061 | 1062 | select { 1063 | case <-time.After(100 * time.Millisecond): 1064 | t.Fatal("expected read timeout") 1065 | case err := <-errCh: 1066 | if err != ErrTimeout { 1067 | t.Fatalf("expected ErrTimeout; got %v", err) 1068 | } 1069 | } 1070 | } 1071 | 1072 | func TestWriteDeadline(t *testing.T) { 1073 | client, server := testClientServer(t) 1074 | 1075 | stream, err := client.Open() 1076 | if err != nil { 1077 | t.Fatalf("err: %v", err) 1078 | } 1079 | defer stream.Close() 1080 | 1081 | stream2, err := server.Accept() 1082 | if err != nil { 1083 | t.Fatalf("err: %v", err) 1084 | } 1085 | defer stream2.Close() 1086 | 1087 | if err := stream.SetWriteDeadline(time.Now().Add(50 * time.Millisecond)); err != nil { 1088 | t.Fatalf("err: %v", err) 1089 | } 1090 | 1091 | buf := make([]byte, 512) 1092 | for i := 0; i < int(initialStreamWindow); i++ { 1093 | _, err := stream.Write(buf) 1094 | if err != nil && err == 
ErrTimeout { 1095 | return 1096 | } else if err != nil { 1097 | t.Fatalf("err: %v", err) 1098 | } 1099 | } 1100 | t.Fatalf("Expected timeout") 1101 | } 1102 | 1103 | func TestWriteDeadline_BlockedWrite(t *testing.T) { 1104 | client, server := testClientServer(t) 1105 | 1106 | stream, err := client.Open() 1107 | if err != nil { 1108 | t.Fatalf("err: %v", err) 1109 | } 1110 | defer stream.Close() 1111 | 1112 | stream2, err := server.Accept() 1113 | if err != nil { 1114 | t.Fatalf("err: %v", err) 1115 | } 1116 | defer stream2.Close() 1117 | 1118 | // Start a goroutine making writes that will block 1119 | errCh := make(chan error, 1) 1120 | go func() { 1121 | buf := make([]byte, 512) 1122 | for i := 0; i < int(initialStreamWindow); i++ { 1123 | _, err := stream.Write(buf) 1124 | if err == nil { 1125 | continue 1126 | } 1127 | 1128 | errCh <- err 1129 | close(errCh) 1130 | return 1131 | } 1132 | 1133 | close(errCh) 1134 | }() 1135 | 1136 | // Wait to ensure the write has started. 1137 | time.Sleep(5 * time.Millisecond) 1138 | 1139 | // Update the write deadline 1140 | if err := stream.SetWriteDeadline(time.Now().Add(5 * time.Millisecond)); err != nil { 1141 | t.Fatalf("err: %v", err) 1142 | } 1143 | 1144 | select { 1145 | case <-time.After(1 * time.Second): 1146 | t.Fatal("expected write timeout") 1147 | case err := <-errCh: 1148 | if err != ErrTimeout { 1149 | t.Fatalf("expected ErrTimeout; got %v", err) 1150 | } 1151 | } 1152 | } 1153 | 1154 | func TestBacklogExceeded(t *testing.T) { 1155 | client, server := testClientServer(t) 1156 | 1157 | // Fill the backlog 1158 | max := client.config.AcceptBacklog 1159 | for i := 0; i < max; i++ { 1160 | stream, err := client.Open() 1161 | if err != nil { 1162 | t.Fatalf("err: %v", err) 1163 | } 1164 | defer stream.Close() 1165 | 1166 | if _, err := stream.Write([]byte("foo")); err != nil { 1167 | t.Fatalf("err: %v", err) 1168 | } 1169 | } 1170 | 1171 | // Attempt to open a new stream 1172 | errCh := make(chan error, 1) 1173 | go 
func() { 1174 | _, err := client.Open() 1175 | errCh <- err 1176 | }() 1177 | 1178 | // Shutdown the server 1179 | go func() { 1180 | time.Sleep(10 * time.Millisecond) 1181 | server.Close() 1182 | }() 1183 | 1184 | select { 1185 | case err := <-errCh: 1186 | if err == nil { 1187 | t.Fatalf("open should fail") 1188 | } 1189 | case <-time.After(time.Second): 1190 | t.Fatalf("timeout") 1191 | } 1192 | } 1193 | 1194 | func TestKeepAlive(t *testing.T) { 1195 | testConnTypes(t, func(t testing.TB, clientConn, serverConn io.ReadWriteCloser) { 1196 | client, server := testClientServerConfig(t, clientConn, serverConn, testConf(), testConf()) 1197 | 1198 | // Give keepalives time to happen 1199 | time.Sleep(200 * time.Millisecond) 1200 | 1201 | // Ping value should increase 1202 | client.pingLock.Lock() 1203 | defer client.pingLock.Unlock() 1204 | if client.pingID == 0 { 1205 | t.Fatalf("should ping") 1206 | } 1207 | 1208 | server.pingLock.Lock() 1209 | defer server.pingLock.Unlock() 1210 | if server.pingID == 0 { 1211 | t.Fatalf("should ping") 1212 | } 1213 | }) 1214 | } 1215 | 1216 | func TestKeepAlive_Timeout(t *testing.T) { 1217 | conn1, conn2 := testConnPipe(t) 1218 | 1219 | clientConf := testConf() 1220 | clientConf.ConnectionWriteTimeout = time.Hour // We're testing keep alives, not connection writes 1221 | clientConf.EnableKeepAlive = false // Just test one direction, so it's deterministic who hangs up on whom 1222 | _ = captureLogs(clientConf) // Client logs aren't part of the test 1223 | client, err := Client(conn1, clientConf) 1224 | if err != nil { 1225 | t.Fatalf("err: %v", err) 1226 | } 1227 | defer client.Close() 1228 | 1229 | serverConf := testConf() 1230 | serverLogs := captureLogs(serverConf) 1231 | server, err := Server(conn2, serverConf) 1232 | if err != nil { 1233 | t.Fatalf("err: %v", err) 1234 | } 1235 | defer server.Close() 1236 | 1237 | errCh := make(chan error, 1) 1238 | go func() { 1239 | _, err := server.Accept() // Wait until server closes 1240 | 
// UnlimitedReader is an io.Reader producing an endless byte stream: Read
// reports the whole buffer as filled without writing to it. Used to
// generate arbitrary amounts of traffic cheaply in throughput tests.
type UnlimitedReader struct{}

// Read always returns len(p) with no error. It yields the processor so a
// tight io.Copy loop over this reader does not starve other goroutines.
func (u *UnlimitedReader) Read(p []byte) (int, error) {
	runtime.Gosched()
	return len(p), nil
}
client, server := testClientServer(t) 1312 | 1313 | var n int64 = 1 * 1024 * 1024 * 1024 1314 | var workers int = 16 1315 | 1316 | errCh := make(chan error, workers*2) 1317 | 1318 | for i := 0; i < workers; i++ { 1319 | go func() { 1320 | stream, err := server.AcceptStream() 1321 | if err != nil { 1322 | errCh <- err 1323 | return 1324 | } 1325 | defer stream.Close() 1326 | 1327 | buf := make([]byte, 4) 1328 | _, err = stream.Read(buf) 1329 | if err != nil { 1330 | errCh <- err 1331 | return 1332 | } 1333 | if !bytes.Equal(buf, []byte{0, 1, 2, 3}) { 1334 | errCh <- errors.New("bad header") 1335 | return 1336 | } 1337 | 1338 | recv, err := io.Copy(io.Discard, stream) 1339 | if err != nil { 1340 | errCh <- err 1341 | return 1342 | } 1343 | if recv != n { 1344 | errCh <- fmt.Errorf("bad: %v", recv) 1345 | return 1346 | } 1347 | 1348 | errCh <- nil 1349 | }() 1350 | } 1351 | for i := 0; i < workers; i++ { 1352 | go func() { 1353 | stream, err := client.Open() 1354 | if err != nil { 1355 | errCh <- err 1356 | return 1357 | } 1358 | defer stream.Close() 1359 | 1360 | _, err = stream.Write([]byte{0, 1, 2, 3}) 1361 | if err != nil { 1362 | errCh <- err 1363 | return 1364 | } 1365 | 1366 | unlimited := &UnlimitedReader{} 1367 | sent, err := io.Copy(stream, io.LimitReader(unlimited, n)) 1368 | if err != nil { 1369 | errCh <- err 1370 | return 1371 | } 1372 | if sent != n { 1373 | errCh <- fmt.Errorf("bad: %v", sent) 1374 | return 1375 | } 1376 | 1377 | errCh <- nil 1378 | }() 1379 | } 1380 | 1381 | drainErrorsUntil(t, errCh, workers*2, 120*time.Second, "timeout") 1382 | } 1383 | 1384 | func TestBacklogExceeded_Accept(t *testing.T) { 1385 | client, server := testClientServer(t) 1386 | 1387 | max := 5 * client.config.AcceptBacklog 1388 | 1389 | errCh := make(chan error, max) 1390 | go func() { 1391 | for i := 0; i < max; i++ { 1392 | stream, err := server.Accept() 1393 | if err != nil { 1394 | errCh <- err 1395 | return 1396 | } 1397 | defer stream.Close() 1398 | errCh <- nil 
1399 | } 1400 | }() 1401 | 1402 | // Fill the backlog 1403 | for i := 0; i < max; i++ { 1404 | stream, err := client.Open() 1405 | if err != nil { 1406 | t.Fatalf("err: %v", err) 1407 | } 1408 | defer stream.Close() 1409 | 1410 | if _, err := stream.Write([]byte("foo")); err != nil { 1411 | t.Fatalf("err: %v", err) 1412 | } 1413 | } 1414 | 1415 | drainErrorsUntil(t, errCh, max, 0, "") 1416 | } 1417 | 1418 | func TestSession_WindowUpdateWriteDuringRead(t *testing.T) { 1419 | conf := testConfNoKeepAlive() 1420 | 1421 | clientConn, serverConn := testConnPipe(t) 1422 | client, server := testClientServerConfig(t, clientConn, serverConn, conf, conf.Clone()) 1423 | 1424 | // Choose a huge flood size that we know will result in a window update. 1425 | flood := int64(client.config.MaxStreamWindowSize) - 1 1426 | 1427 | errCh := make(chan error, 2) 1428 | 1429 | // The server will accept a new stream and then flood data to it. 1430 | go func() { 1431 | stream, err := server.AcceptStream() 1432 | if err != nil { 1433 | errCh <- err 1434 | return 1435 | } 1436 | defer stream.Close() 1437 | 1438 | n, err := stream.Write(make([]byte, flood)) 1439 | if err != nil { 1440 | errCh <- err 1441 | return 1442 | } 1443 | if int64(n) != flood { 1444 | errCh <- fmt.Errorf("short write: %d", n) 1445 | } 1446 | 1447 | errCh <- nil 1448 | }() 1449 | 1450 | // The client will open a stream, block outbound writes, and then 1451 | // listen to the flood from the server, which should time out since 1452 | // it won't be able to send the window update. 
1453 | go func() { 1454 | stream, err := client.OpenStream() 1455 | if err != nil { 1456 | errCh <- err 1457 | return 1458 | } 1459 | defer stream.Close() 1460 | 1461 | conn := clientConn.(*pipeConn) 1462 | conn.writeBlocker.Lock() 1463 | defer conn.writeBlocker.Unlock() 1464 | 1465 | _, err = stream.Read(make([]byte, flood)) 1466 | if err != ErrConnectionWriteTimeout { 1467 | errCh <- err 1468 | return 1469 | } 1470 | 1471 | errCh <- nil 1472 | }() 1473 | 1474 | drainErrorsUntil(t, errCh, 2, 0, "") 1475 | } 1476 | 1477 | // TestSession_PartialReadWindowUpdate asserts that when a client performs a 1478 | // partial read it updates the server's send window. 1479 | func TestSession_PartialReadWindowUpdate(t *testing.T) { 1480 | testConnTypes(t, func(t testing.TB, clientConn, serverConn io.ReadWriteCloser) { 1481 | conf := testConfNoKeepAlive() 1482 | 1483 | client, server := testClientServerConfig(t, clientConn, serverConn, conf, conf.Clone()) 1484 | 1485 | errCh := make(chan error, 1) 1486 | 1487 | // Choose a huge flood size that we know will result in a window update. 1488 | flood := int64(client.config.MaxStreamWindowSize) 1489 | var wr *Stream 1490 | 1491 | // The server will accept a new stream and then flood data to it. 
1492 | go func() { 1493 | var err error 1494 | wr, err = server.AcceptStream() 1495 | if err != nil { 1496 | errCh <- err 1497 | return 1498 | } 1499 | defer wr.Close() 1500 | 1501 | window := atomic.LoadUint32(&wr.sendWindow) 1502 | if window != client.config.MaxStreamWindowSize { 1503 | errCh <- fmt.Errorf("sendWindow: exp=%d, got=%d", client.config.MaxStreamWindowSize, window) 1504 | return 1505 | } 1506 | 1507 | n, err := wr.Write(make([]byte, flood)) 1508 | if err != nil { 1509 | errCh <- err 1510 | return 1511 | } 1512 | if int64(n) != flood { 1513 | errCh <- fmt.Errorf("short write: %d", n) 1514 | return 1515 | } 1516 | window = atomic.LoadUint32(&wr.sendWindow) 1517 | if window != 0 { 1518 | errCh <- fmt.Errorf("sendWindow: exp=%d, got=%d", 0, window) 1519 | return 1520 | } 1521 | errCh <- err 1522 | }() 1523 | 1524 | stream, err := client.OpenStream() 1525 | if err != nil { 1526 | t.Fatalf("err: %v", err) 1527 | } 1528 | defer stream.Close() 1529 | 1530 | drainErrorsUntil(t, errCh, 1, 0, "") 1531 | 1532 | // Only read part of the flood 1533 | partialReadSize := flood/2 + 1 1534 | _, err = stream.Read(make([]byte, partialReadSize)) 1535 | if err != nil { 1536 | t.Fatalf("err: %v", err) 1537 | } 1538 | 1539 | // Wait for window update to be applied by server. Should be "instant" but CI 1540 | // can be slow. 
1541 | time.Sleep(2 * time.Second) 1542 | 1543 | // Assert server received window update 1544 | window := atomic.LoadUint32(&wr.sendWindow) 1545 | if exp := uint32(partialReadSize); window != exp { 1546 | t.Fatalf("sendWindow: exp=%d, got=%d", exp, window) 1547 | } 1548 | }) 1549 | } 1550 | 1551 | func TestSession_sendNoWait_Timeout(t *testing.T) { 1552 | conf := testConfNoKeepAlive() 1553 | 1554 | clientConn, serverConn := testConnPipe(t) 1555 | client, server := testClientServerConfig(t, clientConn, serverConn, conf, conf.Clone()) 1556 | 1557 | errCh := make(chan error, 2) 1558 | 1559 | go func() { 1560 | stream, err := server.AcceptStream() 1561 | if err != nil { 1562 | errCh <- err 1563 | return 1564 | } 1565 | defer stream.Close() 1566 | errCh <- nil 1567 | }() 1568 | 1569 | // The client will open the stream and then block outbound writes, we'll 1570 | // probe sendNoWait once it gets into that state. 1571 | go func() { 1572 | stream, err := client.OpenStream() 1573 | if err != nil { 1574 | errCh <- err 1575 | return 1576 | } 1577 | defer stream.Close() 1578 | 1579 | conn := clientConn.(*pipeConn) 1580 | conn.writeBlocker.Lock() 1581 | defer conn.writeBlocker.Unlock() 1582 | 1583 | hdr := header(make([]byte, headerSize)) 1584 | hdr.encode(typePing, flagACK, 0, 0) 1585 | for { 1586 | err = client.sendNoWait(hdr) 1587 | if err == nil { 1588 | continue 1589 | } else if err == ErrConnectionWriteTimeout { 1590 | break 1591 | } else { 1592 | errCh <- err 1593 | return 1594 | } 1595 | } 1596 | errCh <- nil 1597 | }() 1598 | 1599 | drainErrorsUntil(t, errCh, 2, 0, "") 1600 | } 1601 | 1602 | func TestSession_PingOfDeath(t *testing.T) { 1603 | conf := testConfNoKeepAlive() 1604 | 1605 | clientConn, serverConn := testConnPipe(t) 1606 | client, server := testClientServerConfig(t, clientConn, serverConn, conf, conf.Clone()) 1607 | 1608 | errCh := make(chan error, 2) 1609 | 1610 | var doPingOfDeath sync.Mutex 1611 | doPingOfDeath.Lock() 1612 | 1613 | // This is used later 
to block outbound writes. 1614 | conn := server.conn.(*pipeConn) 1615 | 1616 | // The server will accept a stream, block outbound writes, and then 1617 | // flood its send channel so that no more headers can be queued. 1618 | go func() { 1619 | stream, err := server.AcceptStream() 1620 | if err != nil { 1621 | errCh <- err 1622 | return 1623 | } 1624 | defer stream.Close() 1625 | 1626 | conn.writeBlocker.Lock() 1627 | for { 1628 | hdr := header(make([]byte, headerSize)) 1629 | hdr.encode(typePing, 0, 0, 0) 1630 | err = server.sendNoWait(hdr) 1631 | if err == nil { 1632 | continue 1633 | } else if err == ErrConnectionWriteTimeout { 1634 | break 1635 | } else { 1636 | errCh <- err 1637 | return 1638 | } 1639 | } 1640 | 1641 | doPingOfDeath.Unlock() 1642 | errCh <- nil 1643 | }() 1644 | 1645 | // The client will open a stream and then send the server a ping once it 1646 | // can no longer write. This makes sure the server doesn't deadlock reads 1647 | // while trying to reply to the ping with no ability to write. 1648 | go func() { 1649 | stream, err := client.OpenStream() 1650 | if err != nil { 1651 | errCh <- err 1652 | return 1653 | } 1654 | defer stream.Close() 1655 | 1656 | // This ping will never unblock because the ping id will never 1657 | // show up in a response. 1658 | doPingOfDeath.Lock() 1659 | go func() { _, _ = client.Ping() }() 1660 | 1661 | // Wait for a while to make sure the previous ping times out, 1662 | // then turn writes back on and make sure a ping works again. 
1663 | time.Sleep(2 * server.config.ConnectionWriteTimeout) 1664 | conn.writeBlocker.Unlock() 1665 | if _, err = client.Ping(); err != nil { 1666 | errCh <- err 1667 | return 1668 | } 1669 | 1670 | errCh <- nil 1671 | }() 1672 | 1673 | drainErrorsUntil(t, errCh, 2, 0, "") 1674 | } 1675 | 1676 | func TestSession_ConnectionWriteTimeout(t *testing.T) { 1677 | conf := testConfNoKeepAlive() 1678 | 1679 | clientConn, serverConn := testConnPipe(t) 1680 | client, server := testClientServerConfig(t, clientConn, serverConn, conf, conf.Clone()) 1681 | 1682 | errCh := make(chan error, 2) 1683 | 1684 | go func() { 1685 | stream, err := server.AcceptStream() 1686 | if err != nil { 1687 | errCh <- err 1688 | return 1689 | } 1690 | defer stream.Close() 1691 | errCh <- nil 1692 | }() 1693 | 1694 | // The client will open the stream and then block outbound writes, we'll 1695 | // tee up a write and make sure it eventually times out. 1696 | go func() { 1697 | stream, err := client.OpenStream() 1698 | if err != nil { 1699 | errCh <- err 1700 | return 1701 | } 1702 | defer stream.Close() 1703 | 1704 | conn := clientConn.(*pipeConn) 1705 | conn.writeBlocker.Lock() 1706 | defer conn.writeBlocker.Unlock() 1707 | 1708 | // Since the write goroutine is blocked then this will return a 1709 | // timeout since it can't get feedback about whether the write 1710 | // worked. 
1711 | n, err := stream.Write([]byte("hello")) 1712 | if err != ErrConnectionWriteTimeout { 1713 | errCh <- err 1714 | return 1715 | } 1716 | if n != 0 { 1717 | errCh <- fmt.Errorf("lied about writes: %d", n) 1718 | } 1719 | errCh <- nil 1720 | }() 1721 | 1722 | drainErrorsUntil(t, errCh, 2, 0, "") 1723 | } 1724 | 1725 | func TestCancelAccept(t *testing.T) { 1726 | _, server := testClientServer(t) 1727 | 1728 | ctx, cancel := context.WithCancel(context.Background()) 1729 | t.Cleanup(cancel) 1730 | 1731 | errCh := make(chan error, 1) 1732 | 1733 | go func() { 1734 | stream, err := server.AcceptStreamWithContext(ctx) 1735 | if err != context.Canceled { 1736 | errCh <- err 1737 | return 1738 | } 1739 | 1740 | if stream != nil { 1741 | defer stream.Close() 1742 | } 1743 | errCh <- nil 1744 | }() 1745 | 1746 | cancel() 1747 | 1748 | drainErrorsUntil(t, errCh, 1, 0, "") 1749 | } 1750 | 1751 | // drainErrorsUntil receives `expect` errors from errCh within `timeout`. Fails 1752 | // on any non-nil errors. 1753 | func drainErrorsUntil(t testing.TB, errCh chan error, expect int, timeout time.Duration, msg string) { 1754 | t.Helper() 1755 | start := time.Now() 1756 | var timerC <-chan time.Time 1757 | if timeout > 0 { 1758 | timerC = time.After(timeout) 1759 | } 1760 | 1761 | for found := 0; found < expect; { 1762 | select { 1763 | case <-timerC: 1764 | t.Fatalf(msg+" (timeout was %v)", timeout) 1765 | case err := <-errCh: 1766 | if err != nil { 1767 | t.Fatalf("err: %v", err) 1768 | } else { 1769 | found++ 1770 | } 1771 | } 1772 | } 1773 | t.Logf("drain took %v (timeout was %v)", time.Since(start), timeout) 1774 | } 1775 | -------------------------------------------------------------------------------- /spec.md: -------------------------------------------------------------------------------- 1 | # Specification 2 | 3 | We use this document to detail the internal specification of Yamux. 
4 | This is used both as a guide for implementing Yamux, but also for 5 | alternative interoperable libraries to be built. 6 | 7 | # Framing 8 | 9 | Yamux uses a streaming connection underneath, but imposes a message 10 | framing so that it can be shared between many logical streams. Each 11 | frame contains a header like: 12 | 13 | * Version (8 bits) 14 | * Type (8 bits) 15 | * Flags (16 bits) 16 | * StreamID (32 bits) 17 | * Length (32 bits) 18 | 19 | This means that each header has a 12 byte overhead. 20 | All fields are encoded in network order (big endian). 21 | Each field is described below: 22 | 23 | ## Version Field 24 | 25 | The version field is used for future backward compatibility. At the 26 | current time, the field is always set to 0, to indicate the initial 27 | version. 28 | 29 | ## Type Field 30 | 31 | The type field is used to switch the frame message type. The following 32 | message types are supported: 33 | 34 | * 0x0 Data - Used to transmit data. May transmit zero length payloads 35 | depending on the flags. 36 | 37 | * 0x1 Window Update - Used to update the sender's receive window size. 38 | This is used to implement per-session flow control. 39 | 40 | * 0x2 Ping - Used to measure RTT. It can also be used to heart-beat 41 | and do keep-alives over TCP. 42 | 43 | * 0x3 Go Away - Used to close a session. 44 | 45 | ## Flag Field 46 | 47 | The flags field is used to provide additional information related 48 | to the message type. The following flags are supported: 49 | 50 | * 0x1 SYN - Signals the start of a new stream. May be sent with a data or 51 | window update message. Also sent with a ping to indicate outbound. 52 | 53 | * 0x2 ACK - Acknowledges the start of a new stream. May be sent with a data 54 | or window update message. Also sent with a ping to indicate response. 55 | 56 | * 0x4 FIN - Performs a half-close of a stream. May be sent with a data 57 | message or window update. 58 | 59 | * 0x8 RST - Reset a stream immediately.
May be sent with a data or 60 | window update message. 61 | 62 | ## StreamID Field 63 | 64 | The StreamID field is used to identify the logical stream the frame 65 | is addressing. The client side should use odd ID's, and the server even. 66 | This prevents any collisions. Additionally, the 0 ID is reserved to represent 67 | the session. 68 | 69 | Both Ping and Go Away messages should always use the 0 StreamID. 70 | 71 | ## Length Field 72 | 73 | The meaning of the length field depends on the message type: 74 | 75 | * Data - provides the length of bytes following the header 76 | * Window update - provides a delta update to the window size 77 | * Ping - Contains an opaque value, echoed back 78 | * Go Away - Contains an error code 79 | 80 | # Message Flow 81 | 82 | There is no explicit connection setup, as Yamux relies on an underlying 83 | transport to be provided. However, there is a distinction between client 84 | and server side of the connection. 85 | 86 | ## Opening a stream 87 | 88 | To open a stream, an initial data or window update frame is sent 89 | with a new StreamID. The SYN flag should be set to signal a new stream. 90 | 91 | The receiver must then reply with either a data or window update frame 92 | with the StreamID along with the ACK flag to accept the stream or with 93 | the RST flag to reject the stream. 94 | 95 | Because we are relying on the reliable stream underneath, a connection 96 | can begin sending data once the SYN flag is sent. The corresponding 97 | ACK does not need to be received. This is particularly well suited 98 | for an RPC system where a client wants to open a stream and immediately 99 | fire a request without waiting for the RTT of the ACK. 100 | 101 | This does introduce the possibility of a connection being rejected 102 | after data has been sent already. This is a slight semantic difference 103 | from TCP, where the connection cannot be refused after it is opened. 
104 | Clients should be prepared to handle this by checking for an error 105 | that indicates a RST was received. 106 | 107 | ## Closing a stream 108 | 109 | To close a stream, either side sends a data or window update frame 110 | along with the FIN flag. This does a half-close indicating the sender 111 | will send no further data. 112 | 113 | Once both sides have closed the connection, the stream is closed. 114 | 115 | Alternatively, if an error occurs, the RST flag can be used to 116 | hard close a stream immediately. 117 | 118 | ## Flow Control 119 | 120 | Yamux initially starts each stream with a 256KB window size. 121 | There is no window size for the session. 122 | 123 | To prevent the streams from stalling, window update frames should be 124 | sent regularly. Yamux can be configured to provide a larger limit for 125 | window sizes. Both sides assume the initial 256KB window, but can 126 | immediately send a window update as part of the SYN/ACK indicating a 127 | larger window. 128 | 129 | Both sides should track the number of bytes sent in Data frames 130 | only, as only they are tracked as part of the window size. 131 | 132 | ## Session termination 133 | 134 | When a session is being terminated, the Go Away message should 135 | be sent.
The Length should be set to one of the following to 136 | provide an error code: 137 | 138 | * 0x0 Normal termination 139 | * 0x1 Protocol error 140 | * 0x2 Internal error 141 | -------------------------------------------------------------------------------- /stream.go: -------------------------------------------------------------------------------- 1 | package yamux 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "io" 7 | "sync" 8 | "sync/atomic" 9 | "time" 10 | ) 11 | 12 | type streamState int 13 | 14 | const ( 15 | streamInit streamState = iota 16 | streamSYNSent 17 | streamSYNReceived 18 | streamEstablished 19 | streamLocalClose 20 | streamRemoteClose 21 | streamClosed 22 | streamReset 23 | ) 24 | 25 | // Stream is used to represent a logical stream within a session. Methods on 26 | // Stream are safe to call concurrently with one another, but all Read calls 27 | // must be on the same goroutine and all Write calls must be on the same 28 | // goroutine. 29 | type Stream struct { 30 | recvWindow uint32 31 | sendWindow uint32 32 | 33 | id uint32 34 | session *Session 35 | 36 | state streamState 37 | stateLock sync.Mutex 38 | 39 | recvBuf *bytes.Buffer 40 | recvLock sync.Mutex 41 | 42 | controlHdr header 43 | controlErr chan error 44 | controlHdrLock sync.Mutex 45 | 46 | sendHdr header 47 | sendErr chan error 48 | sendLock sync.Mutex 49 | 50 | recvNotifyCh chan struct{} 51 | sendNotifyCh chan struct{} 52 | 53 | readDeadline atomic.Value // time.Time 54 | writeDeadline atomic.Value // time.Time 55 | 56 | // establishCh is notified if the stream is established or being closed. 57 | establishCh chan struct{} 58 | 59 | // closeTimer is set with stateLock held to honor the StreamCloseTimeout 60 | // setting on Session. 
61 | closeTimer *time.Timer 62 | } 63 | 64 | // newStream is used to construct a new stream within 65 | // a given session for an ID 66 | func newStream(session *Session, id uint32, state streamState) *Stream { 67 | s := &Stream{ 68 | id: id, 69 | session: session, 70 | state: state, 71 | controlHdr: header(make([]byte, headerSize)), 72 | controlErr: make(chan error, 1), 73 | sendHdr: header(make([]byte, headerSize)), 74 | sendErr: make(chan error, 1), 75 | recvWindow: initialStreamWindow, 76 | sendWindow: initialStreamWindow, 77 | recvNotifyCh: make(chan struct{}, 1), 78 | sendNotifyCh: make(chan struct{}, 1), 79 | establishCh: make(chan struct{}, 1), 80 | } 81 | s.readDeadline.Store(time.Time{}) 82 | s.writeDeadline.Store(time.Time{}) 83 | return s 84 | } 85 | 86 | // Session returns the associated stream session 87 | func (s *Stream) Session() *Session { 88 | return s.session 89 | } 90 | 91 | // StreamID returns the ID of this stream 92 | func (s *Stream) StreamID() uint32 { 93 | return s.id 94 | } 95 | 96 | // Read is used to read from the stream. It is safe to call Write, Read, and/or 97 | // Close concurrently with each other, but calls to Read are not reentrant and 98 | // should not be called from multiple goroutines. Multiple Read goroutines would 99 | // receive different chunks of data from the Stream and be unable to reassemble 100 | // them in order or along message boundaries, and may encounter deadlocks. 101 | func (s *Stream) Read(b []byte) (n int, err error) { 102 | defer asyncNotify(s.recvNotifyCh) 103 | START: 104 | 105 | // If the stream is closed and there's no data buffered, return EOF 106 | s.stateLock.Lock() 107 | switch s.state { 108 | case streamLocalClose: 109 | // LocalClose only prohibits further local writes. Handle reads normally. 
110 | case streamRemoteClose: 111 | fallthrough 112 | case streamClosed: 113 | s.recvLock.Lock() 114 | if s.recvBuf == nil || s.recvBuf.Len() == 0 { 115 | s.recvLock.Unlock() 116 | s.stateLock.Unlock() 117 | return 0, io.EOF 118 | } 119 | s.recvLock.Unlock() 120 | case streamReset: 121 | s.stateLock.Unlock() 122 | return 0, ErrConnectionReset 123 | } 124 | s.stateLock.Unlock() 125 | 126 | // If there is no data available, block 127 | s.recvLock.Lock() 128 | if s.recvBuf == nil || s.recvBuf.Len() == 0 { 129 | s.recvLock.Unlock() 130 | goto WAIT 131 | } 132 | 133 | // Read any bytes 134 | n, _ = s.recvBuf.Read(b) 135 | s.recvLock.Unlock() 136 | 137 | // Send a window update potentially 138 | err = s.sendWindowUpdate() 139 | if err == ErrSessionShutdown { 140 | err = nil 141 | } 142 | return n, err 143 | 144 | WAIT: 145 | var timeout <-chan time.Time 146 | var timer *time.Timer 147 | readDeadline := s.readDeadline.Load().(time.Time) 148 | if !readDeadline.IsZero() { 149 | delay := time.Until(readDeadline) 150 | timer = time.NewTimer(delay) 151 | timeout = timer.C 152 | } 153 | select { 154 | case <-s.session.shutdownCh: 155 | case <-s.recvNotifyCh: 156 | case <-timeout: 157 | return 0, ErrTimeout 158 | } 159 | if timer != nil { 160 | if !timer.Stop() { 161 | <-timeout 162 | } 163 | } 164 | goto START 165 | } 166 | 167 | // Write is used to write to the stream. It is safe to call Write, Read, and/or 168 | // Close concurrently with each other, but calls to Write are not reentrant and 169 | // should not be called from multiple goroutines. 170 | func (s *Stream) Write(b []byte) (n int, err error) { 171 | s.sendLock.Lock() 172 | defer s.sendLock.Unlock() 173 | total := 0 174 | for total < len(b) { 175 | n, err := s.write(b[total:]) 176 | total += n 177 | if err != nil { 178 | return total, err 179 | } 180 | } 181 | return total, nil 182 | } 183 | 184 | // write is used to write to the stream, may return on 185 | // a short write. 
186 | func (s *Stream) write(b []byte) (n int, err error) { 187 | var flags uint16 188 | var max uint32 189 | var body []byte 190 | START: 191 | s.stateLock.Lock() 192 | switch s.state { 193 | case streamLocalClose: 194 | fallthrough 195 | case streamClosed: 196 | s.stateLock.Unlock() 197 | return 0, ErrStreamClosed 198 | case streamReset: 199 | s.stateLock.Unlock() 200 | return 0, ErrConnectionReset 201 | } 202 | s.stateLock.Unlock() 203 | 204 | // If there is no data available, block 205 | window := atomic.LoadUint32(&s.sendWindow) 206 | if window == 0 { 207 | goto WAIT 208 | } 209 | 210 | // Determine the flags if any 211 | flags = s.sendFlags() 212 | 213 | // Send up to our send window 214 | max = min(window, uint32(len(b))) 215 | body = b[:max] 216 | 217 | // Send the header 218 | s.sendHdr.encode(typeData, flags, s.id, max) 219 | if err = s.session.waitForSendErr(s.sendHdr, body, s.sendErr); err != nil { 220 | if errors.Is(err, ErrSessionShutdown) || errors.Is(err, ErrConnectionWriteTimeout) { 221 | // Message left in ready queue, header re-use is unsafe. 
222 | s.sendHdr = header(make([]byte, headerSize)) 223 | } 224 | return 0, err 225 | } 226 | 227 | // Reduce our send window 228 | atomic.AddUint32(&s.sendWindow, ^uint32(max-1)) 229 | 230 | // Unlock 231 | return int(max), err 232 | 233 | WAIT: 234 | var timeout <-chan time.Time 235 | var timer *time.Timer 236 | writeDeadline := s.writeDeadline.Load().(time.Time) 237 | if !writeDeadline.IsZero() { 238 | delay := time.Until(writeDeadline) 239 | timer = time.NewTimer(delay) 240 | timeout = timer.C 241 | } 242 | select { 243 | case <-s.session.shutdownCh: 244 | case <-s.sendNotifyCh: 245 | case <-timeout: 246 | return 0, ErrTimeout 247 | } 248 | if timer != nil { 249 | if !timer.Stop() { 250 | <-timeout 251 | } 252 | } 253 | goto START 254 | } 255 | 256 | // sendFlags determines any flags that are appropriate 257 | // based on the current stream state 258 | func (s *Stream) sendFlags() uint16 { 259 | s.stateLock.Lock() 260 | defer s.stateLock.Unlock() 261 | var flags uint16 262 | switch s.state { 263 | case streamInit: 264 | flags |= flagSYN 265 | s.state = streamSYNSent 266 | case streamSYNReceived: 267 | flags |= flagACK 268 | s.state = streamEstablished 269 | } 270 | return flags 271 | } 272 | 273 | // sendWindowUpdate potentially sends a window update enabling 274 | // further writes to take place. Must be invoked with the lock. 
275 | func (s *Stream) sendWindowUpdate() error { 276 | s.controlHdrLock.Lock() 277 | defer s.controlHdrLock.Unlock() 278 | 279 | // Determine the delta update 280 | max := s.session.config.MaxStreamWindowSize 281 | var bufLen uint32 282 | s.recvLock.Lock() 283 | if s.recvBuf != nil { 284 | bufLen = uint32(s.recvBuf.Len()) 285 | } 286 | delta := (max - bufLen) - s.recvWindow 287 | 288 | // Determine the flags if any 289 | flags := s.sendFlags() 290 | 291 | // Check if we can omit the update 292 | if delta < (max/2) && flags == 0 { 293 | s.recvLock.Unlock() 294 | return nil 295 | } 296 | 297 | // Update our window 298 | s.recvWindow += delta 299 | s.recvLock.Unlock() 300 | 301 | // Send the header 302 | s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta) 303 | if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil { 304 | if errors.Is(err, ErrSessionShutdown) || errors.Is(err, ErrConnectionWriteTimeout) { 305 | // Message left in ready queue, header re-use is unsafe. 306 | s.controlHdr = header(make([]byte, headerSize)) 307 | } 308 | return err 309 | } 310 | return nil 311 | } 312 | 313 | // sendClose is used to send a FIN 314 | func (s *Stream) sendClose() error { 315 | s.controlHdrLock.Lock() 316 | defer s.controlHdrLock.Unlock() 317 | 318 | flags := s.sendFlags() 319 | flags |= flagFIN 320 | s.controlHdr.encode(typeWindowUpdate, flags, s.id, 0) 321 | if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil { 322 | if errors.Is(err, ErrSessionShutdown) || errors.Is(err, ErrConnectionWriteTimeout) { 323 | // Message left in ready queue, header re-use is unsafe. 324 | s.controlHdr = header(make([]byte, headerSize)) 325 | } 326 | return err 327 | } 328 | return nil 329 | } 330 | 331 | // Close is used to close the stream. It is safe to call Close concurrently. 
332 | func (s *Stream) Close() error { 333 | closeStream := false 334 | s.stateLock.Lock() 335 | switch s.state { 336 | // Opened means we need to signal a close 337 | case streamSYNSent: 338 | fallthrough 339 | case streamSYNReceived: 340 | fallthrough 341 | case streamEstablished: 342 | s.state = streamLocalClose 343 | goto SEND_CLOSE 344 | 345 | case streamLocalClose: 346 | case streamRemoteClose: 347 | s.state = streamClosed 348 | closeStream = true 349 | goto SEND_CLOSE 350 | 351 | case streamClosed: 352 | case streamReset: 353 | default: 354 | panic("unhandled state") 355 | } 356 | s.stateLock.Unlock() 357 | return nil 358 | SEND_CLOSE: 359 | // This shouldn't happen (the more realistic scenario to cancel the 360 | // timer is via processFlags) but just in case this ever happens, we 361 | // cancel the timer to prevent dangling timers. 362 | if s.closeTimer != nil { 363 | s.closeTimer.Stop() 364 | s.closeTimer = nil 365 | } 366 | 367 | // If we have a StreamCloseTimeout set we start the timeout timer. 368 | // We do this only if we're not already closing the stream since that 369 | // means this was a graceful close. 370 | // 371 | // This prevents memory leaks if one side (this side) closes and the 372 | // remote side poorly behaves and never responds with a FIN to complete 373 | // the close. After the specified timeout, we clean our resources up no 374 | // matter what. 375 | if !closeStream && s.session.config.StreamCloseTimeout > 0 { 376 | s.closeTimer = time.AfterFunc( 377 | s.session.config.StreamCloseTimeout, s.closeTimeout) 378 | } 379 | 380 | s.stateLock.Unlock() 381 | s.sendClose() 382 | s.notifyWaiting() 383 | if closeStream { 384 | s.session.closeStream(s.id) 385 | } 386 | return nil 387 | } 388 | 389 | // closeTimeout is called after StreamCloseTimeout during a close to 390 | // close this stream. 
391 | func (s *Stream) closeTimeout() {
392 | 	// Close our side forcibly
393 | 	s.forceClose()
394 | 
395 | 	// Free the stream from the session map
396 | 	s.session.closeStream(s.id)
397 | 
398 | 	// Send a RST so the remote side closes too.
399 | 	s.sendLock.Lock()
400 | 	defer s.sendLock.Unlock()
401 | 	hdr := header(make([]byte, headerSize))
402 | 	hdr.encode(typeWindowUpdate, flagRST, s.id, 0)
403 | 	_ = s.session.sendNoWait(hdr) // error intentionally ignored: the RST is best-effort during a forced close
404 | }
405 | 
406 | // forceClose is used for when the session is exiting
407 | func (s *Stream) forceClose() {
408 | 	s.stateLock.Lock()
409 | 	s.state = streamClosed
410 | 	s.stateLock.Unlock()
411 | 	s.notifyWaiting()
412 | }
413 | 
414 | // processFlags is used to update the state of the stream
415 | // based on set flags, if any. Acquires s.stateLock internally — NOTE(review): callers must not already hold stateLock (sync.Mutex is not reentrant).
416 | func (s *Stream) processFlags(flags uint16) error {
417 | 	s.stateLock.Lock()
418 | 	defer s.stateLock.Unlock()
419 | 
420 | 	// Deferred close of the stream. NOTE(review): deferred funcs run LIFO, so this closure executes BEFORE the Unlock above — i.e. with stateLock still held; closeTimer.Stop and session.closeStream must not re-acquire stateLock.
421 | 	closeStream := false
422 | 	defer func() {
423 | 		if closeStream {
424 | 			if s.closeTimer != nil {
425 | 				// Stop our close timeout timer since we gracefully closed
426 | 				s.closeTimer.Stop()
427 | 			}
428 | 
429 | 			s.session.closeStream(s.id)
430 | 		}
431 | 	}()
432 | 
433 | 	if flags&flagACK == flagACK {
434 | 		if s.state == streamSYNSent {
435 | 			s.state = streamEstablished
436 | 		}
437 | 		asyncNotify(s.establishCh)
438 | 		s.session.establishStream(s.id)
439 | 	}
440 | 	if flags&flagFIN == flagFIN {
441 | 		switch s.state {
442 | 		case streamSYNSent:
443 | 			fallthrough
444 | 		case streamSYNReceived:
445 | 			fallthrough
446 | 		case streamEstablished:
447 | 			s.state = streamRemoteClose
448 | 			s.notifyWaiting()
449 | 		case streamLocalClose:
450 | 			s.state = streamClosed
451 | 			closeStream = true
452 | 			s.notifyWaiting()
453 | 		default:
454 | 			s.session.logger.Printf("[ERR] yamux: unexpected FIN flag in state %d", s.state)
455 | 			return ErrUnexpectedFlag
456 | 		}
457 | 	}
458 | 	if flags&flagRST == flagRST {
459 | 		s.state = streamReset
460 | 		closeStream = true
461 | 		s.notifyWaiting()
462 | 	}
463 | 	return nil
464 | }
465 | 
466 | // notifyWaiting notifies all the waiting channels
467 | func (s *Stream) notifyWaiting() {
468 | 	asyncNotify(s.recvNotifyCh)
469 | 	asyncNotify(s.sendNotifyCh)
470 | 	asyncNotify(s.establishCh)
471 | }
472 | 
473 | // incrSendWindow updates the size of our send window
474 | func (s *Stream) incrSendWindow(hdr header, flags uint16) error {
475 | 	if err := s.processFlags(flags); err != nil {
476 | 		return err
477 | 	}
478 | 
479 | 	// Increase window, unblock a sender
480 | 	atomic.AddUint32(&s.sendWindow, hdr.Length())
481 | 	asyncNotify(s.sendNotifyCh)
482 | 	return nil
483 | }
484 | 
485 | // readData is used to handle a data frame
486 | func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
487 | 	if err := s.processFlags(flags); err != nil {
488 | 		return err
489 | 	}
490 | 
491 | 	// Check that our recv window is not exceeded
492 | 	length := hdr.Length()
493 | 	if length == 0 {
494 | 		return nil
495 | 	}
496 | 
497 | 	// Wrap in a limited reader
498 | 	conn = &io.LimitedReader{R: conn, N: int64(length)}
499 | 
500 | 	// Copy into buffer
501 | 	s.recvLock.Lock()
502 | 
503 | 	if length > s.recvWindow {
504 | 		s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length)
505 | 		s.recvLock.Unlock()
506 | 		return ErrRecvWindowExceeded
507 | 	}
508 | 
509 | 	if s.recvBuf == nil {
510 | 		// Allocate the receive buffer just-in-time to fit the full data frame.
511 | 		// This way we can read in the whole packet without further allocations.
512 | 		s.recvBuf = bytes.NewBuffer(make([]byte, 0, length))
513 | 	}
514 | 	copiedLength, err := io.Copy(s.recvBuf, conn)
515 | 	if err != nil {
516 | 		s.session.logger.Printf("[ERR] yamux: Failed to read stream data: %v", err)
517 | 		s.recvLock.Unlock()
518 | 		return err
519 | 	}
520 | 
521 | 	// Decrement the receive window
522 | 	s.recvWindow -= uint32(copiedLength)
523 | 	s.recvLock.Unlock()
524 | 
525 | 	// Unblock any readers
526 | 	asyncNotify(s.recvNotifyCh)
527 | 	return nil
528 | }
529 | 
530 | // SetDeadline sets the read and write deadlines
531 | func (s *Stream) SetDeadline(t time.Time) error {
532 | 	if err := s.SetReadDeadline(t); err != nil {
533 | 		return err
534 | 	}
535 | 	if err := s.SetWriteDeadline(t); err != nil {
536 | 		return err
537 | 	}
538 | 	return nil
539 | }
540 | 
541 | // SetReadDeadline sets the deadline for blocked and future Read calls.
542 | func (s *Stream) SetReadDeadline(t time.Time) error {
543 | 	s.readDeadline.Store(t)
544 | 	asyncNotify(s.recvNotifyCh) // wake any blocked reader so it re-checks the new deadline
545 | 	return nil
546 | }
547 | 
548 | // SetWriteDeadline sets the deadline for blocked and future Write calls
549 | func (s *Stream) SetWriteDeadline(t time.Time) error {
550 | 	s.writeDeadline.Store(t)
551 | 	asyncNotify(s.sendNotifyCh) // wake any blocked writer so it re-checks the new deadline
552 | 	return nil
553 | }
554 | 
555 | // Shrink is used to compact the amount of buffers utilized
556 | // This is useful when using Yamux in a connection pool to reduce
557 | // the idle memory utilization.
558 | func (s *Stream) Shrink() {
559 | 	s.recvLock.Lock()
560 | 	if s.recvBuf != nil && s.recvBuf.Len() == 0 {
561 | 		s.recvBuf = nil
562 | 	}
563 | 	s.recvLock.Unlock()
564 | }
565 | 
--------------------------------------------------------------------------------
/testdata/README.md:
--------------------------------------------------------------------------------
1 | Test certificates generated with:
2 | 
3 | ```
4 | go run $(go env GOROOT)/src/crypto/tls/generate_cert.go --host example.com
5 | ```
6 | 
7 | Requires a bash-like shell and Go installed.
8 | -------------------------------------------------------------------------------- /testdata/cert.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIC/DCCAeSgAwIBAgIRAI8YOah8fp9JcV8YbON8488wDQYJKoZIhvcNAQELBQAw 3 | EjEQMA4GA1UEChMHQWNtZSBDbzAeFw0yNDA4MTQyMzE0MTJaFw0yNTA4MTQyMzE0 4 | MTJaMBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw 5 | ggEKAoIBAQC5cvLaZRsKScA/JBgIlXTs3uJj7KLOLwGuRxk3sUQ7aREoqlC1bGR8 6 | wN5mYu9Yso6dEOWJBSXybtSpH60AtGnepAKAra4IDNfLbNmWi+13lyqD/BgIpWoP 7 | Lww6MDHLxNP1+U4hQ2mcOw/hueaSMUahXhjTTTVHq+cLQ9eBxk7b/mcxzHLS6he+ 8 | 0zS2QsQ7p/R5RN5QTALZNoHTgXh7Wou/ynmCHNVzaAOdGIvDTSi6fBdqNgvWrY0w 9 | WcZePVcTiZ2y/4TKEugXqu6RdO1C3rtAFEWCn9q+RyNl0MCcbxo+n3xljy9y3HTW 10 | 0qdpYg2wZJKuRBRlsr0D72pEg+OTd4gNAgMBAAGjTTBLMA4GA1UdDwEB/wQEAwIF 11 | oDATBgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMBYGA1UdEQQPMA2C 12 | C2V4YW1wbGUuY29tMA0GCSqGSIb3DQEBCwUAA4IBAQBvjm0Y6uS285qhh9Ae/4+f 13 | /nc9KVECHCt0w4CAVmkoCULncOLPnfDRgfn0S2jveBeD/1916egnRljYdqHaE/1G 14 | /DHo3b45uC77dCGZzCKl7GC50GOUdirHxNiS99xCPM2rWmoada+v5Oe3kcCBXlJ4 15 | KeDffE7EGo8ACzO5ziKMbR8oThaFrOXIPtUYUFInURbu9VKfRzkLzXNGBZ1WgVZ6 16 | i9McZImuKnKLZJ1e3SlX3PcZwoBYbumaIG1XFx0K4FCO+QsZNOtLPIzA+aVdtFii 17 | f5nn4CxXJ/SGhwnjbJE4lS7vH0JlzVIX5rHEYB4jL8d7TApXVgJ0L/0wgFvQ0bCv 18 | -----END CERTIFICATE----- 19 | -------------------------------------------------------------------------------- /testdata/key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN PRIVATE KEY----- 2 | MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC5cvLaZRsKScA/ 3 | JBgIlXTs3uJj7KLOLwGuRxk3sUQ7aREoqlC1bGR8wN5mYu9Yso6dEOWJBSXybtSp 4 | H60AtGnepAKAra4IDNfLbNmWi+13lyqD/BgIpWoPLww6MDHLxNP1+U4hQ2mcOw/h 5 | ueaSMUahXhjTTTVHq+cLQ9eBxk7b/mcxzHLS6he+0zS2QsQ7p/R5RN5QTALZNoHT 6 | gXh7Wou/ynmCHNVzaAOdGIvDTSi6fBdqNgvWrY0wWcZePVcTiZ2y/4TKEugXqu6R 7 | dO1C3rtAFEWCn9q+RyNl0MCcbxo+n3xljy9y3HTW0qdpYg2wZJKuRBRlsr0D72pE 8 | 
g+OTd4gNAgMBAAECggEAdVXvppM2KqpDQzAZLMUzt/PGFidRU1eWnqhJol08qMJv 9 | ouUwL7onUm/Nx8ZtXheL+IEKWkmxmtTZJTDvi3SbT81B8Bzz8g/+Ma3rdj+Ovo4c 10 | zmmg40eV9Yl1GRQJTb55xjY5Yv5+QeV0xQOUiYc4Az3AQ2GkhnaTtyLzph7NIo+d 11 | n4jiNJIlCP2wO3hyIaIVtQXyoGAwNWHTrJsShGaNi9C32SEqzSzuNF94lnGJ4U9M 12 | 7yLl1GD+uhz6V0q3eY8CHK6g8HOaj1ukMlIz+qZdsY252qpEgmY21kvoUJ9Awjjn 13 | 3dAtR7aX+YFtMdAD8rCPHS00lmSlGOEMcg8m07EHmQKBgQDEyHIKvls9RGRfbMbR 14 | +GQ0HyQu2ACbHu6sc8baZZpr9h0pQJOObvPhVWkTcMpVZqlfJZ3X08WHhBBmTZkQ 15 | F4K1nAs+Ps7U04cBDph3eGGfe+roQSVcVARmTq4is1SOQMtSpnKyRMS1HLFp9sR0 16 | 03m9a43pmnkT45m/BxFp6kCMxwKBgQDxQV5yH0YBfT5i8T63vDs/f9nvxgRODC5d 17 | 0rdgJBd1R4VPTmI5Cbcfa3IY4H5nMgh90x9T7xu209ywp75TYS/Q119eElTcj5tX 18 | xhDitw052F2ZA90nuCsyXQq+01zLeRuhMQQx3HbmgoVDNNJJBRfERakhM8z4FDVi 19 | a5FDrDQoiwKBgBZg6T85iKy+A2Aqwa2NPvACfp3pKKB7cw8fl4Ssu1P9yDExy9YN 20 | 3iRJD0sLr6bopuhQIdQynCseJLNNrdN7qPy4QzsP73ualqbTHxmvEgMOF5fUGMiY 21 | MWvlFL6TgFExIy5CCZcmZOxn1/FCA/N5PUYCXkArtgtB/fEQf7V402BvAoGBAN3g 22 | oazRaD/cYLD8cBLo0ZCf0955veHNwCLXtYB9EPnycf8y9pDAh6Mk3QVWCcp8sGSP 23 | 82LtKA7oIDJzw03JtwEZ4oKQ120VwediqIrpkQdfHw2oCRALh+bEvSotF02mryt6 24 | +gGlYdCzvz3E6ZTwUyBWdKqtileptkMy7KFRUZLrAoGAVHMhb1CJWIHLOSNWBHU3 25 | U0Aq9kY3sK/iECKRa0e7qe8gK/hSm61q3RyyYsrrsp8nyPhuvvoJ61AyD9sf4bWn 26 | 4lQalf69PpfdM3Kr9wgu3B8UG15RYAgu9mEp4f5ys/lB0kdcoNhXHm/omuEI7xho 27 | 0TzPD2rJfUl/Jce1oLGoPL8= 28 | -----END PRIVATE KEY----- 29 | -------------------------------------------------------------------------------- /util.go: -------------------------------------------------------------------------------- 1 | package yamux 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | ) 7 | 8 | // Logger is a abstract of *log.Logger 9 | type Logger interface { 10 | Print(v ...interface{}) 11 | Printf(format string, v ...interface{}) 12 | Println(v ...interface{}) 13 | } 14 | 15 | var ( 16 | timerPool = &sync.Pool{ 17 | New: func() interface{} { 18 | timer := time.NewTimer(time.Hour * 1e6) 19 | timer.Stop() 20 | return timer 21 | }, 22 | } 23 | ) 24 | 25 
| // asyncSendErr is used to try an async send of an error
26 | func asyncSendErr(ch chan error, err error) {
27 | 	if ch == nil {
28 | 		return
29 | 	}
30 | 	select {
31 | 	case ch <- err:
32 | 	default: // non-blocking: drop the error when no receiver is ready and the buffer is full
33 | 	}
34 | }
35 | 
36 | // asyncNotify is used to signal a waiting goroutine
37 | func asyncNotify(ch chan struct{}) {
38 | 	select {
39 | 	case ch <- struct{}{}:
40 | 	default: // non-blocking: a pending signal is already queued, so dropping is safe
41 | 	}
42 | }
43 | 
44 | // min computes the minimum of two values. NOTE(review): this package-level func shadows the Go 1.21+ builtin min throughout the package (repo pins Go 1.23).
45 | func min(a, b uint32) uint32 {
46 | 	if a < b {
47 | 		return a
48 | 	}
49 | 	return b
50 | }
51 | 
--------------------------------------------------------------------------------
/util_test.go:
--------------------------------------------------------------------------------
1 | package yamux
2 | 
3 | import (
4 | 	"testing"
5 | )
6 | 
7 | func TestAsyncSendErr(t *testing.T) {
8 | 	ch := make(chan error)
9 | 	asyncSendErr(ch, ErrTimeout)
10 | 	select {
11 | 	case <-ch:
12 | 		t.Fatalf("should not get")
13 | 	default:
14 | 	}
15 | 
16 | 	ch = make(chan error, 1)
17 | 	asyncSendErr(ch, ErrTimeout)
18 | 	select {
19 | 	case <-ch:
20 | 	default:
21 | 		t.Fatalf("should get")
22 | 	}
23 | }
24 | 
25 | func TestAsyncNotify(t *testing.T) {
26 | 	ch := make(chan struct{})
27 | 	asyncNotify(ch)
28 | 	select {
29 | 	case <-ch:
30 | 		t.Fatalf("should not get")
31 | 	default:
32 | 	}
33 | 
34 | 	ch = make(chan struct{}, 1)
35 | 	asyncNotify(ch)
36 | 	select {
37 | 	case <-ch:
38 | 	default:
39 | 		t.Fatalf("should get")
40 | 	}
41 | }
42 | 
43 | func TestMin(t *testing.T) {
44 | 	if min(1, 2) != 1 {
45 | 		t.Fatalf("bad")
46 | 	}
47 | 	if min(2, 1) != 1 {
48 | 		t.Fatalf("bad")
49 | 	}
50 | }
51 | 
--------------------------------------------------------------------------------