├── .github ├── dependabot.yml └── workflows │ └── tests.yaml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MAINTAINERS.md ├── README.md ├── async.go ├── async_test.go ├── client.go ├── client_test.go ├── conn.go ├── frisbee.go ├── go.mod ├── go.sum ├── helpers_test.go ├── internal └── dialer │ └── dialer.go ├── options.go ├── options_test.go ├── pkg ├── metadata │ ├── metadata.go │ ├── metadata_test.go │ └── pool.go └── packet │ ├── packet.go │ ├── packet_test.go │ ├── pool.go │ └── pool_test.go ├── server.go ├── server_test.go ├── stream.go ├── stream_test.go ├── sync.go ├── sync_test.go ├── testutil.go └── throughput_test.go /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "gomod" # See documentation for possible values 9 | directory: "/" # Location of package manifests 10 | schedule: 11 | interval: "weekly" 12 | -------------------------------------------------------------------------------- /.github/workflows/tests.yaml: -------------------------------------------------------------------------------- 1 | name: Tests and Benchmarks 2 | 3 | on: [pull_request] 4 | 5 | jobs: 6 | tests: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - name: Checkout 10 | uses: actions/checkout@v4 11 | - name: Install Go 12 | uses: actions/setup-go@v5 13 | with: 14 | go-version: "1.22" 15 | check-latest: true 16 | cache: true 17 | - name: Run Tests 18 | run: go test -v ./... 19 | tests-race: 20 | runs-on: ubuntu-latest 21 | steps: 22 | - name: Checkout 23 | uses: actions/checkout@v4 24 | - name: Install Go 25 | uses: actions/setup-go@v5 26 | with: 27 | go-version: "1.22" 28 | check-latest: true 29 | cache: true 30 | - name: Test with Race Conditions 31 | run: go test -race -v ./... 32 | timeout-minutes: 15 33 | benchmarks: 34 | runs-on: ubuntu-latest 35 | steps: 36 | - name: Checkout 37 | uses: actions/checkout@v4 38 | - name: Install Go 39 | uses: actions/setup-go@v5 40 | with: 41 | go-version: "1.22" 42 | check-latest: true 43 | cache: true 44 | - name: Benchmark 45 | run: go test -run=^$ -bench=. -v ./... 46 | benchmarks-race: 47 | runs-on: ubuntu-latest 48 | steps: 49 | - name: Checkout 50 | uses: actions/checkout@v4 51 | - name: Install Go 52 | uses: actions/setup-go@v5 53 | with: 54 | go-version: "1.22" 55 | check-latest: true 56 | cache: true 57 | - name: Benchmark with Race Conditions 58 | run: go test -run=^$ -bench=. -race -timeout 30m -v ./... 59 | timeout-minutes: 30 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.out 2 | *.test 3 | .env 4 | .idea/ 5 | .DS_Store 6 | dist/ -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Frisbee Community Code of Conduct 2 | 3 | Frisbee follows the [CNCF Code of Conduct](https://github.com/cncf/foundation/blob/main/code-of-conduct.md). 
4 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Frisbee uses GitHub to manage reviews of pull requests. 4 | 5 | - If you have a trivial fix or improvement, go ahead and create a pull request, 6 | addressing (with `@...`) the maintainer of this repository (see 7 | [MAINTAINERS.md](MAINTAINERS.md)) in the description of the pull request. 8 | 9 | - If you plan to do something more involved, first discuss your ideas 10 | on our [discord](https://loopholelabs.io/discord). 11 | This will avoid unnecessary work and surely give you and us a good deal 12 | of inspiration. 13 | 14 | - Relevant coding style guidelines are the [Go Code Review 15 | Comments](https://code.google.com/p/go-wiki/wiki/CodeReviewComments) 16 | and the _Formatting and style_ section of Peter Bourgon's [Go: Best 17 | Practices for Production 18 | Environments](http://peter.bourgon.org/go-in-production/#formatting-and-style). 19 | 20 | - Be sure to sign off on the [DCO](https://github.com/apps/dco) 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /MAINTAINERS.md: -------------------------------------------------------------------------------- 1 | - Shivansh Vij @shivanshvij 2 | - Alex Sørlie Glomsaas @supermanifolds 3 | - Felicitas Pojtinger @pojntfx 4 | - Jimmy Moore @jimmyaxod 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Frisbee-Go 2 | 3 | [![License: Apache 2.0](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://www.apache.org/licenses/LICENSE-2.0) 4 | [![Tests](https://github.com/loopholelabs/frisbee-go/actions/workflows/tests.yaml/badge.svg)](https://github.com/loopholelabs/frisbee-go/actions/workflows/tests.yaml) 5 | [![Go Report Card](https://goreportcard.com/badge/github.com/loopholelabs/frisbee-go)](https://goreportcard.com/report/github.com/loopholelabs/frisbee-go) 6 | [![go-doc](https://godoc.org/github.com/loopholelabs/frisbee-go?status.svg)](https://godoc.org/github.com/loopholelabs/frisbee-go) 7 | [![Discord](https://dcbadge.vercel.app/api/server/JYmFhtdPeu?style=flat)](https://loopholelabs.io/discord) 8 | ![Go Version](https://img.shields.io/badge/go%20version-%3E=1.22-61CFDD.svg) 9 | 10 | 11 | This is the [Go](http://golang.org) implementation of **Frisbee**, a bring-your-own 12 | protocol messaging framework designed for performance and 13 | stability. 14 | 15 | ## Usage and Documentation 16 | 17 | Usage instructions and documentation for `frisbee-go` are available 18 | on [GoDoc](https://godoc.org/github.com/loopholelabs/frisbee-go). 19 | 20 | ## Contributing 21 | 22 | Bug reports and pull requests are welcome on GitHub at [https://github.com/loopholelabs/frisbee-go][gitrepo]. For more 23 | contribution information check 24 | out [the contribution guide](https://github.com/loopholelabs/frisbee-go/blob/main/CONTRIBUTING.md). 25 | 26 | ## License 27 | 28 | The Frisbee project is available as open source under the terms of 29 | the [Apache License, Version 2.0](http://www.apache.org/licenses/LICENSE-2.0). 30 | 31 | ## Code of Conduct 32 | 33 | Everyone interacting in the Frisbee project’s codebases, issue trackers, chat rooms and mailing lists is expected to follow the [CNCF Code of Conduct](https://github.com/cncf/foundation/blob/main/code-of-conduct.md). 
34 | 35 | ## Project Managed By: 36 | 37 | [![https://loopholelabs.io][loopholelabs]](https://loopholelabs.io) 38 | 39 | [gitrepo]: https://github.com/loopholelabs/frisbee-go 40 | [loopholelabs]: https://cdn.loopholelabs.io/loopholelabs/LoopholeLabsLogo.svg 41 | [loophomepage]: https://loopholelabs.io 42 | -------------------------------------------------------------------------------- /async.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package frisbee 4 | 5 | import ( 6 | "bufio" 7 | "context" 8 | "crypto/tls" 9 | "encoding/binary" 10 | "errors" 11 | "fmt" 12 | "net" 13 | "sync" 14 | "sync/atomic" 15 | "time" 16 | 17 | "github.com/loopholelabs/common/pkg/queue" 18 | "github.com/loopholelabs/logging/loggers/noop" 19 | "github.com/loopholelabs/logging/types" 20 | 21 | "github.com/loopholelabs/frisbee-go/internal/dialer" 22 | "github.com/loopholelabs/frisbee-go/pkg/metadata" 23 | "github.com/loopholelabs/frisbee-go/pkg/packet" 24 | ) 25 | 26 | // Async is the underlying asynchronous frisbee connection which has extremely efficient read and write logic and 27 | // can handle the specific frisbee requirements. This is not meant to be used on its own, and instead is 28 | // meant to be used by frisbee client and server implementations 29 | type Async struct { 30 | sync.Mutex 31 | conn net.Conn 32 | closed atomic.Bool 33 | writer *bufio.Writer 34 | flushCh chan struct{} 35 | closeCh chan struct{} 36 | incoming *queue.Circular[packet.Packet, *packet.Packet] 37 | staleMu sync.Mutex 38 | stale []*packet.Packet 39 | logger types.Logger 40 | wg sync.WaitGroup 41 | errorMu sync.RWMutex 42 | error error 43 | streamsMu sync.Mutex 44 | streams map[uint16]*Stream 45 | newStreamHandlerMu sync.Mutex 46 | newStreamHandler NewStreamHandler 47 | } 48 | 49 | // ConnectAsync creates a new TCP connection (using net.Dial) and wraps it in a frisbee connection 50 | func ConnectAsync(addr string, keepAlive time.Duration, logger types.Logger, TLSConfig *tls.Config, streamHandler ...NewStreamHandler) (*Async, error) { 51 | var conn net.Conn 52 | var err error 53 | 54 | d := dialer.NewRetry() 55 | 56 | if TLSConfig != nil { 57 | conn, err = d.DialTLS("tcp", addr, TLSConfig) 58 | } else { 59 | conn, err = d.Dial("tcp", addr) 60 | if err == nil { 61 | _ = conn.(*net.TCPConn).SetKeepAlive(true) 62 | _ = conn.(*net.TCPConn).SetKeepAlivePeriod(keepAlive) 63 | } 64 | } 65 | 66 | if err != nil { 67 | return nil, err 68 | } 69 | 70 | return NewAsync(conn, logger, streamHandler...), nil 71 | } 72 | 73 | // NewAsync takes an existing net.Conn object and wraps it in a frisbee connection 74 | func NewAsync(c net.Conn, logger types.Logger, streamHandler ...NewStreamHandler) (conn *Async) { 75 | conn = &Async{ 76 | conn: c, 77 | writer: bufio.NewWriterSize(c, DefaultBufferSize), 78 | incoming: queue.NewCircular[packet.Packet, *packet.Packet](DefaultBufferSize), 79 | flushCh: make(chan struct{}, 3), 80 | closeCh: make(chan struct{}), 81 | streams: make(map[uint16]*Stream), 82 | logger: logger, 83 | } 84 | 85 | if logger == nil { 86 | conn.logger = noop.New(types.InfoLevel) 87 | } 88 | 89 | if len(streamHandler) > 0 && streamHandler[0] != nil { 90 | conn.newStreamHandler = streamHandler[0] 91 | } 92 | 93 | conn.wg.Add(1) 94 | go conn.flushLoop() 95 | 96 | conn.wg.Add(1) 97 | go conn.readLoop() 98 | 99 | conn.wg.Add(1) 100 | go conn.pingLoop() 101 | 102 | return 103 | } 104 | 105 | // SetDeadline sets the read and write deadline on the underlying 
net.Conn 106 | func (c *Async) SetDeadline(t time.Time) error { 107 | if c.closed.Load() { 108 | return ConnectionClosed 109 | } 110 | return c.conn.SetDeadline(t) 111 | } 112 | 113 | // SetReadDeadline sets the read deadline on the underlying net.Conn 114 | func (c *Async) SetReadDeadline(t time.Time) error { 115 | if c.closed.Load() { 116 | return ConnectionClosed 117 | } 118 | return c.conn.SetReadDeadline(t) 119 | } 120 | 121 | // SetWriteDeadline sets the write deadline on the underlying net.Conn 122 | func (c *Async) SetWriteDeadline(t time.Time) error { 123 | if c.closed.Load() { 124 | return ConnectionClosed 125 | } 126 | return c.conn.SetWriteDeadline(t) 127 | } 128 | 129 | // ConnectionState returns the tls.ConnectionState of a *tls.Conn 130 | // if the connection is not *tls.Conn then the NotTLSConnectionError is returned 131 | func (c *Async) ConnectionState() (tls.ConnectionState, error) { 132 | if tlsConn, ok := c.conn.(*tls.Conn); ok { 133 | return tlsConn.ConnectionState(), nil 134 | } 135 | return emptyState, NotTLSConnectionError 136 | } 137 | 138 | // Handshake performs the tls.Handshake() of a *tls.Conn 139 | // if the connection is not *tls.Conn then the NotTLSConnectionError is returned 140 | func (c *Async) Handshake() error { 141 | if tlsConn, ok := c.conn.(*tls.Conn); ok { 142 | return tlsConn.Handshake() 143 | } 144 | return NotTLSConnectionError 145 | } 146 | 147 | // HandshakeContext performs the tls.HandshakeContext() of a *tls.Conn 148 | // if the connection is not *tls.Conn then the NotTLSConnectionError is returned 149 | func (c *Async) HandshakeContext(ctx context.Context) error { 150 | if tlsConn, ok := c.conn.(*tls.Conn); ok { 151 | return tlsConn.HandshakeContext(ctx) 152 | } 153 | return NotTLSConnectionError 154 | } 155 | 156 | // LocalAddr returns the local address of the underlying net.Conn 157 | func (c *Async) LocalAddr() net.Addr { 158 | return c.conn.LocalAddr() 159 | } 160 | 161 | // RemoteAddr returns the remote address of the underlying net.Conn 162 | func (c *Async) RemoteAddr() net.Addr { 163 | return c.conn.RemoteAddr() 164 | } 165 | 166 | // CloseChannel returns a channel that can be listened to for a close event on a frisbee connection 167 | func (c *Async) CloseChannel() <-chan struct{} { 168 | return c.closeCh 169 | } 170 | 171 | // WritePacket takes a packet.Packet and queues it up to send asynchronously. 172 | // 173 | // If packet.Metadata.ContentLength == 0, then the content array's length must be 0. Otherwise, it is required that packet.Metadata.ContentLength == len(content). 174 | func (c *Async) WritePacket(p *packet.Packet) error { 175 | if p.Metadata.Operation <= RESERVED9 { 176 | return InvalidOperation 177 | } 178 | return c.writePacket(p, true) 179 | } 180 | 181 | // ReadPacket is a blocking function that will wait until a Frisbee packet is available and then return it (and its content). 182 | // In the event that the connection is closed, ReadPacket will return an error. 
183 | func (c *Async) ReadPacket() (*packet.Packet, error) { 184 | if c.closed.Load() { 185 | c.staleMu.Lock() 186 | if len(c.stale) > 0 { 187 | var p *packet.Packet 188 | p, c.stale = c.stale[0], c.stale[1:] 189 | c.staleMu.Unlock() 190 | return p, nil 191 | } 192 | c.staleMu.Unlock() 193 | c.Logger().Debug().Err(ConnectionClosed).Msg("error while popping from packet queue") 194 | return nil, ConnectionClosed 195 | } 196 | 197 | readPacket, err := c.incoming.Pop() 198 | if err != nil { 199 | if c.closed.Load() { 200 | c.staleMu.Lock() 201 | if len(c.stale) > 0 { 202 | var p *packet.Packet 203 | p, c.stale = c.stale[0], c.stale[1:] 204 | c.staleMu.Unlock() 205 | return p, nil 206 | } 207 | c.staleMu.Unlock() 208 | c.Logger().Debug().Err(ConnectionClosed).Msg("error while popping from packet queue") 209 | return nil, ConnectionClosed 210 | } 211 | c.Logger().Debug().Err(err).Msg("error while popping from packet queue") 212 | return nil, err 213 | } 214 | 215 | return readPacket, nil 216 | } 217 | 218 | // Flush allows for synchronous messaging by flushing the write buffer and instantly sending packets 219 | func (c *Async) Flush() error { 220 | err := c.flush() 221 | if err != nil { 222 | return c.closeWithError(err) 223 | } 224 | return nil 225 | } 226 | 227 | // WriteBufferSize returns the size of the underlying write buffer (used for internal packet handling and for heartbeat logic) 228 | func (c *Async) WriteBufferSize() int { 229 | c.Lock() 230 | if c.closed.Load() { 231 | c.Unlock() 232 | return 0 233 | } 234 | i := c.writer.Buffered() 235 | c.Unlock() 236 | return i 237 | } 238 | 239 | // Logger returns the underlying logger of the frisbee connection 240 | func (c *Async) Logger() types.Logger { 241 | return c.logger 242 | } 243 | 244 | // Error returns the error that caused the frisbee.Async connection to close 245 | func (c *Async) Error() error { 246 | c.errorMu.RLock() 247 | defer c.errorMu.RUnlock() 248 | return c.error 249 | } 250 | 251 | // Closed returns whether the frisbee.Async connection is closed 252 | func (c *Async) Closed() bool { 253 | return c.closed.Load() 254 | } 255 | 256 | // Raw shuts off all of frisbee's underlying functionality and converts the frisbee connection into a normal TCP connection (net.Conn) 257 | func (c *Async) Raw() net.Conn { 258 | _ = c.close() 259 | return c.conn 260 | } 261 | 262 | // NewStream returns a new stream that can be used to send and receive packets 263 | func (c *Async) NewStream(id uint16) (stream *Stream) { 264 | c.streamsMu.Lock() 265 | if stream = c.streams[id]; stream == nil { 266 | stream = newStream(id, c) 267 | c.streams[id] = stream 268 | } 269 | c.streamsMu.Unlock() 270 | return 271 | } 272 | 273 | // SetNewStreamHandler sets the callback handler for new streams. 274 | // 275 | // It's important to note that this handler is called for new streams and if it is 276 | // not set then stream packets will be dropped. 277 | // 278 | // It's also important to note that the handler itself is called in its own goroutine to 279 | // avoid blocking the read lop. 
This means that the handler must be thread-safe.` 280 | func (c *Async) SetNewStreamHandler(handler NewStreamHandler) { 281 | c.newStreamHandlerMu.Lock() 282 | c.newStreamHandler = handler 283 | c.newStreamHandlerMu.Unlock() 284 | } 285 | 286 | // Close closes the frisbee connection gracefully 287 | func (c *Async) Close() error { 288 | err := c.close() 289 | if err != nil && errors.Is(err, ConnectionClosed) { 290 | return nil 291 | } 292 | _ = c.conn.Close() 293 | return err 294 | } 295 | 296 | // write packet is the internal write packet function that does not check for reserved operations. 297 | func (c *Async) writePacket(p *packet.Packet, closeOnErr bool) error { 298 | if int(p.Metadata.ContentLength) != p.Content.Len() { 299 | return InvalidContentLength 300 | } 301 | if DefaultMaxContentLength > 0 && p.Metadata.ContentLength > DefaultMaxContentLength { 302 | return ContentLengthExceeded 303 | } 304 | 305 | encodedMetadata := metadata.GetBuffer() 306 | binary.BigEndian.PutUint16(encodedMetadata[metadata.MagicOffset:metadata.MagicOffset+metadata.MagicSize], metadata.PacketMagicHeader) 307 | binary.BigEndian.PutUint16(encodedMetadata[metadata.IdOffset:metadata.IdOffset+metadata.IdSize], p.Metadata.Id) 308 | binary.BigEndian.PutUint16(encodedMetadata[metadata.OperationOffset:metadata.OperationOffset+metadata.OperationSize], p.Metadata.Operation) 309 | binary.BigEndian.PutUint32(encodedMetadata[metadata.ContentLengthOffset:metadata.ContentLengthOffset+metadata.ContentLengthSize], p.Metadata.ContentLength) 310 | 311 | c.Lock() 312 | if c.closed.Load() { 313 | c.Unlock() 314 | return ConnectionClosed 315 | } 316 | err := c.conn.SetWriteDeadline(time.Now().Add(DefaultDeadline)) 317 | if err != nil { 318 | c.Unlock() 319 | if c.closed.Load() { 320 | c.Logger().Debug().Err(ConnectionClosed).Uint16("Packet ID", p.Metadata.Id).Msg("error while setting write deadline before writing packet") 321 | return ConnectionClosed 322 | } 323 | c.Logger().Debug().Err(err).Uint16("Packet ID", p.Metadata.Id).Msg("error while setting write deadline before writing packet") 324 | if closeOnErr { 325 | return c.closeWithError(err) 326 | } 327 | return err 328 | } 329 | _, err = c.writer.Write(encodedMetadata[:]) 330 | metadata.PutBuffer(encodedMetadata) 331 | if err != nil { 332 | c.Unlock() 333 | if c.closed.Load() { 334 | c.Logger().Debug().Err(ConnectionClosed).Uint16("Packet ID", p.Metadata.Id).Msg("error while writing encoded metadata") 335 | return ConnectionClosed 336 | } 337 | c.Logger().Debug().Err(err).Uint16("Packet ID", p.Metadata.Id).Msg("error while writing encoded metadata") 338 | if closeOnErr { 339 | return c.closeWithError(err) 340 | } 341 | return err 342 | } 343 | if p.Metadata.ContentLength != 0 { 344 | _, err = c.writer.Write(p.Content.Bytes()[:p.Metadata.ContentLength]) 345 | if err != nil { 346 | c.Unlock() 347 | if c.closed.Load() { 348 | c.Logger().Debug().Err(ConnectionClosed).Uint16("Packet ID", p.Metadata.Id).Msg("error while writing packet content") 349 | return ConnectionClosed 350 | } 351 | c.Logger().Debug().Err(err).Uint16("Packet ID", p.Metadata.Id).Msg("error while writing packet content") 352 | if closeOnErr { 353 | return c.closeWithError(err) 354 | } 355 | return err 356 | } 357 | } 358 | 359 | if len(c.flushCh) == 0 { 360 | select { 361 | case c.flushCh <- struct{}{}: 362 | default: 363 | } 364 | } 365 | 366 | c.Unlock() 367 | 368 | return nil 369 | } 370 | 371 | // flush is an internal function for flushing data from the write buffer, however 372 | // it is unique in that 
it does not call closeWithError (and so does not try and close the underlying connection) 373 | // when it encounters an error, and instead leaves that responsibility to its parent caller 374 | func (c *Async) flush() error { 375 | c.Lock() 376 | if c.closed.Load() { 377 | c.Unlock() 378 | return ConnectionClosed 379 | } 380 | if c.writer.Buffered() > 0 { 381 | err := c.conn.SetWriteDeadline(time.Now().Add(DefaultDeadline)) 382 | if err != nil { 383 | c.Unlock() 384 | return err 385 | } 386 | err = c.writer.Flush() 387 | if err != nil { 388 | c.Unlock() 389 | c.Logger().Error().Err(err).Msg("error while flushing data") 390 | return err 391 | } 392 | } 393 | c.Unlock() 394 | return nil 395 | } 396 | 397 | func (c *Async) close() error { 398 | c.staleMu.Lock() 399 | c.streamsMu.Lock() 400 | if c.closed.CompareAndSwap(false, true) { 401 | c.Logger().Debug().Msg("connection close called, killing goroutines") 402 | c.Lock() 403 | c.incoming.Close() 404 | close(c.closeCh) 405 | close(c.flushCh) 406 | c.Unlock() 407 | _ = c.conn.SetDeadline(pastTime) 408 | c.wg.Wait() 409 | _ = c.conn.SetDeadline(emptyTime) 410 | c.stale = c.incoming.Drain() 411 | c.staleMu.Unlock() 412 | for _, stream := range c.streams { 413 | _ = stream.closeSend(false) 414 | } 415 | c.streamsMu.Unlock() 416 | c.Lock() 417 | if c.writer.Buffered() > 0 { 418 | _ = c.conn.SetWriteDeadline(time.Now().Add(DefaultDeadline)) 419 | _ = c.writer.Flush() 420 | _ = c.conn.SetWriteDeadline(emptyTime) 421 | } 422 | c.Unlock() 423 | return nil 424 | } 425 | c.staleMu.Unlock() 426 | c.streamsMu.Unlock() 427 | return ConnectionClosed 428 | } 429 | 430 | func (c *Async) closeWithError(err error) error { 431 | c.errorMu.Lock() 432 | defer c.errorMu.Unlock() 433 | 434 | c.error = err 435 | closeError := c.close() 436 | if closeError != nil { 437 | c.Logger().Debug().Err(closeError).Msgf("attempted to close connection with error `%s`, but got error while closing", err) 438 | c.error = errors.Join(closeError, err) 439 | return c.error 440 | } 441 | _ = c.conn.Close() 442 | return err 443 | } 444 | 445 | func (c *Async) flushLoop() { 446 | var err error 447 | for { 448 | if _, ok := <-c.flushCh; !ok { 449 | c.wg.Done() 450 | return 451 | } 452 | err = c.flush() 453 | if err != nil { 454 | c.wg.Done() 455 | _ = c.closeWithError(err) 456 | return 457 | } 458 | } 459 | } 460 | 461 | func (c *Async) pingLoop() { 462 | ticker := time.NewTicker(DefaultPingInterval) 463 | defer ticker.Stop() 464 | var err error 465 | for { 466 | select { 467 | case <-c.closeCh: 468 | c.wg.Done() 469 | return 470 | case <-ticker.C: 471 | err = c.writePacket(PINGPacket, false) 472 | if err != nil { 473 | c.wg.Done() 474 | _ = c.closeWithError(err) 475 | return 476 | } 477 | } 478 | } 479 | } 480 | 481 | func (c *Async) readLoop() { 482 | buf := make([]byte, DefaultBufferSize) 483 | var index int 484 | var stream *Stream 485 | var isStream bool 486 | var newStreamHandler NewStreamHandler 487 | for { 488 | buf = buf[:cap(buf)] 489 | if len(buf) < metadata.Size { 490 | c.Logger().Debug().Err(InvalidBufferLength).Msg("error during read loop, calling closeWithError") 491 | c.wg.Done() 492 | _ = c.closeWithError(InvalidBufferLength) 493 | return 494 | } 495 | 496 | var n int 497 | var err error 498 | for n < metadata.Size { 499 | var nn int 500 | err = c.conn.SetReadDeadline(time.Now().Add(DefaultDeadline)) 501 | if err != nil { 502 | c.Logger().Debug().Err(err).Msg("error setting read deadline during read loop, calling closeWithError") 503 | c.wg.Done() 504 | _ = 
c.closeWithError(err) 505 | return 506 | } 507 | nn, err = c.conn.Read(buf[n:]) 508 | n += nn 509 | if err != nil { 510 | if n < metadata.Size { 511 | c.wg.Done() 512 | _ = c.closeWithError(err) 513 | return 514 | } 515 | break 516 | } 517 | } 518 | 519 | index = 0 520 | for index < n { 521 | p := packet.Get() 522 | p.Metadata.Magic = binary.BigEndian.Uint16(buf[index+metadata.MagicOffset : index+metadata.MagicOffset+metadata.MagicSize]) 523 | p.Metadata.Id = binary.BigEndian.Uint16(buf[index+metadata.IdOffset : index+metadata.IdOffset+metadata.IdSize]) 524 | p.Metadata.Operation = binary.BigEndian.Uint16(buf[index+metadata.OperationOffset : index+metadata.OperationOffset+metadata.OperationSize]) 525 | p.Metadata.ContentLength = binary.BigEndian.Uint32(buf[index+metadata.ContentLengthOffset : index+metadata.ContentLengthOffset+metadata.ContentLengthSize]) 526 | index += metadata.Size 527 | 528 | if p.Metadata.Magic != metadata.PacketMagicHeader { 529 | c.Logger().Debug().Str("magic", fmt.Sprintf("0x%04x", p.Metadata.Magic)).Msg("received packet with incorrect magic header") 530 | c.wg.Done() 531 | _ = c.closeWithError(InvalidMagicHeader) 532 | return 533 | } 534 | 535 | if DefaultMaxContentLength > 0 && p.Metadata.ContentLength > DefaultMaxContentLength { 536 | c.Logger().Debug(). 537 | Uint32("content_length", p.Metadata.ContentLength). 538 | Uint32("max_content_length", DefaultMaxContentLength). 539 | Msg("received packet that exceeds max content length") 540 | 541 | c.wg.Done() 542 | _ = c.closeWithError(ContentLengthExceeded) 543 | return 544 | } 545 | 546 | switch p.Metadata.Operation { 547 | case PING: 548 | c.Logger().Trace().Msg("PING Packet received by read loop, sending back PONG packet") 549 | err = c.writePacket(PONGPacket, false) 550 | if err != nil { 551 | c.wg.Done() 552 | _ = c.closeWithError(err) 553 | return 554 | } 555 | packet.Put(p) 556 | case PONG: 557 | c.Logger().Trace().Msg("PONG Packet received by read loop") 558 | packet.Put(p) 559 | case STREAM: 560 | c.Logger().Trace().Msg("STREAM Packet received by read loop") 561 | isStream = true 562 | c.newStreamHandlerMu.Lock() 563 | newStreamHandler = c.newStreamHandler 564 | c.newStreamHandlerMu.Unlock() 565 | if newStreamHandler != nil || p.Metadata.ContentLength == 0 { 566 | c.streamsMu.Lock() 567 | stream = c.streams[p.Metadata.Id] 568 | c.streamsMu.Unlock() 569 | } 570 | fallthrough 571 | default: 572 | if p.Metadata.ContentLength > 0 { 573 | if n-index < int(p.Metadata.ContentLength) { 574 | minSize := int(p.Metadata.ContentLength) - p.Content.Write(buf[index:n]) 575 | n = 0 576 | for cap(buf) < minSize { 577 | buf = append(buf[:cap(buf)], 0) 578 | } 579 | buf = buf[:cap(buf)] 580 | for n < minSize { 581 | var nn int 582 | err = c.conn.SetReadDeadline(time.Now().Add(DefaultDeadline)) 583 | if err != nil { 584 | c.wg.Done() 585 | _ = c.closeWithError(err) 586 | return 587 | } 588 | nn, err = c.conn.Read(buf[n:]) 589 | n += nn 590 | if err != nil { 591 | if n < minSize { 592 | c.wg.Done() 593 | _ = c.closeWithError(err) 594 | return 595 | } 596 | break 597 | } 598 | } 599 | p.Content.Write(buf[:minSize]) 600 | index = minSize 601 | } else { 602 | index += p.Content.Write(buf[index : index+int(p.Metadata.ContentLength)]) 603 | } 604 | } 605 | if !isStream { 606 | err = c.incoming.Push(p) 607 | if err != nil { 608 | c.Logger().Debug().Err(err).Msg("error while pushing to incoming packet queue") 609 | c.wg.Done() 610 | _ = c.closeWithError(err) 611 | return 612 | } 613 | } else { 614 | if p.Metadata.ContentLength == 0 { 
615 | if stream != nil { 616 | stream.close() 617 | c.streamsMu.Lock() 618 | delete(c.streams, p.Metadata.Id) 619 | c.streamsMu.Unlock() 620 | } 621 | packet.Put(p) 622 | } else { 623 | if newStreamHandler == nil { 624 | c.Logger().Debug().Msg("STREAM Packet discarded by read loop") 625 | packet.Put(p) 626 | } else { 627 | if stream == nil { 628 | stream = newStream(p.Metadata.Id, c) 629 | c.streamsMu.Lock() 630 | c.streams[p.Metadata.Id] = stream 631 | c.streamsMu.Unlock() 632 | go newStreamHandler(stream) 633 | } 634 | err = stream.queue.Push(p) 635 | if err != nil { 636 | c.Logger().Debug().Err(err).Msg("error while pushing to a stream queue packet queue") 637 | c.wg.Done() 638 | _ = c.closeWithError(err) 639 | return 640 | } 641 | } 642 | } 643 | } 644 | newStreamHandler = nil 645 | stream = nil 646 | isStream = false 647 | } 648 | if n == index { 649 | index = 0 650 | buf = buf[:cap(buf)] 651 | if len(buf) < metadata.Size { 652 | c.wg.Done() 653 | _ = c.closeWithError(InvalidBufferLength) 654 | return 655 | } 656 | n = 0 657 | for n < metadata.Size { 658 | var nn int 659 | err = c.conn.SetReadDeadline(time.Now().Add(DefaultDeadline)) 660 | if err != nil { 661 | c.wg.Done() 662 | _ = c.closeWithError(err) 663 | return 664 | } 665 | nn, err = c.conn.Read(buf[n:]) 666 | n += nn 667 | if err != nil { 668 | if n < metadata.Size { 669 | c.wg.Done() 670 | _ = c.closeWithError(err) 671 | return 672 | } 673 | break 674 | } 675 | } 676 | } else if n-index < metadata.Size { 677 | copy(buf, buf[index:n]) 678 | n -= index 679 | index = n 680 | 681 | buf = buf[:cap(buf)] 682 | minSize := metadata.Size - index 683 | if len(buf) < minSize { 684 | c.wg.Done() 685 | _ = c.closeWithError(InvalidBufferLength) 686 | return 687 | } 688 | n = 0 689 | for n < minSize { 690 | var nn int 691 | err = c.conn.SetReadDeadline(time.Now().Add(DefaultDeadline)) 692 | if err != nil { 693 | c.wg.Done() 694 | _ = c.closeWithError(err) 695 | return 696 | } 697 | nn, err = c.conn.Read(buf[index+n:]) 698 | n += nn 699 | if err != nil { 700 | if n < minSize { 701 | c.wg.Done() 702 | _ = c.closeWithError(err) 703 | return 704 | } 705 | break 706 | } 707 | } 708 | n += index 709 | index = 0 710 | } 711 | } 712 | } 713 | } 714 | -------------------------------------------------------------------------------- /async_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package frisbee 4 | 5 | import ( 6 | "crypto/rand" 7 | "encoding/binary" 8 | "io" 9 | "net" 10 | "runtime" 11 | "sync" 12 | "testing" 13 | "time" 14 | 15 | "github.com/stretchr/testify/assert" 16 | "github.com/stretchr/testify/require" 17 | 18 | "github.com/loopholelabs/logging" 19 | "github.com/loopholelabs/polyglot/v2" 20 | "github.com/loopholelabs/testing/conn/pair" 21 | 22 | "github.com/loopholelabs/frisbee-go/pkg/metadata" 23 | "github.com/loopholelabs/frisbee-go/pkg/packet" 24 | ) 25 | 26 | func TestNewAsync(t *testing.T) { 27 | t.Parallel() 28 | 29 | const packetSize = 512 30 | 31 | emptyLogger := logging.Test(t, logging.Noop, t.Name()) 32 | 33 | reader, writer := net.Pipe() 34 | 35 | readerConn := NewAsync(reader, emptyLogger) 36 | writerConn := NewAsync(writer, emptyLogger) 37 | 38 | p := packet.Get() 39 | p.Metadata.Id = 64 40 | p.Metadata.Operation = 32 41 | 42 | err := writerConn.WritePacket(p) 43 | require.NoError(t, err) 44 | packet.Put(p) 45 | 46 | p, err = readerConn.ReadPacket() 47 | require.NoError(t, err) 48 | require.NotNil(t, p.Metadata) 49 | 
assert.Equal(t, uint16(64), p.Metadata.Id) 50 | assert.Equal(t, uint16(32), p.Metadata.Operation) 51 | assert.Equal(t, uint32(0), p.Metadata.ContentLength) 52 | assert.Equal(t, 0, p.Content.Len()) 53 | 54 | data := make([]byte, packetSize) 55 | _, _ = rand.Read(data) 56 | 57 | p.Content.Write(data) 58 | p.Metadata.ContentLength = packetSize 59 | 60 | err = writerConn.WritePacket(p) 61 | require.NoError(t, err) 62 | 63 | packet.Put(p) 64 | 65 | p, err = readerConn.ReadPacket() 66 | require.NoError(t, err) 67 | assert.NotNil(t, p.Metadata) 68 | assert.Equal(t, uint16(64), p.Metadata.Id) 69 | assert.Equal(t, uint16(32), p.Metadata.Operation) 70 | assert.Equal(t, uint32(packetSize), p.Metadata.ContentLength) 71 | assert.Equal(t, len(data), p.Content.Len()) 72 | expected := polyglot.NewBufferFromBytes(data) 73 | expected.MoveOffset(len(data)) 74 | assert.Equal(t, expected.Bytes(), p.Content.Bytes()) 75 | 76 | packet.Put(p) 77 | 78 | err = readerConn.Close() 79 | assert.NoError(t, err) 80 | err = writerConn.Close() 81 | assert.NoError(t, err) 82 | } 83 | 84 | func TestAsyncLargeWrite(t *testing.T) { 85 | t.Parallel() 86 | 87 | const testSize = 100000 88 | const packetSize = 512 89 | 90 | emptyLogger := logging.Test(t, logging.Noop, t.Name()) 91 | 92 | reader, writer := net.Pipe() 93 | 94 | readerConn := NewAsync(reader, emptyLogger) 95 | writerConn := NewAsync(writer, emptyLogger) 96 | 97 | randomData := make([][]byte, testSize) 98 | p := packet.Get() 99 | p.Metadata.Id = 64 100 | p.Metadata.Operation = 32 101 | p.Metadata.ContentLength = packetSize 102 | 103 | for i := 0; i < testSize; i++ { 104 | randomData[i] = make([]byte, packetSize) 105 | _, _ = rand.Read(randomData[i]) 106 | p.Content.Write(randomData[i]) 107 | err := writerConn.WritePacket(p) 108 | p.Content.Reset() 109 | assert.NoError(t, err) 110 | } 111 | packet.Put(p) 112 | 113 | for i := 0; i < testSize; i++ { 114 | p, err := readerConn.ReadPacket() 115 | assert.NoError(t, err) 116 | assert.NotNil(t, p.Metadata) 117 | assert.Equal(t, uint16(64), p.Metadata.Id) 118 | assert.Equal(t, uint16(32), p.Metadata.Operation) 119 | assert.Equal(t, uint32(packetSize), p.Metadata.ContentLength) 120 | assert.Equal(t, len(randomData[i]), p.Content.Len()) 121 | expected := polyglot.NewBufferFromBytes(randomData[i]) 122 | expected.MoveOffset(len(randomData[i])) 123 | assert.Equal(t, expected.Bytes(), p.Content.Bytes()) 124 | packet.Put(p) 125 | } 126 | 127 | // Verify large writes past max length fails. 128 | big := make([]byte, 2*DefaultMaxContentLength) 129 | 130 | p = packet.Get() 131 | p.Metadata.Id = 64 132 | p.Metadata.Operation = 32 133 | p.Metadata.ContentLength = uint32(len(big)) 134 | p.Content.Write(big) 135 | 136 | err := writerConn.WritePacket(p) 137 | assert.ErrorIs(t, err, ContentLengthExceeded) 138 | p.Content.Reset() 139 | packet.Put(p) 140 | 141 | err = readerConn.Close() 142 | assert.NoError(t, err) 143 | err = writerConn.Close() 144 | assert.NoError(t, err) 145 | } 146 | 147 | func TestAsyncLargeRead(t *testing.T) { 148 | t.Parallel() 149 | 150 | client, server, err := pair.New() 151 | require.NoError(t, err) 152 | 153 | serverConn := NewAsync(server, logging.Test(t, logging.Noop, t.Name())) 154 | t.Cleanup(func() { serverConn.Close() }) 155 | 156 | // Write a large packet that exceeds the maximum limit. 
157 | bigData := make([]byte, 2*DefaultMaxContentLength+metadata.Size) 158 | binary.BigEndian.PutUint16(bigData[metadata.MagicOffset:metadata.MagicOffset+metadata.MagicSize], metadata.PacketMagicHeader) 159 | binary.BigEndian.PutUint16(bigData[metadata.IdOffset:metadata.IdOffset+metadata.IdSize], 0xFFFF) 160 | binary.BigEndian.PutUint16(bigData[metadata.OperationOffset:metadata.OperationOffset+metadata.OperationSize], 0xFFFF) 161 | binary.BigEndian.PutUint32(bigData[metadata.ContentLengthOffset:metadata.ContentLengthOffset+metadata.ContentLengthSize], 2*DefaultMaxContentLength) 162 | 163 | _, err = client.Write(bigData) 164 | 165 | // Verify client was disconnected. 166 | var opError *net.OpError 167 | assert.ErrorAs(t, err, &opError) 168 | 169 | // Verify server connection was closed with the expected error. 170 | _, err = serverConn.ReadPacket() 171 | assert.ErrorIs(t, err, ConnectionClosed) 172 | assert.ErrorIs(t, serverConn.Error(), ContentLengthExceeded) 173 | } 174 | 175 | func TestAsyncDisableMaxContentLength(t *testing.T) { 176 | // Don't run in parallel since it modifies DefaultMaxContentLength. 177 | 178 | oldMax := DisableMaxContentLength(t) 179 | logger := logging.Test(t, logging.Noop, t.Name()) 180 | 181 | t.Run("read", func(t *testing.T) { 182 | client, server, err := pair.New() 183 | require.NoError(t, err) 184 | 185 | serverConn := NewAsync(server, logger) 186 | t.Cleanup(func() { serverConn.Close() }) 187 | 188 | // Write a large packet that would exceed the default maximum limit. 189 | bigData := make([]byte, 2*oldMax+metadata.Size) 190 | binary.BigEndian.PutUint16(bigData[metadata.MagicOffset:metadata.MagicOffset+metadata.MagicSize], metadata.PacketMagicHeader) 191 | binary.BigEndian.PutUint16(bigData[metadata.IdOffset:metadata.IdOffset+metadata.IdSize], 0xFFFF) 192 | binary.BigEndian.PutUint16(bigData[metadata.OperationOffset:metadata.OperationOffset+metadata.OperationSize], 0xFFFF) 193 | binary.BigEndian.PutUint32(bigData[metadata.ContentLengthOffset:metadata.ContentLengthOffset+metadata.ContentLengthSize], 2*oldMax) 194 | 195 | n, err := client.Write(bigData) 196 | assert.NoError(t, err) 197 | assert.Equal(t, int(2*oldMax+metadata.Size), n) 198 | 199 | // Verify packet can be read. 200 | p, err := serverConn.ReadPacket() 201 | require.NoError(t, err) 202 | assert.Equal(t, int(2*oldMax), p.Content.Len()) 203 | }) 204 | 205 | t.Run("write", func(t *testing.T) { 206 | client, server, err := pair.New() 207 | require.NoError(t, err) 208 | 209 | clientConn := NewAsync(client, logger) 210 | t.Cleanup(func() { clientConn.Close() }) 211 | 212 | serverConn := NewAsync(server, logger) 213 | t.Cleanup(func() { serverConn.Close() }) 214 | 215 | doneCh := make(chan any) 216 | go func() { 217 | defer close(doneCh) 218 | 219 | p, err := serverConn.ReadPacket() 220 | assert.NoError(t, err) 221 | if err == nil { 222 | assert.Equal(t, int(2*oldMax), p.Content.Len()) 223 | } 224 | }() 225 | 226 | p := packet.Get() 227 | t.Cleanup(func() { 228 | p.Content.Reset() 229 | packet.Put(p) 230 | }) 231 | 232 | // Write a large packet that would exceed the default maximum limit. 
233 | content := make([]byte, 2*oldMax) 234 | 235 | p.Metadata.Id = 64 236 | p.Metadata.Operation = 32 237 | p.Metadata.ContentLength = uint32(len(content)) 238 | p.Content.Write(content) 239 | 240 | err = clientConn.WritePacket(p) 241 | require.NoError(t, err) 242 | 243 | <-doneCh 244 | }) 245 | 246 | } 247 | 248 | func TestAsyncInvalidPacket(t *testing.T) { 249 | t.Parallel() 250 | 251 | client, server, err := pair.New() 252 | require.NoError(t, err) 253 | 254 | serverConn := NewAsync(server, logging.Test(t, logging.Noop, t.Name())) 255 | t.Cleanup(func() { serverConn.Close() }) 256 | 257 | // Write invalid data. 258 | httpReq := []byte(`GET / HTTP/1.1 259 | Host: www.example.com 260 | User-Agent: curl/8.9.1 261 | Accept: */*`) 262 | _, err = client.Write(httpReq) 263 | assert.NoError(t, err) 264 | 265 | // Verify server connection was closed with the expected error. 266 | _, err = serverConn.ReadPacket() 267 | assert.ErrorIs(t, err, ConnectionClosed) 268 | assert.ErrorIs(t, serverConn.Error(), InvalidMagicHeader) 269 | } 270 | 271 | func TestAsyncRawConn(t *testing.T) { 272 | t.Parallel() 273 | 274 | const testSize = 100000 275 | const packetSize = 32 276 | 277 | emptyLogger := logging.Test(t, logging.Noop, t.Name()) 278 | 279 | reader, writer, err := pair.New() 280 | require.NoError(t, err) 281 | 282 | readerConn := NewAsync(reader, emptyLogger) 283 | writerConn := NewAsync(writer, emptyLogger) 284 | 285 | randomData := make([]byte, packetSize) 286 | _, _ = rand.Read(randomData) 287 | 288 | p := packet.Get() 289 | p.Metadata.Id = 64 290 | p.Metadata.Operation = 32 291 | p.Content.Write(randomData) 292 | p.Metadata.ContentLength = packetSize 293 | 294 | for i := 0; i < testSize; i++ { 295 | err := writerConn.WritePacket(p) 296 | assert.NoError(t, err) 297 | } 298 | 299 | packet.Put(p) 300 | 301 | for i := 0; i < testSize; i++ { 302 | p, err := readerConn.ReadPacket() 303 | assert.NoError(t, err) 304 | assert.NotNil(t, p.Metadata) 305 | assert.Equal(t, uint16(64), p.Metadata.Id) 306 | assert.Equal(t, uint16(32), p.Metadata.Operation) 307 | assert.Equal(t, uint32(packetSize), p.Metadata.ContentLength) 308 | assert.Equal(t, packetSize, p.Content.Len()) 309 | expected := polyglot.NewBufferFromBytes(randomData) 310 | expected.MoveOffset(len(randomData)) 311 | assert.Equal(t, expected.Bytes(), p.Content.Bytes()) 312 | } 313 | 314 | rawReaderConn := readerConn.Raw() 315 | rawWriterConn := writerConn.Raw() 316 | 317 | rawWriteMessage := []byte("TEST CASE MESSAGE") 318 | 319 | written, err := rawReaderConn.Write(rawWriteMessage) 320 | assert.NoError(t, err) 321 | assert.Equal(t, len(rawWriteMessage), written) 322 | rawReadMessage := make([]byte, len(rawWriteMessage)) 323 | read, err := rawWriterConn.Read(rawReadMessage) 324 | assert.NoError(t, err) 325 | assert.Equal(t, len(rawWriteMessage), read) 326 | assert.Equal(t, rawWriteMessage, rawReadMessage) 327 | 328 | err = readerConn.Close() 329 | assert.NoError(t, err) 330 | err = writerConn.Close() 331 | assert.NoError(t, err) 332 | 333 | assert.NoError(t, pair.Cleanup(rawReaderConn, rawWriterConn)) 334 | } 335 | 336 | func TestAsyncReadClose(t *testing.T) { 337 | t.Parallel() 338 | 339 | reader, writer := net.Pipe() 340 | 341 | emptyLogger := logging.Test(t, logging.Noop, t.Name()) 342 | 343 | readerConn := NewAsync(reader, emptyLogger) 344 | writerConn := NewAsync(writer, emptyLogger) 345 | 346 | p := packet.Get() 347 | p.Metadata.Id = 64 348 | p.Metadata.Operation = 32 349 | 350 | err := writerConn.WritePacket(p) 351 | require.NoError(t, err) 
352 | 353 | packet.Put(p) 354 | 355 | err = writerConn.Flush() 356 | require.NoError(t, err) 357 | 358 | p, err = readerConn.ReadPacket() 359 | require.NoError(t, err) 360 | require.NotNil(t, p.Metadata) 361 | assert.Equal(t, uint16(64), p.Metadata.Id) 362 | assert.Equal(t, uint16(32), p.Metadata.Operation) 363 | assert.Equal(t, uint32(0), p.Metadata.ContentLength) 364 | assert.Equal(t, 0, p.Content.Len()) 365 | 366 | err = readerConn.conn.Close() 367 | assert.NoError(t, err) 368 | 369 | time.Sleep(DefaultPingInterval * 2) 370 | 371 | err = writerConn.WritePacket(p) 372 | if err == nil { 373 | err = writerConn.Flush() 374 | assert.Error(t, err) 375 | } 376 | assert.Error(t, writerConn.Error()) 377 | 378 | err = readerConn.Close() 379 | assert.NoError(t, err) 380 | err = writerConn.Close() 381 | assert.NoError(t, err) 382 | } 383 | 384 | func TestAsyncReadAvailableClose(t *testing.T) { 385 | t.Parallel() 386 | 387 | reader, writer := net.Pipe() 388 | 389 | emptyLogger := logging.Test(t, logging.Noop, t.Name()) 390 | 391 | readerConn := NewAsync(reader, emptyLogger) 392 | writerConn := NewAsync(writer, emptyLogger) 393 | 394 | p := packet.Get() 395 | p.Metadata.Id = 64 396 | p.Metadata.Operation = 32 397 | 398 | err := writerConn.WritePacket(p) 399 | require.NoError(t, err) 400 | 401 | err = writerConn.WritePacket(p) 402 | require.NoError(t, err) 403 | 404 | packet.Put(p) 405 | 406 | err = writerConn.Close() 407 | require.NoError(t, err) 408 | 409 | p, err = readerConn.ReadPacket() 410 | require.NoError(t, err) 411 | require.NotNil(t, p.Metadata) 412 | assert.Equal(t, uint16(64), p.Metadata.Id) 413 | assert.Equal(t, uint16(32), p.Metadata.Operation) 414 | assert.Equal(t, uint32(0), p.Metadata.ContentLength) 415 | assert.Equal(t, 0, p.Content.Len()) 416 | 417 | p, err = readerConn.ReadPacket() 418 | require.NoError(t, err) 419 | require.NotNil(t, p.Metadata) 420 | assert.Equal(t, uint16(64), p.Metadata.Id) 421 | assert.Equal(t, uint16(32), p.Metadata.Operation) 422 | assert.Equal(t, uint32(0), p.Metadata.ContentLength) 423 | assert.Equal(t, 0, p.Content.Len()) 424 | 425 | _, err = readerConn.ReadPacket() 426 | require.Error(t, err) 427 | 428 | err = readerConn.Close() 429 | assert.NoError(t, err) 430 | err = writerConn.Close() 431 | assert.NoError(t, err) 432 | } 433 | 434 | func TestAsyncWriteClose(t *testing.T) { 435 | t.Parallel() 436 | 437 | reader, writer := net.Pipe() 438 | 439 | emptyLogger := logging.Test(t, logging.Noop, t.Name()) 440 | 441 | readerConn := NewAsync(reader, emptyLogger) 442 | writerConn := NewAsync(writer, emptyLogger) 443 | 444 | p := packet.Get() 445 | p.Metadata.Id = 64 446 | p.Metadata.Operation = 32 447 | 448 | err := writerConn.WritePacket(p) 449 | require.NoError(t, err) 450 | 451 | packet.Put(p) 452 | 453 | err = writerConn.Flush() 454 | assert.NoError(t, err) 455 | 456 | p, err = readerConn.ReadPacket() 457 | assert.NoError(t, err) 458 | assert.NotNil(t, p.Metadata) 459 | assert.Equal(t, uint16(64), p.Metadata.Id) 460 | assert.Equal(t, uint16(32), p.Metadata.Operation) 461 | assert.Equal(t, uint32(0), p.Metadata.ContentLength) 462 | assert.Equal(t, 0, p.Content.Len()) 463 | 464 | err = writerConn.WritePacket(p) 465 | assert.NoError(t, err) 466 | 467 | packet.Put(p) 468 | 469 | err = writerConn.conn.Close() 470 | assert.NoError(t, err) 471 | 472 | runtime.Gosched() 473 | time.Sleep(DefaultDeadline * 2) 474 | runtime.Gosched() 475 | 476 | _, err = readerConn.ReadPacket() 477 | assert.ErrorIs(t, err, ConnectionClosed) 478 | assert.ErrorIs(t, readerConn.Error(), 
io.EOF) 479 | 480 | err = readerConn.Close() 481 | assert.NoError(t, err) 482 | err = writerConn.Close() 483 | assert.NoError(t, err) 484 | } 485 | 486 | func TestAsyncTimeout(t *testing.T) { 487 | t.Parallel() 488 | 489 | emptyLogger := logging.Test(t, logging.Noop, t.Name()) 490 | 491 | reader, writer, err := pair.New() 492 | require.NoError(t, err) 493 | 494 | readerConn := NewAsync(reader, emptyLogger) 495 | writerConn := NewAsync(writer, emptyLogger) 496 | 497 | p := packet.Get() 498 | p.Metadata.Id = 64 499 | p.Metadata.Operation = 32 500 | 501 | err = writerConn.WritePacket(p) 502 | assert.NoError(t, err) 503 | 504 | packet.Put(p) 505 | 506 | err = writerConn.Flush() 507 | assert.NoError(t, err) 508 | 509 | p, err = readerConn.ReadPacket() 510 | assert.NoError(t, err) 511 | assert.NotNil(t, p.Metadata) 512 | assert.Equal(t, uint16(64), p.Metadata.Id) 513 | assert.Equal(t, uint16(32), p.Metadata.Operation) 514 | assert.Equal(t, uint32(0), p.Metadata.ContentLength) 515 | assert.Equal(t, 0, p.Content.Len()) 516 | 517 | time.Sleep(DefaultDeadline * 2) 518 | 519 | err = writerConn.Error() 520 | require.NoError(t, err) 521 | 522 | err = writerConn.WritePacket(p) 523 | require.NoError(t, err) 524 | 525 | err = writerConn.Flush() 526 | require.NoError(t, err) 527 | 528 | packet.Put(p) 529 | 530 | time.Sleep(DefaultDeadline) 531 | require.Equal(t, 1, readerConn.incoming.Length()) 532 | 533 | err = writerConn.conn.Close() 534 | require.NoError(t, err) 535 | 536 | runtime.Gosched() 537 | time.Sleep(DefaultDeadline * 2) 538 | runtime.Gosched() 539 | 540 | p, err = readerConn.ReadPacket() 541 | require.NoError(t, err) 542 | assert.NotNil(t, p.Metadata) 543 | assert.Equal(t, uint16(64), p.Metadata.Id) 544 | assert.Equal(t, uint16(32), p.Metadata.Operation) 545 | assert.Equal(t, uint32(0), p.Metadata.ContentLength) 546 | assert.Equal(t, 0, p.Content.Len()) 547 | 548 | _, err = readerConn.ReadPacket() 549 | require.ErrorIs(t, err, ConnectionClosed) 550 | 551 | err = readerConn.Error() 552 | if err == nil { 553 | runtime.Gosched() 554 | time.Sleep(DefaultDeadline * 3) 555 | runtime.Gosched() 556 | } 557 | require.Error(t, readerConn.Error()) 558 | 559 | err = readerConn.Close() 560 | assert.NoError(t, err) 561 | err = writerConn.Close() 562 | assert.NoError(t, err) 563 | } 564 | 565 | func BenchmarkAsyncThroughputPipe(b *testing.B) { 566 | DisableMaxContentLength(b) 567 | 568 | const testSize = 100 569 | 570 | emptyLogger := logging.Test(b, logging.Noop, b.Name()) 571 | 572 | reader, writer := net.Pipe() 573 | 574 | readerConn := NewAsync(reader, emptyLogger) 575 | writerConn := NewAsync(writer, emptyLogger) 576 | 577 | b.Run("32 Bytes", throughputRunner(testSize, 32, readerConn, writerConn)) 578 | b.Run("512 Bytes", throughputRunner(testSize, 512, readerConn, writerConn)) 579 | b.Run("1024 Bytes", throughputRunner(testSize, 1024, readerConn, writerConn)) 580 | b.Run("2048 Bytes", throughputRunner(testSize, 2048, readerConn, writerConn)) 581 | b.Run("4096 Bytes", throughputRunner(testSize, 4096, readerConn, writerConn)) 582 | 583 | _ = readerConn.Close() 584 | _ = writerConn.Close() 585 | } 586 | 587 | func BenchmarkAsyncThroughputNetwork(b *testing.B) { 588 | DisableMaxContentLength(b) 589 | 590 | const testSize = 100 591 | 592 | emptyLogger := logging.Test(b, logging.Noop, b.Name()) 593 | 594 | reader, writer, err := pair.New() 595 | if err != nil { 596 | b.Fatal(err) 597 | } 598 | 599 | readerConn := NewAsync(reader, emptyLogger) 600 | writerConn := NewAsync(writer, emptyLogger) 601 | 602 | 
b.Run("32 Bytes", throughputRunner(testSize, 32, readerConn, writerConn)) 603 | b.Run("512 Bytes", throughputRunner(testSize, 512, readerConn, writerConn)) 604 | b.Run("1024 Bytes", throughputRunner(testSize, 1024, readerConn, writerConn)) 605 | b.Run("2048 Bytes", throughputRunner(testSize, 2048, readerConn, writerConn)) 606 | b.Run("4096 Bytes", throughputRunner(testSize, 4096, readerConn, writerConn)) 607 | 608 | _ = readerConn.Close() 609 | _ = writerConn.Close() 610 | } 611 | 612 | func BenchmarkAsyncThroughputNetworkMultiple(b *testing.B) { 613 | DisableMaxContentLength(b) 614 | 615 | const testSize = 100 616 | 617 | throughputRunner := func(testSize uint32, packetSize uint32, readerConn Conn, writerConn Conn) func(b *testing.B) { 618 | return func(b *testing.B) { 619 | var err error 620 | 621 | randomData := make([]byte, packetSize) 622 | _, _ = rand.Read(randomData) 623 | 624 | p := packet.Get() 625 | p.Metadata.Id = 64 626 | p.Metadata.Operation = 32 627 | p.Content.Write(randomData) 628 | p.Metadata.ContentLength = packetSize 629 | 630 | for i := 0; i < b.N; i++ { 631 | done := make(chan struct{}, 1) 632 | errCh := make(chan error, 1) 633 | go func() { 634 | for i := uint32(0); i < testSize; i++ { 635 | p, err := readerConn.ReadPacket() 636 | if err != nil { 637 | errCh <- err 638 | return 639 | } 640 | packet.Put(p) 641 | } 642 | done <- struct{}{} 643 | }() 644 | for i := uint32(0); i < testSize; i++ { 645 | select { 646 | case err = <-errCh: 647 | b.Fatal(err) 648 | default: 649 | err = writerConn.WritePacket(p) 650 | if err != nil { 651 | b.Fatal(err) 652 | } 653 | } 654 | } 655 | select { 656 | case <-done: 657 | continue 658 | case err = <-errCh: 659 | b.Fatal(err) 660 | } 661 | } 662 | 663 | packet.Put(p) 664 | } 665 | } 666 | 667 | runner := func(numClients int, packetSize uint32) func(b *testing.B) { 668 | return func(b *testing.B) { 669 | var wg sync.WaitGroup 670 | wg.Add(numClients) 671 | b.SetBytes(int64(testSize * packetSize)) 672 | b.ReportAllocs() 673 | for i := 0; i < numClients; i++ { 674 | go func() { 675 | emptyLogger := logging.Test(b, logging.Noop, b.Name()) 676 | 677 | reader, writer, err := pair.New() 678 | if err != nil { 679 | b.Error(err) 680 | } 681 | 682 | readerConn := NewAsync(reader, emptyLogger) 683 | writerConn := NewAsync(writer, emptyLogger) 684 | throughputRunner(testSize, packetSize, readerConn, writerConn)(b) 685 | 686 | _ = readerConn.Close() 687 | _ = writerConn.Close() 688 | wg.Done() 689 | }() 690 | } 691 | wg.Wait() 692 | } 693 | } 694 | 695 | b.Run("1 Pair, 32 Bytes", runner(1, 32)) 696 | b.Run("2 Pair, 32 Bytes", runner(2, 32)) 697 | b.Run("5 Pair, 32 Bytes", runner(5, 32)) 698 | b.Run("10 Pair, 32 Bytes", runner(10, 32)) 699 | b.Run("Half CPU Pair, 32 Bytes", runner(runtime.NumCPU()/2, 32)) 700 | b.Run("CPU Pair, 32 Bytes", runner(runtime.NumCPU(), 32)) 701 | b.Run("Double CPU Pair, 32 Bytes", runner(runtime.NumCPU()*2, 32)) 702 | 703 | b.Run("1 Pair, 512 Bytes", runner(1, 512)) 704 | b.Run("2 Pair, 512 Bytes", runner(2, 512)) 705 | b.Run("5 Pair, 512 Bytes", runner(5, 512)) 706 | b.Run("10 Pair, 512 Bytes", runner(10, 512)) 707 | b.Run("Half CPU Pair, 512 Bytes", runner(runtime.NumCPU()/2, 512)) 708 | b.Run("CPU Pair, 512 Bytes", runner(runtime.NumCPU(), 512)) 709 | b.Run("Double CPU Pair, 512 Bytes", runner(runtime.NumCPU()*2, 512)) 710 | 711 | b.Run("1 Pair, 4096 Bytes", runner(1, 4096)) 712 | b.Run("2 Pair, 4096 Bytes", runner(2, 4096)) 713 | b.Run("5 Pair, 4096 Bytes", runner(5, 4096)) 714 | b.Run("10 Pair, 4096 Bytes", 
runner(10, 4096)) 715 | b.Run("Half CPU Pair, 4096 Bytes", runner(runtime.NumCPU()/2, 4096)) 716 | b.Run("CPU Pair, 4096 Bytes", runner(runtime.NumCPU(), 4096)) 717 | b.Run("Double CPU Pair, 4096 Bytes", runner(runtime.NumCPU()*2, 4096)) 718 | } 719 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package frisbee 4 | 5 | import ( 6 | "context" 7 | "net" 8 | "sync" 9 | "sync/atomic" 10 | 11 | "github.com/loopholelabs/logging/types" 12 | 13 | "github.com/loopholelabs/frisbee-go/pkg/packet" 14 | ) 15 | 16 | // Client connects to a frisbee Server and can send and receive frisbee packets 17 | type Client struct { 18 | conn *Async 19 | handlerTable HandlerTable 20 | options *Options 21 | closed atomic.Bool 22 | wg sync.WaitGroup 23 | heartbeatChannel chan struct{} 24 | 25 | baseContext context.Context 26 | baseContextCancel context.CancelFunc 27 | 28 | // PacketContext is used to define packet-specific contexts based on the incoming packet 29 | // and is run whenever a new packet arrives 30 | PacketContext func(context.Context, *packet.Packet) context.Context 31 | 32 | // UpdateContext is used to update a handler-specific context whenever the returned 33 | // Action from a handler is UPDATE 34 | UpdateContext func(context.Context, *Async) context.Context 35 | 36 | // StreamContext is used to update a handler-specific context whenever a new stream is created 37 | // and is run whenever a new stream is created 38 | StreamContext func(context.Context, *Stream) context.Context 39 | } 40 | 41 | // NewClient returns an uninitialized frisbee Client with the registered ClientRouter. 42 | // The ConnectAsync method must then be called to dial the server and initialize the connection. 43 | func NewClient(handlerTable HandlerTable, ctx context.Context, opts ...Option) (*Client, error) { 44 | for i := uint16(0); i < RESERVED9; i++ { 45 | if _, ok := handlerTable[i]; ok { 46 | return nil, InvalidHandlerTable 47 | } 48 | } 49 | 50 | options := loadOptions(opts...) 51 | var heartbeatChannel chan struct{} 52 | 53 | baseContext, baseContextCancel := context.WithCancel(ctx) 54 | 55 | return &Client{ 56 | handlerTable: handlerTable, 57 | baseContext: baseContext, 58 | baseContextCancel: baseContextCancel, 59 | options: options, 60 | heartbeatChannel: heartbeatChannel, 61 | }, nil 62 | } 63 | 64 | // Connect actually connects to the given frisbee server, and starts the reactor goroutines 65 | // to receive and handle incoming packets. If this function is called, FromConn should not be called. 66 | func (c *Client) Connect(addr string, streamHandler ...NewStreamHandler) error { 67 | c.Logger().Debug().Msgf("Connecting to %s", addr) 68 | var frisbeeConn *Async 69 | var err error 70 | frisbeeConn, err = ConnectAsync(addr, c.options.KeepAlive, c.Logger(), c.options.TLSConfig, streamHandler...) 71 | if err != nil { 72 | return err 73 | } 74 | c.conn = frisbeeConn 75 | c.Logger().Info().Msgf("Connected to %s", addr) 76 | 77 | c.wg.Add(1) 78 | go c.handleConn() 79 | c.Logger().Debug().Msgf("Connection handler started for %s", addr) 80 | return nil 81 | } 82 | 83 | // FromConn takes a pre-existing connection to a Frisbee server and starts the reactor goroutines 84 | // to receive and handle incoming packets. If this function is called, Connect should not be called. 
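// A minimal, hedged usage sketch (illustrative only, not part of the package
// documentation): the handler table contents, the operation value 10, and the
// pre-dialed net.Conn named conn are placeholders.
//
//	table := make(HandlerTable)
//	table[10] = func(_ context.Context, _ *packet.Packet) (*packet.Packet, Action) {
//		return nil, NONE
//	}
//	c, err := NewClient(table, context.Background())
//	if err != nil {
//		// handle the error
//	}
//	if err := c.FromConn(conn); err != nil {
//		// handle the error
//	}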
85 | func (c *Client) FromConn(conn net.Conn, streamHandler ...NewStreamHandler) error { 86 | c.conn = NewAsync(conn, c.Logger(), streamHandler...) 87 | c.wg.Add(1) 88 | go c.handleConn() 89 | c.Logger().Debug().Msgf("Connection handler started for %s", c.conn.RemoteAddr()) 90 | return nil 91 | } 92 | 93 | // Closed checks whether this client has been closed 94 | func (c *Client) Closed() bool { 95 | return c.closed.Load() 96 | } 97 | 98 | // Error checks whether this client has an error 99 | func (c *Client) Error() error { 100 | return c.conn.Error() 101 | } 102 | 103 | // Close closes the frisbee client and kills all the goroutines 104 | func (c *Client) Close() error { 105 | if c.closed.CompareAndSwap(false, true) { 106 | c.baseContextCancel() 107 | err := c.conn.Close() 108 | if err != nil { 109 | return err 110 | } 111 | c.wg.Wait() 112 | return nil 113 | } 114 | return c.conn.Close() 115 | } 116 | 117 | // WritePacket sends a frisbee packet.Packet from the client to the server 118 | func (c *Client) WritePacket(p *packet.Packet) error { 119 | return c.conn.WritePacket(p) 120 | } 121 | 122 | // Flush flushes any queued frisbee Packets from the client to the server 123 | func (c *Client) Flush() error { 124 | return c.conn.Flush() 125 | } 126 | 127 | // CloseChannel returns a channel that can be listened to see if this client has been closed 128 | func (c *Client) CloseChannel() <-chan struct{} { 129 | return c.conn.CloseChannel() 130 | } 131 | 132 | // Raw converts the frisbee client into a normal net.Conn object, and returns it. 133 | // This is especially useful in proxying and streaming scenarios. 134 | func (c *Client) Raw() (net.Conn, error) { 135 | if c.conn == nil { 136 | return nil, ConnectionNotInitialized 137 | } 138 | if c.closed.CompareAndSwap(false, true) { 139 | conn := c.conn.Raw() 140 | c.wg.Wait() 141 | return conn, nil 142 | } 143 | return c.conn.Raw(), nil 144 | } 145 | 146 | // Stream returns a new Stream object that can be used to send and receive frisbee packets 147 | func (c *Client) Stream(id uint16) *Stream { 148 | return c.conn.NewStream(id) 149 | } 150 | 151 | // SetStreamHandler sets the callback handler for new streams. 152 | // 153 | // It's important to note that this handler is called for new streams and if it is 154 | // not set then stream packets will be dropped. 155 | // 156 | // It's also important to note that the handler itself is called in its own goroutine to 157 | // avoid blocking the read loop. This means that the handler must be thread-safe. 
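// A hedged sketch of registering a stream handler on an already-connected
// Client c (the variable names are placeholders chosen for illustration):
//
//	c.SetStreamHandler(func(ctx context.Context, s *Stream) {
//		// Runs in its own goroutine, so blocking here does not stall the
//		// connection's read loop.
//		defer s.Close()
//		// ... read from / write to the stream ...
//	})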
158 | func (c *Client) SetStreamHandler(f func(context.Context, *Stream)) { 159 | if f == nil { 160 | c.conn.SetNewStreamHandler(nil) 161 | } 162 | c.conn.SetNewStreamHandler(func(s *Stream) { 163 | streamCtx := c.baseContext 164 | if c.StreamContext != nil { 165 | streamCtx = c.StreamContext(streamCtx, s) 166 | } 167 | f(streamCtx, s) 168 | }) 169 | } 170 | 171 | // Logger returns the client's logger (useful for ClientRouter functions) 172 | func (c *Client) Logger() types.Logger { 173 | return c.options.Logger 174 | } 175 | 176 | func (c *Client) handleConn() { 177 | var p *packet.Packet 178 | var outgoing *packet.Packet 179 | var action Action 180 | var err error 181 | var handlerFunc Handler 182 | for { 183 | if c.closed.Load() { 184 | c.wg.Done() 185 | return 186 | } 187 | p, err = c.conn.ReadPacket() 188 | if err != nil { 189 | c.Logger().Debug().Err(err).Msg("error while getting packet frisbee connection") 190 | c.wg.Done() 191 | _ = c.Close() 192 | return 193 | } 194 | handlerFunc = c.handlerTable[p.Metadata.Operation] 195 | if handlerFunc != nil { 196 | packetCtx := c.baseContext 197 | if c.PacketContext != nil { 198 | packetCtx = c.PacketContext(packetCtx, p) 199 | } 200 | outgoing, action = handlerFunc(packetCtx, p) 201 | if outgoing != nil && outgoing.Metadata.ContentLength == uint32(outgoing.Content.Len()) { 202 | err = c.conn.WritePacket(outgoing) 203 | if outgoing != p { 204 | packet.Put(outgoing) 205 | } 206 | packet.Put(p) 207 | if err != nil { 208 | c.Logger().Error().Err(err).Msg("error while writing to frisbee conn") 209 | c.wg.Done() 210 | _ = c.Close() 211 | return 212 | } 213 | } else { 214 | packet.Put(p) 215 | } 216 | switch action { 217 | case NONE: 218 | case CLOSE: 219 | c.Logger().Debug().Msgf("Closing connection %s because of CLOSE action", c.conn.RemoteAddr()) 220 | c.wg.Done() 221 | _ = c.Close() 222 | return 223 | } 224 | } else { 225 | packet.Put(p) 226 | } 227 | } 228 | } 229 | -------------------------------------------------------------------------------- /client_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package frisbee 4 | 5 | import ( 6 | "context" 7 | "crypto/rand" 8 | "net" 9 | "testing" 10 | 11 | "github.com/loopholelabs/logging" 12 | "github.com/loopholelabs/testing/conn/pair" 13 | "github.com/stretchr/testify/assert" 14 | "github.com/stretchr/testify/require" 15 | 16 | "github.com/loopholelabs/frisbee-go/pkg/metadata" 17 | "github.com/loopholelabs/frisbee-go/pkg/packet" 18 | ) 19 | 20 | const ( 21 | clientConnContextKey = "conn" 22 | ) 23 | 24 | func TestClientRaw(t *testing.T) { 25 | t.Parallel() 26 | 27 | const testSize = 100 28 | const packetSize = 512 29 | 30 | clientHandlerTable := make(HandlerTable) 31 | serverHandlerTable := make(HandlerTable) 32 | 33 | serverIsRaw := make(chan struct{}, 1) 34 | 35 | serverHandlerTable[metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { 36 | return 37 | } 38 | 39 | var rawServerConn, rawClientConn net.Conn 40 | serverHandlerTable[metadata.PacketProbe] = func(ctx context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { 41 | conn := ctx.Value(clientConnContextKey).(*Async) 42 | rawServerConn = conn.Raw() 43 | serverIsRaw <- struct{}{} 44 | return 45 | } 46 | 47 | clientHandlerTable[metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { 48 | return 49 | } 50 | 51 | emptyLogger := 
logging.Test(t, logging.Noop, t.Name()) 52 | s, err := NewServer(serverHandlerTable, context.Background(), WithLogger(emptyLogger)) 53 | require.NoError(t, err) 54 | 55 | s.SetConcurrency(1) 56 | 57 | s.ConnContext = func(ctx context.Context, c *Async) context.Context { 58 | return context.WithValue(ctx, clientConnContextKey, c) 59 | } 60 | 61 | serverConn, clientConn, err := pair.New() 62 | require.NoError(t, err) 63 | 64 | go s.ServeConn(serverConn) 65 | 66 | c, err := NewClient(clientHandlerTable, context.Background(), WithLogger(emptyLogger)) 67 | assert.NoError(t, err) 68 | _, err = c.Raw() 69 | assert.ErrorIs(t, ConnectionNotInitialized, err) 70 | 71 | err = c.FromConn(clientConn) 72 | require.NoError(t, err) 73 | 74 | data := make([]byte, packetSize) 75 | _, _ = rand.Read(data) 76 | 77 | p := packet.Get() 78 | p.Metadata.Operation = metadata.PacketPing 79 | p.Content.Write(data) 80 | p.Metadata.ContentLength = packetSize 81 | 82 | for q := 0; q < testSize; q++ { 83 | p.Metadata.Id = uint16(q) 84 | err := c.WritePacket(p) 85 | assert.NoError(t, err) 86 | } 87 | p.Reset() 88 | p.Metadata.Operation = metadata.PacketProbe 89 | 90 | err = c.WritePacket(p) 91 | assert.NoError(t, err) 92 | 93 | rawClientConn, err = c.Raw() 94 | require.NoError(t, err) 95 | 96 | <-serverIsRaw 97 | 98 | clientBytes := []byte("CLIENT WRITE") 99 | 100 | write, err := rawClientConn.Write(clientBytes) 101 | assert.NoError(t, err) 102 | assert.Equal(t, len(clientBytes), write) 103 | 104 | serverBuffer := make([]byte, len(clientBytes)) 105 | read, err := rawServerConn.Read(serverBuffer) 106 | assert.NoError(t, err) 107 | assert.Equal(t, len(clientBytes), read) 108 | 109 | assert.Equal(t, clientBytes, serverBuffer) 110 | 111 | err = c.Close() 112 | assert.NoError(t, err) 113 | err = rawClientConn.Close() 114 | assert.NoError(t, err) 115 | 116 | err = s.Shutdown() 117 | assert.NoError(t, err) 118 | err = rawServerConn.Close() 119 | assert.NoError(t, err) 120 | } 121 | 122 | func TestClientStaleClose(t *testing.T) { 123 | t.Parallel() 124 | 125 | const testSize = 100 126 | const packetSize = 512 127 | 128 | clientHandlerTable := make(HandlerTable) 129 | serverHandlerTable := make(HandlerTable) 130 | 131 | finished := make(chan struct{}, 1) 132 | 133 | serverHandlerTable[metadata.PacketPing] = func(_ context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action Action) { 134 | if incoming.Metadata.Id == testSize-1 { 135 | outgoing = incoming 136 | action = CLOSE 137 | } 138 | return 139 | } 140 | 141 | clientHandlerTable[metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { 142 | finished <- struct{}{} 143 | return 144 | } 145 | 146 | emptyLogger := logging.Test(t, logging.Noop, t.Name()) 147 | s, err := NewServer(serverHandlerTable, context.Background(), WithLogger(emptyLogger)) 148 | require.NoError(t, err) 149 | 150 | s.SetConcurrency(1) 151 | 152 | serverConn, clientConn, err := pair.New() 153 | require.NoError(t, err) 154 | 155 | go s.ServeConn(serverConn) 156 | 157 | c, err := NewClient(clientHandlerTable, context.Background(), WithLogger(emptyLogger)) 158 | assert.NoError(t, err) 159 | _, err = c.Raw() 160 | assert.ErrorIs(t, ConnectionNotInitialized, err) 161 | 162 | err = c.FromConn(clientConn) 163 | require.NoError(t, err) 164 | 165 | data := make([]byte, packetSize) 166 | _, _ = rand.Read(data) 167 | 168 | p := packet.Get() 169 | p.Metadata.Operation = metadata.PacketPing 170 | p.Content.Write(data) 171 | p.Metadata.ContentLength = 
packetSize 172 | 173 | for q := 0; q < testSize; q++ { 174 | p.Metadata.Id = uint16(q) 175 | err := c.WritePacket(p) 176 | assert.NoError(t, err) 177 | } 178 | packet.Put(p) 179 | <-finished 180 | 181 | _, err = c.conn.ReadPacket() 182 | assert.ErrorIs(t, err, ConnectionClosed) 183 | 184 | err = c.Close() 185 | assert.NoError(t, err) 186 | 187 | err = s.Shutdown() 188 | assert.NoError(t, err) 189 | } 190 | 191 | func BenchmarkThroughputClient(b *testing.B) { 192 | DisableMaxContentLength(b) 193 | 194 | const testSize = 1<<16 - 1 195 | const packetSize = 512 196 | 197 | clientHandlerTable := make(HandlerTable) 198 | serverHandlerTable := make(HandlerTable) 199 | 200 | serverHandlerTable[metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { 201 | return 202 | } 203 | 204 | clientHandlerTable[metadata.PacketPong] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { 205 | return 206 | } 207 | 208 | emptyLogger := logging.Test(b, logging.Noop, b.Name()) 209 | s, err := NewServer(serverHandlerTable, context.Background(), WithLogger(emptyLogger)) 210 | if err != nil { 211 | b.Fatal(err) 212 | } 213 | 214 | s.SetConcurrency(1) 215 | 216 | serverConn, clientConn, err := pair.New() 217 | if err != nil { 218 | b.Fatal(err) 219 | } 220 | 221 | go s.ServeConn(serverConn) 222 | 223 | c, err := NewClient(clientHandlerTable, context.Background(), WithLogger(emptyLogger)) 224 | if err != nil { 225 | b.Fatal(err) 226 | } 227 | err = c.FromConn(clientConn) 228 | if err != nil { 229 | b.Fatal(err) 230 | } 231 | 232 | data := make([]byte, packetSize) 233 | _, _ = rand.Read(data) 234 | p := packet.Get() 235 | 236 | p.Metadata.Operation = metadata.PacketPing 237 | p.Content.Write(data) 238 | p.Metadata.ContentLength = packetSize 239 | 240 | b.Run("test", func(b *testing.B) { 241 | b.SetBytes(testSize * packetSize) 242 | b.ReportAllocs() 243 | b.ResetTimer() 244 | for i := 0; i < b.N; i++ { 245 | for q := 0; q < testSize; q++ { 246 | p.Metadata.Id = uint16(q) 247 | err = c.WritePacket(p) 248 | if err != nil { 249 | b.Fatal(err) 250 | } 251 | } 252 | } 253 | }) 254 | b.StopTimer() 255 | packet.Put(p) 256 | 257 | err = c.Close() 258 | if err != nil { 259 | b.Fatal(err) 260 | } 261 | err = s.Shutdown() 262 | if err != nil { 263 | b.Fatal(err) 264 | } 265 | } 266 | 267 | func BenchmarkThroughputResponseClient(b *testing.B) { 268 | DisableMaxContentLength(b) 269 | 270 | const testSize = 1<<16 - 1 271 | const packetSize = 512 272 | 273 | clientHandlerTable := make(HandlerTable) 274 | serverHandlerTable := make(HandlerTable) 275 | 276 | finished := make(chan struct{}, 1) 277 | 278 | serverHandlerTable[metadata.PacketPing] = func(_ context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action Action) { 279 | if incoming.Metadata.Id == testSize-1 { 280 | incoming.Reset() 281 | incoming.Metadata.Id = testSize 282 | incoming.Metadata.Operation = metadata.PacketPong 283 | outgoing = incoming 284 | } 285 | return 286 | } 287 | 288 | clientHandlerTable[metadata.PacketPong] = func(_ context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action Action) { 289 | if incoming.Metadata.Id == testSize { 290 | finished <- struct{}{} 291 | } 292 | return 293 | } 294 | 295 | emptyLogger := logging.Test(b, logging.Noop, b.Name()) 296 | s, err := NewServer(serverHandlerTable, context.Background(), WithLogger(emptyLogger)) 297 | if err != nil { 298 | b.Fatal(err) 299 | } 300 | 301 | s.SetConcurrency(1) 302 | 303 | serverConn, 
clientConn, err := pair.New() 304 | if err != nil { 305 | b.Fatal(err) 306 | } 307 | 308 | go s.ServeConn(serverConn) 309 | 310 | c, err := NewClient(clientHandlerTable, context.Background(), WithLogger(emptyLogger)) 311 | if err != nil { 312 | b.Fatal(err) 313 | } 314 | err = c.FromConn(clientConn) 315 | if err != nil { 316 | b.Fatal(err) 317 | } 318 | 319 | data := make([]byte, packetSize) 320 | _, _ = rand.Read(data) 321 | p := packet.Get() 322 | p.Metadata.Operation = metadata.PacketPing 323 | 324 | p.Content.Write(data) 325 | p.Metadata.ContentLength = packetSize 326 | 327 | b.Run("test", func(b *testing.B) { 328 | b.SetBytes(testSize * packetSize) 329 | b.ReportAllocs() 330 | b.ResetTimer() 331 | for i := 0; i < b.N; i++ { 332 | for q := 0; q < testSize; q++ { 333 | p.Metadata.Id = uint16(q) 334 | err = c.WritePacket(p) 335 | if err != nil { 336 | b.Fatal(err) 337 | } 338 | } 339 | <-finished 340 | } 341 | }) 342 | b.StopTimer() 343 | 344 | packet.Put(p) 345 | 346 | err = c.Close() 347 | if err != nil { 348 | b.Fatal(err) 349 | } 350 | err = s.Shutdown() 351 | if err != nil { 352 | b.Fatal(err) 353 | } 354 | } 355 | -------------------------------------------------------------------------------- /conn.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package frisbee 4 | 5 | import ( 6 | "context" 7 | "crypto/tls" 8 | "errors" 9 | "net" 10 | "time" 11 | 12 | "github.com/loopholelabs/logging/types" 13 | 14 | "github.com/loopholelabs/frisbee-go/pkg/packet" 15 | ) 16 | 17 | // DefaultBufferSize is the size of the default buffer 18 | const DefaultBufferSize = 1 << 16 19 | 20 | var ( 21 | DefaultDeadline = time.Second * 5 22 | DefaultPingInterval = time.Millisecond * 500 23 | DefaultMaxContentLength = uint32(5 * 1024 * 1024) // 5 MB 24 | 25 | emptyTime = time.Time{} 26 | pastTime = time.Unix(1, 0) 27 | 28 | emptyState = tls.ConnectionState{} 29 | ) 30 | 31 | var ( 32 | NotTLSConnectionError = errors.New("connection is not of type *tls.Conn") 33 | ) 34 | 35 | type Conn interface { 36 | Close() error 37 | LocalAddr() net.Addr 38 | RemoteAddr() net.Addr 39 | ConnectionState() (tls.ConnectionState, error) 40 | Handshake() error 41 | HandshakeContext(context.Context) error 42 | SetDeadline(time.Time) error 43 | SetReadDeadline(time.Time) error 44 | SetWriteDeadline(time.Time) error 45 | WritePacket(*packet.Packet) error 46 | ReadPacket() (*packet.Packet, error) 47 | Logger() types.Logger 48 | Error() error 49 | Raw() net.Conn 50 | } 51 | -------------------------------------------------------------------------------- /frisbee.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package frisbee 4 | 5 | import ( 6 | "context" 7 | "errors" 8 | "time" 9 | 10 | "github.com/loopholelabs/polyglot/v2" 11 | 12 | "github.com/loopholelabs/frisbee-go/pkg/metadata" 13 | "github.com/loopholelabs/frisbee-go/pkg/packet" 14 | ) 15 | 16 | // These are various frisbee errors that can be returned by the client or server: 17 | var ( 18 | InvalidContentLength = errors.New("invalid content length") 19 | ConnectionClosed = errors.New("connection closed") 20 | StreamClosed = errors.New("stream closed") 21 | InvalidStreamPacket = errors.New("invalid stream packet") 22 | ConnectionNotInitialized = errors.New("connection not initialized") 23 | InvalidMagicHeader = errors.New("invalid magic header") 24 | InvalidBufferLength = errors.New("invalid buffer 
length") 25 | InvalidHandlerTable = errors.New("invalid handler table configuration, a reserved value may have been used") 26 | InvalidOperation = errors.New("invalid operation in packet, a reserved value may have been used") 27 | ContentLengthExceeded = errors.New("content length exceeds maximum allowed") 28 | ) 29 | 30 | // Action is an ENUM used to modify the state of the client or server from a Handler function 31 | // 32 | // NONE: used to do nothing (default) 33 | // CLOSE: close the frisbee connection 34 | // SHUTDOWN: shutdown the frisbee client or server 35 | type Action int 36 | 37 | // These are various frisbee actions, used to modify the state of the client or server from a Handler function: 38 | const ( 39 | // NONE is used to do nothing (default) 40 | NONE = Action(iota) 41 | 42 | // CLOSE is used to close the frisbee connection 43 | CLOSE 44 | ) 45 | 46 | // Handler is the handler function called by frisbee for incoming packets of data, depending on the packet's Metadata.Operation field 47 | type Handler func(ctx context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action Action) 48 | 49 | // HandlerTable is the lookup table for Frisbee handler functions - based on the Metadata.Operation field of a packet, 50 | // Frisbee will look up the correct handler for that packet. 51 | type HandlerTable map[uint16]Handler 52 | 53 | // These are internal reserved packet types, and are the reason you cannot use 0-9 in Handler functions: 54 | const ( 55 | // PING is used to check if a client is still alive 56 | PING = uint16(iota) 57 | 58 | // PONG is used to respond to a PING packets 59 | PONG 60 | 61 | // STREAM is used to request that a new stream be created by the receiver to 62 | // receive packets with the same packet ID until a packet with a ContentLength of 0 is received 63 | STREAM 64 | 65 | RESERVED3 66 | RESERVED4 67 | RESERVED5 68 | RESERVED6 69 | RESERVED7 70 | RESERVED8 71 | RESERVED9 72 | ) 73 | 74 | var ( 75 | // PINGPacket is a pre-allocated Frisbee Packet for PING Packets 76 | PINGPacket = &packet.Packet{ 77 | Metadata: &metadata.Metadata{ 78 | Operation: PING, 79 | }, 80 | Content: polyglot.NewBuffer(), 81 | } 82 | 83 | // PONGPacket is a pre-allocated Frisbee Packet for PONG Packets 84 | PONGPacket = &packet.Packet{ 85 | Metadata: &metadata.Metadata{ 86 | Operation: PONG, 87 | }, 88 | Content: polyglot.NewBuffer(), 89 | } 90 | ) 91 | 92 | // temporary is an interface used to check if an error is recoverable 93 | type temporary interface { 94 | Temporary() bool 95 | } 96 | 97 | const ( 98 | // maxBackoff is the maximum amount ot time to wait before retrying to accept from a listener 99 | maxBackoff = time.Second 100 | 101 | // minBackoff is the minimum amount ot time to wait before retrying to accept from a listener 102 | minBackoff = time.Millisecond * 5 103 | ) 104 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/loopholelabs/frisbee-go 2 | 3 | go 1.22 4 | 5 | require ( 6 | github.com/loopholelabs/common v0.4.10 7 | github.com/loopholelabs/logging v0.3.2 8 | github.com/loopholelabs/polyglot/v2 v2.0.5 9 | github.com/loopholelabs/testing v0.2.3 10 | github.com/stretchr/testify v1.10.0 11 | go.uber.org/goleak v1.3.0 12 | ) 13 | 14 | require ( 15 | github.com/davecgh/go-spew v1.1.1 // indirect 16 | github.com/kr/text v0.2.0 // indirect 17 | github.com/mattn/go-colorable v0.1.13 // indirect 18 | 
github.com/mattn/go-isatty v0.0.20 // indirect 19 | github.com/pmezard/go-difflib v1.0.0 // indirect 20 | github.com/rogpeppe/go-internal v1.13.1 // indirect 21 | github.com/rs/zerolog v1.33.0 // indirect 22 | golang.org/x/sys v0.24.0 // indirect 23 | gopkg.in/yaml.v3 v3.0.1 // indirect 24 | ) 25 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= 2 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 3 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 4 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 5 | github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= 6 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 7 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 8 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 9 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 10 | github.com/loopholelabs/common v0.4.10 h1:BMJSMwH0PiVtdpOlXNPlW827B9WPJ/Gkb/q20NLeOjw= 11 | github.com/loopholelabs/common v0.4.10/go.mod h1:wc17hLpzZaDbndb7Fh3MXQDnhf4Cmf/JKC+LmXaD6II= 12 | github.com/loopholelabs/logging v0.3.2 h1:JPfQr/YcYoMlEbgwgoslQs/C2He/nPVCAD1FpdNceRo= 13 | github.com/loopholelabs/logging v0.3.2/go.mod h1:9p5/U/hNxghtUoD6k2Rzx7X4tr4ik3QzhyYUu1m7D58= 14 | github.com/loopholelabs/polyglot/v2 v2.0.5 h1:KYxebXBPnsNnJpuI6SLdOi+NLeqceV00vB97MYYsEQ8= 15 | github.com/loopholelabs/polyglot/v2 v2.0.5/go.mod h1:O0J6ScdwAy1nYlRcTKggyieUsIVaLZT6i7IZR6adHcM= 16 | github.com/loopholelabs/testing v0.2.3 h1:4nVuK5ctaE6ua5Z0dYk2l7xTFmcpCYLUeGjRBp8keOA= 17 | github.com/loopholelabs/testing v0.2.3/go.mod h1:gqtGY91soYD1fQoKQt/6kP14OYpS7gcbcIgq5mc9m8Q= 18 | github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= 19 | github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= 20 | github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= 21 | github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= 22 | github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= 23 | github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= 24 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 25 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 26 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 27 | github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= 28 | github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= 29 | github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= 30 | github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= 31 | github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= 32 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 33 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 34 | go.uber.org/goleak v1.3.0 
h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= 35 | go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= 36 | golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 37 | golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 38 | golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 39 | golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= 40 | golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 41 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 42 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 43 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 44 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 45 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 46 | -------------------------------------------------------------------------------- /helpers_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package frisbee 4 | 5 | import ( 6 | "crypto/rand" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/require" 10 | "go.uber.org/goleak" 11 | 12 | "github.com/loopholelabs/frisbee-go/pkg/packet" 13 | ) 14 | 15 | func TestMain(m *testing.M) { 16 | goleak.VerifyTestMain(m) 17 | } 18 | 19 | func throughputRunner(testSize, packetSize uint32, readerConn, writerConn Conn) func(b *testing.B) { 20 | return func(b *testing.B) { 21 | b.SetBytes(int64(testSize * packetSize)) 22 | b.ReportAllocs() 23 | 24 | randomData := make([]byte, packetSize) 25 | _, err := rand.Read(randomData) 26 | require.NoError(b, err) 27 | 28 | p := packet.Get() 29 | p.Metadata.Id = 64 30 | p.Metadata.Operation = 32 31 | p.Content.Write(randomData) 32 | p.Metadata.ContentLength = packetSize 33 | 34 | b.ResetTimer() 35 | for i := 0; i < b.N; i++ { 36 | done := make(chan struct{}, 1) 37 | errCh := make(chan error, 1) 38 | go func() { 39 | for i := uint32(0); i < testSize; i++ { 40 | p, err := readerConn.ReadPacket() 41 | if err != nil { 42 | errCh <- err 43 | return 44 | } 45 | packet.Put(p) 46 | } 47 | done <- struct{}{} 48 | }() 49 | for i := uint32(0); i < testSize; i++ { 50 | select { 51 | case err = <-errCh: 52 | b.Fatal(err) 53 | default: 54 | err = writerConn.WritePacket(p) 55 | if err != nil { 56 | b.Fatal(err) 57 | } 58 | } 59 | } 60 | select { 61 | case <-done: 62 | continue 63 | case err = <-errCh: 64 | b.Fatal(err) 65 | } 66 | } 67 | b.StopTimer() 68 | 69 | packet.Put(p) 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /internal/dialer/dialer.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package dialer 4 | 5 | import ( 6 | "crypto/tls" 7 | "net" 8 | "time" 9 | ) 10 | 11 | // Retry is a simple net.Dialer that retries dialing NumRetries times. 12 | type Retry struct { 13 | *net.Dialer 14 | NumRetries int 15 | } 16 | 17 | // NewRetry returns a Retry Dialer with default values. 
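// Illustrative only: a rough sketch of dialing with retries. The network and
// address below are placeholders, not values this package defines.
//
//	conn, err := NewRetry().Dial("tcp", "127.0.0.1:8080")
//	if err == nil {
//		defer conn.Close()
//		// use conn
//	}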
18 | func NewRetry() *Retry { 19 | return &Retry{ 20 | Dialer: &net.Dialer{ 21 | Timeout: time.Second, 22 | KeepAlive: time.Second * 15, 23 | }, 24 | NumRetries: 10, 25 | } 26 | } 27 | 28 | // Dial calls the underlying *net.Dial to dial a net.Conn, but retries on failure 29 | func (r *Retry) Dial(network, address string) (c net.Conn, err error) { 30 | for i := 0; i < r.NumRetries; i++ { 31 | c, err = r.Dialer.Dial(network, address) 32 | if err == nil { 33 | return 34 | } 35 | } 36 | return 37 | } 38 | 39 | // DialTLS creates a new TLS Dialer using the underlying *net.Dial and uses it to dial a net.Conn, but retries on failure 40 | func (r *Retry) DialTLS(network, address string, config *tls.Config) (c net.Conn, err error) { 41 | d := &tls.Dialer{ 42 | NetDialer: r.Dialer, 43 | Config: config, 44 | } 45 | for i := 0; i < r.NumRetries; i++ { 46 | c, err = d.Dial(network, address) 47 | if err == nil { 48 | return 49 | } 50 | } 51 | return 52 | } 53 | -------------------------------------------------------------------------------- /options.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package frisbee 4 | 5 | import ( 6 | "crypto/tls" 7 | "time" 8 | 9 | "github.com/loopholelabs/logging/loggers/noop" 10 | "github.com/loopholelabs/logging/types" 11 | ) 12 | 13 | // Option is used to generate frisbee client and server options internally 14 | type Option func(opts *Options) 15 | 16 | // Options is used to provide the frisbee client and server with configuration options. 17 | // 18 | // Default Values: 19 | // 20 | // options := Options { 21 | // KeepAlive: time.Minute * 3, 22 | // Logger: &DefaultLogger, 23 | // } 24 | type Options struct { 25 | KeepAlive time.Duration 26 | Logger types.Logger 27 | TLSConfig *tls.Config 28 | } 29 | 30 | func loadOptions(options ...Option) *Options { 31 | opts := new(Options) 32 | for _, option := range options { 33 | option(opts) 34 | } 35 | 36 | if opts.Logger == nil { 37 | opts.Logger = noop.New(types.InfoLevel) 38 | } 39 | 40 | if opts.KeepAlive == 0 { 41 | opts.KeepAlive = time.Minute * 3 42 | } 43 | 44 | return opts 45 | } 46 | 47 | // WithOptions allows users to pass in an Options struct to configure a frisbee client or server 48 | func WithOptions(options Options) Option { 49 | return func(opts *Options) { 50 | *opts = options 51 | } 52 | } 53 | 54 | // WithKeepAlive allows users to define TCP keepalive options for the frisbee client or server (use -1 to disable) 55 | func WithKeepAlive(keepAlive time.Duration) Option { 56 | return func(opts *Options) { 57 | opts.KeepAlive = keepAlive 58 | } 59 | } 60 | 61 | // WithLogger sets the logger for the frisbee client or server 62 | func WithLogger(logger types.Logger) Option { 63 | return func(opts *Options) { 64 | opts.Logger = logger 65 | } 66 | } 67 | 68 | // WithTLS sets the TLS configuration for Frisbee. By default, no TLS configuration is used, and 69 | // Frisbee will use unencrypted TCP connections. If the Frisbee Server is using TLS, then you must pass in 70 | // a TLS config (even an empty one `&tls.Config{}`) for the Frisbee Client. 
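// A hedged example of a client opting in to TLS (handlerTable and the empty
// tls.Config are placeholders for illustration):
//
//	c, err := NewClient(handlerTable, context.Background(), WithTLS(&tls.Config{}))
//	if err != nil {
//		// handle the error
//	}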
71 | func WithTLS(tlsConfig *tls.Config) Option { 72 | return func(opts *Options) { 73 | opts.TLSConfig = tlsConfig 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /options_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package frisbee 4 | 5 | import ( 6 | "crypto/tls" 7 | "testing" 8 | "time" 9 | 10 | "github.com/stretchr/testify/assert" 11 | 12 | "github.com/loopholelabs/logging" 13 | ) 14 | 15 | func TestWithoutOptions(t *testing.T) { 16 | t.Parallel() 17 | 18 | options := loadOptions() 19 | 20 | assert.Equal(t, time.Minute*3, options.KeepAlive) 21 | assert.NotNil(t, options.Logger) 22 | assert.Nil(t, options.TLSConfig) 23 | } 24 | 25 | func TestWithOptions(t *testing.T) { 26 | t.Parallel() 27 | 28 | option := WithOptions(Options{ 29 | KeepAlive: time.Minute * 6, 30 | Logger: nil, 31 | TLSConfig: &tls.Config{}, 32 | }) 33 | 34 | options := loadOptions(option) 35 | 36 | assert.Equal(t, time.Minute*6, options.KeepAlive) 37 | assert.NotNil(t, options.Logger) 38 | assert.Equal(t, &tls.Config{}, options.TLSConfig) 39 | } 40 | 41 | func TestDisableOptions(t *testing.T) { 42 | t.Parallel() 43 | 44 | option := WithOptions(Options{ 45 | KeepAlive: -1, 46 | }) 47 | 48 | options := loadOptions(option) 49 | 50 | assert.Equal(t, time.Duration(-1), options.KeepAlive) 51 | assert.NotNil(t, options.Logger) 52 | assert.Nil(t, options.TLSConfig) 53 | } 54 | 55 | func TestIndividualOptions(t *testing.T) { 56 | t.Parallel() 57 | 58 | logger := logging.Test(t, logging.Noop, t.Name()) 59 | tlsConfig := &tls.Config{ 60 | InsecureSkipVerify: true, 61 | } 62 | 63 | keepAliveOption := WithKeepAlive(time.Minute * 6) 64 | loggerOption := WithLogger(logger) 65 | TLSOption := WithTLS(tlsConfig) 66 | 67 | options := loadOptions(keepAliveOption, loggerOption, TLSOption) 68 | 69 | assert.Equal(t, time.Minute*6, options.KeepAlive) 70 | assert.Equal(t, logger, options.Logger) 71 | assert.Equal(t, tlsConfig, options.TLSConfig) 72 | } 73 | -------------------------------------------------------------------------------- /pkg/metadata/metadata.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package metadata 4 | 5 | import ( 6 | "encoding/binary" 7 | "errors" 8 | "unsafe" 9 | ) 10 | 11 | var ( 12 | EncodingErr = errors.New("error while encoding metadata") 13 | DecodingErr = errors.New("error while decoding metadata") 14 | InvalidBufferLengthErr = errors.New("invalid buffer length") 15 | ) 16 | 17 | var ( 18 | PacketMagicHeader = uint16(0x0F) 19 | ) 20 | 21 | const ( 22 | PacketPing = uint16(10) // PING 23 | PacketPong = uint16(11) // PONG 24 | PacketProbe = uint16(12) // PACKET 25 | ) 26 | 27 | const ( 28 | MagicOffset = 0 // 0 29 | MagicSize = 2 30 | 31 | IdOffset = MagicOffset + MagicSize // 2 32 | IdSize = 2 33 | 34 | OperationOffset = IdOffset + IdSize // 4 35 | OperationSize = 2 36 | 37 | ContentLengthOffset = OperationOffset + OperationSize // 6 38 | ContentLengthSize = 4 39 | 40 | Size = ContentLengthOffset + ContentLengthSize // 10 41 | ) 42 | 43 | // Metadata is 8 bytes in length 44 | type Metadata struct { 45 | Magic uint16 // 2 bytes 46 | Id uint16 // 2 Bytes 47 | Operation uint16 // 2 Bytes 48 | ContentLength uint32 // 4 Bytes 49 | } 50 | 51 | // Encode Metadata 52 | func (fm *Metadata) Encode() (b *Buffer, err error) { 53 | defer func() { 54 | if recoveredErr := recover(); 
recoveredErr != nil { 55 | err = errors.Join(recoveredErr.(error), EncodingErr) 56 | } 57 | }() 58 | 59 | b = NewBuffer() 60 | binary.BigEndian.PutUint16(b[MagicOffset:MagicOffset+MagicSize], fm.Magic) 61 | binary.BigEndian.PutUint16(b[IdOffset:IdOffset+IdSize], fm.Id) 62 | binary.BigEndian.PutUint16(b[OperationOffset:OperationOffset+OperationSize], fm.Operation) 63 | binary.BigEndian.PutUint32(b[ContentLengthOffset:ContentLengthOffset+ContentLengthSize], fm.ContentLength) 64 | 65 | return 66 | } 67 | 68 | // Decode Metadata 69 | func (fm *Metadata) Decode(buf *Buffer) (err error) { 70 | defer func() { 71 | if recoveredErr := recover(); recoveredErr != nil { 72 | err = errors.Join(recoveredErr.(error), DecodingErr) 73 | } 74 | }() 75 | 76 | fm.Magic = binary.BigEndian.Uint16(buf[MagicOffset : MagicOffset+MagicSize]) 77 | fm.Id = binary.BigEndian.Uint16(buf[IdOffset : IdOffset+IdSize]) 78 | fm.Operation = binary.BigEndian.Uint16(buf[OperationOffset : OperationOffset+OperationSize]) 79 | fm.ContentLength = binary.BigEndian.Uint32(buf[ContentLengthOffset : ContentLengthOffset+ContentLengthSize]) 80 | 81 | return nil 82 | } 83 | 84 | func Encode(id, operation uint16, contentLength uint32) (*Buffer, error) { 85 | metadata := Metadata{ 86 | Magic: PacketMagicHeader, 87 | Id: id, 88 | Operation: operation, 89 | ContentLength: contentLength, 90 | } 91 | 92 | return metadata.Encode() 93 | } 94 | 95 | func Decode(buf []byte) (*Metadata, error) { 96 | if len(buf) < Size { 97 | return nil, InvalidBufferLengthErr 98 | } 99 | 100 | m := new(Metadata) 101 | return m, m.Decode((*Buffer)(unsafe.Pointer(&buf[0]))) 102 | } 103 | -------------------------------------------------------------------------------- /pkg/metadata/metadata_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package metadata 4 | 5 | import ( 6 | "encoding/binary" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func TestMessageEncodeDecode(t *testing.T) { 14 | t.Parallel() 15 | 16 | message := &Metadata{ 17 | Magic: PacketMagicHeader, 18 | Id: uint16(64), 19 | Operation: PacketProbe, 20 | ContentLength: uint32(0), 21 | } 22 | 23 | correct := NewBuffer() 24 | 25 | binary.BigEndian.PutUint16(correct[MagicOffset:MagicOffset+MagicSize], PacketMagicHeader) 26 | binary.BigEndian.PutUint16(correct[IdOffset:IdOffset+IdSize], uint16(64)) 27 | binary.BigEndian.PutUint16(correct[OperationOffset:OperationOffset+OperationSize], PacketProbe) 28 | binary.BigEndian.PutUint32(correct[ContentLengthOffset:ContentLengthOffset+ContentLengthSize], uint32(0)) 29 | 30 | encoded, err := message.Encode() 31 | require.NoError(t, err) 32 | assert.Equal(t, correct, encoded) 33 | 34 | decoderMessage := &Metadata{} 35 | 36 | err = decoderMessage.Decode(encoded) 37 | require.NoError(t, err) 38 | assert.Equal(t, message, decoderMessage) 39 | } 40 | 41 | func TestEncodeDecode(t *testing.T) { 42 | t.Parallel() 43 | 44 | encodedBytes, err := Encode(64, PacketPong, 512) 45 | assert.Equal(t, nil, err) 46 | 47 | message, err := Decode(encodedBytes[:]) 48 | require.NoError(t, err) 49 | assert.Equal(t, PacketMagicHeader, message.Magic) 50 | assert.Equal(t, uint32(512), message.ContentLength) 51 | assert.Equal(t, uint16(64), message.Id) 52 | assert.Equal(t, PacketPong, message.Operation) 53 | 54 | emptyEncodedBytes, err := Encode(64, PacketPing, 0) 55 | assert.Equal(t, nil, err) 56 | 57 | emptyMessage, err := 
Decode(emptyEncodedBytes[:]) 58 | require.NoError(t, err) 59 | assert.Equal(t, PacketMagicHeader, message.Magic) 60 | assert.Equal(t, uint32(0), emptyMessage.ContentLength) 61 | assert.Equal(t, uint16(64), emptyMessage.Id) 62 | assert.Equal(t, PacketPing, emptyMessage.Operation) 63 | 64 | invalidMessage, err := Decode(emptyEncodedBytes[1:]) 65 | require.Error(t, err) 66 | assert.ErrorIs(t, InvalidBufferLengthErr, err) 67 | assert.Nil(t, invalidMessage) 68 | } 69 | 70 | func BenchmarkEncode(b *testing.B) { 71 | for i := 0; i < b.N; i++ { 72 | _, _ = Encode(uint16(i), PacketProbe, 512) 73 | } 74 | } 75 | 76 | func BenchmarkDecode(b *testing.B) { 77 | encodedMessage, _ := Encode(0, PacketProbe, 512) 78 | 79 | b.ResetTimer() 80 | for i := 0; i < b.N; i++ { 81 | _, _ = Decode(encodedMessage[:]) 82 | } 83 | } 84 | 85 | func BenchmarkEncodeDecode(b *testing.B) { 86 | b.ResetTimer() 87 | for i := 0; i < b.N; i++ { 88 | encodedMessage, _ := Encode(uint16(i), PacketProbe, 512) 89 | _, _ = Decode(encodedMessage[:]) 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /pkg/metadata/pool.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package metadata 4 | 5 | import ( 6 | "github.com/loopholelabs/common/pkg/pool" 7 | ) 8 | 9 | type Buffer [Size]byte 10 | 11 | func NewBuffer() *Buffer { 12 | return new(Buffer) 13 | } 14 | 15 | func (b *Buffer) Reset() {} 16 | 17 | var ( 18 | bufferPool = pool.NewPool[Buffer, *Buffer](NewBuffer) 19 | ) 20 | 21 | func GetBuffer() *Buffer { 22 | return bufferPool.Get() 23 | } 24 | 25 | func PutBuffer(b *Buffer) { 26 | bufferPool.Put(b) 27 | } 28 | -------------------------------------------------------------------------------- /pkg/packet/packet.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package packet 4 | 5 | import ( 6 | "github.com/loopholelabs/polyglot/v2" 7 | 8 | "github.com/loopholelabs/frisbee-go/pkg/metadata" 9 | ) 10 | 11 | // Packet is the structured frisbee data packet, and contains the following: 12 | // 13 | // type Packet struct { 14 | // Metadata struct { 15 | // Magic uint16 // 2 Bytes 16 | // Id uint16 // 2 Bytes 17 | // Operation uint16 // 2 Bytes 18 | // ContentLength uint32 // 4 Bytes 19 | // } 20 | // Content *content.Content 21 | // } 22 | // 23 | // The ID field can be used however the user sees fit, however ContentLength must match the length of the content being 24 | // delivered with the frisbee packet (see the Async.WritePacket function for more details), and the Operation field must be greater than uint16(9). 
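// A short, hedged sketch of filling in a Packet so that ContentLength matches
// the written content (the Id, Operation, and payload values are placeholders):
//
//	payload := []byte("hello")
//	p := packet.Get()
//	p.Metadata.Id = 16
//	p.Metadata.Operation = 32 // must be greater than uint16(9)
//	p.Content.Write(payload)
//	p.Metadata.ContentLength = uint32(len(payload))
//	// ... send the packet, then return it to the pool:
//	packet.Put(p)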
25 | type Packet struct { 26 | Metadata *metadata.Metadata 27 | Content *polyglot.Buffer 28 | } 29 | 30 | func (p *Packet) Reset() { 31 | p.Metadata.Magic = 0 32 | p.Metadata.Id = 0 33 | p.Metadata.Operation = 0 34 | p.Metadata.ContentLength = 0 35 | p.Content.Reset() 36 | } 37 | 38 | func New() *Packet { 39 | return &Packet{ 40 | Metadata: new(metadata.Metadata), 41 | Content: polyglot.NewBuffer(), 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /pkg/packet/packet_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package packet 4 | 5 | import ( 6 | "crypto/rand" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestNew(t *testing.T) { 13 | t.Parallel() 14 | 15 | p := Get() 16 | 17 | assert.IsType(t, new(Packet), p) 18 | assert.NotNil(t, p.Metadata) 19 | assert.Equal(t, uint16(0), p.Metadata.Magic) 20 | assert.Equal(t, uint16(0), p.Metadata.Id) 21 | assert.Equal(t, uint16(0), p.Metadata.Operation) 22 | assert.Equal(t, uint32(0), p.Metadata.ContentLength) 23 | assert.Equal(t, Get().Content, p.Content) 24 | 25 | Put(p) 26 | } 27 | 28 | func TestWrite(t *testing.T) { 29 | t.Parallel() 30 | 31 | p := Get() 32 | 33 | b := make([]byte, 32) 34 | _, err := rand.Read(b) 35 | assert.NoError(t, err) 36 | 37 | p.Content.Write(b) 38 | assert.Equal(t, b, p.Content.Bytes()) 39 | 40 | p.Reset() 41 | assert.NotEqual(t, b, p.Content.Bytes()) 42 | assert.Equal(t, 0, p.Content.Len()) 43 | assert.Equal(t, 512, p.Content.Cap()) 44 | 45 | b = make([]byte, 1024) 46 | _, err = rand.Read(b) 47 | assert.NoError(t, err) 48 | 49 | p.Content.Write(b) 50 | 51 | assert.Equal(t, b, p.Content.Bytes()) 52 | assert.Equal(t, 1024, p.Content.Len()) 53 | assert.GreaterOrEqual(t, p.Content.Cap(), 1024) 54 | 55 | } 56 | -------------------------------------------------------------------------------- /pkg/packet/pool.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package packet 4 | 5 | import ( 6 | "github.com/loopholelabs/common/pkg/pool" 7 | ) 8 | 9 | var ( 10 | packetPool = NewPool() 11 | ) 12 | 13 | func NewPool() *pool.Pool[Packet, *Packet] { 14 | return pool.NewPool(New) 15 | } 16 | 17 | func Get() (s *Packet) { 18 | return packetPool.Get() 19 | } 20 | 21 | func Put(p *Packet) { 22 | packetPool.Put(p) 23 | } 24 | -------------------------------------------------------------------------------- /pkg/packet/pool_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package packet 4 | 5 | import ( 6 | "crypto/rand" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestRecycle(t *testing.T) { 13 | pool := NewPool() 14 | 15 | p := pool.Get() 16 | 17 | p.Metadata.Id = 32 18 | p.Metadata.Operation = 64 19 | p.Metadata.ContentLength = 128 20 | 21 | pool.Put(p) 22 | p = pool.Get() 23 | 24 | testData := make([]byte, p.Content.Cap()*2) 25 | _, err := rand.Read(testData) 26 | assert.NoError(t, err) 27 | for { 28 | assert.NotNil(t, p.Metadata) 29 | assert.Equal(t, uint16(0), p.Metadata.Id) 30 | assert.Equal(t, uint16(0), p.Metadata.Operation) 31 | assert.Equal(t, uint32(0), p.Metadata.ContentLength) 32 | assert.Equal(t, *pool.Get().Content, *p.Content) 33 | 34 | p.Content.Write(testData) 35 | assert.Equal(t, len(testData), p.Content.Len()) 36 | 
assert.GreaterOrEqual(t, p.Content.Cap(), len(testData)) 37 | 38 | pool.Put(p) 39 | p = pool.Get() 40 | 41 | assert.NotNil(t, p.Metadata) 42 | assert.Equal(t, uint16(0), p.Metadata.Id) 43 | assert.Equal(t, uint16(0), p.Metadata.Operation) 44 | assert.Equal(t, uint32(0), p.Metadata.ContentLength) 45 | 46 | if p.Content.Cap() < len(testData) { 47 | continue 48 | } 49 | assert.Equal(t, 0, p.Content.Len()) 50 | assert.GreaterOrEqual(t, p.Content.Cap(), len(testData)) 51 | break 52 | } 53 | 54 | pool.Put(p) 55 | } 56 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package frisbee 4 | 5 | import ( 6 | "context" 7 | "crypto/tls" 8 | "errors" 9 | "net" 10 | "sync" 11 | "sync/atomic" 12 | "time" 13 | 14 | "github.com/loopholelabs/frisbee-go/pkg/packet" 15 | "github.com/loopholelabs/logging/types" 16 | ) 17 | 18 | var ( 19 | OnClosedNil = errors.New("OnClosed function cannot be nil") 20 | PreWriteNil = errors.New("PreWrite function cannot be nil") 21 | ListenerNil = errors.New("listener cannot be nil") 22 | ) 23 | 24 | var ( 25 | defaultOnClosed = func(_ *Async, _ error) {} 26 | 27 | defaultPreWrite = func() {} 28 | 29 | defaultStreamHandler = func(stream *Stream) { 30 | _ = stream.Close() 31 | } 32 | ) 33 | 34 | // Server accepts connections from frisbee Clients and can send and receive frisbee Packets 35 | type Server struct { 36 | listener net.Listener 37 | handlerTable HandlerTable 38 | shutdown atomic.Bool 39 | options *Options 40 | wg sync.WaitGroup 41 | connections map[*Async]struct{} 42 | connectionsMu sync.Mutex 43 | startedCh chan struct{} 44 | concurrency uint64 45 | limiter chan struct{} 46 | 47 | baseContext context.Context 48 | baseContextCancel context.CancelFunc 49 | 50 | // onClosed is a function run by the server whenever a connection is closed 51 | onClosed func(*Async, error) 52 | 53 | // preWrite is run by the server before a write happens 54 | preWrite func() 55 | 56 | // streamHandler is used to handle incoming client-initiated streams on the server 57 | streamHandler func(*Stream) 58 | 59 | // ConnContext is used to define a connection-specific context based on the incoming connection 60 | // and is run whenever a new connection is opened 61 | ConnContext func(context.Context, *Async) context.Context 62 | 63 | // StreamContext is used to define a stream-specific context based on the incoming stream 64 | // and is run whenever a new stream is opened 65 | StreamContext func(context.Context, *Stream) context.Context 66 | 67 | // PacketContext is used to define a handler-specific contexts based on the incoming packet 68 | // and is run whenever a new packet arrives 69 | PacketContext func(context.Context, *packet.Packet) context.Context 70 | 71 | // UpdateContext is used to update a handler-specific context whenever the returned 72 | // Action from a handler is UPDATE 73 | UpdateContext func(context.Context, *Async) context.Context 74 | } 75 | 76 | // NewServer returns an uninitialized frisbee Server with the registered HandlerTable. 77 | // The Start method must then be called to start the server and listen for connections. 78 | func NewServer(handlerTable HandlerTable, ctx context.Context, opts ...Option) (*Server, error) { 79 | options := loadOptions(opts...) 
80 | 81 | baseContext, baseContextCancel := context.WithCancel(ctx) 82 | 83 | s := &Server{ 84 | options: options, 85 | connections: make(map[*Async]struct{}), 86 | startedCh: make(chan struct{}), 87 | baseContext: baseContext, 88 | baseContextCancel: baseContextCancel, 89 | onClosed: defaultOnClosed, 90 | preWrite: defaultPreWrite, 91 | streamHandler: defaultStreamHandler, 92 | } 93 | 94 | return s, s.SetHandlerTable(handlerTable) 95 | } 96 | 97 | // SetOnClosed sets the onClosed function for the server. If f is nil, it returns an error. 98 | func (s *Server) SetOnClosed(f func(*Async, error)) error { 99 | if f == nil { 100 | return OnClosedNil 101 | } 102 | s.onClosed = f 103 | return nil 104 | } 105 | 106 | // SetPreWrite sets the preWrite function for the server. If f is nil, it returns an error. 107 | func (s *Server) SetPreWrite(f func()) error { 108 | if f == nil { 109 | return PreWriteNil 110 | } 111 | s.preWrite = f 112 | return nil 113 | } 114 | 115 | // SetStreamHandler sets the streamHandler function for the server. 116 | func (s *Server) SetStreamHandler(f func(context.Context, *Stream)) error { 117 | s.streamHandler = func(stream *Stream) { 118 | streamCtx := s.baseContext 119 | if s.StreamContext != nil { 120 | streamCtx = s.StreamContext(streamCtx, stream) 121 | } 122 | f(streamCtx, stream) 123 | } 124 | return nil 125 | } 126 | 127 | // SetHandlerTable sets the handler table for the server. 128 | // 129 | // This function should not be called once the server has started. 130 | func (s *Server) SetHandlerTable(handlerTable HandlerTable) error { 131 | for i := uint16(0); i < RESERVED9; i++ { 132 | if _, ok := handlerTable[i]; ok { 133 | return InvalidHandlerTable 134 | } 135 | } 136 | 137 | s.handlerTable = handlerTable 138 | return nil 139 | } 140 | 141 | // GetHandlerTable gets the handler table for the server. 142 | // 143 | // This function should not be called once the server has started. 144 | func (s *Server) GetHandlerTable() HandlerTable { 145 | return s.handlerTable 146 | } 147 | 148 | // SetConcurrency sets the maximum number of concurrent goroutines that will be created 149 | // by the server to handle incoming packets. 150 | // 151 | // An important caveat of this is that handlers must always thread-safe if they share resources 152 | // between connections. If the concurrency is set to a value != 1, then the handlers 153 | // must also be thread-safe if they share resources per connection. 154 | // 155 | // This function should not be called once the server has started. 156 | func (s *Server) SetConcurrency(concurrency uint64) { 157 | s.concurrency = concurrency 158 | if s.concurrency > 1 { 159 | s.limiter = make(chan struct{}, s.concurrency) 160 | } 161 | } 162 | 163 | // Start will start the frisbee server and its reactor goroutines 164 | // to receive and handle incoming connections. If the baseContext, ConnContext, 165 | // onClosed, OnShutdown, or preWrite functions have not been defined, it will 166 | // use the default functions for these. 
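// A minimal, hedged sketch of bringing a server up (handlerTable and the
// listen address are placeholders; Start blocks until the listener stops):
//
//	s, err := NewServer(handlerTable, context.Background())
//	if err != nil {
//		// handle the error
//	}
//	if err := s.Start(":8080"); err != nil {
//		// handle the error
//	}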
167 | func (s *Server) Start(addr string) error { 168 | var listener net.Listener 169 | var err error 170 | if s.options.TLSConfig != nil { 171 | listener, err = tls.Listen("tcp", addr, s.options.TLSConfig) 172 | } else { 173 | listener, err = net.Listen("tcp", addr) 174 | } 175 | if err != nil { 176 | return err 177 | } 178 | return s.StartWithListener(listener) 179 | } 180 | 181 | // StartWithListener will start the frisbee server and its reactor goroutines 182 | // to receive and handle incoming connections with a given net.Listener. If the baseContext, ConnContext, 183 | // onClosed, OnShutdown, or preWrite functions have not been defined, it will 184 | // use the default functions for these. 185 | func (s *Server) StartWithListener(listener net.Listener) error { 186 | if listener == nil { 187 | return ListenerNil 188 | } 189 | s.listener = listener 190 | s.wg.Add(1) 191 | close(s.startedCh) 192 | return s.handleListener() 193 | } 194 | 195 | // started returns a channel that will be closed when the server has successfully started 196 | // 197 | // This is meant to only be used for testing purposes. 198 | func (s *Server) started() <-chan struct{} { 199 | return s.startedCh 200 | } 201 | 202 | func (s *Server) handleListener() error { 203 | var backoff time.Duration 204 | for { 205 | newConn, err := s.listener.Accept() 206 | if err != nil { 207 | if s.shutdown.Load() { 208 | s.wg.Done() 209 | return nil 210 | } 211 | if ne, ok := err.(temporary); ok && ne.Temporary() { 212 | if backoff == 0 { 213 | backoff = minBackoff 214 | } else { 215 | backoff *= 2 216 | } 217 | if backoff > maxBackoff { 218 | backoff = maxBackoff 219 | } 220 | s.Logger().Warn().Err(err).Msgf("Temporary Accept Error, retrying in %s", backoff) 221 | time.Sleep(backoff) 222 | if s.shutdown.Load() { 223 | s.wg.Done() 224 | return nil 225 | } 226 | continue 227 | } 228 | s.wg.Done() 229 | return err 230 | } 231 | backoff = 0 232 | 233 | s.ServeConn(newConn) 234 | } 235 | } 236 | 237 | func (s *Server) createHandler(conn *Async, closed *atomic.Bool, wg *sync.WaitGroup, ctx context.Context, cancel context.CancelFunc) func(*packet.Packet) { 238 | return func(p *packet.Packet) { 239 | handlerFunc := s.handlerTable[p.Metadata.Operation] 240 | if handlerFunc != nil { 241 | packetCtx := ctx 242 | if s.PacketContext != nil { 243 | packetCtx = s.PacketContext(packetCtx, p) 244 | } 245 | outgoing, action := handlerFunc(packetCtx, p) 246 | if outgoing != nil && outgoing.Metadata.ContentLength == uint32(outgoing.Content.Len()) { 247 | s.preWrite() 248 | err := conn.WritePacket(outgoing) 249 | if outgoing != p { 250 | packet.Put(outgoing) 251 | } 252 | packet.Put(p) 253 | if err != nil { 254 | _ = conn.Close() 255 | if closed.CompareAndSwap(false, true) { 256 | s.onClosed(conn, err) 257 | } 258 | cancel() 259 | wg.Done() 260 | return 261 | } 262 | } else { 263 | packet.Put(p) 264 | } 265 | switch action { 266 | case NONE: 267 | case CLOSE: 268 | _ = conn.Close() 269 | if closed.CompareAndSwap(false, true) { 270 | s.onClosed(conn, nil) 271 | } 272 | cancel() 273 | } 274 | } else { 275 | packet.Put(p) 276 | } 277 | wg.Done() 278 | } 279 | } 280 | 281 | func (s *Server) handleSinglePacket(frisbeeConn *Async, connCtx context.Context) { 282 | var p *packet.Packet 283 | var outgoing *packet.Packet 284 | var action Action 285 | var handlerFunc Handler 286 | var err error 287 | p, err = frisbeeConn.ReadPacket() 288 | if err != nil { 289 | _ = frisbeeConn.Close() 290 | s.onClosed(frisbeeConn, err) 291 | return 292 | } 293 | for { 294 | 
handlerFunc = s.handlerTable[p.Metadata.Operation] 295 | if handlerFunc != nil { 296 | packetCtx := connCtx 297 | if s.PacketContext != nil { 298 | packetCtx = s.PacketContext(packetCtx, p) 299 | } 300 | outgoing, action = handlerFunc(packetCtx, p) 301 | if outgoing != nil && outgoing.Metadata.ContentLength == uint32(outgoing.Content.Len()) { 302 | s.preWrite() 303 | err = frisbeeConn.WritePacket(outgoing) 304 | if outgoing != p { 305 | packet.Put(outgoing) 306 | } 307 | packet.Put(p) 308 | if err != nil { 309 | _ = frisbeeConn.Close() 310 | s.onClosed(frisbeeConn, err) 311 | return 312 | } 313 | } else { 314 | packet.Put(p) 315 | } 316 | switch action { 317 | case NONE: 318 | case CLOSE: 319 | _ = frisbeeConn.Close() 320 | s.onClosed(frisbeeConn, nil) 321 | return 322 | } 323 | } else { 324 | packet.Put(p) 325 | } 326 | p, err = frisbeeConn.ReadPacket() 327 | if err != nil { 328 | _ = frisbeeConn.Close() 329 | s.onClosed(frisbeeConn, err) 330 | return 331 | } 332 | } 333 | } 334 | 335 | func (s *Server) handleUnlimitedPacket(frisbeeConn *Async, connCtx context.Context) { 336 | p, err := frisbeeConn.ReadPacket() 337 | if err != nil { 338 | _ = frisbeeConn.Close() 339 | s.onClosed(frisbeeConn, err) 340 | return 341 | } 342 | wg := new(sync.WaitGroup) 343 | var closed atomic.Bool 344 | connCtx, cancel := context.WithCancel(connCtx) 345 | handle := s.createHandler(frisbeeConn, &closed, wg, connCtx, cancel) 346 | for { 347 | wg.Add(1) 348 | go handle(p) 349 | p, err = frisbeeConn.ReadPacket() 350 | if err != nil { 351 | _ = frisbeeConn.Close() 352 | if closed.CompareAndSwap(false, true) { 353 | s.onClosed(frisbeeConn, err) 354 | } 355 | cancel() 356 | wg.Wait() 357 | return 358 | } 359 | } 360 | } 361 | 362 | func (s *Server) handleLimitedPacket(frisbeeConn *Async, connCtx context.Context) { 363 | p, err := frisbeeConn.ReadPacket() 364 | if err != nil { 365 | _ = frisbeeConn.Close() 366 | s.onClosed(frisbeeConn, err) 367 | return 368 | } 369 | wg := new(sync.WaitGroup) 370 | var closed atomic.Bool 371 | connCtx, cancel := context.WithCancel(connCtx) 372 | handler := s.createHandler(frisbeeConn, &closed, wg, connCtx, cancel) 373 | handle := func(p *packet.Packet) { 374 | handler(p) 375 | <-s.limiter 376 | } 377 | for { 378 | select { 379 | case s.limiter <- struct{}{}: 380 | wg.Add(1) 381 | go handle(p) 382 | p, err = frisbeeConn.ReadPacket() 383 | if err != nil { 384 | _ = frisbeeConn.Close() 385 | if closed.CompareAndSwap(false, true) { 386 | s.onClosed(frisbeeConn, err) 387 | } 388 | cancel() 389 | wg.Wait() 390 | return 391 | } 392 | case <-connCtx.Done(): 393 | _ = frisbeeConn.Close() 394 | if closed.CompareAndSwap(false, true) { 395 | s.onClosed(frisbeeConn, err) 396 | } 397 | wg.Wait() 398 | return 399 | } 400 | } 401 | } 402 | 403 | // ServeConn takes a net.Conn and starts a goroutine to handle it using the Server. 404 | func (s *Server) ServeConn(conn net.Conn) { 405 | s.wg.Add(1) 406 | go s.serveConn(conn) 407 | } 408 | 409 | // serveConn takes a net.Conn and serves it using the Server 410 | // and assumes that the server's wait group has been incremented by 1. 
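// A minimal sketch of driving the server with caller-accepted connections via
// ServeConn (as the tests below do with pair.New) instead of Start or
// StartWithListener. The accept loop, listen address, and empty handler table
// are illustrative assumptions; in this variant the caller owns the listener,
// so the temporary-error backoff performed by handleListener above does not apply.
package main

import (
	"context"
	"log"
	"net"

	frisbee "github.com/loopholelabs/frisbee-go"
)

func main() {
	// Handler table construction is as in the earlier sketch; it is left empty here.
	server, err := frisbee.NewServer(make(frisbee.HandlerTable), context.Background())
	if err != nil {
		log.Fatal(err)
	}

	ln, err := net.Listen("tcp", ":8080")
	if err != nil {
		log.Fatal(err)
	}

	for {
		conn, err := ln.Accept()
		if err != nil {
			log.Fatal(err)
		}
		// ServeConn increments the server's WaitGroup and serves the
		// connection on its own goroutine, so Accept can continue immediately.
		server.ServeConn(conn)
	}
}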
411 | func (s *Server) serveConn(newConn net.Conn) { 412 | var err error 413 | switch v := newConn.(type) { 414 | case *net.TCPConn: 415 | err = v.SetKeepAlive(true) 416 | if err != nil { 417 | s.Logger().Error().Err(err).Msg("Error while setting TCP Keepalive") 418 | _ = v.Close() 419 | s.wg.Done() 420 | return 421 | } 422 | err = v.SetKeepAlivePeriod(s.options.KeepAlive) 423 | if err != nil { 424 | s.Logger().Error().Err(err).Msg("Error while setting TCP Keepalive Period") 425 | _ = v.Close() 426 | s.wg.Done() 427 | return 428 | } 429 | } 430 | 431 | frisbeeConn := NewAsync(newConn, s.Logger(), s.streamHandler) 432 | connCtx := s.baseContext 433 | s.connectionsMu.Lock() 434 | if s.shutdown.Load() { 435 | s.wg.Done() 436 | return 437 | } 438 | s.connections[frisbeeConn] = struct{}{} 439 | s.connectionsMu.Unlock() 440 | if s.ConnContext != nil { 441 | connCtx = s.ConnContext(connCtx, frisbeeConn) 442 | } 443 | switch s.concurrency { 444 | case 0: 445 | s.handleUnlimitedPacket(frisbeeConn, connCtx) 446 | case 1: 447 | s.handleSinglePacket(frisbeeConn, connCtx) 448 | default: 449 | s.handleLimitedPacket(frisbeeConn, connCtx) 450 | } 451 | s.connectionsMu.Lock() 452 | if !s.shutdown.Load() { 453 | delete(s.connections, frisbeeConn) 454 | } 455 | s.connectionsMu.Unlock() 456 | s.wg.Done() 457 | } 458 | 459 | // Logger returns the server's logger (useful for ServerRouter functions) 460 | func (s *Server) Logger() types.Logger { 461 | return s.options.Logger 462 | } 463 | 464 | // Shutdown shuts down the frisbee server and kills all the goroutines and active connections 465 | func (s *Server) Shutdown() error { 466 | if s.shutdown.CompareAndSwap(false, true) { 467 | s.baseContextCancel() 468 | s.connectionsMu.Lock() 469 | for c := range s.connections { 470 | _ = c.Close() 471 | delete(s.connections, c) 472 | } 473 | s.connectionsMu.Unlock() 474 | defer s.wg.Wait() 475 | if s.listener != nil { 476 | return s.listener.Close() 477 | } 478 | } 479 | return nil 480 | } 481 | -------------------------------------------------------------------------------- /server_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package frisbee 4 | 5 | import ( 6 | "context" 7 | "crypto/rand" 8 | "fmt" 9 | "io" 10 | "net" 11 | "net/http" 12 | "sync" 13 | "sync/atomic" 14 | "testing" 15 | "time" 16 | 17 | "github.com/stretchr/testify/assert" 18 | "github.com/stretchr/testify/require" 19 | 20 | "github.com/loopholelabs/logging" 21 | "github.com/loopholelabs/polyglot/v2" 22 | "github.com/loopholelabs/testing/conn" 23 | "github.com/loopholelabs/testing/conn/pair" 24 | 25 | "github.com/loopholelabs/frisbee-go/pkg/metadata" 26 | "github.com/loopholelabs/frisbee-go/pkg/packet" 27 | ) 28 | 29 | const ( 30 | serverConnContextKey = "conn" 31 | ) 32 | 33 | func TestServerRawSingle(t *testing.T) { 34 | t.Parallel() 35 | 36 | const testSize = 100 37 | const packetSize = 512 38 | clientHandlerTable := make(HandlerTable) 39 | serverHandlerTable := make(HandlerTable) 40 | 41 | serverIsRaw := make(chan struct{}, 1) 42 | 43 | serverHandlerTable[metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { 44 | return 45 | } 46 | 47 | var rawServerConn, rawClientConn net.Conn 48 | serverHandlerTable[metadata.PacketProbe] = func(ctx context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { 49 | c := ctx.Value(serverConnContextKey).(*Async) 50 | rawServerConn = c.Raw() 51 | 
serverIsRaw <- struct{}{} 52 | return 53 | } 54 | 55 | clientHandlerTable[metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { 56 | return 57 | } 58 | 59 | emptyLogger := logging.Test(t, logging.Noop, t.Name()) 60 | s, err := NewServer(serverHandlerTable, context.Background(), WithLogger(emptyLogger)) 61 | require.NoError(t, err) 62 | 63 | s.SetConcurrency(1) 64 | 65 | s.ConnContext = func(ctx context.Context, c *Async) context.Context { 66 | return context.WithValue(ctx, serverConnContextKey, c) 67 | } 68 | 69 | serverConn, clientConn, err := pair.New() 70 | require.NoError(t, err) 71 | 72 | go s.ServeConn(serverConn) 73 | 74 | c, err := NewClient(clientHandlerTable, context.Background(), WithLogger(emptyLogger)) 75 | assert.NoError(t, err) 76 | 77 | _, err = c.Raw() 78 | assert.ErrorIs(t, ConnectionNotInitialized, err) 79 | 80 | err = c.FromConn(clientConn) 81 | assert.NoError(t, err) 82 | 83 | data := make([]byte, packetSize) 84 | _, _ = rand.Read(data) 85 | p := packet.Get() 86 | p.Content.Write(data) 87 | p.Metadata.ContentLength = packetSize 88 | p.Metadata.Operation = metadata.PacketPing 89 | expected := polyglot.NewBufferFromBytes(data) 90 | expected.MoveOffset(len(data)) 91 | assert.Equal(t, expected.Bytes(), p.Content.Bytes()) 92 | 93 | for q := 0; q < testSize; q++ { 94 | p.Metadata.Id = uint16(q) 95 | err = c.WritePacket(p) 96 | assert.NoError(t, err) 97 | } 98 | 99 | p.Reset() 100 | assert.Equal(t, 0, p.Content.Len()) 101 | p.Metadata.Operation = metadata.PacketProbe 102 | 103 | err = c.WritePacket(p) 104 | require.NoError(t, err) 105 | 106 | packet.Put(p) 107 | 108 | rawClientConn, err = c.Raw() 109 | require.NoError(t, err) 110 | 111 | <-serverIsRaw 112 | 113 | serverBytes := []byte("SERVER WRITE") 114 | 115 | write, err := rawServerConn.Write(serverBytes) 116 | assert.NoError(t, err) 117 | assert.Equal(t, cap(serverBytes), write) 118 | 119 | clientBuffer := make([]byte, cap(serverBytes)) 120 | read, err := rawClientConn.Read(clientBuffer[:]) 121 | assert.NoError(t, err) 122 | assert.Equal(t, cap(serverBytes), read) 123 | 124 | assert.Equal(t, serverBytes, clientBuffer[:read]) 125 | 126 | err = c.Close() 127 | assert.NoError(t, err) 128 | err = rawClientConn.Close() 129 | assert.NoError(t, err) 130 | 131 | err = s.Shutdown() 132 | assert.NoError(t, err) 133 | err = rawServerConn.Close() 134 | assert.NoError(t, err) 135 | } 136 | 137 | func TestServerStaleCloseSingle(t *testing.T) { 138 | t.Parallel() 139 | 140 | const testSize = 100 141 | const packetSize = 512 142 | clientHandlerTable := make(HandlerTable) 143 | serverHandlerTable := make(HandlerTable) 144 | 145 | finished := make(chan struct{}, 1) 146 | 147 | serverHandlerTable[metadata.PacketPing] = func(_ context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action Action) { 148 | if incoming.Metadata.Id == testSize-1 { 149 | outgoing = incoming 150 | action = CLOSE 151 | } 152 | return 153 | } 154 | 155 | clientHandlerTable[metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { 156 | finished <- struct{}{} 157 | return 158 | } 159 | 160 | emptyLogger := logging.Test(t, logging.Noop, t.Name()) 161 | s, err := NewServer(serverHandlerTable, context.Background(), WithLogger(emptyLogger)) 162 | require.NoError(t, err) 163 | 164 | s.SetConcurrency(1) 165 | 166 | serverConn, clientConn, err := pair.New() 167 | require.NoError(t, err) 168 | 169 | go s.ServeConn(serverConn) 170 | 171 | c, err := 
NewClient(clientHandlerTable, context.Background(), WithLogger(emptyLogger)) 172 | assert.NoError(t, err) 173 | _, err = c.Raw() 174 | assert.ErrorIs(t, ConnectionNotInitialized, err) 175 | 176 | err = c.FromConn(clientConn) 177 | require.NoError(t, err) 178 | 179 | data := make([]byte, packetSize) 180 | _, _ = rand.Read(data) 181 | p := packet.Get() 182 | p.Content.Write(data) 183 | p.Metadata.ContentLength = packetSize 184 | p.Metadata.Operation = metadata.PacketPing 185 | expected := polyglot.NewBufferFromBytes(data) 186 | expected.MoveOffset(len(data)) 187 | assert.Equal(t, expected.Bytes(), p.Content.Bytes()) 188 | 189 | for q := 0; q < testSize; q++ { 190 | p.Metadata.Id = uint16(q) 191 | err = c.WritePacket(p) 192 | assert.NoError(t, err) 193 | } 194 | packet.Put(p) 195 | <-finished 196 | 197 | _, err = c.conn.ReadPacket() 198 | assert.ErrorIs(t, err, ConnectionClosed) 199 | 200 | err = c.Close() 201 | assert.NoError(t, err) 202 | 203 | err = s.Shutdown() 204 | assert.NoError(t, err) 205 | } 206 | 207 | func TestServerMultipleConnectionsSingle(t *testing.T) { 208 | t.Parallel() 209 | 210 | const testSize = 100 211 | const packetSize = 512 212 | 213 | runner := func(t *testing.T, num int) { 214 | finished := make([]chan struct{}, num) 215 | clientTables := make([]HandlerTable, num) 216 | for i := 0; i < num; i++ { 217 | idx := i 218 | finished[idx] = make(chan struct{}, 1) 219 | clientTables[i] = make(HandlerTable) 220 | clientTables[i][metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { 221 | finished[idx] <- struct{}{} 222 | return 223 | } 224 | } 225 | serverHandlerTable := make(HandlerTable) 226 | serverHandlerTable[metadata.PacketPing] = func(_ context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action Action) { 227 | if incoming.Metadata.Id == testSize-1 { 228 | outgoing = incoming 229 | action = CLOSE 230 | } 231 | return 232 | } 233 | 234 | emptyLogger := logging.Test(t, logging.Noop, t.Name()) 235 | s, err := NewServer(serverHandlerTable, context.Background(), WithLogger(emptyLogger)) 236 | require.NoError(t, err) 237 | 238 | s.SetConcurrency(1) 239 | 240 | var wg sync.WaitGroup 241 | 242 | wg.Add(1) 243 | go func() { 244 | err := s.Start(conn.Listen) 245 | require.NoError(t, err) 246 | wg.Done() 247 | }() 248 | 249 | <-s.started() 250 | listenAddr := s.listener.Addr().String() 251 | 252 | clients := make([]*Client, num) 253 | for i := 0; i < num; i++ { 254 | clients[i], err = NewClient(clientTables[i], context.Background(), WithLogger(emptyLogger)) 255 | assert.NoError(t, err) 256 | _, err = clients[i].Raw() 257 | assert.ErrorIs(t, ConnectionNotInitialized, err) 258 | 259 | err = clients[i].Connect(listenAddr) 260 | require.NoError(t, err) 261 | } 262 | 263 | data := make([]byte, packetSize) 264 | _, err = rand.Read(data) 265 | assert.NoError(t, err) 266 | 267 | var clientWg sync.WaitGroup 268 | for i := 0; i < num; i++ { 269 | idx := i 270 | clientWg.Add(1) 271 | go func() { 272 | p := packet.Get() 273 | p.Content.Write(data) 274 | p.Metadata.ContentLength = packetSize 275 | p.Metadata.Operation = metadata.PacketPing 276 | expected := polyglot.NewBufferFromBytes(data) 277 | expected.MoveOffset(len(data)) 278 | assert.Equal(t, expected.Bytes(), p.Content.Bytes()) 279 | for q := 0; q < testSize; q++ { 280 | p.Metadata.Id = uint16(q) 281 | err := clients[idx].WritePacket(p) 282 | assert.NoError(t, err) 283 | } 284 | <-finished[idx] 285 | err := clients[idx].Close() 286 | assert.NoError(t, err) 287 | 
clientWg.Done() 288 | packet.Put(p) 289 | }() 290 | } 291 | 292 | clientWg.Wait() 293 | 294 | err = s.Shutdown() 295 | assert.NoError(t, err) 296 | wg.Wait() 297 | 298 | } 299 | 300 | t.Run("1", func(t *testing.T) { runner(t, 1) }) 301 | t.Run("2", func(t *testing.T) { runner(t, 2) }) 302 | t.Run("3", func(t *testing.T) { runner(t, 3) }) 303 | t.Run("5", func(t *testing.T) { runner(t, 5) }) 304 | t.Run("10", func(t *testing.T) { runner(t, 10) }) 305 | t.Run("100", func(t *testing.T) { runner(t, 100) }) 306 | } 307 | 308 | func TestServerRawUnlimited(t *testing.T) { 309 | t.Parallel() 310 | 311 | const testSize = 100 312 | const packetSize = 512 313 | clientHandlerTable := make(HandlerTable) 314 | serverHandlerTable := make(HandlerTable) 315 | 316 | serverIsRaw := make(chan struct{}, 1) 317 | 318 | serverHandlerTable[metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { 319 | return 320 | } 321 | 322 | var rawServerConn, rawClientConn net.Conn 323 | serverHandlerTable[metadata.PacketProbe] = func(ctx context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { 324 | c := ctx.Value(serverConnContextKey).(*Async) 325 | rawServerConn = c.Raw() 326 | serverIsRaw <- struct{}{} 327 | return 328 | } 329 | 330 | clientHandlerTable[metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { 331 | return 332 | } 333 | 334 | emptyLogger := logging.Test(t, logging.Noop, t.Name()) 335 | s, err := NewServer(serverHandlerTable, context.Background(), WithLogger(emptyLogger)) 336 | require.NoError(t, err) 337 | 338 | s.SetConcurrency(0) 339 | 340 | s.ConnContext = func(ctx context.Context, c *Async) context.Context { 341 | return context.WithValue(ctx, serverConnContextKey, c) 342 | } 343 | 344 | serverConn, clientConn, err := pair.New() 345 | require.NoError(t, err) 346 | 347 | go s.ServeConn(serverConn) 348 | 349 | c, err := NewClient(clientHandlerTable, context.Background(), WithLogger(emptyLogger)) 350 | assert.NoError(t, err) 351 | 352 | _, err = c.Raw() 353 | assert.ErrorIs(t, ConnectionNotInitialized, err) 354 | 355 | err = c.FromConn(clientConn) 356 | assert.NoError(t, err) 357 | 358 | data := make([]byte, packetSize) 359 | _, _ = rand.Read(data) 360 | p := packet.Get() 361 | p.Content.Write(data) 362 | p.Metadata.ContentLength = packetSize 363 | p.Metadata.Operation = metadata.PacketPing 364 | expected := polyglot.NewBufferFromBytes(data) 365 | expected.MoveOffset(len(data)) 366 | assert.Equal(t, expected.Bytes(), p.Content.Bytes()) 367 | 368 | for q := 0; q < testSize; q++ { 369 | p.Metadata.Id = uint16(q) 370 | err = c.WritePacket(p) 371 | assert.NoError(t, err) 372 | } 373 | 374 | p.Reset() 375 | assert.Equal(t, 0, p.Content.Len()) 376 | p.Metadata.Operation = metadata.PacketProbe 377 | 378 | err = c.WritePacket(p) 379 | require.NoError(t, err) 380 | 381 | packet.Put(p) 382 | 383 | rawClientConn, err = c.Raw() 384 | require.NoError(t, err) 385 | 386 | <-serverIsRaw 387 | 388 | serverBytes := []byte("SERVER WRITE") 389 | 390 | write, err := rawServerConn.Write(serverBytes) 391 | assert.NoError(t, err) 392 | assert.Equal(t, len(serverBytes), write) 393 | 394 | clientBuffer := make([]byte, len(serverBytes)) 395 | read, err := rawClientConn.Read(clientBuffer) 396 | assert.NoError(t, err) 397 | assert.Equal(t, len(serverBytes), read) 398 | 399 | assert.Equal(t, serverBytes, clientBuffer) 400 | 401 | err = c.Close() 402 | assert.NoError(t, err) 403 | err = rawClientConn.Close() 404 | 
assert.NoError(t, err) 405 | 406 | err = s.Shutdown() 407 | assert.NoError(t, err) 408 | err = rawServerConn.Close() 409 | assert.NoError(t, err) 410 | } 411 | 412 | func TestServerStaleCloseUnlimited(t *testing.T) { 413 | t.Parallel() 414 | 415 | const testSize = 100 416 | const packetSize = 512 417 | clientHandlerTable := make(HandlerTable) 418 | serverHandlerTable := make(HandlerTable) 419 | 420 | finished := make(chan struct{}, 1) 421 | 422 | var count atomic.Int32 423 | serverHandlerTable[metadata.PacketPing] = func(_ context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action Action) { 424 | if count.Add(1) == testSize-1 { 425 | outgoing = incoming 426 | action = CLOSE 427 | count.Store(0) 428 | } 429 | return 430 | } 431 | 432 | clientHandlerTable[metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { 433 | finished <- struct{}{} 434 | return 435 | } 436 | 437 | emptyLogger := logging.Test(t, logging.Noop, t.Name()) 438 | s, err := NewServer(serverHandlerTable, context.Background(), WithLogger(emptyLogger)) 439 | require.NoError(t, err) 440 | 441 | s.SetConcurrency(0) 442 | 443 | serverConn, clientConn, err := pair.New() 444 | require.NoError(t, err) 445 | 446 | go s.ServeConn(serverConn) 447 | 448 | c, err := NewClient(clientHandlerTable, context.Background(), WithLogger(emptyLogger)) 449 | assert.NoError(t, err) 450 | _, err = c.Raw() 451 | assert.ErrorIs(t, ConnectionNotInitialized, err) 452 | 453 | err = c.FromConn(clientConn) 454 | require.NoError(t, err) 455 | 456 | data := make([]byte, packetSize) 457 | _, _ = rand.Read(data) 458 | p := packet.Get() 459 | p.Content.Write(data) 460 | p.Metadata.ContentLength = packetSize 461 | p.Metadata.Operation = metadata.PacketPing 462 | expected := polyglot.NewBufferFromBytes(data) 463 | expected.MoveOffset(len(data)) 464 | assert.Equal(t, expected.Bytes(), p.Content.Bytes()) 465 | 466 | for q := 0; q < testSize; q++ { 467 | p.Metadata.Id = uint16(q) 468 | err = c.WritePacket(p) 469 | assert.NoError(t, err) 470 | } 471 | packet.Put(p) 472 | <-finished 473 | 474 | _, err = c.conn.ReadPacket() 475 | assert.ErrorIs(t, err, ConnectionClosed) 476 | 477 | err = c.Close() 478 | assert.NoError(t, err) 479 | 480 | err = s.Shutdown() 481 | assert.NoError(t, err) 482 | } 483 | 484 | func TestServerMultipleConnectionsUnlimited(t *testing.T) { 485 | t.Parallel() 486 | 487 | const testSize = 100 488 | const packetSize = 512 489 | 490 | runner := func(t *testing.T, num int) { 491 | finished := make([]chan struct{}, num) 492 | clientTables := make([]HandlerTable, num) 493 | for i := 0; i < num; i++ { 494 | idx := i 495 | finished[idx] = make(chan struct{}, 1) 496 | clientTables[i] = make(HandlerTable) 497 | clientTables[i][metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { 498 | finished[idx] <- struct{}{} 499 | return 500 | } 501 | } 502 | clientCounts := make([]atomic.Uint32, num) 503 | 504 | serverHandlerTable := make(HandlerTable) 505 | serverHandlerTable[metadata.PacketPing] = func(_ context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action Action) { 506 | if clientCounts[incoming.Metadata.Id].Add(1) == testSize-1 { 507 | outgoing = incoming 508 | action = CLOSE 509 | clientCounts[incoming.Metadata.Id].Store(0) 510 | } 511 | return 512 | } 513 | 514 | emptyLogger := logging.Test(t, logging.Noop, t.Name()) 515 | s, err := NewServer(serverHandlerTable, context.Background(), WithLogger(emptyLogger)) 516 | 
require.NoError(t, err) 517 | 518 | s.SetConcurrency(0) 519 | 520 | var wg sync.WaitGroup 521 | 522 | wg.Add(1) 523 | go func() { 524 | err := s.Start(conn.Listen) 525 | require.NoError(t, err) 526 | wg.Done() 527 | }() 528 | 529 | <-s.started() 530 | listenAddr := s.listener.Addr().String() 531 | 532 | clients := make([]*Client, num) 533 | for i := 0; i < num; i++ { 534 | clients[i], err = NewClient(clientTables[i], context.Background(), WithLogger(emptyLogger)) 535 | assert.NoError(t, err) 536 | _, err = clients[i].Raw() 537 | assert.ErrorIs(t, ConnectionNotInitialized, err) 538 | 539 | err = clients[i].Connect(listenAddr) 540 | require.NoError(t, err) 541 | } 542 | 543 | data := make([]byte, packetSize) 544 | _, err = rand.Read(data) 545 | assert.NoError(t, err) 546 | 547 | var clientWg sync.WaitGroup 548 | for i := 0; i < num; i++ { 549 | idx := i 550 | clientWg.Add(1) 551 | go func() { 552 | p := packet.Get() 553 | p.Content.Write(data) 554 | p.Metadata.ContentLength = packetSize 555 | p.Metadata.Operation = metadata.PacketPing 556 | p.Metadata.Id = uint16(idx) 557 | expected := polyglot.NewBufferFromBytes(data) 558 | expected.MoveOffset(len(data)) 559 | assert.Equal(t, expected.Bytes(), p.Content.Bytes()) 560 | for q := 0; q < testSize; q++ { 561 | err := clients[idx].WritePacket(p) 562 | assert.NoError(t, err) 563 | } 564 | <-finished[idx] 565 | err := clients[idx].Close() 566 | assert.NoError(t, err) 567 | clientWg.Done() 568 | packet.Put(p) 569 | }() 570 | } 571 | 572 | clientWg.Wait() 573 | 574 | err = s.Shutdown() 575 | assert.NoError(t, err) 576 | wg.Wait() 577 | 578 | } 579 | 580 | t.Run("1", func(t *testing.T) { runner(t, 1) }) 581 | t.Run("2", func(t *testing.T) { runner(t, 2) }) 582 | t.Run("3", func(t *testing.T) { runner(t, 3) }) 583 | t.Run("5", func(t *testing.T) { runner(t, 5) }) 584 | t.Run("10", func(t *testing.T) { runner(t, 10) }) 585 | t.Run("100", func(t *testing.T) { runner(t, 100) }) 586 | } 587 | 588 | func TestServerRawLimited(t *testing.T) { 589 | t.Parallel() 590 | 591 | const testSize = 100 592 | const packetSize = 512 593 | clientHandlerTable := make(HandlerTable) 594 | serverHandlerTable := make(HandlerTable) 595 | 596 | serverIsRaw := make(chan struct{}, 1) 597 | 598 | serverHandlerTable[metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { 599 | return 600 | } 601 | 602 | var rawServerConn, rawClientConn net.Conn 603 | serverHandlerTable[metadata.PacketProbe] = func(ctx context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { 604 | c := ctx.Value(serverConnContextKey).(*Async) 605 | rawServerConn = c.Raw() 606 | serverIsRaw <- struct{}{} 607 | return 608 | } 609 | 610 | clientHandlerTable[metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { 611 | return 612 | } 613 | 614 | emptyLogger := logging.Test(t, logging.Noop, t.Name()) 615 | s, err := NewServer(serverHandlerTable, context.Background(), WithLogger(emptyLogger)) 616 | require.NoError(t, err) 617 | 618 | s.SetConcurrency(10) 619 | 620 | s.ConnContext = func(ctx context.Context, c *Async) context.Context { 621 | return context.WithValue(ctx, serverConnContextKey, c) 622 | } 623 | 624 | serverConn, clientConn, err := pair.New() 625 | require.NoError(t, err) 626 | 627 | go s.ServeConn(serverConn) 628 | 629 | c, err := NewClient(clientHandlerTable, context.Background(), WithLogger(emptyLogger)) 630 | assert.NoError(t, err) 631 | 632 | _, err = c.Raw() 633 | 
assert.ErrorIs(t, ConnectionNotInitialized, err) 634 | 635 | err = c.FromConn(clientConn) 636 | assert.NoError(t, err) 637 | 638 | data := make([]byte, packetSize) 639 | _, _ = rand.Read(data) 640 | p := packet.Get() 641 | p.Content.Write(data) 642 | p.Metadata.ContentLength = packetSize 643 | p.Metadata.Operation = metadata.PacketPing 644 | expected := polyglot.NewBufferFromBytes(data) 645 | expected.MoveOffset(len(data)) 646 | assert.Equal(t, expected.Bytes(), p.Content.Bytes()) 647 | 648 | for q := 0; q < testSize; q++ { 649 | p.Metadata.Id = uint16(q) 650 | err = c.WritePacket(p) 651 | assert.NoError(t, err) 652 | } 653 | 654 | p.Reset() 655 | assert.Equal(t, 0, p.Content.Len()) 656 | p.Metadata.Operation = metadata.PacketProbe 657 | 658 | err = c.WritePacket(p) 659 | require.NoError(t, err) 660 | 661 | packet.Put(p) 662 | 663 | rawClientConn, err = c.Raw() 664 | require.NoError(t, err) 665 | 666 | <-serverIsRaw 667 | 668 | serverBytes := []byte("SERVER WRITE") 669 | 670 | write, err := rawServerConn.Write(serverBytes) 671 | assert.NoError(t, err) 672 | assert.Equal(t, len(serverBytes), write) 673 | 674 | clientBuffer := make([]byte, len(serverBytes)) 675 | read, err := rawClientConn.Read(clientBuffer) 676 | assert.NoError(t, err) 677 | assert.Equal(t, len(serverBytes), read) 678 | 679 | assert.Equal(t, serverBytes, clientBuffer) 680 | 681 | err = c.Close() 682 | assert.NoError(t, err) 683 | err = rawClientConn.Close() 684 | assert.NoError(t, err) 685 | 686 | err = s.Shutdown() 687 | assert.NoError(t, err) 688 | err = rawServerConn.Close() 689 | assert.NoError(t, err) 690 | } 691 | 692 | func TestServerStaleCloseLimited(t *testing.T) { 693 | t.Parallel() 694 | 695 | const testSize = 100 696 | const packetSize = 512 697 | clientHandlerTable := make(HandlerTable) 698 | serverHandlerTable := make(HandlerTable) 699 | 700 | finished := make(chan struct{}, 1) 701 | 702 | var count atomic.Int32 703 | serverHandlerTable[metadata.PacketPing] = func(_ context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action Action) { 704 | if count.Add(1) == testSize-1 { 705 | outgoing = incoming 706 | action = CLOSE 707 | count.Store(0) 708 | } 709 | return 710 | } 711 | 712 | clientHandlerTable[metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { 713 | finished <- struct{}{} 714 | return 715 | } 716 | 717 | emptyLogger := logging.Test(t, logging.Noop, t.Name()) 718 | s, err := NewServer(serverHandlerTable, context.Background(), WithLogger(emptyLogger)) 719 | require.NoError(t, err) 720 | 721 | s.SetConcurrency(10) 722 | 723 | serverConn, clientConn, err := pair.New() 724 | require.NoError(t, err) 725 | 726 | go s.ServeConn(serverConn) 727 | 728 | c, err := NewClient(clientHandlerTable, context.Background(), WithLogger(emptyLogger)) 729 | assert.NoError(t, err) 730 | _, err = c.Raw() 731 | assert.ErrorIs(t, ConnectionNotInitialized, err) 732 | 733 | err = c.FromConn(clientConn) 734 | require.NoError(t, err) 735 | 736 | data := make([]byte, packetSize) 737 | _, _ = rand.Read(data) 738 | p := packet.Get() 739 | p.Content.Write(data) 740 | p.Metadata.ContentLength = packetSize 741 | p.Metadata.Operation = metadata.PacketPing 742 | expected := polyglot.NewBufferFromBytes(data) 743 | expected.MoveOffset(len(data)) 744 | assert.Equal(t, expected.Bytes(), p.Content.Bytes()) 745 | 746 | for q := 0; q < testSize; q++ { 747 | p.Metadata.Id = uint16(q) 748 | err = c.WritePacket(p) 749 | assert.NoError(t, err) 750 | } 751 | packet.Put(p) 752 | 
<-finished 753 | 754 | _, err = c.conn.ReadPacket() 755 | assert.ErrorIs(t, err, ConnectionClosed) 756 | 757 | err = c.Close() 758 | assert.NoError(t, err) 759 | 760 | err = s.Shutdown() 761 | assert.NoError(t, err) 762 | } 763 | 764 | func TestServerMultipleConnectionsLimited(t *testing.T) { 765 | t.Parallel() 766 | 767 | const testSize = 100 768 | const packetSize = 512 769 | 770 | runner := func(t *testing.T, num int) { 771 | finished := make([]chan struct{}, num) 772 | clientTables := make([]HandlerTable, num) 773 | for i := 0; i < num; i++ { 774 | idx := i 775 | finished[idx] = make(chan struct{}, 1) 776 | clientTables[i] = make(HandlerTable) 777 | clientTables[i][metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { 778 | finished[idx] <- struct{}{} 779 | return 780 | } 781 | } 782 | 783 | clientCounts := make([]atomic.Uint32, num) 784 | 785 | serverHandlerTable := make(HandlerTable) 786 | serverHandlerTable[metadata.PacketPing] = func(_ context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action Action) { 787 | if clientCounts[incoming.Metadata.Id].Add(1) == testSize-1 { 788 | outgoing = incoming 789 | action = CLOSE 790 | clientCounts[incoming.Metadata.Id].Store(0) 791 | } 792 | return 793 | } 794 | 795 | emptyLogger := logging.Test(t, logging.Noop, t.Name()) 796 | s, err := NewServer(serverHandlerTable, context.Background(), WithLogger(emptyLogger)) 797 | require.NoError(t, err) 798 | 799 | s.SetConcurrency(10) 800 | 801 | var wg sync.WaitGroup 802 | 803 | wg.Add(1) 804 | go func() { 805 | err := s.Start(conn.Listen) 806 | require.NoError(t, err) 807 | wg.Done() 808 | }() 809 | 810 | <-s.started() 811 | listenAddr := s.listener.Addr().String() 812 | 813 | clients := make([]*Client, num) 814 | for i := 0; i < num; i++ { 815 | clients[i], err = NewClient(clientTables[i], context.Background(), WithLogger(emptyLogger)) 816 | assert.NoError(t, err) 817 | _, err = clients[i].Raw() 818 | assert.ErrorIs(t, ConnectionNotInitialized, err) 819 | 820 | err = clients[i].Connect(listenAddr) 821 | require.NoError(t, err) 822 | } 823 | 824 | data := make([]byte, packetSize) 825 | _, err = rand.Read(data) 826 | assert.NoError(t, err) 827 | 828 | var clientWg sync.WaitGroup 829 | for i := 0; i < num; i++ { 830 | idx := i 831 | clientWg.Add(1) 832 | go func() { 833 | p := packet.Get() 834 | p.Content.Write(data) 835 | p.Metadata.ContentLength = packetSize 836 | p.Metadata.Operation = metadata.PacketPing 837 | p.Metadata.Id = uint16(idx) 838 | expected := polyglot.NewBufferFromBytes(data) 839 | expected.MoveOffset(len(data)) 840 | assert.Equal(t, expected.Bytes(), p.Content.Bytes()) 841 | for q := 0; q < testSize; q++ { 842 | err := clients[idx].WritePacket(p) 843 | assert.NoError(t, err) 844 | } 845 | <-finished[idx] 846 | err := clients[idx].Close() 847 | assert.NoError(t, err) 848 | clientWg.Done() 849 | packet.Put(p) 850 | }() 851 | } 852 | 853 | clientWg.Wait() 854 | 855 | err = s.Shutdown() 856 | assert.NoError(t, err) 857 | wg.Wait() 858 | 859 | } 860 | 861 | t.Run("1", func(t *testing.T) { runner(t, 1) }) 862 | t.Run("2", func(t *testing.T) { runner(t, 2) }) 863 | t.Run("3", func(t *testing.T) { runner(t, 3) }) 864 | t.Run("5", func(t *testing.T) { runner(t, 5) }) 865 | t.Run("10", func(t *testing.T) { runner(t, 10) }) 866 | t.Run("100", func(t *testing.T) { runner(t, 100) }) 867 | } 868 | 869 | func TestServerInvalidPacket(t *testing.T) { 870 | t.Parallel() 871 | 872 | // Ensure request is rejected promptly. 
873 | ctx, cancel := context.WithTimeout(context.Background(), time.Second) 874 | t.Cleanup(cancel) 875 | 876 | emptyLogger := logging.Test(t, logging.Noop, t.Name()) 877 | s, err := NewServer(nil, context.Background(), WithLogger(emptyLogger)) 878 | require.NoError(t, err) 879 | 880 | ln, err := net.Listen("tcp", ":0") 881 | require.NoError(t, err) 882 | 883 | go s.StartWithListener(ln) 884 | t.Cleanup(func() { s.Shutdown() }) 885 | 886 | url := fmt.Sprintf("http://%s/", ln.Addr()) 887 | req, err := http.NewRequestWithContext(ctx, "GET", url, nil) 888 | require.NoError(t, err) 889 | 890 | _, err = http.DefaultClient.Do(req) 891 | require.ErrorIs(t, err, io.EOF) 892 | } 893 | 894 | func BenchmarkThroughputServerSingle(b *testing.B) { 895 | DisableMaxContentLength(b) 896 | 897 | const testSize = 1<<16 - 1 898 | const packetSize = 512 899 | 900 | handlerTable := make(HandlerTable) 901 | 902 | handlerTable[metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { 903 | return 904 | } 905 | 906 | emptyLogger := logging.Test(b, logging.Noop, b.Name()) 907 | server, err := NewServer(handlerTable, context.Background(), WithLogger(emptyLogger)) 908 | if err != nil { 909 | b.Fatal(err) 910 | } 911 | 912 | server.SetConcurrency(1) 913 | 914 | serverConn, clientConn, err := pair.New() 915 | if err != nil { 916 | b.Fatal(err) 917 | } 918 | 919 | go server.ServeConn(serverConn) 920 | 921 | frisbeeConn := NewAsync(clientConn, emptyLogger) 922 | 923 | data := make([]byte, packetSize) 924 | _, _ = rand.Read(data) 925 | p := packet.Get() 926 | p.Metadata.Operation = metadata.PacketPing 927 | 928 | p.Content.Write(data) 929 | p.Metadata.ContentLength = packetSize 930 | 931 | b.Run("test", func(b *testing.B) { 932 | b.SetBytes(testSize * packetSize) 933 | b.ReportAllocs() 934 | b.ResetTimer() 935 | for i := 0; i < b.N; i++ { 936 | for q := 0; q < testSize; q++ { 937 | p.Metadata.Id = uint16(q) 938 | err = frisbeeConn.WritePacket(p) 939 | if err != nil { 940 | b.Fatal(err) 941 | } 942 | } 943 | } 944 | }) 945 | b.StopTimer() 946 | 947 | packet.Put(p) 948 | 949 | err = frisbeeConn.Close() 950 | if err != nil { 951 | b.Fatal(err) 952 | } 953 | err = server.Shutdown() 954 | if err != nil { 955 | b.Fatal(err) 956 | } 957 | } 958 | 959 | func BenchmarkThroughputServerUnlimited(b *testing.B) { 960 | DisableMaxContentLength(b) 961 | 962 | const testSize = 1<<16 - 1 963 | const packetSize = 512 964 | 965 | handlerTable := make(HandlerTable) 966 | 967 | handlerTable[metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { 968 | time.Sleep(time.Millisecond * 50) 969 | return 970 | } 971 | 972 | emptyLogger := logging.Test(b, logging.Noop, b.Name()) 973 | server, err := NewServer(handlerTable, context.Background(), WithLogger(emptyLogger)) 974 | if err != nil { 975 | b.Fatal(err) 976 | } 977 | 978 | server.SetConcurrency(0) 979 | 980 | serverConn, clientConn, err := pair.New() 981 | if err != nil { 982 | b.Fatal(err) 983 | } 984 | 985 | go server.ServeConn(serverConn) 986 | 987 | frisbeeConn := NewAsync(clientConn, emptyLogger) 988 | 989 | data := make([]byte, packetSize) 990 | _, _ = rand.Read(data) 991 | p := packet.Get() 992 | p.Metadata.Operation = metadata.PacketPing 993 | 994 | p.Content.Write(data) 995 | p.Metadata.ContentLength = packetSize 996 | 997 | b.Run("test", func(b *testing.B) { 998 | b.SetBytes(testSize * packetSize) 999 | b.ReportAllocs() 1000 | b.ResetTimer() 1001 | for i := 0; i < b.N; i++ { 1002 | 
for q := 0; q < testSize; q++ { 1003 | p.Metadata.Id = uint16(q) 1004 | err = frisbeeConn.WritePacket(p) 1005 | if err != nil { 1006 | b.Fatal(err) 1007 | } 1008 | } 1009 | } 1010 | }) 1011 | b.StopTimer() 1012 | 1013 | packet.Put(p) 1014 | 1015 | err = frisbeeConn.Close() 1016 | if err != nil { 1017 | b.Fatal(err) 1018 | } 1019 | err = server.Shutdown() 1020 | if err != nil { 1021 | b.Fatal(err) 1022 | } 1023 | } 1024 | 1025 | func BenchmarkThroughputServerLimited(b *testing.B) { 1026 | DisableMaxContentLength(b) 1027 | 1028 | const testSize = 1<<16 - 1 1029 | const packetSize = 512 1030 | 1031 | handlerTable := make(HandlerTable) 1032 | 1033 | handlerTable[metadata.PacketPing] = func(_ context.Context, _ *packet.Packet) (outgoing *packet.Packet, action Action) { 1034 | time.Sleep(time.Millisecond * 50) 1035 | return 1036 | } 1037 | 1038 | emptyLogger := logging.Test(b, logging.Noop, b.Name()) 1039 | server, err := NewServer(handlerTable, context.Background(), WithLogger(emptyLogger)) 1040 | if err != nil { 1041 | b.Fatal(err) 1042 | } 1043 | 1044 | server.SetConcurrency(1 << 14) 1045 | 1046 | serverConn, clientConn, err := pair.New() 1047 | if err != nil { 1048 | b.Fatal(err) 1049 | } 1050 | 1051 | go server.ServeConn(serverConn) 1052 | 1053 | frisbeeConn := NewAsync(clientConn, emptyLogger) 1054 | 1055 | data := make([]byte, packetSize) 1056 | _, _ = rand.Read(data) 1057 | p := packet.Get() 1058 | p.Metadata.Operation = metadata.PacketPing 1059 | 1060 | p.Content.Write(data) 1061 | p.Metadata.ContentLength = packetSize 1062 | 1063 | b.Run("test", func(b *testing.B) { 1064 | b.SetBytes(testSize * packetSize) 1065 | b.ReportAllocs() 1066 | b.ResetTimer() 1067 | for i := 0; i < b.N; i++ { 1068 | for q := 0; q < testSize; q++ { 1069 | p.Metadata.Id = uint16(q) 1070 | err = frisbeeConn.WritePacket(p) 1071 | if err != nil { 1072 | b.Fatal(err) 1073 | } 1074 | } 1075 | } 1076 | }) 1077 | b.StopTimer() 1078 | 1079 | packet.Put(p) 1080 | 1081 | err = frisbeeConn.Close() 1082 | if err != nil { 1083 | b.Fatal(err) 1084 | } 1085 | err = server.Shutdown() 1086 | if err != nil { 1087 | b.Fatal(err) 1088 | } 1089 | } 1090 | 1091 | func BenchmarkThroughputResponseServerSingle(b *testing.B) { 1092 | DisableMaxContentLength(b) 1093 | 1094 | const testSize = 1<<16 - 1 1095 | const packetSize = 512 1096 | 1097 | serverConn, clientConn, err := pair.New() 1098 | if err != nil { 1099 | b.Fatal(err) 1100 | } 1101 | 1102 | handlerTable := make(HandlerTable) 1103 | 1104 | handlerTable[metadata.PacketPing] = func(_ context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action Action) { 1105 | if incoming.Metadata.Id == testSize-1 { 1106 | incoming.Reset() 1107 | incoming.Metadata.Id = testSize 1108 | incoming.Metadata.Operation = metadata.PacketPong 1109 | outgoing = incoming 1110 | } 1111 | return 1112 | } 1113 | 1114 | emptyLogger := logging.Test(b, logging.Noop, b.Name()) 1115 | server, err := NewServer(handlerTable, context.Background(), WithLogger(emptyLogger)) 1116 | if err != nil { 1117 | b.Fatal(err) 1118 | } 1119 | 1120 | server.SetConcurrency(1) 1121 | 1122 | go server.ServeConn(serverConn) 1123 | 1124 | frisbeeConn := NewAsync(clientConn, emptyLogger) 1125 | 1126 | data := make([]byte, packetSize) 1127 | _, _ = rand.Read(data) 1128 | 1129 | p := packet.Get() 1130 | p.Metadata.Operation = metadata.PacketPing 1131 | 1132 | p.Content.Write(data) 1133 | p.Metadata.ContentLength = packetSize 1134 | 1135 | b.Run("test", func(b *testing.B) { 1136 | b.SetBytes(testSize * packetSize) 1137 | 
b.ReportAllocs() 1138 | b.ResetTimer() 1139 | for i := 0; i < b.N; i++ { 1140 | for q := 0; q < testSize; q++ { 1141 | p.Metadata.Id = uint16(q) 1142 | err = frisbeeConn.WritePacket(p) 1143 | if err != nil { 1144 | b.Fatal(err) 1145 | } 1146 | } 1147 | readPacket, err := frisbeeConn.ReadPacket() 1148 | if err != nil { 1149 | b.Fatal(err) 1150 | } 1151 | 1152 | if readPacket.Metadata.Id != testSize { 1153 | b.Fatal("invalid decoded metadata id", readPacket.Metadata.Id) 1154 | } 1155 | 1156 | if readPacket.Metadata.Operation != metadata.PacketPong { 1157 | b.Fatal("invalid decoded operation", readPacket.Metadata.Operation) 1158 | } 1159 | packet.Put(readPacket) 1160 | } 1161 | 1162 | }) 1163 | b.StopTimer() 1164 | 1165 | packet.Put(p) 1166 | 1167 | err = frisbeeConn.Close() 1168 | if err != nil { 1169 | b.Fatal(err) 1170 | } 1171 | err = server.Shutdown() 1172 | if err != nil { 1173 | b.Fatal(err) 1174 | } 1175 | } 1176 | 1177 | func BenchmarkThroughputResponseServerSlowSingle(b *testing.B) { 1178 | DisableMaxContentLength(b) 1179 | 1180 | const testSize = 1<<16 - 1 1181 | const packetSize = 512 1182 | 1183 | serverConn, clientConn, err := pair.New() 1184 | if err != nil { 1185 | b.Fatal(err) 1186 | } 1187 | 1188 | handlerTable := make(HandlerTable) 1189 | 1190 | handlerTable[metadata.PacketPing] = func(_ context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action Action) { 1191 | time.Sleep(time.Microsecond * 50) 1192 | if incoming.Metadata.Id == testSize-1 { 1193 | incoming.Reset() 1194 | incoming.Metadata.Id = testSize 1195 | incoming.Metadata.Operation = metadata.PacketPong 1196 | outgoing = incoming 1197 | } 1198 | return 1199 | } 1200 | 1201 | emptyLogger := logging.Test(b, logging.Noop, b.Name()) 1202 | server, err := NewServer(handlerTable, context.Background(), WithLogger(emptyLogger)) 1203 | if err != nil { 1204 | b.Fatal(err) 1205 | } 1206 | 1207 | server.SetConcurrency(1) 1208 | 1209 | go server.ServeConn(serverConn) 1210 | 1211 | frisbeeConn := NewAsync(clientConn, emptyLogger) 1212 | 1213 | data := make([]byte, packetSize) 1214 | _, _ = rand.Read(data) 1215 | 1216 | p := packet.Get() 1217 | p.Metadata.Operation = metadata.PacketPing 1218 | 1219 | p.Content.Write(data) 1220 | p.Metadata.ContentLength = packetSize 1221 | 1222 | b.Run("test", func(b *testing.B) { 1223 | b.SetBytes(testSize * packetSize) 1224 | b.ReportAllocs() 1225 | b.ResetTimer() 1226 | for i := 0; i < b.N; i++ { 1227 | for q := 0; q < testSize; q++ { 1228 | p.Metadata.Id = uint16(q) 1229 | err = frisbeeConn.WritePacket(p) 1230 | if err != nil { 1231 | b.Fatal(err) 1232 | } 1233 | } 1234 | readPacket, err := frisbeeConn.ReadPacket() 1235 | if err != nil { 1236 | b.Fatal(err) 1237 | } 1238 | 1239 | if readPacket.Metadata.Id != testSize { 1240 | b.Fatal("invalid decoded metadata id", readPacket.Metadata.Id) 1241 | } 1242 | 1243 | if readPacket.Metadata.Operation != metadata.PacketPong { 1244 | b.Fatal("invalid decoded operation", readPacket.Metadata.Operation) 1245 | } 1246 | packet.Put(readPacket) 1247 | } 1248 | 1249 | }) 1250 | b.StopTimer() 1251 | 1252 | packet.Put(p) 1253 | 1254 | err = frisbeeConn.Close() 1255 | if err != nil { 1256 | b.Fatal(err) 1257 | } 1258 | err = server.Shutdown() 1259 | if err != nil { 1260 | b.Fatal(err) 1261 | } 1262 | } 1263 | 1264 | func BenchmarkThroughputResponseServerSlowUnlimited(b *testing.B) { 1265 | DisableMaxContentLength(b) 1266 | 1267 | const testSize = 1<<16 - 1 1268 | const packetSize = 512 1269 | 1270 | serverConn, clientConn, err := pair.New() 1271 | 
if err != nil { 1272 | b.Fatal(err) 1273 | } 1274 | 1275 | handlerTable := make(HandlerTable) 1276 | 1277 | var count atomic.Uint64 1278 | handlerTable[metadata.PacketPing] = func(_ context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action Action) { 1279 | time.Sleep(time.Microsecond * 50) 1280 | if count.Add(1) == testSize-1 { 1281 | incoming.Reset() 1282 | incoming.Metadata.Id = testSize 1283 | incoming.Metadata.Operation = metadata.PacketPong 1284 | outgoing = incoming 1285 | count.Store(0) 1286 | } 1287 | return 1288 | } 1289 | 1290 | emptyLogger := logging.Test(b, logging.Noop, b.Name()) 1291 | server, err := NewServer(handlerTable, context.Background(), WithLogger(emptyLogger)) 1292 | if err != nil { 1293 | b.Fatal(err) 1294 | } 1295 | 1296 | server.SetConcurrency(0) 1297 | 1298 | go server.ServeConn(serverConn) 1299 | 1300 | frisbeeConn := NewAsync(clientConn, emptyLogger) 1301 | 1302 | data := make([]byte, packetSize) 1303 | _, _ = rand.Read(data) 1304 | 1305 | p := packet.Get() 1306 | p.Metadata.Operation = metadata.PacketPing 1307 | 1308 | p.Content.Write(data) 1309 | p.Metadata.ContentLength = packetSize 1310 | 1311 | b.Run("test", func(b *testing.B) { 1312 | b.SetBytes(testSize * packetSize) 1313 | b.ReportAllocs() 1314 | b.ResetTimer() 1315 | for i := 0; i < b.N; i++ { 1316 | for q := 0; q < testSize; q++ { 1317 | p.Metadata.Id = uint16(q) 1318 | err = frisbeeConn.WritePacket(p) 1319 | if err != nil { 1320 | b.Fatal(err) 1321 | } 1322 | } 1323 | readPacket, err := frisbeeConn.ReadPacket() 1324 | if err != nil { 1325 | b.Fatal(err) 1326 | } 1327 | 1328 | if readPacket.Metadata.Id != testSize { 1329 | b.Fatal("invalid decoded metadata id", readPacket.Metadata.Id) 1330 | } 1331 | 1332 | if readPacket.Metadata.Operation != metadata.PacketPong { 1333 | b.Fatal("invalid decoded operation", readPacket.Metadata.Operation) 1334 | } 1335 | 1336 | packet.Put(readPacket) 1337 | } 1338 | 1339 | }) 1340 | b.StopTimer() 1341 | 1342 | packet.Put(p) 1343 | 1344 | err = frisbeeConn.Close() 1345 | if err != nil { 1346 | b.Fatal(err) 1347 | } 1348 | err = server.Shutdown() 1349 | if err != nil { 1350 | b.Fatal(err) 1351 | } 1352 | } 1353 | 1354 | func BenchmarkThroughputResponseServerSlowLimited(b *testing.B) { 1355 | DisableMaxContentLength(b) 1356 | 1357 | const testSize = 1<<16 - 1 1358 | const packetSize = 512 1359 | 1360 | serverConn, clientConn, err := pair.New() 1361 | if err != nil { 1362 | b.Fatal(err) 1363 | } 1364 | 1365 | handlerTable := make(HandlerTable) 1366 | 1367 | var count atomic.Uint64 1368 | handlerTable[metadata.PacketPing] = func(_ context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action Action) { 1369 | time.Sleep(time.Microsecond * 50) 1370 | if count.Add(1) == testSize-1 { 1371 | incoming.Reset() 1372 | incoming.Metadata.Id = testSize 1373 | incoming.Metadata.Operation = metadata.PacketPong 1374 | outgoing = incoming 1375 | count.Store(0) 1376 | } 1377 | return 1378 | } 1379 | 1380 | emptyLogger := logging.Test(b, logging.Noop, b.Name()) 1381 | server, err := NewServer(handlerTable, context.Background(), WithLogger(emptyLogger)) 1382 | if err != nil { 1383 | b.Fatal(err) 1384 | } 1385 | 1386 | server.SetConcurrency(100) 1387 | 1388 | go server.ServeConn(serverConn) 1389 | 1390 | frisbeeConn := NewAsync(clientConn, emptyLogger) 1391 | 1392 | data := make([]byte, packetSize) 1393 | _, _ = rand.Read(data) 1394 | 1395 | p := packet.Get() 1396 | p.Metadata.Operation = metadata.PacketPing 1397 | 1398 | p.Content.Write(data) 1399 | 
p.Metadata.ContentLength = packetSize 1400 | 1401 | b.Run("test", func(b *testing.B) { 1402 | b.SetBytes(testSize * packetSize) 1403 | b.ReportAllocs() 1404 | b.ResetTimer() 1405 | for i := 0; i < b.N; i++ { 1406 | for q := 0; q < testSize; q++ { 1407 | p.Metadata.Id = uint16(q) 1408 | err = frisbeeConn.WritePacket(p) 1409 | if err != nil { 1410 | b.Fatal(err) 1411 | } 1412 | } 1413 | readPacket, err := frisbeeConn.ReadPacket() 1414 | if err != nil { 1415 | b.Fatal(err) 1416 | } 1417 | 1418 | if readPacket.Metadata.Id != testSize { 1419 | b.Fatal("invalid decoded metadata id", readPacket.Metadata.Id) 1420 | } 1421 | 1422 | if readPacket.Metadata.Operation != metadata.PacketPong { 1423 | b.Fatal("invalid decoded operation", readPacket.Metadata.Operation) 1424 | } 1425 | 1426 | packet.Put(readPacket) 1427 | } 1428 | 1429 | }) 1430 | b.StopTimer() 1431 | 1432 | packet.Put(p) 1433 | 1434 | err = frisbeeConn.Close() 1435 | if err != nil { 1436 | b.Fatal(err) 1437 | } 1438 | err = server.Shutdown() 1439 | if err != nil { 1440 | b.Fatal(err) 1441 | } 1442 | } 1443 | -------------------------------------------------------------------------------- /stream.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package frisbee 4 | 5 | import ( 6 | "sync" 7 | "sync/atomic" 8 | 9 | "github.com/loopholelabs/common/pkg/queue" 10 | 11 | "github.com/loopholelabs/frisbee-go/pkg/packet" 12 | ) 13 | 14 | // DefaultStreamBufferSize is the default size of the stream buffer. 15 | const DefaultStreamBufferSize = 1 << 12 16 | 17 | type NewStreamHandler func(*Stream) 18 | 19 | type Stream struct { 20 | id uint16 21 | conn *Async 22 | closed atomic.Bool 23 | queue *queue.Circular[packet.Packet, *packet.Packet] 24 | staleMu sync.Mutex 25 | stale []*packet.Packet 26 | } 27 | 28 | func newStream(id uint16, conn *Async) *Stream { 29 | return &Stream{ 30 | id: id, 31 | conn: conn, 32 | queue: queue.NewCircular[packet.Packet, *packet.Packet](DefaultStreamBufferSize), 33 | } 34 | } 35 | 36 | // ReadPacket is a blocking function that will wait until a Frisbee packet is available and then return it (and its content). 37 | // In the event that the connection is closed, ReadPacket will return an error. 38 | func (s *Stream) ReadPacket() (*packet.Packet, error) { 39 | if s.closed.Load() { 40 | s.staleMu.Lock() 41 | if len(s.stale) > 0 { 42 | var p *packet.Packet 43 | p, s.stale = s.stale[0], s.stale[1:] 44 | s.staleMu.Unlock() 45 | return p, nil 46 | } 47 | s.staleMu.Unlock() 48 | return nil, StreamClosed 49 | } 50 | 51 | readPacket, err := s.queue.Pop() 52 | if err != nil { 53 | if s.closed.Load() { 54 | s.staleMu.Lock() 55 | if len(s.stale) > 0 { 56 | var p *packet.Packet 57 | p, s.stale = s.stale[0], s.stale[1:] 58 | s.staleMu.Unlock() 59 | return p, nil 60 | } 61 | s.staleMu.Unlock() 62 | return nil, StreamClosed 63 | } 64 | return nil, err 65 | } 66 | 67 | return readPacket, nil 68 | } 69 | 70 | // WritePacket will write the given packet to the stream but the ID and Operation will be 71 | // overwritten with the stream's ID and the STREAM operation. Packets send to a stream 72 | // must have a ContentLength greater than 0. 
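// A minimal sketch of the stream API described above, assuming two Async
// connections over net.Pipe as in stream_test.go below. The stream ID, the
// "hello" payload, and the no-op logger are illustrative; error handling is
// mostly omitted. Note that a packet written to a stream must carry a
// non-zero ContentLength, and its Id/Operation are overwritten by WritePacket.
package main

import (
	"net"

	frisbee "github.com/loopholelabs/frisbee-go"
	"github.com/loopholelabs/frisbee-go/pkg/packet"
	"github.com/loopholelabs/logging/loggers/noop"
	"github.com/loopholelabs/logging/types"
)

func main() {
	logger := noop.New(types.InfoLevel)
	reader, writer := net.Pipe()

	// Peer-initiated streams are delivered to the handler registered here.
	readerStreamCh := make(chan *frisbee.Stream, 1)
	readerConn := frisbee.NewAsync(reader, logger)
	readerConn.SetNewStreamHandler(func(s *frisbee.Stream) {
		readerStreamCh <- s
	})
	writerConn := frisbee.NewAsync(writer, logger)

	// Open stream 0 on the writer side and send a single packet.
	stream := writerConn.NewStream(0)
	p := packet.Get()
	p.Content.Write([]byte("hello"))
	p.Metadata.ContentLength = uint32(p.Content.Len())
	_ = stream.WritePacket(p)
	packet.Put(p)

	// Receive the peer-initiated stream and read the packet back out of it.
	readerStream := <-readerStreamCh
	if in, err := readerStream.ReadPacket(); err == nil {
		packet.Put(in) // in.Content held the payload
	}

	_ = stream.Close()
	_ = readerConn.Close()
	_ = writerConn.Close()
}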
73 | func (s *Stream) WritePacket(p *packet.Packet) error { 74 | if s.closed.Load() { 75 | return StreamClosed 76 | } 77 | if p.Metadata.ContentLength == 0 { 78 | return InvalidStreamPacket 79 | } 80 | p.Metadata.Id = s.id 81 | p.Metadata.Operation = STREAM 82 | return s.conn.writePacket(p, true) 83 | } 84 | 85 | // ID returns the stream's ID. 86 | func (s *Stream) ID() uint16 { 87 | return s.id 88 | } 89 | 90 | // Conn returns the connection that the stream is associated with. 91 | func (s *Stream) Conn() *Async { 92 | return s.conn 93 | } 94 | 95 | // Close will close the stream and prevent any further reads or writes. 96 | func (s *Stream) Close() error { 97 | return s.closeSend(true) 98 | } 99 | 100 | func (s *Stream) closeSend(lock bool) error { 101 | s.staleMu.Lock() 102 | if s.closed.CompareAndSwap(false, true) { 103 | s.queue.Close() 104 | s.stale = s.queue.Drain() 105 | s.staleMu.Unlock() 106 | 107 | p := packet.Get() 108 | p.Metadata.Id = s.id 109 | p.Metadata.Operation = STREAM 110 | err := s.conn.writePacket(p, true) 111 | packet.Put(p) 112 | 113 | if lock { 114 | s.conn.streamsMu.Lock() 115 | delete(s.conn.streams, s.id) 116 | s.conn.streamsMu.Unlock() 117 | } 118 | 119 | return err 120 | } 121 | s.staleMu.Unlock() 122 | return StreamClosed 123 | } 124 | 125 | // close will close the stream and prevent any further reads or writes without sending a stream close packet. 126 | func (s *Stream) close() { 127 | s.staleMu.Lock() 128 | if s.closed.CompareAndSwap(false, true) { 129 | s.queue.Close() 130 | s.stale = s.queue.Drain() 131 | } 132 | s.staleMu.Unlock() 133 | } 134 | -------------------------------------------------------------------------------- /stream_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package frisbee 4 | 5 | import ( 6 | "crypto/rand" 7 | "net" 8 | "testing" 9 | "time" 10 | 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | 14 | "github.com/loopholelabs/logging" 15 | 16 | "github.com/loopholelabs/frisbee-go/pkg/packet" 17 | ) 18 | 19 | func TestNewStream(t *testing.T) { 20 | t.Parallel() 21 | 22 | const packetSize = 512 23 | 24 | emptyLogger := logging.Test(t, logging.Noop, t.Name()) 25 | reader, writer := net.Pipe() 26 | 27 | readerConn := NewAsync(reader, emptyLogger) 28 | writerConn := NewAsync(writer, emptyLogger) 29 | 30 | writerStream := writerConn.NewStream(0) 31 | 32 | data := make([]byte, packetSize) 33 | _, err := rand.Read(data) 34 | require.NoError(t, err) 35 | 36 | p := packet.Get() 37 | p.Metadata.Id = 64 38 | p.Metadata.Operation = 32 39 | p.Metadata.ContentLength = uint32(packetSize) 40 | p.Content.Write(data) 41 | 42 | readerStreamCh := make(chan *Stream) 43 | var readerStream *Stream 44 | 45 | readerConn.SetNewStreamHandler(func(stream *Stream) { 46 | readerStreamCh <- stream 47 | }) 48 | 49 | err = writerStream.WritePacket(p) 50 | require.NoError(t, err) 51 | packet.Put(p) 52 | 53 | timer := time.NewTimer(DefaultDeadline) 54 | select { 55 | case <-timer.C: 56 | t.Fatal("timed out waiting for reader stream") 57 | case readerStream = <-readerStreamCh: 58 | } 59 | 60 | p, err = readerStream.ReadPacket() 61 | require.NoError(t, err) 62 | require.NotNil(t, p.Metadata) 63 | assert.Equal(t, readerStream.ID(), p.Metadata.Id) 64 | assert.Equal(t, STREAM, p.Metadata.Operation) 65 | assert.Equal(t, uint32(packetSize), p.Metadata.ContentLength) 66 | assert.Equal(t, data, p.Content.Bytes()) 67 | 68 | err = 
readerStream.Close() 69 | require.NoError(t, err) 70 | 71 | time.Sleep(DefaultDeadline) 72 | 73 | err = writerStream.Close() 74 | require.ErrorIs(t, err, StreamClosed) 75 | 76 | err = readerConn.Close() 77 | assert.NoError(t, err) 78 | err = writerConn.Close() 79 | assert.NoError(t, err) 80 | } 81 | 82 | func TestNewStreamStale(t *testing.T) { 83 | t.Parallel() 84 | 85 | const packetSize = 512 86 | 87 | emptyLogger := logging.Test(t, logging.Noop, t.Name()) 88 | 89 | reader, writer := net.Pipe() 90 | 91 | readerConn := NewAsync(reader, emptyLogger) 92 | writerConn := NewAsync(writer, emptyLogger) 93 | 94 | writerStream := writerConn.NewStream(0) 95 | 96 | data := make([]byte, packetSize) 97 | _, err := rand.Read(data) 98 | require.NoError(t, err) 99 | 100 | p := packet.Get() 101 | p.Metadata.Id = 64 102 | p.Metadata.Operation = 32 103 | p.Metadata.ContentLength = uint32(packetSize) 104 | p.Content.Write(data) 105 | 106 | readerStreamCh := make(chan *Stream) 107 | var readerStream *Stream 108 | 109 | readerConn.SetNewStreamHandler(func(stream *Stream) { 110 | readerStreamCh <- stream 111 | }) 112 | 113 | err = writerStream.WritePacket(p) 114 | require.NoError(t, err) 115 | 116 | err = writerStream.WritePacket(p) 117 | require.NoError(t, err) 118 | 119 | packet.Put(p) 120 | 121 | timer := time.NewTimer(DefaultDeadline) 122 | select { 123 | case <-timer.C: 124 | t.Fatal("timed out waiting for reader stream") 125 | case readerStream = <-readerStreamCh: 126 | } 127 | 128 | err = writerStream.Close() 129 | require.NoError(t, err) 130 | 131 | time.Sleep(DefaultDeadline) 132 | 133 | p, err = readerStream.ReadPacket() 134 | require.NoError(t, err) 135 | require.NotNil(t, p.Metadata) 136 | assert.Equal(t, readerStream.ID(), p.Metadata.Id) 137 | assert.Equal(t, STREAM, p.Metadata.Operation) 138 | assert.Equal(t, uint32(packetSize), p.Metadata.ContentLength) 139 | assert.Equal(t, data, p.Content.Bytes()) 140 | 141 | p, err = readerStream.ReadPacket() 142 | require.NoError(t, err) 143 | require.NotNil(t, p.Metadata) 144 | assert.Equal(t, readerStream.ID(), p.Metadata.Id) 145 | assert.Equal(t, STREAM, p.Metadata.Operation) 146 | assert.Equal(t, uint32(packetSize), p.Metadata.ContentLength) 147 | assert.Equal(t, data, p.Content.Bytes()) 148 | 149 | _, err = readerStream.ReadPacket() 150 | require.ErrorIs(t, err, StreamClosed) 151 | 152 | err = readerConn.Close() 153 | assert.NoError(t, err) 154 | } 155 | 156 | func TestNewStreamDualCreate(t *testing.T) { 157 | t.Parallel() 158 | 159 | const packetSize = 512 160 | 161 | emptyLogger := logging.Test(t, logging.Noop, t.Name()) 162 | 163 | reader, writer := net.Pipe() 164 | 165 | readerConn := NewAsync(reader, emptyLogger, func(_ *Stream) {}) 166 | writerConn := NewAsync(writer, emptyLogger, func(_ *Stream) {}) 167 | 168 | writerStream := writerConn.NewStream(0) 169 | readerStream := readerConn.NewStream(0) 170 | 171 | data := make([]byte, packetSize) 172 | _, err := rand.Read(data) 173 | require.NoError(t, err) 174 | 175 | p := packet.Get() 176 | p.Metadata.Id = 64 177 | p.Metadata.Operation = 32 178 | p.Metadata.ContentLength = uint32(packetSize) 179 | p.Content.Write(data) 180 | 181 | err = writerStream.WritePacket(p) 182 | require.NoError(t, err) 183 | 184 | err = writerStream.WritePacket(p) 185 | require.NoError(t, err) 186 | 187 | packet.Put(p) 188 | 189 | err = writerStream.Close() 190 | require.NoError(t, err) 191 | 192 | time.Sleep(DefaultDeadline) 193 | 194 | p, err = readerStream.ReadPacket() 195 | require.NoError(t, err) 196 | require.NotNil(t, 
p.Metadata) 197 | assert.Equal(t, readerStream.ID(), p.Metadata.Id) 198 | assert.Equal(t, STREAM, p.Metadata.Operation) 199 | assert.Equal(t, uint32(packetSize), p.Metadata.ContentLength) 200 | assert.Equal(t, data, p.Content.Bytes()) 201 | 202 | p, err = readerStream.ReadPacket() 203 | require.NoError(t, err) 204 | require.NotNil(t, p.Metadata) 205 | assert.Equal(t, readerStream.ID(), p.Metadata.Id) 206 | assert.Equal(t, STREAM, p.Metadata.Operation) 207 | assert.Equal(t, uint32(packetSize), p.Metadata.ContentLength) 208 | assert.Equal(t, data, p.Content.Bytes()) 209 | 210 | _, err = readerStream.ReadPacket() 211 | require.ErrorIs(t, err, StreamClosed) 212 | 213 | err = readerConn.Close() 214 | assert.NoError(t, err) 215 | } 216 | 217 | func TestStreamConnClose(t *testing.T) { 218 | t.Parallel() 219 | 220 | const packetSize = 512 221 | 222 | emptyLogger := logging.Test(t, logging.Noop, t.Name()) 223 | 224 | reader, writer := net.Pipe() 225 | 226 | readerConn := NewAsync(reader, emptyLogger, func(_ *Stream) {}) 227 | writerConn := NewAsync(writer, emptyLogger, func(_ *Stream) {}) 228 | 229 | writerStream := writerConn.NewStream(0) 230 | readerStream := readerConn.NewStream(0) 231 | 232 | data := make([]byte, packetSize) 233 | _, err := rand.Read(data) 234 | require.NoError(t, err) 235 | 236 | p := packet.Get() 237 | p.Metadata.Id = 64 238 | p.Metadata.Operation = 32 239 | p.Metadata.ContentLength = uint32(packetSize) 240 | p.Content.Write(data) 241 | 242 | err = writerStream.WritePacket(p) 243 | require.NoError(t, err) 244 | 245 | packet.Put(p) 246 | 247 | err = writerConn.Close() 248 | require.NoError(t, err) 249 | 250 | time.Sleep(DefaultDeadline) 251 | 252 | err = writerStream.Close() 253 | require.ErrorIs(t, err, StreamClosed) 254 | 255 | p, err = readerStream.ReadPacket() 256 | require.NoError(t, err) 257 | require.NotNil(t, p.Metadata) 258 | assert.Equal(t, readerStream.ID(), p.Metadata.Id) 259 | assert.Equal(t, STREAM, p.Metadata.Operation) 260 | assert.Equal(t, uint32(packetSize), p.Metadata.ContentLength) 261 | assert.Equal(t, data, p.Content.Bytes()) 262 | 263 | _, err = readerStream.ReadPacket() 264 | require.ErrorIs(t, err, StreamClosed) 265 | 266 | err = readerConn.Close() 267 | assert.NoError(t, err) 268 | } 269 | -------------------------------------------------------------------------------- /sync.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package frisbee 4 | 5 | import ( 6 | "context" 7 | "crypto/tls" 8 | "encoding/binary" 9 | "errors" 10 | "fmt" 11 | "io" 12 | "net" 13 | "sync" 14 | "sync/atomic" 15 | "time" 16 | 17 | "github.com/loopholelabs/logging/loggers/noop" 18 | "github.com/loopholelabs/logging/types" 19 | 20 | "github.com/loopholelabs/frisbee-go/internal/dialer" 21 | "github.com/loopholelabs/frisbee-go/pkg/metadata" 22 | "github.com/loopholelabs/frisbee-go/pkg/packet" 23 | ) 24 | 25 | // Sync is the underlying synchronous frisbee connection which has extremely efficient read and write logic and 26 | // can handle the specific frisbee requirements. 
This is not meant to be used on its own, and instead is 27 | // meant to be used by frisbee client and server implementations 28 | type Sync struct { 29 | sync.Mutex 30 | conn net.Conn 31 | closed atomic.Bool 32 | logger types.Logger 33 | error atomic.Value 34 | ctxMu sync.RWMutex 35 | ctx context.Context 36 | } 37 | 38 | // ConnectSync creates a new TCP connection (using net.Dial) and wraps it in a frisbee connection 39 | func ConnectSync(addr string, keepAlive time.Duration, logger types.Logger, TLSConfig *tls.Config) (*Sync, error) { 40 | var conn net.Conn 41 | var err error 42 | 43 | d := dialer.NewRetry() 44 | 45 | if TLSConfig != nil { 46 | conn, err = d.DialTLS("tcp", addr, TLSConfig) 47 | } else { 48 | conn, err = d.Dial("tcp", addr) 49 | if err == nil { 50 | _ = conn.(*net.TCPConn).SetKeepAlive(true) 51 | _ = conn.(*net.TCPConn).SetKeepAlivePeriod(keepAlive) 52 | } 53 | } 54 | 55 | if err != nil { 56 | return nil, err 57 | } 58 | 59 | return NewSync(conn, logger), nil 60 | } 61 | 62 | // NewSync takes an existing net.Conn object and wraps it in a frisbee connection 63 | func NewSync(c net.Conn, logger types.Logger) (conn *Sync) { 64 | conn = &Sync{ 65 | conn: c, 66 | logger: logger, 67 | } 68 | 69 | if logger == nil { 70 | conn.logger = noop.New(types.InfoLevel) 71 | } 72 | return 73 | } 74 | 75 | // SetDeadline sets the read and write deadline on the underlying net.Conn 76 | func (c *Sync) SetDeadline(t time.Time) error { 77 | return c.conn.SetDeadline(t) 78 | } 79 | 80 | // SetReadDeadline sets the read deadline on the underlying net.Conn 81 | func (c *Sync) SetReadDeadline(t time.Time) error { 82 | return c.conn.SetReadDeadline(t) 83 | } 84 | 85 | // SetWriteDeadline sets the write deadline on the underlying net.Conn 86 | func (c *Sync) SetWriteDeadline(t time.Time) error { 87 | return c.conn.SetWriteDeadline(t) 88 | } 89 | 90 | // ConnectionState returns the tls.ConnectionState of a *tls.Conn 91 | // if the connection is not *tls.Conn then the NotTLSConnectionError is returned 92 | func (c *Sync) ConnectionState() (tls.ConnectionState, error) { 93 | if tlsConn, ok := c.conn.(*tls.Conn); ok { 94 | return tlsConn.ConnectionState(), nil 95 | } 96 | return emptyState, NotTLSConnectionError 97 | } 98 | 99 | // Handshake performs the tls.Handshake() of a *tls.Conn 100 | // if the connection is not *tls.Conn then the NotTLSConnectionError is returned 101 | func (c *Sync) Handshake() error { 102 | if tlsConn, ok := c.conn.(*tls.Conn); ok { 103 | return tlsConn.Handshake() 104 | } 105 | return NotTLSConnectionError 106 | } 107 | 108 | // HandshakeContext performs the tls.HandshakeContext() of a *tls.Conn 109 | // if the connection is not *tls.Conn then the NotTLSConnectionError is returned 110 | func (c *Sync) HandshakeContext(ctx context.Context) error { 111 | if tlsConn, ok := c.conn.(*tls.Conn); ok { 112 | return tlsConn.HandshakeContext(ctx) 113 | } 114 | return NotTLSConnectionError 115 | } 116 | 117 | // LocalAddr returns the local address of the underlying net.Conn 118 | func (c *Sync) LocalAddr() net.Addr { 119 | return c.conn.LocalAddr() 120 | } 121 | 122 | // RemoteAddr returns the remote address of the underlying net.Conn 123 | func (c *Sync) RemoteAddr() net.Addr { 124 | return c.conn.RemoteAddr() 125 | } 126 | 127 | // WritePacket takes a packet.Packet and sends it synchronously. 128 | // 129 | // If packet.Metadata.ContentLength == 0, then the content array must be nil. Otherwise, it is required that packet.Metadata.ContentLength == len(content). 
130 | func (c *Sync) WritePacket(p *packet.Packet) error { 131 | if int(p.Metadata.ContentLength) != p.Content.Len() { 132 | return InvalidContentLength 133 | } 134 | if DefaultMaxContentLength > 0 && p.Metadata.ContentLength > DefaultMaxContentLength { 135 | return ContentLengthExceeded 136 | } 137 | 138 | var encodedMetadata [metadata.Size]byte 139 | 140 | binary.BigEndian.PutUint16(encodedMetadata[metadata.MagicOffset:metadata.MagicOffset+metadata.MagicSize], metadata.PacketMagicHeader) 141 | binary.BigEndian.PutUint16(encodedMetadata[metadata.IdOffset:metadata.IdOffset+metadata.IdSize], p.Metadata.Id) 142 | binary.BigEndian.PutUint16(encodedMetadata[metadata.OperationOffset:metadata.OperationOffset+metadata.OperationSize], p.Metadata.Operation) 143 | binary.BigEndian.PutUint32(encodedMetadata[metadata.ContentLengthOffset:metadata.ContentLengthOffset+metadata.ContentLengthSize], p.Metadata.ContentLength) 144 | 145 | c.Lock() 146 | if c.closed.Load() { 147 | c.Unlock() 148 | return ConnectionClosed 149 | } 150 | 151 | _, err := c.conn.Write(encodedMetadata[:]) 152 | if err != nil { 153 | c.Unlock() 154 | if c.closed.Load() { 155 | c.Logger().Debug().Err(ConnectionClosed).Uint16("Packet ID", p.Metadata.Id).Msg("error while writing encoded metadata") 156 | return ConnectionClosed 157 | } 158 | c.Logger().Debug().Err(err).Uint16("Packet ID", p.Metadata.Id).Msg("error while writing encoded metadata") 159 | return c.closeWithError(err) 160 | } 161 | if p.Metadata.ContentLength != 0 { 162 | _, err = c.conn.Write(p.Content.Bytes()[:p.Metadata.ContentLength]) 163 | if err != nil { 164 | c.Unlock() 165 | if c.closed.Load() { 166 | c.Logger().Debug().Err(ConnectionClosed).Uint16("Packet ID", p.Metadata.Id).Msg("error while writing packet content") 167 | return ConnectionClosed 168 | } 169 | c.Logger().Debug().Err(err).Uint16("Packet ID", p.Metadata.Id).Msg("error while writing packet content") 170 | return c.closeWithError(err) 171 | } 172 | } 173 | 174 | c.Unlock() 175 | return nil 176 | } 177 | 178 | // ReadPacket is a blocking function that will wait until a frisbee packet is available and then return it (and its content). 179 | // In the event that the connection is closed, ReadPacket will return an error. 
180 | func (c *Sync) ReadPacket() (*packet.Packet, error) { 181 | if c.closed.Load() { 182 | return nil, ConnectionClosed 183 | } 184 | var encodedPacket [metadata.Size]byte 185 | 186 | _, err := io.ReadAtLeast(c.conn, encodedPacket[:], metadata.Size) 187 | if err != nil { 188 | if c.closed.Load() { 189 | c.Logger().Debug().Err(ConnectionClosed).Msg("error while reading from underlying net.Conn") 190 | return nil, ConnectionClosed 191 | } 192 | c.Logger().Debug().Err(err).Msg("error while reading from underlying net.Conn") 193 | return nil, c.closeWithError(err) 194 | } 195 | p := packet.Get() 196 | 197 | p.Metadata.Magic = binary.BigEndian.Uint16(encodedPacket[metadata.MagicOffset : metadata.MagicOffset+metadata.MagicSize]) 198 | p.Metadata.Id = binary.BigEndian.Uint16(encodedPacket[metadata.IdOffset : metadata.IdOffset+metadata.IdSize]) 199 | p.Metadata.Operation = binary.BigEndian.Uint16(encodedPacket[metadata.OperationOffset : metadata.OperationOffset+metadata.OperationSize]) 200 | p.Metadata.ContentLength = binary.BigEndian.Uint32(encodedPacket[metadata.ContentLengthOffset : metadata.ContentLengthOffset+metadata.ContentLengthSize]) 201 | 202 | if p.Metadata.Magic != metadata.PacketMagicHeader { 203 | c.Logger().Debug().Str("magic", fmt.Sprintf("%x", p.Metadata.Magic)).Msg("received packet with incorrect magic header") 204 | return nil, c.closeWithError(InvalidMagicHeader) 205 | } 206 | 207 | if DefaultMaxContentLength > 0 && p.Metadata.ContentLength > DefaultMaxContentLength { 208 | c.Logger().Debug(). 209 | Uint32("content_length", p.Metadata.ContentLength). 210 | Uint32("max_content_length", DefaultMaxContentLength). 211 | Msg("received packet that exceeds max content length") 212 | return nil, c.closeWithError(ContentLengthExceeded) 213 | } 214 | 215 | if p.Metadata.ContentLength > 0 { 216 | contentLength := int(p.Metadata.ContentLength) 217 | p.Content.Grow(contentLength) 218 | p.Content.MoveOffset(contentLength) 219 | _, err = io.ReadAtLeast(c.conn, p.Content.Bytes(), contentLength) 220 | if err != nil { 221 | if c.closed.Load() { 222 | c.Logger().Debug().Err(ConnectionClosed).Msg("error while reading from underlying net.Conn") 223 | return nil, ConnectionClosed 224 | } 225 | c.Logger().Debug().Err(err).Msg("error while reading from underlying net.Conn") 226 | return nil, c.closeWithError(err) 227 | } 228 | } 229 | 230 | return p, nil 231 | } 232 | 233 | // SetContext allows users to save a context within a connection 234 | func (c *Sync) SetContext(ctx context.Context) { 235 | c.ctxMu.Lock() 236 | c.ctx = ctx 237 | c.ctxMu.Unlock() 238 | } 239 | 240 | // Context returns the saved context within the connection 241 | func (c *Sync) Context() (ctx context.Context) { 242 | c.ctxMu.RLock() 243 | ctx = c.ctx 244 | c.ctxMu.RUnlock() 245 | return 246 | } 247 | 248 | // Logger returns the underlying logger of the frisbee connection 249 | func (c *Sync) Logger() types.Logger { 250 | return c.logger 251 | } 252 | 253 | // Error returns the error that caused the frisbee.Sync to close or go into a paused state 254 | func (c *Sync) Error() error { 255 | err := c.error.Load() 256 | if err == nil { 257 | return nil 258 | } 259 | return err.(error) 260 | } 261 | 262 | // Raw shuts off all of frisbee's underlying functionality and converts the frisbee connection into a normal TCP connection (net.Conn) 263 | func (c *Sync) Raw() net.Conn { 264 | _ = c.close() 265 | return c.conn 266 | } 267 | 268 | // Close closes the frisbee connection gracefully 269 | func (c *Sync) Close() error { 270 | err := 
c.close() 271 | if errors.Is(err, ConnectionClosed) { 272 | return nil 273 | } 274 | _ = c.conn.Close() 275 | return err 276 | } 277 | 278 | func (c *Sync) close() error { 279 | if c.closed.CompareAndSwap(false, true) { 280 | return nil 281 | } 282 | return ConnectionClosed 283 | } 284 | 285 | func (c *Sync) closeWithError(err error) error { 286 | closeError := c.close() 287 | if errors.Is(closeError, ConnectionClosed) { 288 | c.Logger().Debug().Err(err).Msg("attempted to close connection with error, but connection already closed") 289 | return ConnectionClosed 290 | } else { 291 | c.Logger().Debug().Err(err).Msgf("closing connection with error") 292 | } 293 | c.error.Store(err) 294 | _ = c.conn.Close() 295 | return err 296 | } 297 | -------------------------------------------------------------------------------- /sync_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: Apache-2.0 2 | 3 | package frisbee 4 | 5 | import ( 6 | "crypto/rand" 7 | "encoding/binary" 8 | "io" 9 | "net" 10 | "testing" 11 | 12 | "github.com/stretchr/testify/assert" 13 | "github.com/stretchr/testify/require" 14 | 15 | "github.com/loopholelabs/logging" 16 | "github.com/loopholelabs/polyglot/v2" 17 | "github.com/loopholelabs/testing/conn/pair" 18 | 19 | "github.com/loopholelabs/frisbee-go/pkg/metadata" 20 | "github.com/loopholelabs/frisbee-go/pkg/packet" 21 | ) 22 | 23 | func TestNewSync(t *testing.T) { 24 | t.Parallel() 25 | const packetSize = 512 26 | 27 | emptyLogger := logging.Test(t, logging.Noop, t.Name()) 28 | 29 | reader, writer := net.Pipe() 30 | 31 | readerConn := NewSync(reader, emptyLogger) 32 | writerConn := NewSync(writer, emptyLogger) 33 | 34 | start := make(chan struct{}, 1) 35 | end := make(chan struct{}, 1) 36 | 37 | p := packet.Get() 38 | p.Metadata.Id = 64 39 | p.Metadata.Operation = 32 40 | 41 | go func() { 42 | start <- struct{}{} 43 | p, err := readerConn.ReadPacket() 44 | assert.NoError(t, err) 45 | assert.NotNil(t, p.Metadata) 46 | assert.Equal(t, uint16(64), p.Metadata.Id) 47 | assert.Equal(t, uint16(32), p.Metadata.Operation) 48 | assert.Equal(t, uint32(0), p.Metadata.ContentLength) 49 | assert.Equal(t, 0, p.Content.Len()) 50 | end <- struct{}{} 51 | packet.Put(p) 52 | }() 53 | 54 | <-start 55 | err := writerConn.WritePacket(p) 56 | assert.NoError(t, err) 57 | <-end 58 | 59 | data := make([]byte, packetSize) 60 | _, _ = rand.Read(data) 61 | 62 | p.Content.Write(data) 63 | p.Metadata.ContentLength = packetSize 64 | 65 | go func() { 66 | start <- struct{}{} 67 | p, err := readerConn.ReadPacket() 68 | assert.NoError(t, err) 69 | assert.NotNil(t, p.Metadata) 70 | assert.Equal(t, uint16(64), p.Metadata.Id) 71 | assert.Equal(t, uint16(32), p.Metadata.Operation) 72 | assert.Equal(t, uint32(packetSize), p.Metadata.ContentLength) 73 | assert.Equal(t, packetSize, p.Content.Len()) 74 | expected := polyglot.NewBufferFromBytes(data) 75 | expected.MoveOffset(len(data)) 76 | assert.Equal(t, expected.Bytes(), p.Content.Bytes()) 77 | end <- struct{}{} 78 | packet.Put(p) 79 | }() 80 | 81 | <-start 82 | err = writerConn.WritePacket(p) 83 | assert.NoError(t, err) 84 | 85 | packet.Put(p) 86 | <-end 87 | 88 | err = readerConn.Close() 89 | assert.NoError(t, err) 90 | err = writerConn.Close() 91 | assert.NoError(t, err) 92 | } 93 | 94 | func TestSyncLargeWrite(t *testing.T) { 95 | t.Parallel() 96 | 97 | const testSize = 100000 98 | const packetSize = 512 99 | 100 | emptyLogger := logging.Test(t, logging.Noop, t.Name()) 101 | 102 | reader, writer 
:= net.Pipe() 103 | 104 | readerConn := NewSync(reader, emptyLogger) 105 | writerConn := NewSync(writer, emptyLogger) 106 | 107 | randomData := make([][]byte, testSize) 108 | 109 | p := packet.Get() 110 | p.Metadata.Id = 64 111 | p.Metadata.Operation = 32 112 | p.Metadata.ContentLength = packetSize 113 | 114 | start := make(chan struct{}, 1) 115 | end := make(chan struct{}, 1) 116 | 117 | go func() { 118 | start <- struct{}{} 119 | for i := 0; i < testSize; i++ { 120 | p, err := readerConn.ReadPacket() 121 | assert.NoError(t, err) 122 | assert.NotNil(t, p.Metadata) 123 | assert.Equal(t, uint16(64), p.Metadata.Id) 124 | assert.Equal(t, uint16(32), p.Metadata.Operation) 125 | assert.Equal(t, uint32(packetSize), p.Metadata.ContentLength) 126 | assert.Equal(t, packetSize, p.Content.Len()) 127 | expected := polyglot.NewBufferFromBytes(randomData[i]) 128 | expected.MoveOffset(len(randomData[i])) 129 | assert.Equal(t, expected.Bytes(), p.Content.Bytes()) 130 | packet.Put(p) 131 | } 132 | end <- struct{}{} 133 | }() 134 | 135 | <-start 136 | for i := 0; i < testSize; i++ { 137 | randomData[i] = make([]byte, packetSize) 138 | _, _ = rand.Read(randomData[i]) 139 | p.Content.Write(randomData[i]) 140 | err := writerConn.WritePacket(p) 141 | p.Content.Reset() 142 | assert.NoError(t, err) 143 | } 144 | <-end 145 | 146 | packet.Put(p) 147 | 148 | // Verify that large writes past the max content length fail. 149 | big := make([]byte, 2*DefaultMaxContentLength) 150 | 151 | p = packet.Get() 152 | p.Metadata.Id = 64 153 | p.Metadata.Operation = 32 154 | p.Metadata.ContentLength = uint32(len(big)) 155 | p.Content.Write(big) 156 | 157 | err := writerConn.WritePacket(p) 158 | assert.ErrorIs(t, err, ContentLengthExceeded) 159 | p.Content.Reset() 160 | packet.Put(p) 161 | 162 | err = readerConn.Close() 163 | assert.NoError(t, err) 164 | err = writerConn.Close() 165 | assert.NoError(t, err) 166 | } 167 | 168 | func TestSyncLargeRead(t *testing.T) { 169 | t.Parallel() 170 | 171 | client, server, err := pair.New() 172 | require.NoError(t, err) 173 | 174 | serverConn := NewSync(server, logging.Test(t, logging.Noop, t.Name())) 175 | t.Cleanup(func() { serverConn.Close() }) 176 | 177 | // Write a large packet that exceeds the maximum limit. 178 | bigData := make([]byte, 2*DefaultMaxContentLength+metadata.Size) 179 | binary.BigEndian.PutUint16(bigData[metadata.MagicOffset:metadata.MagicOffset+metadata.MagicSize], metadata.PacketMagicHeader) 180 | binary.BigEndian.PutUint16(bigData[metadata.IdOffset:metadata.IdOffset+metadata.IdSize], 0xFFFF) 181 | binary.BigEndian.PutUint16(bigData[metadata.OperationOffset:metadata.OperationOffset+metadata.OperationSize], 0xFFFF) 182 | binary.BigEndian.PutUint32(bigData[metadata.ContentLengthOffset:metadata.ContentLengthOffset+metadata.ContentLengthSize], 2*DefaultMaxContentLength) 183 | 184 | doneCh := make(chan any) 185 | go func() { 186 | defer close(doneCh) 187 | 188 | _, err := client.Write(bigData) 189 | 190 | // Verify client was disconnected. 191 | var opError *net.OpError 192 | assert.ErrorAs(t, err, &opError) 193 | }() 194 | 195 | // Read the packet and verify it failed. 196 | _, err = serverConn.ReadPacket() 197 | assert.ErrorIs(t, err, ContentLengthExceeded) 198 | 199 | <-doneCh 200 | } 201 | 202 | func TestSyncDisableMaxContentLength(t *testing.T) { 203 | // Don't run in parallel since it modifies DefaultMaxContentLength. 
204 | 205 | oldMax := DisableMaxContentLength(t) 206 | logger := logging.Test(t, logging.Noop, t.Name()) 207 | 208 | t.Run("read", func(t *testing.T) { 209 | client, server, err := pair.New() 210 | require.NoError(t, err) 211 | 212 | serverConn := NewSync(server, logger) 213 | t.Cleanup(func() { serverConn.Close() }) 214 | 215 | // Write a large packet that would exceed the default maximum limit. 216 | content := make([]byte, 2*oldMax+metadata.Size) 217 | binary.BigEndian.PutUint16(content[metadata.MagicOffset:metadata.MagicOffset+metadata.MagicSize], metadata.PacketMagicHeader) 218 | binary.BigEndian.PutUint16(content[metadata.IdOffset:metadata.IdOffset+metadata.IdSize], 0xFFFF) 219 | binary.BigEndian.PutUint16(content[metadata.OperationOffset:metadata.OperationOffset+metadata.OperationSize], 0xFFFF) 220 | binary.BigEndian.PutUint32(content[metadata.ContentLengthOffset:metadata.ContentLengthOffset+metadata.ContentLengthSize], 2*oldMax) 221 | 222 | doneCh := make(chan any) 223 | go func() { 224 | defer close(doneCh) 225 | 226 | n, err := client.Write(content) 227 | assert.NoError(t, err) 228 | assert.Equal(t, int(2*oldMax+metadata.Size), n) 229 | }() 230 | 231 | // Verify packet can be read. 232 | p, err := serverConn.ReadPacket() 233 | assert.NoError(t, err) 234 | assert.Equal(t, int(2*oldMax), p.Content.Len()) 235 | 236 | <-doneCh 237 | }) 238 | 239 | t.Run("write", func(t *testing.T) { 240 | client, server, err := pair.New() 241 | require.NoError(t, err) 242 | 243 | clientConn := NewSync(client, logger) 244 | t.Cleanup(func() { clientConn.Close() }) 245 | 246 | serverConn := NewSync(server, logger) 247 | t.Cleanup(func() { serverConn.Close() }) 248 | 249 | doneCh := make(chan any) 250 | go func() { 251 | defer close(doneCh) 252 | 253 | p, err := serverConn.ReadPacket() 254 | assert.NoError(t, err) 255 | if err == nil { 256 | assert.Equal(t, int(2*oldMax), p.Content.Len()) 257 | } 258 | }() 259 | 260 | p := packet.Get() 261 | t.Cleanup(func() { 262 | p.Content.Reset() 263 | packet.Put(p) 264 | }) 265 | 266 | // Write a large packet that would exceed the default maximum limit. 
267 | content := make([]byte, 2*oldMax) 268 | 269 | p.Metadata.Id = 64 270 | p.Metadata.Operation = 32 271 | p.Metadata.ContentLength = uint32(len(content)) 272 | p.Content.Write(content) 273 | 274 | err = clientConn.WritePacket(p) 275 | require.NoError(t, err) 276 | 277 | <-doneCh 278 | }) 279 | } 280 | 281 | func TestSyncRawConn(t *testing.T) { 282 | t.Parallel() 283 | 284 | const testSize = 100000 285 | const packetSize = 32 286 | 287 | emptyLogger := logging.Test(t, logging.Noop, t.Name()) 288 | 289 | reader, writer, err := pair.New() 290 | require.NoError(t, err) 291 | 292 | start := make(chan struct{}, 1) 293 | end := make(chan struct{}, 1) 294 | 295 | readerConn := NewSync(reader, emptyLogger) 296 | writerConn := NewSync(writer, emptyLogger) 297 | 298 | randomData := make([]byte, packetSize) 299 | _, _ = rand.Read(randomData) 300 | 301 | p := packet.Get() 302 | p.Metadata.Id = 64 303 | p.Metadata.Operation = 32 304 | p.Content.Write(randomData) 305 | p.Metadata.ContentLength = packetSize 306 | 307 | go func() { 308 | start <- struct{}{} 309 | for i := 0; i < testSize; i++ { 310 | p, err := readerConn.ReadPacket() 311 | assert.NoError(t, err) 312 | assert.NotNil(t, p.Metadata) 313 | assert.Equal(t, uint16(64), p.Metadata.Id) 314 | assert.Equal(t, uint16(32), p.Metadata.Operation) 315 | assert.Equal(t, uint32(packetSize), p.Metadata.ContentLength) 316 | assert.Equal(t, packetSize, p.Content.Len()) 317 | expected := polyglot.NewBufferFromBytes(randomData) 318 | expected.MoveOffset(len(randomData)) 319 | assert.Equal(t, expected.Bytes(), p.Content.Bytes()) 320 | packet.Put(p) 321 | } 322 | end <- struct{}{} 323 | }() 324 | 325 | <-start 326 | for i := 0; i < testSize; i++ { 327 | err := writerConn.WritePacket(p) 328 | assert.NoError(t, err) 329 | } 330 | <-end 331 | 332 | rawReaderConn := readerConn.Raw() 333 | rawWriterConn := writerConn.Raw() 334 | 335 | rawWriteMessage := []byte("TEST CASE MESSAGE") 336 | 337 | written, err := rawReaderConn.Write(rawWriteMessage) 338 | assert.NoError(t, err) 339 | assert.Equal(t, len(rawWriteMessage), written) 340 | rawReadMessage := make([]byte, len(rawWriteMessage)) 341 | read, err := rawWriterConn.Read(rawReadMessage) 342 | assert.NoError(t, err) 343 | assert.Equal(t, len(rawWriteMessage), read) 344 | assert.Equal(t, rawWriteMessage, rawReadMessage) 345 | 346 | err = readerConn.Close() 347 | assert.NoError(t, err) 348 | err = writerConn.Close() 349 | assert.NoError(t, err) 350 | 351 | err = rawReaderConn.Close() 352 | assert.NoError(t, err) 353 | err = rawWriterConn.Close() 354 | assert.NoError(t, err) 355 | } 356 | 357 | func TestSyncInvalid(t *testing.T) { 358 | t.Parallel() 359 | 360 | client, server, err := pair.New() 361 | require.NoError(t, err) 362 | 363 | serverConn := NewSync(server, logging.Test(t, logging.Noop, t.Name())) 364 | t.Cleanup(func() { serverConn.Close() }) 365 | 366 | httpReq := []byte(`GET / HTTP/1.1 367 | Host: www.example.com 368 | User-Agent: curl/8.9.1 369 | Accept: */*`) 370 | _, err = client.Write(httpReq) 371 | assert.NoError(t, err) 372 | 373 | _, err = serverConn.ReadPacket() 374 | assert.ErrorIs(t, err, InvalidMagicHeader) 375 | assert.ErrorIs(t, serverConn.Error(), InvalidMagicHeader) 376 | } 377 | 378 | func TestSyncReadClose(t *testing.T) { 379 | t.Parallel() 380 | 381 | reader, writer := net.Pipe() 382 | 383 | emptyLogger := logging.Test(t, logging.Noop, t.Name()) 384 | 385 | readerConn := NewSync(reader, emptyLogger) 386 | writerConn := NewSync(writer, emptyLogger) 387 | 388 | p := packet.Get() 389 | 
p.Metadata.Id = 64 390 | p.Metadata.Operation = 32 391 | 392 | start := make(chan struct{}, 1) 393 | end := make(chan struct{}, 1) 394 | 395 | go func() { 396 | start <- struct{}{} 397 | p, err := readerConn.ReadPacket() 398 | assert.NoError(t, err) 399 | assert.NotNil(t, p.Metadata) 400 | assert.Equal(t, uint16(64), p.Metadata.Id) 401 | assert.Equal(t, uint16(32), p.Metadata.Operation) 402 | assert.Equal(t, uint32(0), p.Metadata.ContentLength) 403 | assert.Equal(t, 0, p.Content.Len()) 404 | end <- struct{}{} 405 | packet.Put(p) 406 | }() 407 | 408 | <-start 409 | err := writerConn.WritePacket(p) 410 | assert.NoError(t, err) 411 | <-end 412 | 413 | err = readerConn.conn.Close() 414 | assert.NoError(t, err) 415 | 416 | err = writerConn.WritePacket(p) 417 | assert.Error(t, err) 418 | assert.ErrorIs(t, writerConn.Error(), io.ErrClosedPipe) 419 | 420 | packet.Put(p) 421 | 422 | err = readerConn.Close() 423 | assert.NoError(t, err) 424 | err = writerConn.Close() 425 | assert.NoError(t, err) 426 | } 427 | 428 | func TestSyncWriteClose(t *testing.T) { 429 | t.Parallel() 430 | 431 | reader, writer := net.Pipe() 432 | 433 | emptyLogger := logging.Test(t, logging.Noop, t.Name()) 434 | 435 | readerConn := NewSync(reader, emptyLogger) 436 | writerConn := NewSync(writer, emptyLogger) 437 | 438 | p := packet.Get() 439 | p.Metadata.Id = 64 440 | p.Metadata.Operation = 32 441 | 442 | start := make(chan struct{}, 1) 443 | end := make(chan struct{}, 1) 444 | 445 | go func() { 446 | start <- struct{}{} 447 | p, err := readerConn.ReadPacket() 448 | assert.NoError(t, err) 449 | assert.NotNil(t, p.Metadata) 450 | assert.Equal(t, uint16(64), p.Metadata.Id) 451 | assert.Equal(t, uint16(32), p.Metadata.Operation) 452 | assert.Equal(t, uint32(0), p.Metadata.ContentLength) 453 | assert.Equal(t, 0, p.Content.Len()) 454 | packet.Put(p) 455 | end <- struct{}{} 456 | }() 457 | 458 | <-start 459 | err := writerConn.WritePacket(p) 460 | assert.NoError(t, err) 461 | <-end 462 | 463 | packet.Put(p) 464 | 465 | err = writerConn.conn.Close() 466 | assert.NoError(t, err) 467 | 468 | _, err = readerConn.ReadPacket() 469 | assert.ErrorIs(t, err, io.EOF) 470 | assert.ErrorIs(t, readerConn.Error(), io.EOF) 471 | 472 | err = readerConn.Close() 473 | assert.NoError(t, err) 474 | err = writerConn.Close() 475 | assert.NoError(t, err) 476 | } 477 | 478 | func BenchmarkSyncThroughputPipe(b *testing.B) { 479 | DisableMaxContentLength(b) 480 | 481 | const testSize = 100 482 | 483 | emptyLogger := logging.Test(b, logging.Noop, b.Name()) 484 | 485 | reader, writer := net.Pipe() 486 | 487 | readerConn := NewSync(reader, emptyLogger) 488 | writerConn := NewSync(writer, emptyLogger) 489 | 490 | b.Run("32 Bytes", throughputRunner(testSize, 32, readerConn, writerConn)) 491 | b.Run("512 Bytes", throughputRunner(testSize, 512, readerConn, writerConn)) 492 | b.Run("1024 Bytes", throughputRunner(testSize, 1024, readerConn, writerConn)) 493 | b.Run("2048 Bytes", throughputRunner(testSize, 2048, readerConn, writerConn)) 494 | b.Run("4096 Bytes", throughputRunner(testSize, 4096, readerConn, writerConn)) 495 | 496 | _ = readerConn.Close() 497 | _ = writerConn.Close() 498 | } 499 | 500 | func BenchmarkSyncThroughputNetwork(b *testing.B) { 501 | DisableMaxContentLength(b) 502 | 503 | const testSize = 100 504 | 505 | emptyLogger := logging.Test(b, logging.Noop, b.Name()) 506 | 507 | reader, writer, err := pair.New() 508 | if err != nil { 509 | b.Fatal(err) 510 | } 511 | 512 | readerConn := NewSync(reader, emptyLogger) 513 | writerConn := NewSync(writer, 
emptyLogger) 514 | 515 | b.Run("32 Bytes", throughputRunner(testSize, 32, readerConn, writerConn)) 516 | b.Run("512 Bytes", throughputRunner(testSize, 512, readerConn, writerConn)) 517 | b.Run("1024 Bytes", throughputRunner(testSize, 1024, readerConn, writerConn)) 518 | b.Run("2048 Bytes", throughputRunner(testSize, 2048, readerConn, writerConn)) 519 | b.Run("4096 Bytes", throughputRunner(testSize, 4096, readerConn, writerConn)) 520 | 521 | _ = readerConn.Close() 522 | _ = writerConn.Close() 523 | } 524 | -------------------------------------------------------------------------------- /testutil.go: -------------------------------------------------------------------------------- 1 | package frisbee 2 | 3 | import "testing" 4 | 5 | func DisableMaxContentLength(tb testing.TB) uint32 { 6 | tb.Helper() 7 | 8 | oldMax := DefaultMaxContentLength 9 | DefaultMaxContentLength = 0 10 | tb.Cleanup(func() { DefaultMaxContentLength = oldMax }) 11 | 12 | return oldMax 13 | } 14 | -------------------------------------------------------------------------------- /throughput_test.go: -------------------------------------------------------------------------------- 1 | //go:build !race 2 | 3 | // SPDX-License-Identifier: Apache-2.0 4 | 5 | package frisbee 6 | 7 | import ( 8 | "bufio" 9 | "io" 10 | "net" 11 | "testing" 12 | "time" 13 | 14 | "github.com/loopholelabs/logging" 15 | "github.com/loopholelabs/testing/conn/pair" 16 | ) 17 | 18 | func BenchmarkAsyncThroughputLarge(b *testing.B) { 19 | DisableMaxContentLength(b) 20 | 21 | const testSize = 100 22 | 23 | emptyLogger := logging.Test(b, logging.Noop, b.Name()) 24 | 25 | reader, writer, err := pair.New() 26 | if err != nil { 27 | b.Fatal(err) 28 | } 29 | 30 | readerConn := NewAsync(reader, emptyLogger) 31 | writerConn := NewAsync(writer, emptyLogger) 32 | 33 | b.Run("1MB", throughputRunner(testSize, 1<<20, readerConn, writerConn)) 34 | b.Run("2MB", throughputRunner(testSize, 1<<21, readerConn, writerConn)) 35 | b.Run("4MB", throughputRunner(testSize, 1<<22, readerConn, writerConn)) 36 | b.Run("8MB", throughputRunner(testSize, 1<<23, readerConn, writerConn)) 37 | b.Run("16MB", throughputRunner(testSize, 1<<24, readerConn, writerConn)) 38 | 39 | _ = readerConn.Close() 40 | _ = writerConn.Close() 41 | } 42 | 43 | func BenchmarkSyncThroughputLarge(b *testing.B) { 44 | DisableMaxContentLength(b) 45 | 46 | const testSize = 100 47 | 48 | emptyLogger := logging.Test(b, logging.Noop, b.Name()) 49 | 50 | reader, writer, err := pair.New() 51 | if err != nil { 52 | b.Fatal(err) 53 | } 54 | 55 | readerConn := NewSync(reader, emptyLogger) 56 | writerConn := NewSync(writer, emptyLogger) 57 | 58 | b.Run("1MB", throughputRunner(testSize, 1<<20, readerConn, writerConn)) 59 | b.Run("2MB", throughputRunner(testSize, 1<<21, readerConn, writerConn)) 60 | b.Run("4MB", throughputRunner(testSize, 1<<22, readerConn, writerConn)) 61 | b.Run("8MB", throughputRunner(testSize, 1<<23, readerConn, writerConn)) 62 | b.Run("16MB", throughputRunner(testSize, 1<<24, readerConn, writerConn)) 63 | 64 | _ = readerConn.Close() 65 | _ = writerConn.Close() 66 | } 67 | 68 | func BenchmarkTCPThroughput(b *testing.B) { 69 | DisableMaxContentLength(b) 70 | 71 | const testSize = 100 72 | 73 | reader, writer, err := pair.New() 74 | if err != nil { 75 | b.Fatal(err) 76 | } 77 | 78 | TCPThroughputRunner := func(testSize uint32, packetSize uint32, readerConn net.Conn, writerConn net.Conn) func(*testing.B) { 79 | bufWriter := bufio.NewWriter(writerConn) 80 | bufReader := bufio.NewReader(readerConn) 81 | 
return func(b *testing.B) { 82 | b.SetBytes(int64(testSize * packetSize)) 83 | b.ReportAllocs() 84 | var err error 85 | 86 | randomData := make([]byte, packetSize) 87 | readData := make([]byte, packetSize) 88 | b.ResetTimer() 89 | for i := 0; i < b.N; i++ { 90 | done := make(chan struct{}, 1) 91 | errCh := make(chan error, 1) 92 | go func() { 93 | for i := uint32(0); i < testSize; i++ { 94 | err := readerConn.SetReadDeadline(time.Now().Add(DefaultDeadline)) 95 | if err != nil { 96 | errCh <- err 97 | return 98 | } 99 | _, err = io.ReadAtLeast(bufReader, readData[0:], int(packetSize)) 100 | if err != nil { 101 | errCh <- err 102 | return 103 | } 104 | } 105 | done <- struct{}{} 106 | }() 107 | for i := uint32(0); i < testSize; i++ { 108 | select { 109 | case err = <-errCh: 110 | b.Fatal(err) 111 | default: 112 | err = writerConn.SetWriteDeadline(time.Now().Add(DefaultDeadline)) 113 | if err != nil { 114 | b.Fatal(err) 115 | } 116 | _, err = bufWriter.Write(randomData) 117 | if err != nil { 118 | b.Fatal(err) 119 | } 120 | } 121 | } 122 | err = writerConn.SetWriteDeadline(time.Now().Add(DefaultDeadline)) 123 | if err != nil { 124 | b.Fatal(err) 125 | } 126 | err = bufWriter.Flush() 127 | if err != nil { 128 | b.Fatal(err) 129 | } 130 | select { 131 | case <-done: 132 | continue 133 | case err := <-errCh: 134 | b.Fatal(err) 135 | } 136 | } 137 | b.StopTimer() 138 | } 139 | } 140 | 141 | b.Run("32 Bytes", TCPThroughputRunner(testSize, 32, reader, writer)) 142 | b.Run("512 Bytes", TCPThroughputRunner(testSize, 512, reader, writer)) 143 | b.Run("1024 Bytes", TCPThroughputRunner(testSize, 1024, reader, writer)) 144 | b.Run("2048 Bytes", TCPThroughputRunner(testSize, 2048, reader, writer)) 145 | b.Run("4096 Bytes", TCPThroughputRunner(testSize, 4096, reader, writer)) 146 | b.Run("1MB", TCPThroughputRunner(testSize, 1<<20, reader, writer)) 147 | b.Run("2MB", TCPThroughputRunner(testSize, 1<<21, reader, writer)) 148 | b.Run("4MB", TCPThroughputRunner(testSize, 1<<22, reader, writer)) 149 | b.Run("8MB", TCPThroughputRunner(testSize, 1<<23, reader, writer)) 150 | b.Run("16MB", TCPThroughputRunner(testSize, 1<<24, reader, writer)) 151 | 152 | _ = reader.Close() 153 | _ = writer.Close() 154 | } 155 | --------------------------------------------------------------------------------
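Editor's note, appended after the dump: the Sync connection in sync.go above (NewSync, WritePacket, ReadPacket, Close) together with the packet pool in pkg/packet is enough to move frisbee packets over any net.Conn. The snippet below is a minimal usage sketch rather than code from the repository; the main package, the payload bytes, and the use of net.Pipe are illustrative assumptions, while every frisbee call (packet.Get/Put, NewSync, WritePacket, ReadPacket, Close) is taken from the files shown above.

package main

import (
	"fmt"
	"net"

	frisbee "github.com/loopholelabs/frisbee-go"
	"github.com/loopholelabs/frisbee-go/pkg/packet"
)

func main() {
	// An in-memory pipe stands in for a real TCP connection (illustrative assumption).
	reader, writer := net.Pipe()
	readerConn := frisbee.NewSync(reader, nil) // a nil logger falls back to a no-op logger
	writerConn := frisbee.NewSync(writer, nil)
	defer func() { _ = readerConn.Close() }()
	defer func() { _ = writerConn.Close() }()

	// Build a packet; Metadata.ContentLength must match the written content exactly,
	// otherwise WritePacket returns InvalidContentLength.
	p := packet.Get()
	p.Metadata.Id = 64
	p.Metadata.Operation = 32
	payload := []byte("hello frisbee") // illustrative payload
	p.Content.Write(payload)
	p.Metadata.ContentLength = uint32(len(payload))

	// net.Pipe is unbuffered, so read on a separate goroutine.
	done := make(chan struct{})
	go func() {
		defer close(done)
		in, err := readerConn.ReadPacket()
		if err != nil {
			fmt.Println("read error:", err)
			return
		}
		fmt.Printf("id=%d op=%d content=%q\n", in.Metadata.Id, in.Metadata.Operation, in.Content.Bytes())
		packet.Put(in) // return the packet to the pool once it is no longer needed
	}()

	if err := writerConn.WritePacket(p); err != nil {
		fmt.Println("write error:", err)
	}
	packet.Put(p)
	<-done
}

In real use, ConnectSync with a keep-alive duration and an optional tls.Config would replace net.Pipe, exactly as the sync.go constructor above describes.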