├── .github ├── dependabot.yml └── workflows │ ├── ci.yml │ ├── daily.yml │ └── static.yml ├── LICENSE.txt ├── Makefile ├── README.md ├── accept.go ├── accept_test.go ├── autobahn_test.go ├── ci ├── bench.sh ├── fmt.sh ├── lint.sh ├── out │ └── .gitignore └── test.sh ├── close.go ├── close_test.go ├── compress.go ├── compress_test.go ├── conn.go ├── conn_test.go ├── dial.go ├── dial_test.go ├── doc.go ├── example_test.go ├── export_test.go ├── frame.go ├── frame_test.go ├── go.mod ├── go.sum ├── hijack.go ├── hijack_go120_test.go ├── internal ├── bpool │ └── bpool.go ├── errd │ └── wrap.go ├── examples │ ├── README.md │ ├── chat │ │ ├── README.md │ │ ├── chat.go │ │ ├── chat_test.go │ │ ├── index.css │ │ ├── index.html │ │ ├── index.js │ │ └── main.go │ ├── echo │ │ ├── README.md │ │ ├── main.go │ │ ├── server.go │ │ └── server_test.go │ ├── go.mod │ └── go.sum ├── test │ ├── assert │ │ └── assert.go │ ├── doc.go │ ├── wstest │ │ ├── echo.go │ │ └── pipe.go │ └── xrand │ │ └── xrand.go ├── thirdparty │ ├── doc.go │ ├── frame_test.go │ ├── gin_test.go │ ├── go.mod │ └── go.sum ├── util │ └── util.go ├── wsjs │ └── wsjs_js.go └── xsync │ ├── go.go │ └── go_test.go ├── main_test.go ├── mask.go ├── mask_amd64.s ├── mask_arm64.s ├── mask_asm.go ├── mask_asm_test.go ├── mask_go.go ├── mask_test.go ├── netconn.go ├── netconn_js.go ├── netconn_notjs.go ├── read.go ├── stringer.go ├── write.go ├── ws_js.go ├── ws_js_test.go └── wsjson ├── wsjson.go └── wsjson_test.go /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | # Track in case we ever add dependencies. 4 | - package-ecosystem: 'gomod' 5 | directory: '/' 6 | schedule: 7 | interval: 'weekly' 8 | commit-message: 9 | prefix: 'chore' 10 | 11 | # Keep example and test/benchmark deps up-to-date. 
12 | - package-ecosystem: 'gomod' 13 | directories: 14 | - '/internal/examples' 15 | - '/internal/thirdparty' 16 | schedule: 17 | interval: 'monthly' 18 | commit-message: 19 | prefix: 'chore' 20 | labels: [] 21 | groups: 22 | internal-deps: 23 | patterns: 24 | - '*' 25 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | on: 3 | push: 4 | branches: 5 | - master 6 | pull_request: 7 | branches: 8 | - master 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | fmt: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v4 18 | - uses: actions/setup-go@v5 19 | with: 20 | go-version-file: ./go.mod 21 | - run: make fmt 22 | 23 | lint: 24 | runs-on: ubuntu-latest 25 | steps: 26 | - uses: actions/checkout@v4 27 | - run: go version 28 | - uses: actions/setup-go@v5 29 | with: 30 | go-version-file: ./go.mod 31 | - run: make lint 32 | 33 | test: 34 | runs-on: ubuntu-latest 35 | steps: 36 | - name: Disable AppArmor 37 | if: runner.os == 'Linux' 38 | run: | 39 | # Disable AppArmor for Ubuntu 23.10+. 
40 | # https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md 41 | echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns 42 | - uses: actions/checkout@v4 43 | - uses: actions/setup-go@v5 44 | with: 45 | go-version-file: ./go.mod 46 | - run: make test 47 | - uses: actions/upload-artifact@v4 48 | with: 49 | name: coverage.html 50 | path: ./ci/out/coverage.html 51 | 52 | bench: 53 | runs-on: ubuntu-latest 54 | steps: 55 | - uses: actions/checkout@v4 56 | - uses: actions/setup-go@v5 57 | with: 58 | go-version-file: ./go.mod 59 | - run: make bench 60 | -------------------------------------------------------------------------------- /.github/workflows/daily.yml: -------------------------------------------------------------------------------- 1 | name: daily 2 | on: 3 | workflow_dispatch: 4 | schedule: 5 | - cron: '42 0 * * *' # daily at 00:42 6 | concurrency: 7 | group: ${{ github.workflow }} 8 | cancel-in-progress: true 9 | 10 | jobs: 11 | bench: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v4 15 | - uses: actions/setup-go@v5 16 | with: 17 | go-version-file: ./go.mod 18 | - run: AUTOBAHN=1 make bench 19 | test: 20 | runs-on: ubuntu-latest 21 | steps: 22 | - name: Disable AppArmor 23 | if: runner.os == 'Linux' 24 | run: | 25 | # Disable AppArmor for Ubuntu 23.10+. 
26 | # https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md 27 | echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns 28 | - uses: actions/checkout@v4 29 | - uses: actions/setup-go@v5 30 | with: 31 | go-version-file: ./go.mod 32 | - run: AUTOBAHN=1 make test 33 | - uses: actions/upload-artifact@v4 34 | with: 35 | name: coverage.html 36 | path: ./ci/out/coverage.html 37 | bench-dev: 38 | runs-on: ubuntu-latest 39 | steps: 40 | - uses: actions/checkout@v4 41 | with: 42 | ref: dev 43 | - uses: actions/setup-go@v5 44 | with: 45 | go-version-file: ./go.mod 46 | - run: AUTOBAHN=1 make bench 47 | test-dev: 48 | runs-on: ubuntu-latest 49 | steps: 50 | - name: Disable AppArmor 51 | if: runner.os == 'Linux' 52 | run: | 53 | # Disable AppArmor for Ubuntu 23.10+. 54 | # https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md 55 | echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns 56 | - uses: actions/checkout@v4 57 | with: 58 | ref: dev 59 | - uses: actions/setup-go@v5 60 | with: 61 | go-version-file: ./go.mod 62 | - run: AUTOBAHN=1 make test 63 | - uses: actions/upload-artifact@v4 64 | with: 65 | name: coverage-dev.html 66 | path: ./ci/out/coverage.html 67 | -------------------------------------------------------------------------------- /.github/workflows/static.yml: -------------------------------------------------------------------------------- 1 | name: static 2 | 3 | on: 4 | push: 5 | branches: ['master'] 6 | workflow_dispatch: 7 | 8 | # Set permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages. 
9 | permissions: 10 | contents: read 11 | pages: write 12 | id-token: write 13 | 14 | concurrency: 15 | group: pages 16 | cancel-in-progress: true 17 | 18 | jobs: 19 | deploy: 20 | environment: 21 | name: github-pages 22 | url: ${{ steps.deployment.outputs.page_url }} 23 | runs-on: ubuntu-latest 24 | steps: 25 | - name: Disable AppArmor 26 | if: runner.os == 'Linux' 27 | run: | 28 | # Disable AppArmor for Ubuntu 23.10+. 29 | # https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md 30 | echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns 31 | - name: Checkout 32 | uses: actions/checkout@v4 33 | - name: Setup Pages 34 | uses: actions/configure-pages@v5 35 | - name: Setup Go 36 | uses: actions/setup-go@v5 37 | with: 38 | go-version-file: ./go.mod 39 | - name: Generate coverage and badge 40 | run: | 41 | make test 42 | mkdir -p ./ci/out/static 43 | cp ./ci/out/coverage.html ./ci/out/static/coverage.html 44 | percent=$(go tool cover -func ./ci/out/coverage.prof | tail -n1 | awk '{print $3}' | tr -d '%') 45 | wget -O ./ci/out/static/coverage.svg "https://img.shields.io/badge/coverage-${percent}%25-success" 46 | - name: Upload artifact 47 | uses: actions/upload-pages-artifact@v3 48 | with: 49 | path: ./ci/out/static/ 50 | - name: Deploy to GitHub Pages 51 | id: deployment 52 | uses: actions/deploy-pages@v4 53 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2025 Coder 2 | 3 | Permission to use, copy, modify, and distribute this software for any 4 | purpose with or without fee is hereby granted, provided that the above 5 | copyright notice and this permission notice appear in all copies. 
6 | 7 | THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 14 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all 2 | all: fmt lint test 3 | 4 | .PHONY: fmt 5 | fmt: 6 | ./ci/fmt.sh 7 | 8 | .PHONY: lint 9 | lint: 10 | ./ci/lint.sh 11 | 12 | .PHONY: test 13 | test: 14 | ./ci/test.sh 15 | 16 | .PHONY: bench 17 | bench: 18 | ./ci/bench.sh -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # websocket 2 | 3 | [![Go Reference](https://pkg.go.dev/badge/github.com/coder/websocket.svg)](https://pkg.go.dev/github.com/coder/websocket) 4 | [![Go Coverage](https://coder.github.io/websocket/coverage.svg)](https://coder.github.io/websocket/coverage.html) 5 | 6 | websocket is a minimal and idiomatic WebSocket library for Go. 7 | 8 | ## Install 9 | 10 | ```sh 11 | go get github.com/coder/websocket 12 | ``` 13 | 14 | > [!NOTE] 15 | > Coder now maintains this project as explained in [this blog post](https://coder.com/blog/websocket). 16 | > We're grateful to [nhooyr](https://github.com/nhooyr) for authoring and maintaining this project from 17 | > 2019 to 2024. 
18 | 19 | ## Highlights 20 | 21 | - Minimal and idiomatic API 22 | - First-class [context.Context](https://blog.golang.org/context) support 23 | - Fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) 24 | - [Zero dependencies](https://pkg.go.dev/github.com/coder/websocket?tab=imports) 25 | - JSON helpers in the [wsjson](https://pkg.go.dev/github.com/coder/websocket/wsjson) subpackage 26 | - Zero-alloc reads and writes 27 | - Concurrent writes 28 | - [Close handshake](https://pkg.go.dev/github.com/coder/websocket#Conn.Close) 29 | - [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) wrapper 30 | - [Ping pong](https://pkg.go.dev/github.com/coder/websocket#Conn.Ping) API 31 | - [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression 32 | - [CloseRead](https://pkg.go.dev/github.com/coder/websocket#Conn.CloseRead) helper for write-only connections 33 | - Compile to [Wasm](https://pkg.go.dev/github.com/coder/websocket#hdr-Wasm) 34 | 35 | ## Roadmap 36 | 37 | See GitHub issues for minor issues, but the major future enhancements are: 38 | 39 | - [ ] Perfect examples [#217](https://github.com/nhooyr/websocket/issues/217) 40 | - [ ] wstest.Pipe for in-memory testing [#340](https://github.com/nhooyr/websocket/issues/340) 41 | - [ ] Ping pong heartbeat helper [#267](https://github.com/nhooyr/websocket/issues/267) 42 | - [ ] Ping pong instrumentation callbacks [#246](https://github.com/nhooyr/websocket/issues/246) 43 | - [ ] Graceful shutdown helpers [#209](https://github.com/nhooyr/websocket/issues/209) 44 | - [ ] Assembly for WebSocket masking [#16](https://github.com/nhooyr/websocket/issues/16) 45 | - WIP at [#326](https://github.com/nhooyr/websocket/pull/326), about 3x faster 46 | - [ ] HTTP/2 [#4](https://github.com/nhooyr/websocket/issues/4) 47 | - [ ] The holy grail [#402](https://github.com/nhooyr/websocket/issues/402) 48 | 49 | ## Examples 50 | 51 | For a production-quality example that
demonstrates the complete API, see the 52 | [echo example](./internal/examples/echo). 53 | 54 | For a full stack example, see the [chat example](./internal/examples/chat). 55 | 56 | ### Server 57 | 58 | ```go 59 | http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { 60 | c, err := websocket.Accept(w, r, nil) 61 | if err != nil { 62 | // ... 63 | } 64 | defer c.CloseNow() 65 | 66 | // Set the context as needed. Use of r.Context() is not recommended 67 | // to avoid surprising behavior (see http.Hijacker). 68 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) 69 | defer cancel() 70 | 71 | var v interface{} 72 | err = wsjson.Read(ctx, c, &v) 73 | if err != nil { 74 | // ... 75 | } 76 | 77 | log.Printf("received: %v", v) 78 | 79 | c.Close(websocket.StatusNormalClosure, "") 80 | }) 81 | ``` 82 | 83 | ### Client 84 | 85 | ```go 86 | ctx, cancel := context.WithTimeout(context.Background(), time.Minute) 87 | defer cancel() 88 | 89 | c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) 90 | if err != nil { 91 | // ... 92 | } 93 | defer c.CloseNow() 94 | 95 | err = wsjson.Write(ctx, c, "hi") 96 | if err != nil { 97 | // ... 98 | } 99 | 100 | c.Close(websocket.StatusNormalClosure, "") 101 | ``` 102 | 103 | ## Comparison 104 | 105 | ### gorilla/websocket 106 | 107 | Advantages of [gorilla/websocket](https://github.com/gorilla/websocket): 108 | 109 | - Mature and widely used 110 | - [Prepared writes](https://pkg.go.dev/github.com/gorilla/websocket#PreparedMessage) 111 | - Configurable [buffer sizes](https://pkg.go.dev/github.com/gorilla/websocket#hdr-Buffers) 112 | - No extra goroutine per connection to support cancellation with context.Context. This costs github.com/coder/websocket 2 KB of memory per connection. 113 | - Will be removed soon with [context.AfterFunc](https://github.com/golang/go/issues/57928). 
See [#411](https://github.com/nhooyr/websocket/issues/411) 114 | 115 | Advantages of github.com/coder/websocket: 116 | 117 | - Minimal and idiomatic API 118 | - Compare godoc of [github.com/coder/websocket](https://pkg.go.dev/github.com/coder/websocket) with [gorilla/websocket](https://pkg.go.dev/github.com/gorilla/websocket) side by side. 119 | - [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) wrapper 120 | - Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535)) 121 | - Full [context.Context](https://blog.golang.org/context) support 122 | - Dial uses [net/http.Client](https://golang.org/pkg/net/http/#Client) 123 | - Will enable easy HTTP/2 support in the future 124 | - Gorilla writes directly to a net.Conn and so duplicates features of net/http.Client. 125 | - Concurrent writes 126 | - Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448)) 127 | - Idiomatic [ping pong](https://pkg.go.dev/github.com/coder/websocket#Conn.Ping) API 128 | - Gorilla requires registering a pong callback before sending a Ping 129 | - Can target Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432)) 130 | - Transparent message buffer reuse with [wsjson](https://pkg.go.dev/github.com/coder/websocket/wsjson) subpackage 131 | - [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go 132 | - Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/). 
133 | Soon we'll have assembly and be 3x faster [#326](https://github.com/nhooyr/websocket/pull/326) 134 | - Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support 135 | - Gorilla only supports no-context-takeover mode 136 | - [CloseRead](https://pkg.go.dev/github.com/coder/websocket#Conn.CloseRead) helper for write-only connections ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492)) 137 | 138 | ### golang.org/x/net/websocket 139 | 140 | [golang.org/x/net/websocket](https://pkg.go.dev/golang.org/x/net/websocket) is deprecated. 141 | See [golang/go/issues/18152](https://github.com/golang/go/issues/18152). 142 | 143 | The [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) wrapper can help in transitioning 144 | to github.com/coder/websocket. 145 | 146 | ### gobwas/ws 147 | 148 | [gobwas/ws](https://github.com/gobwas/ws) has an extremely flexible API that allows it to be used 149 | in an event-driven style for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb). 150 | 151 | However, it is quite bloated. See https://pkg.go.dev/github.com/gobwas/ws 152 | 153 | When writing idiomatic Go, github.com/coder/websocket will be faster and easier to use. 154 | 155 | ### lesismal/nbio 156 | 157 | [lesismal/nbio](https://github.com/lesismal/nbio) is similar to gobwas/ws in that the API is 158 | event-driven for performance reasons. 159 | 160 | However, it is quite bloated. See https://pkg.go.dev/github.com/lesismal/nbio 161 | 162 | When writing idiomatic Go, github.com/coder/websocket will be faster and easier to use.
163 | -------------------------------------------------------------------------------- /accept.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | // +build !js 3 | 4 | package websocket 5 | 6 | import ( 7 | "bytes" 8 | "context" 9 | "crypto/sha1" 10 | "encoding/base64" 11 | "errors" 12 | "fmt" 13 | "io" 14 | "log" 15 | "net/http" 16 | "net/textproto" 17 | "net/url" 18 | "path" 19 | "strings" 20 | 21 | "github.com/coder/websocket/internal/errd" 22 | ) 23 | 24 | // AcceptOptions represents Accept's options. 25 | type AcceptOptions struct { 26 | // Subprotocols lists the WebSocket subprotocols that Accept will negotiate with the client. 27 | // The empty subprotocol will always be negotiated as per RFC 6455. If you would like to 28 | // reject it, close the connection when c.Subprotocol() == "". 29 | Subprotocols []string 30 | 31 | // InsecureSkipVerify is used to disable Accept's origin verification behaviour. 32 | // 33 | // You probably want to use OriginPatterns instead. 34 | InsecureSkipVerify bool 35 | 36 | // OriginPatterns lists the host patterns for authorized origins. 37 | // The request host is always authorized. 38 | // Use this to enable cross origin WebSockets. 39 | // 40 | // i.e javascript running on example.com wants to access a WebSocket server at chat.example.com. 41 | // In such a case, example.com is the origin and chat.example.com is the request host. 42 | // One would set this field to []string{"example.com"} to authorize example.com to connect. 43 | // 44 | // Each pattern is matched case insensitively against the request origin host 45 | // with path.Match. 46 | // See https://golang.org/pkg/path/#Match 47 | // 48 | // Please ensure you understand the ramifications of enabling this. 49 | // If used incorrectly your WebSocket server will be open to CSRF attacks. 
50 | // 51 | // Do not use * as a pattern to allow any origin, prefer to use InsecureSkipVerify instead 52 | // to bring attention to the danger of such a setting. 53 | OriginPatterns []string 54 | 55 | // CompressionMode controls the compression mode. 56 | // Defaults to CompressionDisabled. 57 | // 58 | // See docs on CompressionMode for details. 59 | CompressionMode CompressionMode 60 | 61 | // CompressionThreshold controls the minimum size of a message before compression is applied. 62 | // 63 | // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes 64 | // for CompressionContextTakeover. 65 | CompressionThreshold int 66 | 67 | // OnPingReceived is an optional callback invoked synchronously when a ping frame is received. 68 | // 69 | // The payload contains the application data of the ping frame. 70 | // If the callback returns false, the subsequent pong frame will not be sent. 71 | // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine. 72 | OnPingReceived func(ctx context.Context, payload []byte) bool 73 | 74 | // OnPongReceived is an optional callback invoked synchronously when a pong frame is received. 75 | // 76 | // The payload contains the application data of the pong frame. 77 | // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine. 78 | // 79 | // Unlike OnPingReceived, this callback does not return a value because a pong frame 80 | // is a response to a ping and does not trigger any further frame transmission. 81 | OnPongReceived func(ctx context.Context, payload []byte) 82 | } 83 | 84 | func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions { 85 | var o AcceptOptions 86 | if opts != nil { 87 | o = *opts 88 | } 89 | return &o 90 | } 91 | 92 | // Accept accepts a WebSocket handshake from a client and upgrades the 93 | // the connection to a WebSocket. 94 | // 95 | // Accept will not allow cross origin requests by default. 
96 | // See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests. 97 | // 98 | // Accept will write a response to w on all errors. 99 | // 100 | // Note that using the http.Request Context after Accept returns may lead to 101 | // unexpected behavior (see http.Hijacker). 102 | func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { 103 | return accept(w, r, opts) 104 | } 105 | 106 | func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) { 107 | defer errd.Wrap(&err, "failed to accept WebSocket connection") 108 | 109 | errCode, err := verifyClientRequest(w, r) 110 | if err != nil { 111 | http.Error(w, err.Error(), errCode) 112 | return nil, err 113 | } 114 | 115 | opts = opts.cloneWithDefaults() 116 | if !opts.InsecureSkipVerify { 117 | err = authenticateOrigin(r, opts.OriginPatterns) 118 | if err != nil { 119 | if errors.Is(err, path.ErrBadPattern) { 120 | log.Printf("websocket: %v", err) 121 | err = errors.New(http.StatusText(http.StatusForbidden)) 122 | } 123 | http.Error(w, err.Error(), http.StatusForbidden) 124 | return nil, err 125 | } 126 | } 127 | 128 | hj, ok := hijacker(w) 129 | if !ok { 130 | err = errors.New("http.ResponseWriter does not implement http.Hijacker") 131 | http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) 132 | return nil, err 133 | } 134 | 135 | w.Header().Set("Upgrade", "websocket") 136 | w.Header().Set("Connection", "Upgrade") 137 | 138 | key := r.Header.Get("Sec-WebSocket-Key") 139 | w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) 140 | 141 | subproto := selectSubprotocol(r, opts.Subprotocols) 142 | if subproto != "" { 143 | w.Header().Set("Sec-WebSocket-Protocol", subproto) 144 | } 145 | 146 | copts, ok := selectDeflate(websocketExtensions(r.Header), opts.CompressionMode) 147 | if ok { 148 | w.Header().Set("Sec-WebSocket-Extensions", copts.String()) 149 | } 150 | 151 | 
w.WriteHeader(http.StatusSwitchingProtocols) 152 | // See https://github.com/nhooyr/websocket/issues/166 153 | if ginWriter, ok := w.(interface { 154 | WriteHeaderNow() 155 | }); ok { 156 | ginWriter.WriteHeaderNow() 157 | } 158 | 159 | netConn, brw, err := hj.Hijack() 160 | if err != nil { 161 | err = fmt.Errorf("failed to hijack connection: %w", err) 162 | http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 163 | return nil, err 164 | } 165 | 166 | // https://github.com/golang/go/issues/32314 167 | b, _ := brw.Reader.Peek(brw.Reader.Buffered()) 168 | brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn)) 169 | 170 | return newConn(connConfig{ 171 | subprotocol: w.Header().Get("Sec-WebSocket-Protocol"), 172 | rwc: netConn, 173 | client: false, 174 | copts: copts, 175 | flateThreshold: opts.CompressionThreshold, 176 | onPingReceived: opts.OnPingReceived, 177 | onPongReceived: opts.OnPongReceived, 178 | 179 | br: brw.Reader, 180 | bw: brw.Writer, 181 | }), nil 182 | } 183 | 184 | func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) { 185 | if !r.ProtoAtLeast(1, 1) { 186 | return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) 187 | } 188 | 189 | if !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade") { 190 | w.Header().Set("Connection", "Upgrade") 191 | w.Header().Set("Upgrade", "websocket") 192 | return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) 193 | } 194 | 195 | if !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket") { 196 | w.Header().Set("Connection", "Upgrade") 197 | w.Header().Set("Upgrade", "websocket") 198 | return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) 199 | } 200 | 
201 | if r.Method != "GET" { 202 | return http.StatusMethodNotAllowed, fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method) 203 | } 204 | 205 | if r.Header.Get("Sec-WebSocket-Version") != "13" { 206 | w.Header().Set("Sec-WebSocket-Version", "13") 207 | return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) 208 | } 209 | 210 | websocketSecKeys := r.Header.Values("Sec-WebSocket-Key") 211 | if len(websocketSecKeys) == 0 { 212 | return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key") 213 | } 214 | 215 | if len(websocketSecKeys) > 1 { 216 | return http.StatusBadRequest, errors.New("WebSocket protocol violation: multiple Sec-WebSocket-Key headers") 217 | } 218 | 219 | // The RFC states to remove any leading or trailing whitespace. 220 | websocketSecKey := strings.TrimSpace(websocketSecKeys[0]) 221 | if v, err := base64.StdEncoding.DecodeString(websocketSecKey); err != nil || len(v) != 16 { 222 | return http.StatusBadRequest, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Key %q, must be a 16 byte base64 encoded string", websocketSecKey) 223 | } 224 | 225 | return 0, nil 226 | } 227 | 228 | func authenticateOrigin(r *http.Request, originHosts []string) error { 229 | origin := r.Header.Get("Origin") 230 | if origin == "" { 231 | return nil 232 | } 233 | 234 | u, err := url.Parse(origin) 235 | if err != nil { 236 | return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) 237 | } 238 | 239 | if strings.EqualFold(r.Host, u.Host) { 240 | return nil 241 | } 242 | 243 | for _, hostPattern := range originHosts { 244 | matched, err := match(hostPattern, u.Host) 245 | if err != nil { 246 | return fmt.Errorf("failed to parse path pattern %q: %w", hostPattern, err) 247 | } 248 | if matched { 249 | return nil 250 | } 251 | } 252 | if u.Host == "" { 253 | return 
fmt.Errorf("request Origin %q is not a valid URL with a host", origin) 254 | } 255 | return fmt.Errorf("request Origin %q is not authorized for Host %q", u.Host, r.Host) 256 | } 257 | 258 | func match(pattern, s string) (bool, error) { 259 | return path.Match(strings.ToLower(pattern), strings.ToLower(s)) 260 | } 261 | 262 | func selectSubprotocol(r *http.Request, subprotocols []string) string { 263 | cps := headerTokens(r.Header, "Sec-WebSocket-Protocol") 264 | for _, sp := range subprotocols { 265 | for _, cp := range cps { 266 | if strings.EqualFold(sp, cp) { 267 | return cp 268 | } 269 | } 270 | } 271 | return "" 272 | } 273 | 274 | func selectDeflate(extensions []websocketExtension, mode CompressionMode) (*compressionOptions, bool) { 275 | if mode == CompressionDisabled { 276 | return nil, false 277 | } 278 | for _, ext := range extensions { 279 | switch ext.name { 280 | // We used to implement x-webkit-deflate-frame too for Safari but Safari has bugs... 281 | // See https://github.com/nhooyr/websocket/issues/218 282 | case "permessage-deflate": 283 | copts, ok := acceptDeflate(ext, mode) 284 | if ok { 285 | return copts, true 286 | } 287 | } 288 | } 289 | return nil, false 290 | } 291 | 292 | func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOptions, bool) { 293 | copts := mode.opts() 294 | for _, p := range ext.params { 295 | switch p { 296 | case "client_no_context_takeover": 297 | copts.clientNoContextTakeover = true 298 | continue 299 | case "server_no_context_takeover": 300 | copts.serverNoContextTakeover = true 301 | continue 302 | case "client_max_window_bits", 303 | "server_max_window_bits=15": 304 | continue 305 | } 306 | 307 | if strings.HasPrefix(p, "client_max_window_bits=") { 308 | // We can't adjust the deflate window, but decoding with a larger window is acceptable. 
309 | continue 310 | } 311 | return nil, false 312 | } 313 | return copts, true 314 | } 315 | 316 | func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool { 317 | for _, t := range headerTokens(h, key) { 318 | if strings.EqualFold(t, token) { 319 | return true 320 | } 321 | } 322 | return false 323 | } 324 | 325 | type websocketExtension struct { 326 | name string 327 | params []string 328 | } 329 | 330 | func websocketExtensions(h http.Header) []websocketExtension { 331 | var exts []websocketExtension 332 | extStrs := headerTokens(h, "Sec-WebSocket-Extensions") 333 | for _, extStr := range extStrs { 334 | if extStr == "" { 335 | continue 336 | } 337 | 338 | vals := strings.Split(extStr, ";") 339 | for i := range vals { 340 | vals[i] = strings.TrimSpace(vals[i]) 341 | } 342 | 343 | e := websocketExtension{ 344 | name: vals[0], 345 | params: vals[1:], 346 | } 347 | 348 | exts = append(exts, e) 349 | } 350 | return exts 351 | } 352 | 353 | func headerTokens(h http.Header, key string) []string { 354 | key = textproto.CanonicalMIMEHeaderKey(key) 355 | var tokens []string 356 | for _, v := range h[key] { 357 | v = strings.TrimSpace(v) 358 | for _, t := range strings.Split(v, ",") { 359 | t = strings.TrimSpace(t) 360 | tokens = append(tokens, t) 361 | } 362 | } 363 | return tokens 364 | } 365 | 366 | var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") 367 | 368 | func secWebSocketAccept(secWebSocketKey string) string { 369 | h := sha1.New() 370 | h.Write([]byte(secWebSocketKey)) 371 | h.Write(keyGUID) 372 | 373 | return base64.StdEncoding.EncodeToString(h.Sum(nil)) 374 | } 375 | -------------------------------------------------------------------------------- /autobahn_test.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | // +build !js 3 | 4 | package websocket_test 5 | 6 | import ( 7 | "context" 8 | "encoding/json" 9 | "errors" 10 | "fmt" 11 | "io" 12 | "net" 13 | "os" 14 | "os/exec" 
15 | "strconv" 16 | "strings" 17 | "testing" 18 | "time" 19 | 20 | "github.com/coder/websocket" 21 | "github.com/coder/websocket/internal/errd" 22 | "github.com/coder/websocket/internal/test/assert" 23 | "github.com/coder/websocket/internal/test/wstest" 24 | "github.com/coder/websocket/internal/util" 25 | ) 26 | 27 | var excludedAutobahnCases = []string{ 28 | // We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just 29 | // more performance overhead. 30 | "6.*", "7.5.1", 31 | 32 | // We skip the tests related to requestMaxWindowBits as that is unimplemented due 33 | // to limitations in compress/flate. See https://github.com/golang/go/issues/3155 34 | "13.3.*", "13.4.*", "13.5.*", "13.6.*", 35 | } 36 | 37 | var autobahnCases = []string{"*"} 38 | 39 | // Used to run individual test cases. autobahnCases runs only those cases matched 40 | // and not excluded by excludedAutobahnCases. Adding cases here means excludedAutobahnCases 41 | // is niled. 42 | var onlyAutobahnCases = []string{} 43 | 44 | func TestAutobahn(t *testing.T) { 45 | t.Parallel() 46 | 47 | if os.Getenv("AUTOBAHN") == "" { 48 | t.SkipNow() 49 | } 50 | 51 | if os.Getenv("AUTOBAHN") == "fast" { 52 | // These are the slow tests. 
53 | excludedAutobahnCases = append(excludedAutobahnCases, 54 | "9.*", "12.*", "13.*", 55 | ) 56 | } 57 | 58 | if len(onlyAutobahnCases) > 0 { 59 | excludedAutobahnCases = []string{} 60 | autobahnCases = onlyAutobahnCases 61 | } 62 | 63 | ctx, cancel := context.WithTimeout(context.Background(), time.Hour) 64 | defer cancel() 65 | 66 | wstestURL, closeFn, err := wstestServer(t, ctx) 67 | assert.Success(t, err) 68 | defer func() { 69 | assert.Success(t, closeFn()) 70 | }() 71 | 72 | err = waitWS(ctx, wstestURL) 73 | assert.Success(t, err) 74 | 75 | cases, err := wstestCaseCount(ctx, wstestURL) 76 | assert.Success(t, err) 77 | 78 | t.Run("cases", func(t *testing.T) { 79 | for i := 1; i <= cases; i++ { 80 | i := i 81 | t.Run("", func(t *testing.T) { 82 | ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) 83 | defer cancel() 84 | 85 | c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/runCase?case=%v&agent=main", i), &websocket.DialOptions{ 86 | CompressionMode: websocket.CompressionContextTakeover, 87 | }) 88 | assert.Success(t, err) 89 | err = wstest.EchoLoop(ctx, c) 90 | t.Logf("echoLoop: %v", err) 91 | }) 92 | } 93 | }) 94 | 95 | c, _, err := websocket.Dial(ctx, wstestURL+"/updateReports?agent=main", nil) 96 | assert.Success(t, err) 97 | c.Close(websocket.StatusNormalClosure, "") 98 | 99 | checkWSTestIndex(t, "./ci/out/autobahn-report/index.json") 100 | } 101 | 102 | func waitWS(ctx context.Context, url string) error { 103 | ctx, cancel := context.WithTimeout(ctx, time.Second*5) 104 | defer cancel() 105 | 106 | for ctx.Err() == nil { 107 | c, _, err := websocket.Dial(ctx, url, nil) 108 | if err != nil { 109 | continue 110 | } 111 | c.Close(websocket.StatusNormalClosure, "") 112 | return nil 113 | } 114 | 115 | return ctx.Err() 116 | } 117 | 118 | func wstestServer(tb testing.TB, ctx context.Context) (url string, closeFn func() error, err error) { 119 | defer errd.Wrap(&err, "failed to start autobahn wstest server") 120 | 121 | serverAddr, 
err := unusedListenAddr() 122 | if err != nil { 123 | return "", nil, err 124 | } 125 | _, serverPort, err := net.SplitHostPort(serverAddr) 126 | if err != nil { 127 | return "", nil, err 128 | } 129 | 130 | url = "ws://" + serverAddr 131 | const outDir = "ci/out/autobahn-report" 132 | 133 | specFile, err := tempJSONFile(map[string]interface{}{ 134 | "url": url, 135 | "outdir": outDir, 136 | "cases": autobahnCases, 137 | "exclude-cases": excludedAutobahnCases, 138 | }) 139 | if err != nil { 140 | return "", nil, fmt.Errorf("failed to write spec: %w", err) 141 | } 142 | 143 | ctx, cancel := context.WithTimeout(ctx, time.Hour) 144 | defer func() { 145 | if err != nil { 146 | cancel() 147 | } 148 | }() 149 | 150 | dockerPull := exec.CommandContext(ctx, "docker", "pull", "crossbario/autobahn-testsuite") 151 | dockerPull.Stdout = util.WriterFunc(func(p []byte) (int, error) { 152 | tb.Log(string(p)) 153 | return len(p), nil 154 | }) 155 | dockerPull.Stderr = util.WriterFunc(func(p []byte) (int, error) { 156 | tb.Log(string(p)) 157 | return len(p), nil 158 | }) 159 | tb.Log(dockerPull) 160 | err = dockerPull.Run() 161 | if err != nil { 162 | return "", nil, fmt.Errorf("failed to pull docker image: %w", err) 163 | } 164 | 165 | wd, err := os.Getwd() 166 | if err != nil { 167 | return "", nil, err 168 | } 169 | 170 | var args []string 171 | args = append(args, "run", "-i", "--rm", 172 | "-v", fmt.Sprintf("%s:%[1]s", specFile), 173 | "-v", fmt.Sprintf("%s/ci:/ci", wd), 174 | fmt.Sprintf("-p=%s:%s", serverAddr, serverPort), 175 | "crossbario/autobahn-testsuite", 176 | ) 177 | args = append(args, "wstest", "--mode", "fuzzingserver", "--spec", specFile, 178 | // Disables some server that runs as part of fuzzingserver mode. 179 | // See https://github.com/crossbario/autobahn-testsuite/blob/058db3a36b7c3a1edf68c282307c6b899ca4857f/autobahntestsuite/autobahntestsuite/wstest.py#L124 180 | "--webport=0", 181 | ) 182 | wstest := exec.CommandContext(ctx, "docker", args...) 
183 | wstest.Stdout = util.WriterFunc(func(p []byte) (int, error) { 184 | tb.Log(string(p)) 185 | return len(p), nil 186 | }) 187 | wstest.Stderr = util.WriterFunc(func(p []byte) (int, error) { 188 | tb.Log(string(p)) 189 | return len(p), nil 190 | }) 191 | tb.Log(wstest) 192 | err = wstest.Start() 193 | if err != nil { 194 | return "", nil, fmt.Errorf("failed to start wstest: %w", err) 195 | } 196 | 197 | return url, func() error { 198 | defer cancel() // Release the timeout context created above once wstest has exited. 199 | if err := wstest.Process.Kill(); err != nil { 200 | return fmt.Errorf("failed to kill wstest: %w", err) 201 | } 202 | err := wstest.Wait() 203 | var ee *exec.ExitError 204 | if errors.As(err, &ee) && ee.ExitCode() == -1 { 205 | return nil 206 | } 207 | return err 208 | }, nil 209 | } 210 | 211 | func wstestCaseCount(ctx context.Context, url string) (cases int, err error) { 212 | defer errd.Wrap(&err, "failed to get case count") 213 | 214 | c, _, err := websocket.Dial(ctx, url+"/getCaseCount", nil) 215 | if err != nil { 216 | return 0, err 217 | } 218 | defer c.Close(websocket.StatusInternalError, "") 219 | 220 | _, r, err := c.Reader(ctx) 221 | if err != nil { 222 | return 0, err 223 | } 224 | b, err := io.ReadAll(r) 225 | if err != nil { 226 | return 0, err 227 | } 228 | cases, err = strconv.Atoi(string(b)) 229 | if err != nil { 230 | return 0, err 231 | } 232 | 233 | c.Close(websocket.StatusNormalClosure, "") 234 | 235 | return cases, nil 236 | } 237 | 238 | func checkWSTestIndex(t *testing.T, path string) { 239 | wstestOut, err := os.ReadFile(path) 240 | assert.Success(t, err) 241 | 242 | var indexJSON map[string]map[string]struct { 243 | Behavior string `json:"behavior"` 244 | BehaviorClose string `json:"behaviorClose"` 245 | } 246 | err = json.Unmarshal(wstestOut, &indexJSON) 247 | assert.Success(t, err) 248 | 249 | for _, tests := range indexJSON { 250 | for test, result := range tests { 251 | t.Run(test, func(t *testing.T) { 252 | switch result.BehaviorClose { 253 | case "OK", "INFORMATIONAL": 254 | default: 255 |
t.Errorf("bad close behaviour") 256 | } 257 | 258 | switch result.Behavior { 259 | case "OK", "NON-STRICT", "INFORMATIONAL": 260 | default: 261 | t.Errorf("failed") 262 | } 263 | }) 264 | } 265 | } 266 | 267 | if t.Failed() { 268 | htmlPath := strings.Replace(path, ".json", ".html", 1) 269 | t.Errorf("detected autobahn violation, see %q", htmlPath) 270 | } 271 | } 272 | 273 | func unusedListenAddr() (_ string, err error) { 274 | defer errd.Wrap(&err, "failed to get unused listen address") 275 | l, err := net.Listen("tcp", "localhost:0") 276 | if err != nil { 277 | return "", err 278 | } 279 | l.Close() 280 | return l.Addr().String(), nil 281 | } 282 | 283 | func tempJSONFile(v interface{}) (string, error) { 284 | f, err := os.CreateTemp("", "temp.json") 285 | if err != nil { 286 | return "", fmt.Errorf("temp file: %w", err) 287 | } 288 | defer f.Close() 289 | 290 | e := json.NewEncoder(f) 291 | e.SetIndent("", "\t") 292 | err = e.Encode(v) 293 | if err != nil { 294 | return "", fmt.Errorf("json encode: %w", err) 295 | } 296 | 297 | err = f.Close() 298 | if err != nil { 299 | return "", fmt.Errorf("close temp file: %w", err) 300 | } 301 | 302 | return f.Name(), nil 303 | } 304 | -------------------------------------------------------------------------------- /ci/bench.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | set -eu 3 | cd -- "$(dirname "$0")/.." 4 | 5 | go test --run=^$ --bench=. --benchmem "$@" ./... 6 | # For profiling add: --memprofile ci/out/prof.mem --cpuprofile ci/out/prof.cpu -o ci/out/websocket.test 7 | ( 8 | cd ./internal/thirdparty 9 | go test --run=^$ --bench=. --benchmem "$@" . 10 | 11 | GOARCH=arm64 go test -c -o ../../ci/out/thirdparty-arm64.test "$@" . 
12 | if [ "$#" -eq 0 ]; then 13 | if [ "${CI-}" ]; then 14 | sudo apt-get update 15 | sudo apt-get install -y qemu-user-static 16 | ln -s /usr/bin/qemu-aarch64-static /usr/local/bin/qemu-aarch64 17 | fi 18 | qemu-aarch64 ../../ci/out/thirdparty-arm64.test --test.run=^$ --test.bench=Benchmark_mask --test.benchmem 19 | fi 20 | ) 21 | -------------------------------------------------------------------------------- /ci/fmt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | set -eu 3 | cd -- "$(dirname "$0")/.." 4 | 5 | X_TOOLS_VERSION=v0.31.0 6 | 7 | go mod tidy 8 | (cd ./internal/thirdparty && go mod tidy) 9 | (cd ./internal/examples && go mod tidy) 10 | gofmt -w -s . 11 | go run golang.org/x/tools/cmd/goimports@${X_TOOLS_VERSION} -w "-local=$(go list -m)" . 12 | 13 | git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html" | xargs npx prettier@3.3.3 \ 14 | --check \ 15 | --log-level=warn \ 16 | --print-width=90 \ 17 | --no-semi \ 18 | --single-quote \ 19 | --arrow-parens=avoid 20 | 21 | go run golang.org/x/tools/cmd/stringer@${X_TOOLS_VERSION} -type=opcode,MessageType,StatusCode -output=stringer.go 22 | 23 | if [ "${CI-}" ]; then 24 | git diff --exit-code 25 | fi 26 | -------------------------------------------------------------------------------- /ci/lint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | set -eu 3 | cd -- "$(dirname "$0")/.." 4 | 5 | STATICCHECK_VERSION=v0.6.1 6 | GOVULNCHECK_VERSION=v1.1.4 7 | 8 | go vet ./... 9 | GOOS=js GOARCH=wasm go vet ./... 10 | 11 | go install honnef.co/go/tools/cmd/staticcheck@${STATICCHECK_VERSION} 12 | staticcheck ./... 13 | GOOS=js GOARCH=wasm staticcheck ./... 14 | 15 | govulncheck() { 16 | tmpf=$(mktemp) 17 | if ! command govulncheck "$@" >"$tmpf" 2>&1; then 18 | cat "$tmpf" 19 | fi 20 | } 21 | go install golang.org/x/vuln/cmd/govulncheck@${GOVULNCHECK_VERSION} 22 | govulncheck ./... 
23 | GOOS=js GOARCH=wasm govulncheck ./... 24 | 25 | ( 26 | cd ./internal/examples 27 | go vet ./... 28 | staticcheck ./... 29 | govulncheck ./... 30 | ) 31 | ( 32 | cd ./internal/thirdparty 33 | go vet ./... 34 | staticcheck ./... 35 | govulncheck ./... 36 | ) 37 | -------------------------------------------------------------------------------- /ci/out/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | -------------------------------------------------------------------------------- /ci/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | set -eu 3 | cd -- "$(dirname "$0")/.." 4 | 5 | ( 6 | cd ./internal/examples 7 | go test "$@" ./... 8 | ) 9 | ( 10 | cd ./internal/thirdparty 11 | go test "$@" ./... 12 | ) 13 | 14 | ( 15 | GOARCH=arm64 go test -c -o ./ci/out/websocket-arm64.test "$@" . 16 | if [ "$#" -eq 0 ]; then 17 | if [ "${CI-}" ]; then 18 | sudo apt-get update 19 | sudo apt-get install -y qemu-user-static 20 | ln -s /usr/bin/qemu-aarch64-static /usr/local/bin/qemu-aarch64 21 | fi 22 | qemu-aarch64 ./ci/out/websocket-arm64.test -test.run=TestMask 23 | fi 24 | ) 25 | 26 | 27 | go install github.com/agnivade/wasmbrowsertest@8be019f6c6dceae821467b4c589eb195c2b761ce 28 | go test --race --bench=. --timeout=1h --covermode=atomic --coverprofile=ci/out/coverage.prof --coverpkg=./... "$@" ./... 29 | sed -i.bak '/stringer\.go/d' ci/out/coverage.prof 30 | sed -i.bak '/github\.com\/coder\/websocket\/internal\/test/d' ci/out/coverage.prof 31 | sed -i.bak '/examples/d' ci/out/coverage.prof 32 | 33 | # Last line is the total coverage.
34 | go tool cover -func ci/out/coverage.prof | tail -n1 35 | 36 | go tool cover -html=ci/out/coverage.prof -o=ci/out/coverage.html 37 | -------------------------------------------------------------------------------- /close.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | // +build !js 3 | 4 | package websocket 5 | 6 | import ( 7 | "context" 8 | "encoding/binary" 9 | "errors" 10 | "fmt" 11 | "net" 12 | "time" 13 | 14 | "github.com/coder/websocket/internal/errd" 15 | ) 16 | 17 | // StatusCode represents a WebSocket status code. 18 | // https://tools.ietf.org/html/rfc6455#section-7.4 19 | type StatusCode int 20 | 21 | // https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number 22 | // 23 | // These are only the status codes defined by the protocol. 24 | // 25 | // You can define custom codes in the 3000-4999 range. 26 | // The 3000-3999 range is reserved for use by libraries, frameworks and applications. 27 | // The 4000-4999 range is reserved for private use. 28 | const ( 29 | StatusNormalClosure StatusCode = 1000 30 | StatusGoingAway StatusCode = 1001 31 | StatusProtocolError StatusCode = 1002 32 | StatusUnsupportedData StatusCode = 1003 33 | 34 | // 1004 is reserved and so unexported. 35 | statusReserved StatusCode = 1004 36 | 37 | // StatusNoStatusRcvd cannot be sent in a close message. 38 | // It is reserved for when a close message is received without 39 | // a status code. 40 | StatusNoStatusRcvd StatusCode = 1005 41 | 42 | // StatusAbnormalClosure is exported for use only with Wasm. 43 | // In non Wasm Go, the returned error will indicate whether the 44 | // connection was closed abnormally. 
45 | StatusAbnormalClosure StatusCode = 1006 46 | 47 | StatusInvalidFramePayloadData StatusCode = 1007 48 | StatusPolicyViolation StatusCode = 1008 49 | StatusMessageTooBig StatusCode = 1009 50 | StatusMandatoryExtension StatusCode = 1010 51 | StatusInternalError StatusCode = 1011 52 | StatusServiceRestart StatusCode = 1012 53 | StatusTryAgainLater StatusCode = 1013 54 | StatusBadGateway StatusCode = 1014 55 | 56 | // StatusTLSHandshake is only exported for use with Wasm. 57 | // In non Wasm Go, the returned error will indicate whether there was 58 | // a TLS handshake failure. 59 | StatusTLSHandshake StatusCode = 1015 60 | ) 61 | 62 | // CloseError is returned when the connection is closed with a status and reason. 63 | // 64 | // Use Go 1.13's errors.As to check for this error. 65 | // Also see the CloseStatus helper. 66 | type CloseError struct { 67 | Code StatusCode 68 | Reason string 69 | } 70 | 71 | func (ce CloseError) Error() string { 72 | return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason) 73 | } 74 | 75 | // CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab 76 | // the status code from a CloseError. 77 | // 78 | // -1 will be returned if the passed error is nil or not a CloseError. 79 | func CloseStatus(err error) StatusCode { 80 | var ce CloseError 81 | if errors.As(err, &ce) { 82 | return ce.Code 83 | } 84 | return -1 85 | } 86 | 87 | // Close performs the WebSocket close handshake with the given status code and reason. 88 | // 89 | // It will write a WebSocket close frame with a timeout of 5s and then wait 5s for 90 | // the peer to send a close frame. 91 | // All data messages received from the peer during the close handshake will be discarded. 92 | // 93 | // The connection can only be closed once. Additional calls to Close do not repeat the 94 | // handshake; they wait for the connection's goroutines to exit and then return an error based on net.ErrClosed. 95 | // 96 | // The reason must be at most 123 bytes: a close frame's payload is limited to 125 bytes, 2 of which hold the status code. Avoid sending a dynamic reason.
97 | // 98 | // Close will unblock all goroutines interacting with the connection once 99 | // complete. 100 | func (c *Conn) Close(code StatusCode, reason string) (err error) { 101 | defer errd.Wrap(&err, "failed to close WebSocket") 102 | 103 | if c.casClosing() { 104 | err = c.waitGoroutines() 105 | if err != nil { 106 | return err 107 | } 108 | return net.ErrClosed 109 | } 110 | defer func() { 111 | if errors.Is(err, net.ErrClosed) { 112 | err = nil 113 | } 114 | }() 115 | 116 | err = c.closeHandshake(code, reason) 117 | 118 | err2 := c.close() 119 | if err == nil && err2 != nil { 120 | err = err2 121 | } 122 | 123 | err2 = c.waitGoroutines() 124 | if err == nil && err2 != nil { 125 | err = err2 126 | } 127 | 128 | return err 129 | } 130 | 131 | // CloseNow closes the WebSocket connection without attempting a close handshake. 132 | // Use when you do not want the overhead of the close handshake. 133 | func (c *Conn) CloseNow() (err error) { 134 | defer errd.Wrap(&err, "failed to immediately close WebSocket") 135 | 136 | if c.casClosing() { 137 | err = c.waitGoroutines() 138 | if err != nil { 139 | return err 140 | } 141 | return net.ErrClosed 142 | } 143 | defer func() { 144 | if errors.Is(err, net.ErrClosed) { 145 | err = nil 146 | } 147 | }() 148 | 149 | err = c.close() 150 | 151 | err2 := c.waitGoroutines() 152 | if err == nil && err2 != nil { 153 | err = err2 154 | } 155 | return err 156 | } 157 | 158 | func (c *Conn) closeHandshake(code StatusCode, reason string) error { 159 | err := c.writeClose(code, reason) 160 | if err != nil { 161 | return err 162 | } 163 | 164 | err = c.waitCloseHandshake() 165 | if CloseStatus(err) != code { 166 | return err 167 | } 168 | return nil 169 | } 170 | 171 | func (c *Conn) writeClose(code StatusCode, reason string) error { 172 | ce := CloseError{ 173 | Code: code, 174 | Reason: reason, 175 | } 176 | 177 | var p []byte 178 | var err error 179 | if ce.Code != StatusNoStatusRcvd { 180 | p, err = ce.bytes() 181 | if err != 
nil { 182 | return err 183 | } 184 | } 185 | 186 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) 187 | defer cancel() 188 | 189 | err = c.writeControl(ctx, opClose, p) 190 | // If the connection closed as we're writing we ignore the error as we might 191 | // have written the close frame, the peer responded and then someone else read it 192 | // and closed the connection. 193 | if err != nil && !errors.Is(err, net.ErrClosed) { 194 | return err 195 | } 196 | return nil 197 | } 198 | 199 | func (c *Conn) waitCloseHandshake() error { 200 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) 201 | defer cancel() 202 | 203 | err := c.readMu.lock(ctx) 204 | if err != nil { 205 | return err 206 | } 207 | defer c.readMu.unlock() 208 | 209 | for i := int64(0); i < c.msgReader.payloadLength; i++ { 210 | _, err := c.br.ReadByte() 211 | if err != nil { 212 | return err 213 | } 214 | } 215 | 216 | for { 217 | h, err := c.readLoop(ctx) 218 | if err != nil { 219 | return err 220 | } 221 | 222 | for i := int64(0); i < h.payloadLength; i++ { 223 | _, err := c.br.ReadByte() 224 | if err != nil { 225 | return err 226 | } 227 | } 228 | } 229 | } 230 | 231 | func (c *Conn) waitGoroutines() error { 232 | t := time.NewTimer(time.Second * 15) 233 | defer t.Stop() 234 | 235 | select { 236 | case <-c.timeoutLoopDone: 237 | case <-t.C: 238 | return errors.New("failed to wait for timeoutLoop goroutine to exit") 239 | } 240 | 241 | c.closeReadMu.Lock() 242 | closeRead := c.closeReadCtx != nil 243 | c.closeReadMu.Unlock() 244 | if closeRead { 245 | select { 246 | case <-c.closeReadDone: 247 | case <-t.C: 248 | return errors.New("failed to wait for close read goroutine to exit") 249 | } 250 | } 251 | 252 | select { 253 | case <-c.closed: 254 | case <-t.C: 255 | return errors.New("failed to wait for connection to be closed") 256 | } 257 | 258 | return nil 259 | } 260 | 261 | func parseClosePayload(p []byte) (CloseError, error) { 262 | if len(p) == 0 { 
263 | return CloseError{ 264 | Code: StatusNoStatusRcvd, 265 | }, nil 266 | } 267 | 268 | if len(p) < 2 { 269 | return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) 270 | } 271 | 272 | ce := CloseError{ 273 | Code: StatusCode(binary.BigEndian.Uint16(p)), 274 | Reason: string(p[2:]), 275 | } 276 | 277 | if !validWireCloseCode(ce.Code) { 278 | return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code) 279 | } 280 | 281 | return ce, nil 282 | } 283 | 284 | // See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number 285 | // and https://tools.ietf.org/html/rfc6455#section-7.4.1 286 | func validWireCloseCode(code StatusCode) bool { 287 | switch code { 288 | case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: 289 | return false 290 | } 291 | 292 | if code >= StatusNormalClosure && code <= StatusBadGateway { 293 | return true 294 | } 295 | if code >= 3000 && code <= 4999 { 296 | return true 297 | } 298 | 299 | return false 300 | } 301 | 302 | func (ce CloseError) bytes() ([]byte, error) { 303 | p, err := ce.bytesErr() 304 | if err != nil { 305 | err = fmt.Errorf("failed to marshal close frame: %w", err) 306 | ce = CloseError{ 307 | Code: StatusInternalError, 308 | } 309 | p, _ = ce.bytesErr() 310 | } 311 | return p, err 312 | } 313 | 314 | const maxCloseReason = maxControlPayload - 2 315 | 316 | func (ce CloseError) bytesErr() ([]byte, error) { 317 | if len(ce.Reason) > maxCloseReason { 318 | return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason)) 319 | } 320 | 321 | if !validWireCloseCode(ce.Code) { 322 | return nil, fmt.Errorf("status code %v cannot be set", ce.Code) 323 | } 324 | 325 | buf := make([]byte, 2+len(ce.Reason)) 326 | binary.BigEndian.PutUint16(buf, uint16(ce.Code)) 327 | copy(buf[2:], ce.Reason) 328 | return buf, nil 329 | } 330 | 331 | func (c *Conn) casClosing() bool { 332 | 
return c.closing.Swap(true) 333 | } 334 | 335 | func (c *Conn) isClosed() bool { 336 | select { 337 | case <-c.closed: 338 | return true 339 | default: 340 | return false 341 | } 342 | } 343 | -------------------------------------------------------------------------------- /close_test.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | // +build !js 3 | 4 | package websocket 5 | 6 | import ( 7 | "io" 8 | "math" 9 | "strings" 10 | "testing" 11 | 12 | "github.com/coder/websocket/internal/test/assert" 13 | ) 14 | 15 | func TestCloseError(t *testing.T) { 16 | t.Parallel() 17 | 18 | testCases := []struct { 19 | name string 20 | ce CloseError 21 | success bool 22 | }{ 23 | { 24 | name: "normal", 25 | ce: CloseError{ 26 | Code: StatusNormalClosure, 27 | Reason: strings.Repeat("x", maxCloseReason), 28 | }, 29 | success: true, 30 | }, 31 | { 32 | name: "bigReason", 33 | ce: CloseError{ 34 | Code: StatusNormalClosure, 35 | Reason: strings.Repeat("x", maxCloseReason+1), 36 | }, 37 | success: false, 38 | }, 39 | { 40 | name: "bigCode", 41 | ce: CloseError{ 42 | Code: math.MaxUint16, 43 | Reason: strings.Repeat("x", maxCloseReason), 44 | }, 45 | success: false, 46 | }, 47 | } 48 | 49 | for _, tc := range testCases { 50 | tc := tc 51 | t.Run(tc.name, func(t *testing.T) { 52 | t.Parallel() 53 | 54 | _, err := tc.ce.bytesErr() 55 | if tc.success { 56 | assert.Success(t, err) 57 | } else { 58 | assert.Error(t, err) 59 | } 60 | }) 61 | } 62 | 63 | t.Run("Error", func(t *testing.T) { 64 | exp := `status = StatusInternalError and reason = "meow"` 65 | act := CloseError{ 66 | Code: StatusInternalError, 67 | Reason: "meow", 68 | }.Error() 69 | assert.Equal(t, "CloseError.Error()", exp, act) 70 | }) 71 | } 72 | 73 | func Test_parseClosePayload(t *testing.T) { 74 | t.Parallel() 75 | 76 | testCases := []struct { 77 | name string 78 | p []byte 79 | success bool 80 | ce CloseError 81 | }{ 82 | { 83 | name: "normal", 84 | p: 
append([]byte{0x3, 0xE8}, []byte("hello")...), 85 | success: true, 86 | ce: CloseError{ 87 | Code: StatusNormalClosure, 88 | Reason: "hello", 89 | }, 90 | }, 91 | { 92 | name: "nothing", 93 | success: true, 94 | ce: CloseError{ 95 | Code: StatusNoStatusRcvd, 96 | }, 97 | }, 98 | { 99 | name: "oneByte", 100 | p: []byte{0}, 101 | success: false, 102 | }, 103 | { 104 | name: "badStatusCode", 105 | p: []byte{0x17, 0x70}, 106 | success: false, 107 | }, 108 | } 109 | 110 | for _, tc := range testCases { 111 | tc := tc 112 | t.Run(tc.name, func(t *testing.T) { 113 | t.Parallel() 114 | 115 | ce, err := parseClosePayload(tc.p) 116 | if tc.success { 117 | assert.Success(t, err) 118 | assert.Equal(t, "close payload", tc.ce, ce) 119 | } else { 120 | assert.Error(t, err) 121 | } 122 | }) 123 | } 124 | } 125 | 126 | func Test_validWireCloseCode(t *testing.T) { 127 | t.Parallel() 128 | 129 | testCases := []struct { 130 | name string 131 | code StatusCode 132 | valid bool 133 | }{ 134 | { 135 | name: "normal", 136 | code: StatusNormalClosure, 137 | valid: true, 138 | }, 139 | { 140 | name: "noStatus", 141 | code: StatusNoStatusRcvd, 142 | valid: false, 143 | }, 144 | { 145 | name: "3000", 146 | code: 3000, 147 | valid: true, 148 | }, 149 | { 150 | name: "4999", 151 | code: 4999, 152 | valid: true, 153 | }, 154 | { 155 | name: "unknown", 156 | code: 5000, 157 | valid: false, 158 | }, 159 | } 160 | 161 | for _, tc := range testCases { 162 | tc := tc 163 | t.Run(tc.name, func(t *testing.T) { 164 | t.Parallel() 165 | 166 | act := validWireCloseCode(tc.code) 167 | assert.Equal(t, "wire close code", tc.valid, act) 168 | }) 169 | } 170 | } 171 | 172 | func TestCloseStatus(t *testing.T) { 173 | t.Parallel() 174 | 175 | testCases := []struct { 176 | name string 177 | in error 178 | exp StatusCode 179 | }{ 180 | { 181 | name: "nil", 182 | in: nil, 183 | exp: -1, 184 | }, 185 | { 186 | name: "io.EOF", 187 | in: io.EOF, 188 | exp: -1, 189 | }, 190 | { 191 | name: "StatusInternalError", 192 | 
in: CloseError{ 193 | Code: StatusInternalError, 194 | }, 195 | exp: StatusInternalError, 196 | }, 197 | } 198 | 199 | for _, tc := range testCases { 200 | tc := tc 201 | t.Run(tc.name, func(t *testing.T) { 202 | t.Parallel() 203 | 204 | act := CloseStatus(tc.in) 205 | assert.Equal(t, "close status", tc.exp, act) 206 | }) 207 | } 208 | } 209 | -------------------------------------------------------------------------------- /compress.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | // +build !js 3 | 4 | package websocket 5 | 6 | import ( 7 | "compress/flate" 8 | "io" 9 | "sync" 10 | ) 11 | 12 | // CompressionMode represents the modes available to the permessage-deflate extension. 13 | // See https://tools.ietf.org/html/rfc7692 14 | // 15 | // Works in all modern browsers except Safari which does not implement the permessage-deflate extension. 16 | // 17 | // Compression is only used if the peer supports the mode selected. 18 | type CompressionMode int 19 | 20 | const ( 21 | // CompressionDisabled disables the negotiation of the permessage-deflate extension. 22 | // 23 | // This is the default. Do not enable compression without benchmarking for your particular use case first. 24 | CompressionDisabled CompressionMode = iota 25 | 26 | // CompressionContextTakeover compresses each message greater than 128 bytes reusing the 32 KB sliding window from 27 | // previous messages. i.e compression context across messages is preserved. 28 | // 29 | // As most WebSocket protocols are text based and repetitive, this compression mode can be very efficient. 30 | // 31 | // The memory overhead is a fixed 32 KB sliding window, a fixed 1.2 MB flate.Writer and a sync.Pool of 40 KB flate.Reader's 32 | // that are used when reading and then returned. 33 | // 34 | // Thus, it uses more memory than CompressionNoContextTakeover but compresses more efficiently. 
35 | // 36 | // If the peer does not support CompressionContextTakeover then we will fall back to CompressionNoContextTakeover. 37 | CompressionContextTakeover 38 | 39 | // CompressionNoContextTakeover compresses each message greater than 512 bytes. Each message is compressed with 40 | // a new 1.2 MB flate.Writer pulled from a sync.Pool. Each message is read with a 40 KB flate.Reader pulled from 41 | // a sync.Pool. 42 | // 43 | // This means less efficient compression as the sliding window from previous messages will not be used but the 44 | // memory overhead will be lower as there will be no fixed cost for the flate.Writer nor the 32 KB sliding window. 45 | // This matters most when connections are long lived and seldom written to. 46 | // 47 | // Thus, it uses less memory than CompressionContextTakeover but compresses less efficiently. 48 | // 49 | // If the peer does not support CompressionNoContextTakeover then we will fall back to CompressionDisabled. 50 | CompressionNoContextTakeover 51 | ) 52 | 53 | func (m CompressionMode) opts() *compressionOptions { 54 | return &compressionOptions{ 55 | clientNoContextTakeover: m == CompressionNoContextTakeover, 56 | serverNoContextTakeover: m == CompressionNoContextTakeover, 57 | } 58 | } 59 | 60 | type compressionOptions struct { 61 | clientNoContextTakeover bool 62 | serverNoContextTakeover bool 63 | } 64 | 65 | func (copts *compressionOptions) String() string { 66 | s := "permessage-deflate" 67 | if copts.clientNoContextTakeover { 68 | s += "; client_no_context_takeover" 69 | } 70 | if copts.serverNoContextTakeover { 71 | s += "; server_no_context_takeover" 72 | } 73 | return s 74 | } 75 | 76 | // These bytes are required to get flate.Reader to return. 77 | // They are stripped when sending to avoid the overhead, as 78 | // WebSocket framing already marks where the message ends, but 79 | // they must be added back when reading, otherwise flate.Reader keeps 80 | // trying to read more bytes.
81 | const deflateMessageTail = "\x00\x00\xff\xff" 82 | // trimLastFourBytesWriter forwards writes to w but always holds back the last 4 bytes written so far in tail. 83 | type trimLastFourBytesWriter struct { 84 | w io.Writer 85 | tail []byte 86 | } 87 | // reset discards any held-back tail bytes; it is safe to call on a nil receiver. 88 | func (tw *trimLastFourBytesWriter) reset() { 89 | if tw != nil && tw.tail != nil { 90 | tw.tail = tw.tail[:0] 91 | } 92 | } 93 | // Write forwards every byte except the final 4 seen so far to w, keeping those 4 in tail. 94 | func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) { 95 | if tw.tail == nil { 96 | tw.tail = make([]byte, 0, 4) 97 | } 98 | 99 | extra := len(tw.tail) + len(p) - 4 100 | 101 | if extra <= 0 { 102 | tw.tail = append(tw.tail, p...) 103 | return len(p), nil 104 | } 105 | 106 | // Now we need to write as many extra bytes as we can from the previous tail. 107 | if extra > len(tw.tail) { 108 | extra = len(tw.tail) 109 | } 110 | if extra > 0 { 111 | _, err := tw.w.Write(tw.tail[:extra]) 112 | if err != nil { 113 | return 0, err 114 | } 115 | 116 | // Shift remaining bytes in tail over. 117 | n := copy(tw.tail, tw.tail[extra:]) 118 | tw.tail = tw.tail[:n] 119 | } 120 | 121 | // If p is less than or equal to 4 bytes, 122 | // all of it is part of the tail. 123 | if len(p) <= 4 { 124 | tw.tail = append(tw.tail, p...) 125 | return len(p), nil 126 | } 127 | 128 | // Otherwise, only the last 4 bytes are. 129 | tw.tail = append(tw.tail, p[len(p)-4:]...)
130 | 131 | p = p[:len(p)-4] 132 | n, err := tw.w.Write(p) 133 | return n + 4, err 134 | } 135 | 136 | var flateReaderPool sync.Pool 137 | 138 | func getFlateReader(r io.Reader, dict []byte) io.Reader { 139 | fr, ok := flateReaderPool.Get().(io.Reader) 140 | if !ok { 141 | return flate.NewReaderDict(r, dict) 142 | } 143 | fr.(flate.Resetter).Reset(r, dict) 144 | return fr 145 | } 146 | 147 | func putFlateReader(fr io.Reader) { 148 | flateReaderPool.Put(fr) 149 | } 150 | 151 | var flateWriterPool sync.Pool 152 | 153 | func getFlateWriter(w io.Writer) *flate.Writer { 154 | fw, ok := flateWriterPool.Get().(*flate.Writer) 155 | if !ok { 156 | fw, _ = flate.NewWriter(w, flate.BestSpeed) 157 | return fw 158 | } 159 | fw.Reset(w) 160 | return fw 161 | } 162 | 163 | func putFlateWriter(w *flate.Writer) { 164 | flateWriterPool.Put(w) 165 | } 166 | 167 | type slidingWindow struct { 168 | buf []byte 169 | } 170 | 171 | var swPoolMu sync.RWMutex 172 | var swPool = map[int]*sync.Pool{} 173 | 174 | func slidingWindowPool(n int) *sync.Pool { 175 | swPoolMu.RLock() 176 | p, ok := swPool[n] 177 | swPoolMu.RUnlock() 178 | if ok { 179 | return p 180 | } 181 | // Re-check under the write lock: another goroutine may have created the pool since RUnlock. 182 | swPoolMu.Lock() 183 | defer swPoolMu.Unlock() 184 | if p, ok = swPool[n]; !ok { 185 | p = &sync.Pool{} 186 | swPool[n] = p 187 | } 188 | return p 189 | } 190 | 191 | func (sw *slidingWindow) init(n int) { 192 | if sw.buf != nil { 193 | return 194 | } 195 | 196 | if n == 0 { 197 | n = 32768 198 | } 199 | 200 | p := slidingWindowPool(n) 201 | sw2, ok := p.Get().(*slidingWindow) 202 | if ok { 203 | *sw = *sw2 204 | } else { 205 | sw.buf = make([]byte, 0, n) 206 | } 207 | } 208 | 209 | func (sw *slidingWindow) close() { 210 | sw.buf = sw.buf[:0] 211 | swPoolMu.Lock() 212 | swPool[cap(sw.buf)].Put(sw) 213 | swPoolMu.Unlock() 214 | } 215 | 216 | func (sw *slidingWindow) write(p []byte) { 217 | if len(p) >= cap(sw.buf) { 218 | sw.buf = sw.buf[:cap(sw.buf)] 219 | p = p[len(p)-cap(sw.buf):] 220 | copy(sw.buf, p) 221 | return 222 | } 223 | 224 | left := cap(sw.buf) -
len(sw.buf) 225 | if left < len(p) { 226 | // We need to shift spaceNeeded bytes from the end to make room for p at the end. 227 | spaceNeeded := len(p) - left 228 | copy(sw.buf, sw.buf[spaceNeeded:]) 229 | sw.buf = sw.buf[:len(sw.buf)-spaceNeeded] 230 | } 231 | 232 | sw.buf = append(sw.buf, p...) 233 | } 234 | -------------------------------------------------------------------------------- /compress_test.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | // +build !js 3 | 4 | package websocket 5 | 6 | import ( 7 | "bytes" 8 | "compress/flate" 9 | "io" 10 | "strings" 11 | "testing" 12 | 13 | "github.com/coder/websocket/internal/test/assert" 14 | "github.com/coder/websocket/internal/test/xrand" 15 | ) 16 | 17 | func Test_slidingWindow(t *testing.T) { 18 | t.Parallel() 19 | 20 | const testCount = 99 21 | const maxWindow = 99999 22 | for i := 0; i < testCount; i++ { 23 | t.Run("", func(t *testing.T) { 24 | t.Parallel() 25 | 26 | input := xrand.String(maxWindow) 27 | windowLength := xrand.Int(maxWindow) 28 | var sw slidingWindow 29 | sw.init(windowLength) 30 | sw.write([]byte(input)) 31 | 32 | assert.Equal(t, "window length", windowLength, cap(sw.buf)) 33 | if !strings.HasSuffix(input, string(sw.buf)) { 34 | t.Fatalf("r.buf is not a suffix of input: %q and %q", input, sw.buf) 35 | } 36 | }) 37 | } 38 | } 39 | 40 | func BenchmarkFlateWriter(b *testing.B) { 41 | b.ReportAllocs() 42 | for i := 0; i < b.N; i++ { 43 | w, _ := flate.NewWriter(io.Discard, flate.BestSpeed) 44 | // We have to write a byte to get the writer to allocate to its full extent. 
45 | w.Write([]byte{'a'}) 46 | w.Flush() 47 | } 48 | } 49 | 50 | func BenchmarkFlateReader(b *testing.B) { 51 | b.ReportAllocs() 52 | 53 | var buf bytes.Buffer 54 | w, _ := flate.NewWriter(&buf, flate.BestSpeed) 55 | w.Write([]byte{'a'}) 56 | w.Flush() 57 | 58 | for i := 0; i < b.N; i++ { 59 | r := flate.NewReader(bytes.NewReader(buf.Bytes())) 60 | io.ReadAll(r) 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /conn.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | // +build !js 3 | 4 | package websocket 5 | 6 | import ( 7 | "bufio" 8 | "context" 9 | "fmt" 10 | "io" 11 | "net" 12 | "runtime" 13 | "strconv" 14 | "sync" 15 | "sync/atomic" 16 | ) 17 | 18 | // MessageType represents the type of a WebSocket message. 19 | // See https://tools.ietf.org/html/rfc6455#section-5.6 20 | type MessageType int 21 | 22 | // MessageType constants. 23 | const ( 24 | // MessageText is for UTF-8 encoded text messages like JSON. 25 | MessageText MessageType = iota + 1 26 | // MessageBinary is for binary messages like protobufs. 27 | MessageBinary 28 | ) 29 | 30 | // Conn represents a WebSocket connection. 31 | // All methods may be called concurrently except for Reader and Read. 32 | // 33 | // You must always read from the connection. Otherwise control 34 | // frames will not be handled. See Reader and CloseRead. 35 | // 36 | // Be sure to call Close on the connection when you 37 | // are finished with it to release associated resources. 38 | // 39 | // On any error from any method, the connection is closed 40 | // with an appropriate reason. 41 | // 42 | // This applies to context expirations as well unfortunately. 
43 | // See https://github.com/nhooyr/websocket/issues/242#issuecomment-633182220 44 | type Conn struct { 45 | noCopy noCopy 46 | 47 | subprotocol string 48 | rwc io.ReadWriteCloser 49 | client bool 50 | copts *compressionOptions 51 | flateThreshold int 52 | br *bufio.Reader 53 | bw *bufio.Writer 54 | 55 | readTimeout chan context.Context 56 | writeTimeout chan context.Context 57 | timeoutLoopDone chan struct{} 58 | 59 | // Read state. 60 | readMu *mu 61 | readHeaderBuf [8]byte 62 | readControlBuf [maxControlPayload]byte 63 | msgReader *msgReader 64 | 65 | // Write state. 66 | msgWriter *msgWriter 67 | writeFrameMu *mu 68 | writeBuf []byte 69 | writeHeaderBuf [8]byte 70 | writeHeader header 71 | 72 | // Close handshake state. 73 | closeStateMu sync.RWMutex 74 | closeReceivedErr error 75 | closeSentErr error 76 | 77 | // CloseRead state. 78 | closeReadMu sync.Mutex 79 | closeReadCtx context.Context 80 | closeReadDone chan struct{} 81 | 82 | closing atomic.Bool 83 | closeMu sync.Mutex // Protects following. 
84 | closed chan struct{} 85 | 86 | pingCounter atomic.Int64 87 | activePingsMu sync.Mutex 88 | activePings map[string]chan<- struct{} 89 | onPingReceived func(context.Context, []byte) bool 90 | onPongReceived func(context.Context, []byte) 91 | } 92 | 93 | type connConfig struct { 94 | subprotocol string 95 | rwc io.ReadWriteCloser 96 | client bool 97 | copts *compressionOptions 98 | flateThreshold int 99 | onPingReceived func(context.Context, []byte) bool 100 | onPongReceived func(context.Context, []byte) 101 | 102 | br *bufio.Reader 103 | bw *bufio.Writer 104 | } 105 | 106 | func newConn(cfg connConfig) *Conn { 107 | c := &Conn{ 108 | subprotocol: cfg.subprotocol, 109 | rwc: cfg.rwc, 110 | client: cfg.client, 111 | copts: cfg.copts, 112 | flateThreshold: cfg.flateThreshold, 113 | 114 | br: cfg.br, 115 | bw: cfg.bw, 116 | 117 | readTimeout: make(chan context.Context), 118 | writeTimeout: make(chan context.Context), 119 | timeoutLoopDone: make(chan struct{}), 120 | 121 | closed: make(chan struct{}), 122 | activePings: make(map[string]chan<- struct{}), 123 | onPingReceived: cfg.onPingReceived, 124 | onPongReceived: cfg.onPongReceived, 125 | } 126 | 127 | c.readMu = newMu(c) 128 | c.writeFrameMu = newMu(c) 129 | 130 | c.msgReader = newMsgReader(c) 131 | 132 | c.msgWriter = newMsgWriter(c) 133 | if c.client { 134 | c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) 135 | } 136 | 137 | if c.flate() && c.flateThreshold == 0 { 138 | c.flateThreshold = 128 139 | if !c.msgWriter.flateContextTakeover() { 140 | c.flateThreshold = 512 141 | } 142 | } 143 | 144 | runtime.SetFinalizer(c, func(c *Conn) { 145 | c.close() 146 | }) 147 | 148 | go c.timeoutLoop() 149 | 150 | return c 151 | } 152 | 153 | // Subprotocol returns the negotiated subprotocol. 154 | // An empty string means the default protocol. 
155 | func (c *Conn) Subprotocol() string { 156 | return c.subprotocol 157 | } 158 | 159 | func (c *Conn) close() error { 160 | c.closeMu.Lock() 161 | defer c.closeMu.Unlock() 162 | 163 | if c.isClosed() { 164 | return net.ErrClosed 165 | } 166 | runtime.SetFinalizer(c, nil) 167 | close(c.closed) 168 | 169 | // Have to close after c.closed is closed to ensure any goroutine that wakes up 170 | // from the connection being closed also sees that c.closed is closed and returns 171 | // closeErr. 172 | err := c.rwc.Close() 173 | // With the close of rwc, these become safe to close. 174 | c.msgWriter.close() 175 | c.msgReader.close() 176 | return err 177 | } 178 | 179 | func (c *Conn) timeoutLoop() { 180 | defer close(c.timeoutLoopDone) 181 | 182 | readCtx := context.Background() 183 | writeCtx := context.Background() 184 | 185 | for { 186 | select { 187 | case <-c.closed: 188 | return 189 | 190 | case writeCtx = <-c.writeTimeout: 191 | case readCtx = <-c.readTimeout: 192 | 193 | case <-readCtx.Done(): 194 | c.close() 195 | return 196 | case <-writeCtx.Done(): 197 | c.close() 198 | return 199 | } 200 | } 201 | } 202 | 203 | func (c *Conn) flate() bool { 204 | return c.copts != nil 205 | } 206 | 207 | // Ping sends a ping to the peer and waits for a pong. 208 | // Use this to measure latency or ensure the peer is responsive. 209 | // Ping must be called concurrently with Reader as it does 210 | // not read from the connection but instead waits for a Reader call 211 | // to read the pong. 212 | // 213 | // TCP Keepalives should suffice for most use cases. 
214 | func (c *Conn) Ping(ctx context.Context) error { 215 | p := c.pingCounter.Add(1) 216 | 217 | err := c.ping(ctx, strconv.FormatInt(p, 10)) 218 | if err != nil { 219 | return fmt.Errorf("failed to ping: %w", err) 220 | } 221 | return nil 222 | } 223 | 224 | func (c *Conn) ping(ctx context.Context, p string) error { 225 | pong := make(chan struct{}, 1) 226 | 227 | c.activePingsMu.Lock() 228 | c.activePings[p] = pong 229 | c.activePingsMu.Unlock() 230 | 231 | defer func() { 232 | c.activePingsMu.Lock() 233 | delete(c.activePings, p) 234 | c.activePingsMu.Unlock() 235 | }() 236 | 237 | err := c.writeControl(ctx, opPing, []byte(p)) 238 | if err != nil { 239 | return err 240 | } 241 | 242 | select { 243 | case <-c.closed: 244 | return net.ErrClosed 245 | case <-ctx.Done(): 246 | return fmt.Errorf("failed to wait for pong: %w", ctx.Err()) 247 | case <-pong: 248 | return nil 249 | } 250 | } 251 | 252 | type mu struct { 253 | c *Conn 254 | ch chan struct{} 255 | } 256 | 257 | func newMu(c *Conn) *mu { 258 | return &mu{ 259 | c: c, 260 | ch: make(chan struct{}, 1), 261 | } 262 | } 263 | 264 | func (m *mu) forceLock() { 265 | m.ch <- struct{}{} 266 | } 267 | 268 | func (m *mu) tryLock() bool { 269 | select { 270 | case m.ch <- struct{}{}: 271 | return true 272 | default: 273 | return false 274 | } 275 | } 276 | 277 | func (m *mu) lock(ctx context.Context) error { 278 | select { 279 | case <-m.c.closed: 280 | return net.ErrClosed 281 | case <-ctx.Done(): 282 | return fmt.Errorf("failed to acquire lock: %w", ctx.Err()) 283 | case m.ch <- struct{}{}: 284 | // To make sure the connection is certainly alive. 285 | // As it's possible the send on m.ch was selected 286 | // over the receive on closed. 287 | select { 288 | case <-m.c.closed: 289 | // Make sure to release. 
290 | m.unlock() 291 | return net.ErrClosed 292 | default: 293 | } 294 | return nil 295 | } 296 | } 297 | 298 | func (m *mu) unlock() { 299 | select { 300 | case <-m.ch: 301 | default: 302 | } 303 | } 304 | 305 | type noCopy struct{} 306 | 307 | func (*noCopy) Lock() {} 308 | -------------------------------------------------------------------------------- /dial.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | // +build !js 3 | 4 | package websocket 5 | 6 | import ( 7 | "bufio" 8 | "bytes" 9 | "context" 10 | "crypto/rand" 11 | "encoding/base64" 12 | "fmt" 13 | "io" 14 | "net/http" 15 | "net/url" 16 | "strings" 17 | "sync" 18 | "time" 19 | 20 | "github.com/coder/websocket/internal/errd" 21 | ) 22 | 23 | // DialOptions represents Dial's options. 24 | type DialOptions struct { 25 | // HTTPClient is used for the connection. 26 | // Its Transport must return writable bodies for WebSocket handshakes. 27 | // http.Transport does beginning with Go 1.12. 28 | HTTPClient *http.Client 29 | 30 | // HTTPHeader specifies the HTTP headers included in the handshake request. 31 | HTTPHeader http.Header 32 | 33 | // Host optionally overrides the Host HTTP header to send. If empty, the value 34 | // of URL.Host will be used. 35 | Host string 36 | 37 | // Subprotocols lists the WebSocket subprotocols to negotiate with the server. 38 | Subprotocols []string 39 | 40 | // CompressionMode controls the compression mode. 41 | // Defaults to CompressionDisabled. 42 | // 43 | // See docs on CompressionMode for details. 44 | CompressionMode CompressionMode 45 | 46 | // CompressionThreshold controls the minimum size of a message before compression is applied. 47 | // 48 | // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes 49 | // for CompressionContextTakeover. 50 | CompressionThreshold int 51 | 52 | // OnPingReceived is an optional callback invoked synchronously when a ping frame is received. 
53 | // 54 | // The payload contains the application data of the ping frame. 55 | // If the callback returns false, the subsequent pong frame will not be sent. 56 | // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine. 57 | OnPingReceived func(ctx context.Context, payload []byte) bool 58 | 59 | // OnPongReceived is an optional callback invoked synchronously when a pong frame is received. 60 | // 61 | // The payload contains the application data of the pong frame. 62 | // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine. 63 | // 64 | // Unlike OnPingReceived, this callback does not return a value because a pong frame 65 | // is a response to a ping and does not trigger any further frame transmission. 66 | OnPongReceived func(ctx context.Context, payload []byte) 67 | } 68 | 69 | func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) { 70 | var cancel context.CancelFunc 71 | 72 | var o DialOptions 73 | if opts != nil { 74 | o = *opts 75 | } 76 | if o.HTTPClient == nil { 77 | o.HTTPClient = http.DefaultClient 78 | } 79 | if o.HTTPClient.Timeout > 0 { 80 | ctx, cancel = context.WithTimeout(ctx, o.HTTPClient.Timeout) 81 | 82 | newClient := *o.HTTPClient 83 | newClient.Timeout = 0 84 | o.HTTPClient = &newClient 85 | } 86 | if o.HTTPHeader == nil { 87 | o.HTTPHeader = http.Header{} 88 | } 89 | newClient := *o.HTTPClient 90 | oldCheckRedirect := o.HTTPClient.CheckRedirect 91 | newClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { 92 | switch req.URL.Scheme { 93 | case "ws": 94 | req.URL.Scheme = "http" 95 | case "wss": 96 | req.URL.Scheme = "https" 97 | } 98 | if oldCheckRedirect != nil { 99 | return oldCheckRedirect(req, via) 100 | } 101 | return nil 102 | } 103 | o.HTTPClient = &newClient 104 | 105 | return ctx, cancel, &o 106 | } 107 | 108 | // Dial performs a WebSocket handshake on url. 
109 | // 110 | // The response is the WebSocket handshake response from the server. 111 | // You never need to close resp.Body yourself. 112 | // 113 | // If an error occurs, the returned response may be non nil. 114 | // However, you can only read the first 1024 bytes of the body. 115 | // 116 | // This function requires at least Go 1.12 as it uses a new feature 117 | // in net/http to perform WebSocket handshakes. 118 | // See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861 119 | // 120 | // URLs with http/https schemes will work and are interpreted as ws/wss. 121 | func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { 122 | return dial(ctx, u, opts, nil) 123 | } 124 | 125 | func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) { 126 | defer errd.Wrap(&err, "failed to WebSocket dial") 127 | 128 | var cancel context.CancelFunc 129 | ctx, cancel, opts = opts.cloneWithDefaults(ctx) 130 | if cancel != nil { 131 | defer cancel() 132 | } 133 | 134 | secWebSocketKey, err := secWebSocketKey(rand) 135 | if err != nil { 136 | return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err) 137 | } 138 | 139 | var copts *compressionOptions 140 | if opts.CompressionMode != CompressionDisabled { 141 | copts = opts.CompressionMode.opts() 142 | } 143 | 144 | resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey) 145 | if err != nil { 146 | return nil, resp, err 147 | } 148 | respBody := resp.Body 149 | resp.Body = nil 150 | defer func() { 151 | if err != nil { 152 | // We read a bit of the body for easier debugging. 
153 | r := io.LimitReader(respBody, 1024) 154 | 155 | timer := time.AfterFunc(time.Second*3, func() { 156 | respBody.Close() 157 | }) 158 | defer timer.Stop() 159 | 160 | b, _ := io.ReadAll(r) 161 | respBody.Close() 162 | resp.Body = io.NopCloser(bytes.NewReader(b)) 163 | } 164 | }() 165 | 166 | copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp) 167 | if err != nil { 168 | return nil, resp, err 169 | } 170 | 171 | rwc, ok := respBody.(io.ReadWriteCloser) 172 | if !ok { 173 | return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody) 174 | } 175 | 176 | return newConn(connConfig{ 177 | subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"), 178 | rwc: rwc, 179 | client: true, 180 | copts: copts, 181 | flateThreshold: opts.CompressionThreshold, 182 | onPingReceived: opts.OnPingReceived, 183 | onPongReceived: opts.OnPongReceived, 184 | br: getBufioReader(rwc), 185 | bw: getBufioWriter(rwc), 186 | }), resp, nil 187 | } 188 | 189 | func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) { 190 | u, err := url.Parse(urls) 191 | if err != nil { 192 | return nil, fmt.Errorf("failed to parse url: %w", err) 193 | } 194 | 195 | switch u.Scheme { 196 | case "ws": 197 | u.Scheme = "http" 198 | case "wss": 199 | u.Scheme = "https" 200 | case "http", "https": 201 | default: 202 | return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme) 203 | } 204 | 205 | req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) 206 | if err != nil { 207 | return nil, fmt.Errorf("failed to create new http request: %w", err) 208 | } 209 | if len(opts.Host) > 0 { 210 | req.Host = opts.Host 211 | } 212 | req.Header = opts.HTTPHeader.Clone() 213 | req.Header.Set("Connection", "Upgrade") 214 | req.Header.Set("Upgrade", "websocket") 215 | req.Header.Set("Sec-WebSocket-Version", "13") 216 | req.Header.Set("Sec-WebSocket-Key", secWebSocketKey) 217 | 
if len(opts.Subprotocols) > 0 { 218 | req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) 219 | } 220 | if copts != nil { 221 | req.Header.Set("Sec-WebSocket-Extensions", copts.String()) 222 | } 223 | 224 | resp, err := opts.HTTPClient.Do(req) 225 | if err != nil { 226 | return nil, fmt.Errorf("failed to send handshake request: %w", err) 227 | } 228 | return resp, nil 229 | } 230 | 231 | func secWebSocketKey(rr io.Reader) (string, error) { 232 | if rr == nil { 233 | rr = rand.Reader 234 | } 235 | b := make([]byte, 16) 236 | _, err := io.ReadFull(rr, b) 237 | if err != nil { 238 | return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err) 239 | } 240 | return base64.StdEncoding.EncodeToString(b), nil 241 | } 242 | 243 | func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) { 244 | if resp.StatusCode != http.StatusSwitchingProtocols { 245 | return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) 246 | } 247 | 248 | if !headerContainsTokenIgnoreCase(resp.Header, "Connection", "Upgrade") { 249 | return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) 250 | } 251 | 252 | if !headerContainsTokenIgnoreCase(resp.Header, "Upgrade", "WebSocket") { 253 | return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) 254 | } 255 | 256 | if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) { 257 | return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", 258 | resp.Header.Get("Sec-WebSocket-Accept"), 259 | secWebSocketKey, 260 | ) 261 | } 262 | 263 | err := verifySubprotocol(opts.Subprotocols, resp) 264 | if err != nil { 265 | return nil, err 266 | } 267 | 268 | 
return verifyServerExtensions(copts, resp.Header) 269 | } 270 | 271 | func verifySubprotocol(subprotos []string, resp *http.Response) error { 272 | proto := resp.Header.Get("Sec-WebSocket-Protocol") 273 | if proto == "" { 274 | return nil 275 | } 276 | 277 | for _, sp2 := range subprotos { 278 | if strings.EqualFold(sp2, proto) { 279 | return nil 280 | } 281 | } 282 | 283 | return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) 284 | } 285 | 286 | func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) { 287 | exts := websocketExtensions(h) 288 | if len(exts) == 0 { 289 | return nil, nil 290 | } 291 | 292 | ext := exts[0] 293 | if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil { 294 | return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:]) 295 | } 296 | 297 | _copts := *copts 298 | copts = &_copts 299 | 300 | for _, p := range ext.params { 301 | switch p { 302 | case "client_no_context_takeover": 303 | copts.clientNoContextTakeover = true 304 | continue 305 | case "server_no_context_takeover": 306 | copts.serverNoContextTakeover = true 307 | continue 308 | } 309 | if strings.HasPrefix(p, "server_max_window_bits=") { 310 | // We can't adjust the deflate window, but decoding with a larger window is acceptable. 
311 | continue 312 | } 313 | 314 | return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p) 315 | } 316 | 317 | return copts, nil 318 | } 319 | 320 | var bufioReaderPool sync.Pool 321 | 322 | func getBufioReader(r io.Reader) *bufio.Reader { 323 | br, ok := bufioReaderPool.Get().(*bufio.Reader) 324 | if !ok { 325 | return bufio.NewReader(r) 326 | } 327 | br.Reset(r) 328 | return br 329 | } 330 | 331 | func putBufioReader(br *bufio.Reader) { 332 | bufioReaderPool.Put(br) 333 | } 334 | 335 | var bufioWriterPool sync.Pool 336 | 337 | func getBufioWriter(w io.Writer) *bufio.Writer { 338 | bw, ok := bufioWriterPool.Get().(*bufio.Writer) 339 | if !ok { 340 | return bufio.NewWriter(w) 341 | } 342 | bw.Reset(w) 343 | return bw 344 | } 345 | 346 | func putBufioWriter(bw *bufio.Writer) { 347 | bufioWriterPool.Put(bw) 348 | } 349 | -------------------------------------------------------------------------------- /dial_test.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | // +build !js 3 | 4 | package websocket_test 5 | 6 | import ( 7 | "bytes" 8 | "context" 9 | "crypto/rand" 10 | "io" 11 | "net/http" 12 | "net/http/httptest" 13 | "net/url" 14 | "strings" 15 | "testing" 16 | "time" 17 | 18 | "github.com/coder/websocket" 19 | "github.com/coder/websocket/internal/test/assert" 20 | "github.com/coder/websocket/internal/util" 21 | "github.com/coder/websocket/internal/xsync" 22 | ) 23 | 24 | func TestBadDials(t *testing.T) { 25 | t.Parallel() 26 | 27 | t.Run("badReq", func(t *testing.T) { 28 | t.Parallel() 29 | 30 | testCases := []struct { 31 | name string 32 | url string 33 | opts *websocket.DialOptions 34 | rand util.ReaderFunc 35 | nilCtx bool 36 | }{ 37 | { 38 | name: "badURL", 39 | url: "://noscheme", 40 | }, 41 | { 42 | name: "badURLScheme", 43 | url: "ftp://nhooyr.io", 44 | }, 45 | { 46 | name: "badTLS", 47 | url: "wss://totallyfake.nhooyr.io", 48 | }, 49 | { 50 | name: "badReader", 51 | rand: func(p 
[]byte) (int, error) { 52 | return 0, io.EOF 53 | }, 54 | }, 55 | { 56 | name: "nilContext", 57 | url: "http://localhost", 58 | nilCtx: true, 59 | }, 60 | } 61 | 62 | for _, tc := range testCases { 63 | tc := tc 64 | t.Run(tc.name, func(t *testing.T) { 65 | t.Parallel() 66 | 67 | var ctx context.Context 68 | var cancel func() 69 | if !tc.nilCtx { 70 | ctx, cancel = context.WithTimeout(context.Background(), time.Second*5) 71 | defer cancel() 72 | } 73 | 74 | if tc.rand == nil { 75 | tc.rand = rand.Reader.Read 76 | } 77 | 78 | _, _, err := websocket.ExportedDial(ctx, tc.url, tc.opts, tc.rand) 79 | assert.Error(t, err) 80 | }) 81 | } 82 | }) 83 | 84 | t.Run("badResponse", func(t *testing.T) { 85 | t.Parallel() 86 | 87 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) 88 | defer cancel() 89 | 90 | _, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{ 91 | HTTPClient: mockHTTPClient(func(*http.Request) (*http.Response, error) { 92 | return &http.Response{ 93 | Body: io.NopCloser(strings.NewReader("hi")), 94 | }, nil 95 | }), 96 | }) 97 | assert.Contains(t, err, "failed to WebSocket dial: expected handshake response status code 101 but got 0") 98 | }) 99 | 100 | t.Run("badBody", func(t *testing.T) { 101 | t.Parallel() 102 | 103 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) 104 | defer cancel() 105 | 106 | rt := func(r *http.Request) (*http.Response, error) { 107 | h := http.Header{} 108 | h.Set("Connection", "Upgrade") 109 | h.Set("Upgrade", "websocket") 110 | h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key"))) 111 | 112 | return &http.Response{ 113 | StatusCode: http.StatusSwitchingProtocols, 114 | Header: h, 115 | Body: io.NopCloser(strings.NewReader("hi")), 116 | }, nil 117 | } 118 | 119 | _, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{ 120 | HTTPClient: mockHTTPClient(rt), 121 | }) 122 | assert.Contains(t, err, "response 
body is not a io.ReadWriteCloser") 123 | }) 124 | } 125 | 126 | func Test_verifyHostOverride(t *testing.T) { 127 | testCases := []struct { 128 | name string 129 | host string 130 | exp string 131 | }{ 132 | { 133 | name: "noOverride", 134 | host: "", 135 | exp: "example.com", 136 | }, 137 | { 138 | name: "hostOverride", 139 | host: "example.net", 140 | exp: "example.net", 141 | }, 142 | } 143 | 144 | for _, tc := range testCases { 145 | tc := tc 146 | t.Run(tc.name, func(t *testing.T) { 147 | t.Parallel() 148 | 149 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) 150 | defer cancel() 151 | 152 | rt := func(r *http.Request) (*http.Response, error) { 153 | assert.Equal(t, "Host", tc.exp, r.Host) 154 | 155 | h := http.Header{} 156 | h.Set("Connection", "Upgrade") 157 | h.Set("Upgrade", "websocket") 158 | h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key"))) 159 | 160 | return &http.Response{ 161 | StatusCode: http.StatusSwitchingProtocols, 162 | Header: h, 163 | Body: mockBody{bytes.NewBufferString("hi")}, 164 | }, nil 165 | } 166 | 167 | c, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{ 168 | HTTPClient: mockHTTPClient(rt), 169 | Host: tc.host, 170 | }) 171 | assert.Success(t, err) 172 | c.CloseNow() 173 | }) 174 | } 175 | 176 | } 177 | 178 | type mockBody struct { 179 | *bytes.Buffer 180 | } 181 | 182 | func (mb mockBody) Close() error { 183 | return nil 184 | } 185 | 186 | func Test_verifyServerHandshake(t *testing.T) { 187 | t.Parallel() 188 | 189 | testCases := []struct { 190 | name string 191 | response func(w http.ResponseWriter) 192 | success bool 193 | }{ 194 | { 195 | name: "badStatus", 196 | response: func(w http.ResponseWriter) { 197 | w.WriteHeader(http.StatusOK) 198 | }, 199 | success: false, 200 | }, 201 | { 202 | name: "badConnection", 203 | response: func(w http.ResponseWriter) { 204 | w.Header().Set("Connection", "???") 205 | 
w.WriteHeader(http.StatusSwitchingProtocols) 206 | }, 207 | success: false, 208 | }, 209 | { 210 | name: "badUpgrade", 211 | response: func(w http.ResponseWriter) { 212 | w.Header().Set("Connection", "Upgrade") 213 | w.Header().Set("Upgrade", "???") 214 | w.WriteHeader(http.StatusSwitchingProtocols) 215 | }, 216 | success: false, 217 | }, 218 | { 219 | name: "badSecWebSocketAccept", 220 | response: func(w http.ResponseWriter) { 221 | w.Header().Set("Connection", "Upgrade") 222 | w.Header().Set("Upgrade", "websocket") 223 | w.Header().Set("Sec-WebSocket-Accept", "xd") 224 | w.WriteHeader(http.StatusSwitchingProtocols) 225 | }, 226 | success: false, 227 | }, 228 | { 229 | name: "badSecWebSocketProtocol", 230 | response: func(w http.ResponseWriter) { 231 | w.Header().Set("Connection", "Upgrade") 232 | w.Header().Set("Upgrade", "websocket") 233 | w.Header().Set("Sec-WebSocket-Protocol", "xd") 234 | w.WriteHeader(http.StatusSwitchingProtocols) 235 | }, 236 | success: false, 237 | }, 238 | { 239 | name: "unsupportedExtension", 240 | response: func(w http.ResponseWriter) { 241 | w.Header().Set("Connection", "Upgrade") 242 | w.Header().Set("Upgrade", "websocket") 243 | w.Header().Set("Sec-WebSocket-Extensions", "meow") 244 | w.WriteHeader(http.StatusSwitchingProtocols) 245 | }, 246 | success: false, 247 | }, 248 | { 249 | name: "unsupportedDeflateParam", 250 | response: func(w http.ResponseWriter) { 251 | w.Header().Set("Connection", "Upgrade") 252 | w.Header().Set("Upgrade", "websocket") 253 | w.Header().Set("Sec-WebSocket-Extensions", "permessage-deflate; meow") 254 | w.WriteHeader(http.StatusSwitchingProtocols) 255 | }, 256 | success: false, 257 | }, 258 | { 259 | name: "success", 260 | response: func(w http.ResponseWriter) { 261 | w.Header().Set("Connection", "Upgrade") 262 | w.Header().Set("Upgrade", "websocket") 263 | w.WriteHeader(http.StatusSwitchingProtocols) 264 | }, 265 | success: true, 266 | }, 267 | } 268 | 269 | for _, tc := range testCases { 270 | tc := tc 
271 | t.Run(tc.name, func(t *testing.T) { 272 | t.Parallel() 273 | 274 | w := httptest.NewRecorder() 275 | tc.response(w) 276 | resp := w.Result() 277 | 278 | r := httptest.NewRequest("GET", "/", nil) 279 | key, err := websocket.SecWebSocketKey(rand.Reader) 280 | assert.Success(t, err) 281 | r.Header.Set("Sec-WebSocket-Key", key) 282 | 283 | if resp.Header.Get("Sec-WebSocket-Accept") == "" { 284 | resp.Header.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(key)) 285 | } 286 | 287 | opts := &websocket.DialOptions{ 288 | Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","), 289 | } 290 | _, err = websocket.VerifyServerResponse(opts, websocket.CompressionModeOpts(opts.CompressionMode), key, resp) 291 | if tc.success { 292 | assert.Success(t, err) 293 | } else { 294 | assert.Error(t, err) 295 | } 296 | }) 297 | } 298 | } 299 | 300 | func mockHTTPClient(fn roundTripperFunc) *http.Client { 301 | return &http.Client{ 302 | Transport: fn, 303 | } 304 | } 305 | 306 | type roundTripperFunc func(*http.Request) (*http.Response, error) 307 | 308 | func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { 309 | return f(r) 310 | } 311 | 312 | func TestDialRedirect(t *testing.T) { 313 | t.Parallel() 314 | 315 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) 316 | defer cancel() 317 | 318 | _, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{ 319 | HTTPClient: mockHTTPClient(func(r *http.Request) (*http.Response, error) { 320 | resp := &http.Response{ 321 | Header: http.Header{}, 322 | } 323 | if r.URL.Scheme != "https" { 324 | resp.Header.Set("Location", "wss://example.com") 325 | resp.StatusCode = http.StatusFound 326 | return resp, nil 327 | } 328 | resp.Header.Set("Connection", "Upgrade") 329 | resp.Header.Set("Upgrade", "meow") 330 | resp.StatusCode = http.StatusSwitchingProtocols 331 | return resp, nil 332 | }), 333 | }) 334 | assert.Contains(t, err, "failed to WebSocket dial: 
WebSocket protocol violation: Upgrade header \"meow\" does not contain websocket") 335 | } 336 | 337 | type forwardProxy struct { 338 | hc *http.Client 339 | } 340 | 341 | func newForwardProxy() *forwardProxy { 342 | return &forwardProxy{ 343 | hc: &http.Client{}, 344 | } 345 | } 346 | 347 | func (fc *forwardProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { 348 | ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) 349 | defer cancel() 350 | 351 | r = r.WithContext(ctx) 352 | r.RequestURI = "" 353 | resp, err := fc.hc.Do(r) 354 | if err != nil { 355 | http.Error(w, err.Error(), http.StatusBadRequest) 356 | return 357 | } 358 | defer resp.Body.Close() 359 | 360 | for k, v := range resp.Header { 361 | w.Header()[k] = v 362 | } 363 | w.Header().Set("PROXIED", "true") 364 | w.WriteHeader(resp.StatusCode) 365 | if resprw, ok := resp.Body.(io.ReadWriter); ok { 366 | c, brw, err := w.(http.Hijacker).Hijack() 367 | if err != nil { 368 | http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 369 | return 370 | } 371 | brw.Flush() 372 | 373 | errc1 := xsync.Go(func() error { 374 | _, err := io.Copy(c, resprw) 375 | return err 376 | }) 377 | errc2 := xsync.Go(func() error { 378 | _, err := io.Copy(resprw, c) 379 | return err 380 | }) 381 | select { 382 | case <-errc1: 383 | case <-errc2: 384 | case <-r.Context().Done(): 385 | } 386 | } else { 387 | io.Copy(w, resp.Body) 388 | } 389 | } 390 | 391 | func TestDialViaProxy(t *testing.T) { 392 | t.Parallel() 393 | 394 | ps := httptest.NewServer(newForwardProxy()) 395 | defer ps.Close() 396 | 397 | s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 398 | err := echoServer(w, r, nil) 399 | assert.Success(t, err) 400 | })) 401 | defer s.Close() 402 | 403 | psu, err := url.Parse(ps.URL) 404 | assert.Success(t, err) 405 | proxyTransport := http.DefaultTransport.(*http.Transport).Clone() 406 | proxyTransport.Proxy = http.ProxyURL(psu) 407 | 
408 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) 409 | defer cancel() 410 | c, resp, err := websocket.Dial(ctx, s.URL, &websocket.DialOptions{ 411 | HTTPClient: &http.Client{ 412 | Transport: proxyTransport, 413 | }, 414 | }) 415 | assert.Success(t, err) 416 | assert.Equal(t, "", "true", resp.Header.Get("PROXIED")) 417 | 418 | assertEcho(t, ctx, c) 419 | assertClose(t, c) 420 | } 421 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | // +build !js 3 | 4 | // Package websocket implements the RFC 6455 WebSocket protocol. 5 | // 6 | // https://tools.ietf.org/html/rfc6455 7 | // 8 | // Use Dial to dial a WebSocket server. 9 | // 10 | // Use Accept to accept a WebSocket client. 11 | // 12 | // Conn represents the resulting WebSocket connection. 13 | // 14 | // The examples are the best way to understand how to correctly use the library. 15 | // 16 | // The wsjson subpackage contains helpers for reading and writing JSON messages. 17 | // 18 | // More documentation at https://github.com/coder/websocket. 19 | // 20 | // # Wasm 21 | // 22 | // The client side supports compiling to Wasm. 23 | // It wraps the WebSocket browser API.
24 | // 25 | // See https://developer.mozilla.org/en-US/docs/Web/API/WebSocket 26 | // 27 | // Some important caveats to be aware of: 28 | // 29 | // - Accept always errors out 30 | // - Conn.Ping is no-op 31 | // - Conn.CloseNow is Close(StatusGoingAway, "") 32 | // - HTTPClient, HTTPHeader and CompressionMode in DialOptions are no-op 33 | // - *http.Response from Dial is &http.Response{} with a 101 status code on success 34 | package websocket // import "github.com/coder/websocket" 35 | -------------------------------------------------------------------------------- /example_test.go: -------------------------------------------------------------------------------- 1 | package websocket_test 2 | 3 | import ( 4 | "context" 5 | "log" 6 | "net/http" 7 | "time" 8 | 9 | "github.com/coder/websocket" 10 | "github.com/coder/websocket/wsjson" 11 | ) 12 | 13 | func ExampleAccept() { 14 | // This handler accepts a WebSocket connection, reads a single JSON 15 | // message from the client and then closes the connection. 16 | 17 | fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 18 | c, err := websocket.Accept(w, r, nil) 19 | if err != nil { 20 | log.Println(err) 21 | return 22 | } 23 | defer c.CloseNow() 24 | 25 | ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) 26 | defer cancel() 27 | 28 | var v interface{} 29 | err = wsjson.Read(ctx, c, &v) 30 | if err != nil { 31 | log.Println(err) 32 | return 33 | } 34 | 35 | c.Close(websocket.StatusNormalClosure, "") 36 | }) 37 | 38 | err := http.ListenAndServe("localhost:8080", fn) 39 | log.Fatal(err) 40 | } 41 | 42 | func ExampleDial() { 43 | // Dials a server, writes a single JSON message and then 44 | // closes the connection. 
45 | 46 | ctx, cancel := context.WithTimeout(context.Background(), time.Minute) 47 | defer cancel() 48 | 49 | c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) 50 | if err != nil { 51 | log.Fatal(err) 52 | } 53 | defer c.CloseNow() 54 | 55 | err = wsjson.Write(ctx, c, "hi") 56 | if err != nil { 57 | log.Fatal(err) 58 | } 59 | 60 | c.Close(websocket.StatusNormalClosure, "") 61 | } 62 | 63 | func ExampleCloseStatus() { 64 | // Dials a server and then expects to be disconnected with status code 65 | // websocket.StatusNormalClosure. 66 | 67 | ctx, cancel := context.WithTimeout(context.Background(), time.Minute) 68 | defer cancel() 69 | 70 | c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) 71 | if err != nil { 72 | log.Fatal(err) 73 | } 74 | defer c.CloseNow() 75 | 76 | _, _, err = c.Reader(ctx) 77 | if websocket.CloseStatus(err) != websocket.StatusNormalClosure { 78 | log.Fatalf("expected to be disconnected with StatusNormalClosure but got: %v", err) 79 | } 80 | } 81 | 82 | func Example_writeOnly() { 83 | // This handler demonstrates how to correctly handle a write only WebSocket connection. 84 | // i.e you only expect to write messages and do not expect to read any messages. 
85 | fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 86 | c, err := websocket.Accept(w, r, nil) 87 | if err != nil { 88 | log.Println(err) 89 | return 90 | } 91 | defer c.CloseNow() 92 | 93 | ctx, cancel := context.WithTimeout(r.Context(), time.Minute*10) 94 | defer cancel() 95 | 96 | ctx = c.CloseRead(ctx) 97 | 98 | t := time.NewTicker(time.Second * 30) 99 | defer t.Stop() 100 | 101 | for { 102 | select { 103 | case <-ctx.Done(): 104 | c.Close(websocket.StatusNormalClosure, "") 105 | return 106 | case <-t.C: 107 | err = wsjson.Write(ctx, c, "hi") 108 | if err != nil { 109 | log.Println(err) 110 | return 111 | } 112 | } 113 | } 114 | }) 115 | 116 | err := http.ListenAndServe("localhost:8080", fn) 117 | log.Fatal(err) 118 | } 119 | 120 | func Example_crossOrigin() { 121 | // This handler demonstrates how to safely accept cross origin WebSockets 122 | // from the origin example.com. 123 | fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 124 | c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ 125 | OriginPatterns: []string{"example.com"}, 126 | }) 127 | if err != nil { 128 | log.Println(err) 129 | return 130 | } 131 | c.Close(websocket.StatusNormalClosure, "cross origin WebSocket accepted") 132 | }) 133 | 134 | err := http.ListenAndServe("localhost:8080", fn) 135 | log.Fatal(err) 136 | } 137 | 138 | func ExampleConn_Ping() { 139 | // Dials a server and pings it 5 times. 140 | 141 | ctx, cancel := context.WithTimeout(context.Background(), time.Minute) 142 | defer cancel() 143 | 144 | c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) 145 | if err != nil { 146 | log.Fatal(err) 147 | } 148 | defer c.CloseNow() 149 | 150 | // Required to read the Pongs from the server. 
151 | ctx = c.CloseRead(ctx) 152 | 153 | for i := 0; i < 5; i++ { 154 | err = c.Ping(ctx) 155 | if err != nil { 156 | log.Fatal(err) 157 | } 158 | } 159 | 160 | c.Close(websocket.StatusNormalClosure, "") 161 | } 162 | 163 | // This example demonstrates full stack chat with an automated test. 164 | func Example_fullStackChat() { 165 | // https://github.com/nhooyr/websocket/tree/master/internal/examples/chat 166 | } 167 | 168 | // This example demonstrates a echo server. 169 | func Example_echo() { 170 | // https://github.com/nhooyr/websocket/tree/master/internal/examples/echo 171 | } 172 | -------------------------------------------------------------------------------- /export_test.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | // +build !js 3 | 4 | package websocket 5 | 6 | import ( 7 | "net" 8 | 9 | "github.com/coder/websocket/internal/util" 10 | ) 11 | 12 | func (c *Conn) RecordBytesWritten() *int { 13 | var bytesWritten int 14 | c.bw.Reset(util.WriterFunc(func(p []byte) (int, error) { 15 | bytesWritten += len(p) 16 | return c.rwc.Write(p) 17 | })) 18 | return &bytesWritten 19 | } 20 | 21 | func (c *Conn) RecordBytesRead() *int { 22 | var bytesRead int 23 | c.br.Reset(util.ReaderFunc(func(p []byte) (int, error) { 24 | n, err := c.rwc.Read(p) 25 | bytesRead += n 26 | return n, err 27 | })) 28 | return &bytesRead 29 | } 30 | 31 | var ErrClosed = net.ErrClosed 32 | 33 | var ExportedDial = dial 34 | var SecWebSocketAccept = secWebSocketAccept 35 | var SecWebSocketKey = secWebSocketKey 36 | var VerifyServerResponse = verifyServerResponse 37 | 38 | var CompressionModeOpts = CompressionMode.opts 39 | -------------------------------------------------------------------------------- /frame.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | 3 | package websocket 4 | 5 | import ( 6 | "bufio" 7 | "encoding/binary" 8 | "fmt" 9 | "io" 10 | "math" 11 | 12 | 
"github.com/coder/websocket/internal/errd" 13 | ) 14 | 15 | // opcode represents a WebSocket opcode. 16 | type opcode int 17 | 18 | // https://tools.ietf.org/html/rfc6455#section-11.8. 19 | const ( 20 | opContinuation opcode = iota 21 | opText 22 | opBinary 23 | // 3 - 7 are reserved for further non-control frames. 24 | _ 25 | _ 26 | _ 27 | _ 28 | _ 29 | opClose 30 | opPing 31 | opPong 32 | // 11-16 are reserved for further control frames. 33 | ) 34 | 35 | // header represents a WebSocket frame header. 36 | // See https://tools.ietf.org/html/rfc6455#section-5.2. 37 | type header struct { 38 | fin bool 39 | rsv1 bool 40 | rsv2 bool 41 | rsv3 bool 42 | opcode opcode 43 | 44 | payloadLength int64 45 | 46 | masked bool 47 | maskKey uint32 48 | } 49 | 50 | // readFrameHeader reads a header from the reader. 51 | // See https://tools.ietf.org/html/rfc6455#section-5.2. 52 | func readFrameHeader(r *bufio.Reader, readBuf []byte) (h header, err error) { 53 | defer errd.Wrap(&err, "failed to read frame header") 54 | 55 | b, err := r.ReadByte() 56 | if err != nil { 57 | return header{}, err 58 | } 59 | 60 | h.fin = b&(1<<7) != 0 61 | h.rsv1 = b&(1<<6) != 0 62 | h.rsv2 = b&(1<<5) != 0 63 | h.rsv3 = b&(1<<4) != 0 64 | 65 | h.opcode = opcode(b & 0xf) 66 | 67 | b, err = r.ReadByte() 68 | if err != nil { 69 | return header{}, err 70 | } 71 | 72 | h.masked = b&(1<<7) != 0 73 | 74 | payloadLength := b &^ (1 << 7) 75 | switch { 76 | case payloadLength < 126: 77 | h.payloadLength = int64(payloadLength) 78 | case payloadLength == 126: 79 | _, err = io.ReadFull(r, readBuf[:2]) 80 | h.payloadLength = int64(binary.BigEndian.Uint16(readBuf)) 81 | case payloadLength == 127: 82 | _, err = io.ReadFull(r, readBuf) 83 | h.payloadLength = int64(binary.BigEndian.Uint64(readBuf)) 84 | } 85 | if err != nil { 86 | return header{}, err 87 | } 88 | 89 | if h.payloadLength < 0 { 90 | return header{}, fmt.Errorf("received negative payload length: %v", h.payloadLength) 91 | } 92 | 93 | if h.masked { 94 | 
_, err = io.ReadFull(r, readBuf[:4]) 95 | if err != nil { 96 | return header{}, err 97 | } 98 | h.maskKey = binary.LittleEndian.Uint32(readBuf) 99 | } 100 | 101 | return h, nil 102 | } 103 | 104 | // maxControlPayload is the maximum length of a control frame payload. 105 | // See https://tools.ietf.org/html/rfc6455#section-5.5. 106 | const maxControlPayload = 125 107 | 108 | // writeFrameHeader writes the bytes of the header to w. 109 | // See https://tools.ietf.org/html/rfc6455#section-5.2 110 | func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) { 111 | defer errd.Wrap(&err, "failed to write frame header") 112 | 113 | var b byte 114 | if h.fin { 115 | b |= 1 << 7 116 | } 117 | if h.rsv1 { 118 | b |= 1 << 6 119 | } 120 | if h.rsv2 { 121 | b |= 1 << 5 122 | } 123 | if h.rsv3 { 124 | b |= 1 << 4 125 | } 126 | 127 | b |= byte(h.opcode) 128 | 129 | err = w.WriteByte(b) 130 | if err != nil { 131 | return err 132 | } 133 | 134 | lengthByte := byte(0) 135 | if h.masked { 136 | lengthByte |= 1 << 7 137 | } 138 | 139 | switch { 140 | case h.payloadLength > math.MaxUint16: 141 | lengthByte |= 127 142 | case h.payloadLength > 125: 143 | lengthByte |= 126 144 | case h.payloadLength >= 0: 145 | lengthByte |= byte(h.payloadLength) 146 | } 147 | err = w.WriteByte(lengthByte) 148 | if err != nil { 149 | return err 150 | } 151 | 152 | switch { 153 | case h.payloadLength > math.MaxUint16: 154 | binary.BigEndian.PutUint64(buf, uint64(h.payloadLength)) 155 | _, err = w.Write(buf) 156 | case h.payloadLength > 125: 157 | binary.BigEndian.PutUint16(buf, uint16(h.payloadLength)) 158 | _, err = w.Write(buf[:2]) 159 | } 160 | if err != nil { 161 | return err 162 | } 163 | 164 | if h.masked { 165 | binary.LittleEndian.PutUint32(buf, h.maskKey) 166 | _, err = w.Write(buf[:4]) 167 | if err != nil { 168 | return err 169 | } 170 | } 171 | 172 | return nil 173 | } 174 | -------------------------------------------------------------------------------- /frame_test.go: 
-------------------------------------------------------------------------------- 1 | //go:build !js 2 | // +build !js 3 | 4 | package websocket 5 | 6 | import ( 7 | "bufio" 8 | "bytes" 9 | "encoding/binary" 10 | "math/bits" 11 | "math/rand" 12 | "strconv" 13 | "testing" 14 | "time" 15 | 16 | "github.com/coder/websocket/internal/test/assert" 17 | ) 18 | // TestHeader round-trips headers through writeFrameHeader/readFrameHeader at the length-encoding boundaries and for random headers. 19 | func TestHeader(t *testing.T) { 20 | t.Parallel() 21 | 22 | t.Run("lengths", func(t *testing.T) { 23 | t.Parallel() 24 | 25 | lengths := []int{ 26 | 124, 27 | 125, 28 | 126, 29 | 127, 30 | // Around math.MaxUint16, where the 16-bit extended length gives way to the 64-bit one. 31 | 65534, 32 | 65535, 33 | 65536, 34 | 65537, 35 | } 36 | 37 | for _, n := range lengths { 38 | n := n 39 | t.Run(strconv.Itoa(n), func(t *testing.T) { 40 | t.Parallel() 41 | 42 | testHeader(t, header{ 43 | payloadLength: int64(n), 44 | }) 45 | }) 46 | } 47 | }) 48 | 49 | t.Run("fuzz", func(t *testing.T) { 50 | t.Parallel() 51 | 52 | r := rand.New(rand.NewSource(time.Now().UnixNano())) 53 | randBool := func() bool { 54 | return r.Intn(2) == 0 55 | } 56 | 57 | for i := 0; i < 10000; i++ { 58 | h := header{ 59 | fin: randBool(), 60 | rsv1: randBool(), 61 | rsv2: randBool(), 62 | rsv3: randBool(), 63 | opcode: opcode(r.Intn(16)), 64 | 65 | masked: randBool(), 66 | payloadLength: r.Int63(), 67 | } 68 | if h.masked { 69 | h.maskKey = r.Uint32() 70 | } 71 | 72 | testHeader(t, h) 73 | } 74 | }) 75 | } 76 | // testHeader writes h into a buffer, reads it back and asserts the decoded header equals h. 77 | func testHeader(t *testing.T, h header) { 78 | b := &bytes.Buffer{} 79 | w := bufio.NewWriter(b) 80 | r := bufio.NewReader(b) 81 | 82 | err := writeFrameHeader(h, w, make([]byte, 8)) 83 | assert.Success(t, err) 84 | 85 | err = w.Flush() 86 | assert.Success(t, err) 87 | 88 | h2, err := readFrameHeader(r, make([]byte, 8)) 89 | assert.Success(t, err) 90 | 91 | assert.Equal(t, "read header", h, h2) 92 | } 93 | // Test_mask checks that mask XORs p with the key in place and, for a 5-byte payload, returns the key rotated right by 8 bits. 94 | func Test_mask(t *testing.T) { 95 | t.Parallel() 96 | 97 | key := []byte{0xa, 0xb, 0xc, 0xff} 98 | key32 := binary.LittleEndian.Uint32(key) 99 | p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc} 100 | gotKey32 := mask(p, key32)
101 | 102 | expP := []byte{0, 0, 0, 0x0d, 0x6} 103 | assert.Equal(t, "p", expP, p) 104 | 105 | expKey32 := bits.RotateLeft32(key32, -8) 106 | assert.Equal(t, "key32", expKey32, gotKey32) 107 | } 108 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/coder/websocket 2 | 3 | go 1.23 4 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coder/websocket/efb626be44240d7979b57427265d9b6402166b96/go.sum -------------------------------------------------------------------------------- /hijack.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | 3 | package websocket 4 | 5 | import ( 6 | "net/http" 7 | ) 8 | // rwUnwrapper is implemented by ResponseWriter wrappers that expose the writer they wrap. 9 | type rwUnwrapper interface { 10 | Unwrap() http.ResponseWriter 11 | } 12 | 13 | // hijacker returns the Hijacker interface of the http.ResponseWriter. 14 | // It follows the Unwrap method of the http.ResponseWriter if available, 15 | // matching the behavior of http.ResponseController. If the Hijacker 16 | // interface is not found, it returns false. 17 | // 18 | // http.ResponseController offers no way to check for the Hijacker 19 | // interface without calling Hijack, so this function performs that 20 | // check itself while following the same Unwrap chain (see 21 | // Test_hijackerHTTPResponseControllerCompatibility).
22 | func hijacker(rw http.ResponseWriter) (http.Hijacker, bool) { 23 | for { 24 | switch t := rw.(type) { 25 | case http.Hijacker: 26 | return t, true 27 | case rwUnwrapper: 28 | rw = t.Unwrap() // a nil result falls through to the default case on the next iteration 29 | default: 30 | return nil, false 31 | } 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /hijack_go120_test.go: -------------------------------------------------------------------------------- 1 | //go:build !js && go1.20 2 | 3 | package websocket 4 | 5 | import ( 6 | "bufio" 7 | "errors" 8 | "net" 9 | "net/http" 10 | "net/http/httptest" 11 | "testing" 12 | 13 | "github.com/coder/websocket/internal/test/assert" 14 | ) 15 | // Test_hijackerHTTPResponseControllerCompatibility checks that hijacker and http.ResponseController both reach the Hijacker behind an Unwrap wrapper. 16 | func Test_hijackerHTTPResponseControllerCompatibility(t *testing.T) { 17 | t.Parallel() 18 | 19 | rr := httptest.NewRecorder() 20 | w := mockUnwrapper{ 21 | ResponseWriter: rr, 22 | unwrap: func() http.ResponseWriter { 23 | return mockHijacker{ 24 | ResponseWriter: rr, 25 | hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) { 26 | return nil, nil, errors.New("haha") 27 | }, 28 | } 29 | }, 30 | } 31 | 32 | _, _, err := http.NewResponseController(w).Hijack() 33 | assert.Contains(t, err, "haha") 34 | hj, ok := hijacker(w) 35 | assert.Equal(t, "hijacker found", ok, true) 36 | _, _, err = hj.Hijack() 37 | assert.Contains(t, err, "haha") 38 | } 39 | -------------------------------------------------------------------------------- /internal/bpool/bpool.go: -------------------------------------------------------------------------------- 1 | package bpool 2 | 3 | import ( 4 | "bytes" 5 | "sync" 6 | ) 7 | // bpool caches *bytes.Buffer values for reuse; New supplies an empty buffer when none is cached. 8 | var bpool = sync.Pool{ 9 | New: func() any { 10 | return &bytes.Buffer{} 11 | }, 12 | } 13 | 14 | // Get returns a buffer from the pool or creates a new one if 15 | // the pool is empty. 16 | func Get() *bytes.Buffer { 17 | b := bpool.Get() 18 | return b.(*bytes.Buffer) 19 | } 20 | 21 | // Put returns a buffer into the pool.
22 | func Put(b *bytes.Buffer) { 23 | b.Reset() // empties the buffer but keeps its capacity for the next Get 24 | bpool.Put(b) 25 | } 26 | -------------------------------------------------------------------------------- /internal/errd/wrap.go: -------------------------------------------------------------------------------- 1 | package errd 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | // Wrap replaces *err with fmt.Errorf(f+": %w", v..., *err) when *err is non-nil. 8 | // Intended for use with defer and a named error return. 9 | // Inspired by https://github.com/golang/go/issues/32676. 10 | func Wrap(err *error, f string, v ...interface{}) { 11 | if *err != nil { 12 | *err = fmt.Errorf(f+": %w", append(v, *err)...) 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /internal/examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | This directory contains more involved examples unsuitable 4 | for display with godoc. 5 | -------------------------------------------------------------------------------- /internal/examples/chat/README.md: -------------------------------------------------------------------------------- 1 | # Chat Example 2 | 3 | This directory contains a full-stack example of a simple chat webapp using github.com/coder/websocket. 4 | 5 | ```bash 6 | $ cd internal/examples/chat 7 | $ go run . localhost:0 8 | listening on ws://127.0.0.1:51055 9 | ``` 10 | 11 | Visit the printed URL to submit and view broadcast messages in a browser. 12 | 13 | ![Image of Example](https://i.imgur.com/VwJl9Bh.png) 14 | 15 | ## Structure 16 | 17 | The frontend is contained in `index.html`, `index.js` and `index.css`. It sets up the 18 | DOM with a scrollable div at the top that is populated with new messages as they are broadcast. 19 | At the bottom it adds a form to submit messages. 20 | 21 | The messages are received via the WebSocket `/subscribe` endpoint and published via 22 | the HTTP POST `/publish` endpoint.
The reason for not publishing messages over the WebSocket 23 | is so that you can easily publish a message with curl. 24 | 25 | The server portion is `main.go` and `chat.go` and implements serving the static frontend 26 | assets, the `/subscribe` WebSocket endpoint and the HTTP POST `/publish` endpoint. 27 | 28 | The code is well commented. I would recommend starting in `main.go` and then `chat.go` followed by 29 | `index.html` and then `index.js`. 30 | 31 | There are two automated tests for the server included in `chat_test.go`. The first is a simple 32 | single-client echo test. It publishes a single message and ensures it's received. 33 | 34 | The second is a complex concurrency test where 16 clients each send 128 unique messages 35 | of at most 128 bytes concurrently. The test ensures all messages are seen by every client. 36 | -------------------------------------------------------------------------------- /internal/examples/chat/chat.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "io" 7 | "log" 8 | "net" 9 | "net/http" 10 | "sync" 11 | "time" 12 | 13 | "golang.org/x/time/rate" 14 | 15 | "github.com/coder/websocket" 16 | ) 17 | 18 | // chatServer enables broadcasting to a set of subscribers. 19 | type chatServer struct { 20 | // subscriberMessageBuffer controls the max number 21 | // of messages that can be queued for a subscriber 22 | // before it is kicked. 23 | // 24 | // Defaults to 16. 25 | subscriberMessageBuffer int 26 | 27 | // publishLimiter controls the rate limit applied to the publish endpoint. 28 | // 29 | // Defaults to one publish every 100ms with a burst of 8. 30 | publishLimiter *rate.Limiter 31 | 32 | // logf controls where logs are sent. 33 | // Defaults to log.Printf. 34 | logf func(f string, v ...interface{}) 35 | 36 | // serveMux routes the various endpoints to the appropriate handler.
37 | serveMux http.ServeMux 38 | // subscribersMu guards subscribers. 39 | subscribersMu sync.Mutex 40 | subscribers map[*subscriber]struct{} 41 | } 42 | 43 | // newChatServer constructs a chatServer with the defaults. 44 | func newChatServer() *chatServer { 45 | cs := &chatServer{ 46 | subscriberMessageBuffer: 16, 47 | logf: log.Printf, 48 | subscribers: make(map[*subscriber]struct{}), 49 | publishLimiter: rate.NewLimiter(rate.Every(time.Millisecond*100), 8), 50 | } 51 | cs.serveMux.Handle("/", http.FileServer(http.Dir("."))) // serves files from the process's working directory 52 | cs.serveMux.HandleFunc("/subscribe", cs.subscribeHandler) 53 | cs.serveMux.HandleFunc("/publish", cs.publishHandler) 54 | 55 | return cs 56 | } 57 | 58 | // subscriber represents a subscriber. 59 | // Messages are sent on the msgs channel and if the client 60 | // cannot keep up with the messages, closeSlow is called. 61 | type subscriber struct { 62 | msgs chan []byte 63 | closeSlow func() 64 | } 65 | // ServeHTTP dispatches every request to cs.serveMux. 66 | func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { 67 | cs.serveMux.ServeHTTP(w, r) 68 | } 69 | 70 | // subscribeHandler accepts the WebSocket connection and then subscribes 71 | // it to all future messages. 72 | func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) { 73 | err := cs.subscribe(w, r) 74 | if errors.Is(err, context.Canceled) { 75 | return 76 | } 77 | if websocket.CloseStatus(err) == websocket.StatusNormalClosure || 78 | websocket.CloseStatus(err) == websocket.StatusGoingAway { 79 | return 80 | } 81 | if err != nil { 82 | cs.logf("%v", err) 83 | return 84 | } 85 | } 86 | 87 | // publishHandler reads the request body with a limit of 8192 bytes and then publishes 88 | // the received message.
89 | func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) { 90 | if r.Method != http.MethodPost { 91 | // RFC 9110 requires a 405 response to carry an Allow header listing the supported methods. 92 | w.Header().Set("Allow", http.MethodPost) 93 | http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) 94 | return 95 | } 96 | body := http.MaxBytesReader(w, r.Body, 8192) 97 | msg, err := io.ReadAll(body) 98 | if err != nil { 99 | http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge) 100 | return 101 | } 102 | cs.publish(msg) 103 | w.WriteHeader(http.StatusAccepted) 104 | } 105 | 106 | // subscribe subscribes the given WebSocket to all broadcast messages. 107 | // It creates a subscriber with a buffered msgs chan to give some room to slower 108 | // connections and then registers the subscriber. It then listens for all messages 109 | // and writes them to the WebSocket. If the context is cancelled or 110 | // an error occurs, it returns and deletes the subscription. 111 | // 112 | // It uses CloseRead to keep reading from the connection to process control 113 | // messages and cancel the context if the connection drops.
114 | func (cs *chatServer) subscribe(w http.ResponseWriter, r *http.Request) error { 115 | var mu sync.Mutex 116 | var c *websocket.Conn 117 | var closed bool 118 | s := &subscriber{ 119 | msgs: make(chan []byte, cs.subscriberMessageBuffer), 120 | closeSlow: func() { 121 | mu.Lock() 122 | defer mu.Unlock() 123 | closed = true 124 | if c != nil { 125 | c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages") 126 | } 127 | }, 128 | } 129 | cs.addSubscriber(s) 130 | defer cs.deleteSubscriber(s) 131 | c2, err := websocket.Accept(w, r, nil) 132 | if err != nil { 133 | return err 134 | } 135 | mu.Lock() 136 | c = c2 137 | closedEarly := closed // closeSlow fired while Accept was still running 138 | mu.Unlock() 139 | defer c.CloseNow() // registered before the closedEarly return so the accepted connection is always released 140 | if closedEarly { 141 | c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages") 142 | return net.ErrClosed 143 | } 144 | 145 | ctx := c.CloseRead(context.Background()) 146 | 147 | for { 148 | select { 149 | case msg := <-s.msgs: 150 | err := writeTimeout(ctx, time.Second*5, c, msg) 151 | if err != nil { 152 | return err 153 | } 154 | case <-ctx.Done(): 155 | return ctx.Err() 156 | } 157 | } 158 | } 159 | 160 | // publish publishes the msg to all subscribers after waiting on publishLimiter. 161 | // It never blocks on a slow subscriber: if a subscriber's buffer is full, 162 | // the message is dropped for it and closeSlow is called. 163 | func (cs *chatServer) publish(msg []byte) { 164 | // Wait before locking so a throttled publish never stalls addSubscriber/deleteSubscriber. 165 | cs.publishLimiter.Wait(context.Background()) // error ignored: with a background context only a zero-burst limiter can fail 166 | 167 | cs.subscribersMu.Lock() 168 | defer cs.subscribersMu.Unlock() 169 | for s := range cs.subscribers { 170 | select { 171 | case s.msgs <- msg: 172 | default: 173 | go s.closeSlow() 174 | } 175 | } 176 | } 177 | 178 | // addSubscriber registers a subscriber. 179 | func (cs *chatServer) addSubscriber(s *subscriber) { 180 | cs.subscribersMu.Lock() 181 | cs.subscribers[s] = struct{}{} 182 | cs.subscribersMu.Unlock() 183 | } 184 | 185 | // deleteSubscriber deletes the given subscriber.
186 | func (cs *chatServer) deleteSubscriber(s *subscriber) { 187 | cs.subscribersMu.Lock() 188 | delete(cs.subscribers, s) 189 | cs.subscribersMu.Unlock() 190 | } 191 | // writeTimeout writes msg to c as a text message using a context derived from ctx that expires after timeout. 192 | func writeTimeout(ctx context.Context, timeout time.Duration, c *websocket.Conn, msg []byte) error { 193 | ctx, cancel := context.WithTimeout(ctx, timeout) 194 | defer cancel() 195 | 196 | return c.Write(ctx, websocket.MessageText, msg) 197 | } 198 | -------------------------------------------------------------------------------- /internal/examples/chat/chat_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "crypto/rand" 6 | "fmt" 7 | "math/big" 8 | "net/http" 9 | "net/http/httptest" 10 | "strings" 11 | "sync" 12 | "testing" 13 | "time" 14 | 15 | "golang.org/x/time/rate" 16 | 17 | "github.com/coder/websocket" 18 | ) 19 | // Test_chatServer exercises the chat server end to end through real HTTP and WebSocket connections. 20 | func Test_chatServer(t *testing.T) { 21 | t.Parallel() 22 | 23 | // This is a simple echo test with a single client. 24 | // The client sends a message and ensures it receives 25 | // it on its WebSocket. 26 | t.Run("simple", func(t *testing.T) { 27 | t.Parallel() 28 | 29 | url, closeFn := setupTest(t) 30 | defer closeFn() 31 | 32 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) 33 | defer cancel() 34 | 35 | cl, err := newClient(ctx, url) 36 | assertSuccess(t, err) 37 | defer cl.Close() 38 | 39 | expMsg := randString(512) 40 | err = cl.publish(ctx, expMsg) 41 | assertSuccess(t, err) 42 | 43 | msg, err := cl.nextMessage() 44 | assertSuccess(t, err) 45 | 46 | if expMsg != msg { 47 | t.Fatalf("expected %v but got %v", expMsg, msg) 48 | } 49 | }) 50 | 51 | // This test is a complex concurrency test. 52 | // 16 clients are started that each send 128 different 53 | // messages of at most 128 bytes concurrently. 54 | // 55 | // The test verifies that every message is seen by every client 56 | // and no errors occur anywhere.
57 | t.Run("concurrency", func(t *testing.T) { 58 | t.Parallel() 59 | 60 | const nmessages = 128 61 | const maxMessageSize = 128 62 | const nclients = 16 63 | 64 | url, closeFn := setupTest(t) 65 | defer closeFn() 66 | 67 | ctx, cancel := context.WithTimeout(context.Background(), time.Minute) 68 | defer cancel() 69 | 70 | var clients []*client 71 | var clientMsgs []map[string]struct{} 72 | for i := 0; i < nclients; i++ { 73 | cl, err := newClient(ctx, url) 74 | assertSuccess(t, err) 75 | defer cl.Close() 76 | 77 | clients = append(clients, cl) 78 | clientMsgs = append(clientMsgs, randMessages(nmessages, maxMessageSize)) 79 | } 80 | // Every client must receive the union of all clients' messages. 81 | allMessages := make(map[string]struct{}) 82 | for _, msgs := range clientMsgs { 83 | for m := range msgs { 84 | allMessages[m] = struct{}{} 85 | } 86 | } 87 | 88 | var wg sync.WaitGroup 89 | for i, cl := range clients { 90 | i := i 91 | cl := cl 92 | 93 | wg.Add(1) 94 | go func() { 95 | defer wg.Done() 96 | err := cl.publishMsgs(ctx, clientMsgs[i]) 97 | if err != nil { 98 | t.Errorf("client %d failed to publish all messages: %v", i, err) 99 | } 100 | }() 101 | 102 | wg.Add(1) 103 | go func() { 104 | defer wg.Done() 105 | err := testAllMessagesReceived(cl, nclients*nmessages, allMessages) 106 | if err != nil { 107 | t.Errorf("client %d failed to receive all messages: %v", i, err) 108 | } 109 | }() 110 | } 111 | 112 | wg.Wait() 113 | }) 114 | } 115 | 116 | // setupTest sets up chatServer that can be used 117 | // via the returned url. 118 | // 119 | // Defer closeFn to ensure everything is cleaned up at 120 | // the end of the test. 121 | // 122 | // chatServer logs will be logged via t.Logf. 123 | func setupTest(t *testing.T) (url string, closeFn func()) { 124 | cs := newChatServer() 125 | cs.logf = t.Logf 126 | 127 | // To ensure tests run quickly even under -race.
128 | cs.subscriberMessageBuffer = 4096 129 | cs.publishLimiter.SetLimit(rate.Inf) 130 | 131 | s := httptest.NewServer(cs) 132 | return s.URL, func() { 133 | s.Close() 134 | } 135 | } 136 | 137 | // testAllMessagesReceived ensures that after n reads, all msgs in msgs 138 | // have been read. 139 | func testAllMessagesReceived(cl *client, n int, msgs map[string]struct{}) error { 140 | msgs = cloneMessages(msgs) 141 | 142 | for i := 0; i < n; i++ { 143 | msg, err := cl.nextMessage() 144 | if err != nil { 145 | return err 146 | } 147 | delete(msgs, msg) 148 | } 149 | 150 | if len(msgs) != 0 { 151 | return fmt.Errorf("did not receive all expected messages: %q", msgs) 152 | } 153 | return nil 154 | } 155 | // cloneMessages returns a shallow copy of msgs so the caller can delete from it freely. 156 | func cloneMessages(msgs map[string]struct{}) map[string]struct{} { 157 | msgs2 := make(map[string]struct{}, len(msgs)) 158 | for m := range msgs { 159 | msgs2[m] = struct{}{} 160 | } 161 | return msgs2 162 | } 163 | // randMessages returns n distinct random strings, each shorter than maxMessageLength bytes. 164 | func randMessages(n, maxMessageLength int) map[string]struct{} { 165 | msgs := make(map[string]struct{}) 166 | for i := 0; i < n; i++ { 167 | m := randString(randInt(maxMessageLength)) 168 | if _, ok := msgs[m]; ok { 169 | i-- 170 | continue 171 | } 172 | msgs[m] = struct{}{} 173 | } 174 | return msgs 175 | } 176 | // assertSuccess fails the test immediately if err is non-nil. 177 | func assertSuccess(t *testing.T, err error) { 178 | t.Helper() 179 | if err != nil { 180 | t.Fatal(err) 181 | } 182 | } 183 | // client is a chat participant: c receives broadcasts and url is the server base used for /publish. 184 | type client struct { 185 | url string 186 | c *websocket.Conn 187 | } 188 | // newClient dials the server's /subscribe endpoint and wraps the resulting connection. 189 | func newClient(ctx context.Context, url string) (*client, error) { 190 | c, _, err := websocket.Dial(ctx, url+"/subscribe", nil) 191 | if err != nil { 192 | return nil, err 193 | } 194 | 195 | cl := &client{ 196 | url: url, 197 | c: c, 198 | } 199 | 200 | return cl, nil 201 | } 202 | // publish POSTs msg to /publish; on any failure it also closes the WebSocket with StatusInternalError. 203 | func (cl *client) publish(ctx context.Context, msg string) (err error) { 204 | defer func() { 205 | if err != nil { 206 | cl.c.Close(websocket.StatusInternalError, "publish failed") 207 | } 208 | }() 209 | 210 | req, _ :=
http.NewRequestWithContext(ctx, http.MethodPost, cl.url+"/publish", strings.NewReader(msg)) 211 | resp, err := http.DefaultClient.Do(req) 212 | if err != nil { 213 | return err 214 | } 215 | defer resp.Body.Close() 216 | if resp.StatusCode != http.StatusAccepted { 217 | return fmt.Errorf("publish request failed: %v", resp.StatusCode) 218 | } 219 | return nil 220 | } 221 | // publishMsgs publishes every message in msgs, stopping at the first error. 222 | func (cl *client) publishMsgs(ctx context.Context, msgs map[string]struct{}) error { 223 | for m := range msgs { 224 | err := cl.publish(ctx, m) 225 | if err != nil { 226 | return err 227 | } 228 | } 229 | return nil 230 | } 231 | // nextMessage waits, with no deadline, for the next message; a non-text message closes the connection and returns an error. 232 | func (cl *client) nextMessage() (string, error) { 233 | typ, b, err := cl.c.Read(context.Background()) 234 | if err != nil { 235 | return "", err 236 | } 237 | 238 | if typ != websocket.MessageText { 239 | cl.c.Close(websocket.StatusUnsupportedData, "expected text message") 240 | return "", fmt.Errorf("expected text message but got %v", typ) 241 | } 242 | return string(b), nil 243 | } 244 | // Close closes the client's WebSocket with StatusNormalClosure. 245 | func (cl *client) Close() error { 246 | return cl.c.Close(websocket.StatusNormalClosure, "") 247 | } 248 | 249 | // randString generates a random string of exactly n bytes; invalid UTF-8 and NUL bytes become '_' and any shortfall is padded with '='. 250 | func randString(n int) string { 251 | b := make([]byte, n) 252 | _, err := rand.Reader.Read(b) 253 | if err != nil { 254 | panic(fmt.Sprintf("failed to generate rand bytes: %v", err)) 255 | } 256 | 257 | s := strings.ToValidUTF8(string(b), "_") 258 | s = strings.ReplaceAll(s, "\x00", "_") 259 | if len(s) > n { 260 | return s[:n] 261 | } 262 | if len(s) < n { 263 | // Pad with = 264 | extra := n - len(s) 265 | return s + strings.Repeat("=", extra) 266 | } 267 | return s 268 | } 269 | 270 | // randInt returns a cryptographically random integer in [0, max); crypto/rand.Int panics if max <= 0.
271 | func randInt(max int) int { // max must be positive: crypto/rand.Int panics otherwise 272 | x, err := rand.Int(rand.Reader, big.NewInt(int64(max))) 273 | if err != nil { 274 | panic(fmt.Sprintf("failed to get random int: %v", err)) 275 | } 276 | return int(x.Int64()) 277 | } 278 | -------------------------------------------------------------------------------- /internal/examples/chat/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | width: 100vw; 3 | min-width: 320px; 4 | } 5 | 6 | #root { 7 | padding: 40px 20px; 8 | max-width: 600px; 9 | margin: auto; 10 | height: 100vh; 11 | 12 | display: flex; 13 | flex-direction: column; 14 | align-items: center; 15 | justify-content: center; 16 | } 17 | 18 | #root > * + * { 19 | margin: 20px 0 0 0; 20 | } 21 | 22 | /* 100vh on Safari does not include the bottom bar. */ 23 | @supports (-webkit-overflow-scrolling: touch) { 24 | #root { 25 | height: 85vh; 26 | } 27 | } 28 | 29 | #message-log { 30 | width: 100%; 31 | flex-grow: 1; 32 | overflow: auto; 33 | } 34 | 35 | #message-log p:first-child { 36 | margin: 0; 37 | } 38 | 39 | #message-log > * + * { 40 | margin: 10px 0 0 0; 41 | } 42 | 43 | #publish-form-container { 44 | width: 100%; 45 | } 46 | 47 | #publish-form { 48 | width: 100%; 49 | display: flex; 50 | height: 40px; 51 | } 52 | 53 | #publish-form > * + * { 54 | margin: 0 0 0 10px; 55 | } 56 | 57 | #publish-form input[type='text'] { 58 | flex-grow: 1; 59 | 60 | -moz-appearance: none; 61 | -webkit-appearance: none; 62 | word-break: normal; 63 | border-radius: 5px; 64 | border: 1px solid #ccc; 65 | } 66 | 67 | #publish-form input[type='submit'] { 68 | color: white; 69 | background-color: black; 70 | border-radius: 5px; 71 | padding: 5px 10px; 72 | border: none; 73 | } 74 | 75 | #publish-form input[type='submit']:hover { 76 | background-color: red; 77 | } 78 | 79 | #publish-form input[type='submit']:active { 80 | background-color: red; 81 | } 82 |
-------------------------------------------------------------------------------- /internal/examples/chat/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | github.com/coder/websocket - Chat Example 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 |
16 |
17 |
18 | 19 | 20 |
21 |
22 |
23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /internal/examples/chat/index.js: -------------------------------------------------------------------------------- 1 | ;(() => { 2 | // expectingMessage is set to true 3 | // if the user has just submitted a message 4 | // and so we should scroll the next message into view when received. 5 | let expectingMessage = false 6 | function dial() { 7 | const conn = new WebSocket(`ws://${location.host}/subscribe`) 8 | 9 | conn.addEventListener('close', ev => { 10 | appendLog(`WebSocket Disconnected code: ${ev.code}, reason: ${ev.reason}`, true) 11 | if (ev.code !== 1001) { 12 | appendLog('Reconnecting in 1s', true) 13 | setTimeout(dial, 1000) 14 | } 15 | }) 16 | conn.addEventListener('open', ev => { 17 | console.info('websocket connected') 18 | }) 19 | 20 | // This is where we handle messages received. 21 | conn.addEventListener('message', ev => { 22 | if (typeof ev.data !== 'string') { 23 | console.error('unexpected message type', typeof ev.data) 24 | return 25 | } 26 | const p = appendLog(ev.data) 27 | if (expectingMessage) { 28 | p.scrollIntoView() 29 | expectingMessage = false 30 | } 31 | }) 32 | } 33 | dial() 34 | 35 | const messageLog = document.getElementById('message-log') 36 | const publishForm = document.getElementById('publish-form') 37 | const messageInput = document.getElementById('message-input') 38 | 39 | // appendLog appends the passed text to messageLog. 40 | function appendLog(text, error) { 41 | const p = document.createElement('p') 42 | // Adding a timestamp to each message makes the log easier to read. 43 | p.innerText = `${new Date().toLocaleTimeString()}: ${text}` 44 | if (error) { 45 | p.style.color = 'red' 46 | p.style.fontStyle = 'bold' 47 | } 48 | messageLog.append(p) 49 | return p 50 | } 51 | appendLog('Submit a message to get started!') 52 | 53 | // onsubmit publishes the message from the user when the form is submitted. 
54 | publishForm.onsubmit = async ev => { 55 | ev.preventDefault() 56 | 57 | const msg = messageInput.value 58 | if (msg === '') { 59 | return 60 | } 61 | messageInput.value = '' 62 | 63 | expectingMessage = true 64 | try { 65 | const resp = await fetch('/publish', { 66 | method: 'POST', 67 | body: msg, 68 | }) 69 | if (resp.status !== 202) { 70 | throw new Error(`Unexpected HTTP Status ${resp.status} ${resp.statusText}`) 71 | } 72 | } catch (err) { 73 | appendLog(`Publish failed: ${err.message}`, true) 74 | } 75 | } 76 | })() 77 | -------------------------------------------------------------------------------- /internal/examples/chat/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "log" 7 | "net" 8 | "net/http" 9 | "os" 10 | "os/signal" 11 | "time" 12 | ) 13 | 14 | func main() { 15 | log.SetFlags(0) 16 | 17 | err := run() 18 | if err != nil { 19 | log.Fatal(err) 20 | } 21 | } 22 | 23 | // run initializes the chatServer and then 24 | // starts a http.Server for the passed in address. 
25 | func run() error { 26 | if len(os.Args) < 2 { 27 | return errors.New("please provide an address to listen on as the first argument") 28 | } 29 | 30 | l, err := net.Listen("tcp", os.Args[1]) 31 | if err != nil { 32 | return err 33 | } 34 | log.Printf("listening on ws://%v", l.Addr()) 35 | 36 | cs := newChatServer() 37 | s := &http.Server{ 38 | Handler: cs, 39 | ReadTimeout: time.Second * 10, 40 | WriteTimeout: time.Second * 10, 41 | } 42 | errc := make(chan error, 1) 43 | go func() { 44 | errc <- s.Serve(l) 45 | }() 46 | 47 | sigs := make(chan os.Signal, 1) 48 | signal.Notify(sigs, os.Interrupt) 49 | select { 50 | case err := <-errc: 51 | log.Printf("failed to serve: %v", err) 52 | case sig := <-sigs: 53 | log.Printf("terminating: %v", sig) 54 | } 55 | 56 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) 57 | defer cancel() 58 | 59 | return s.Shutdown(ctx) 60 | } 61 | -------------------------------------------------------------------------------- /internal/examples/echo/README.md: -------------------------------------------------------------------------------- 1 | # Echo Example 2 | 3 | This directory contains an echo server example using github.com/coder/websocket. 4 | 5 | ```bash 6 | $ cd internal/examples/echo 7 | $ go run . localhost:0 8 | listening on ws://127.0.0.1:51055 9 | ``` 10 | 11 | You can use a WebSocket client like https://github.com/hashrocket/ws to connect. The client must request the `echo` subprotocol, otherwise the server closes the connection. All messages 12 | written will be echoed back. 13 | 14 | ## Structure 15 | 16 | The server is in `server.go` and is implemented as an `http.Handler` that accepts the WebSocket 17 | and then reads all messages and writes them exactly as is back to the connection. 18 | 19 | `server_test.go` contains a small unit test to verify it works correctly. 20 | 21 | `main.go` brings it all together so that you can run it and play around with it.
22 | -------------------------------------------------------------------------------- /internal/examples/echo/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "log" 7 | "net" 8 | "net/http" 9 | "os" 10 | "os/signal" 11 | "time" 12 | ) 13 | 14 | func main() { 15 | log.SetFlags(0) 16 | 17 | err := run() 18 | if err != nil { 19 | log.Fatal(err) 20 | } 21 | } 22 | 23 | // run starts a http.Server for the passed in address 24 | // with all requests handled by echoServer. 25 | func run() error { 26 | if len(os.Args) < 2 { 27 | return errors.New("please provide an address to listen on as the first argument") 28 | } 29 | 30 | l, err := net.Listen("tcp", os.Args[1]) 31 | if err != nil { 32 | return err 33 | } 34 | log.Printf("listening on ws://%v", l.Addr()) 35 | 36 | s := &http.Server{ 37 | Handler: echoServer{ 38 | logf: log.Printf, 39 | }, 40 | ReadTimeout: time.Second * 10, 41 | WriteTimeout: time.Second * 10, 42 | } 43 | errc := make(chan error, 1) 44 | go func() { 45 | errc <- s.Serve(l) 46 | }() 47 | 48 | sigs := make(chan os.Signal, 1) 49 | signal.Notify(sigs, os.Interrupt) 50 | select { 51 | case err := <-errc: 52 | log.Printf("failed to serve: %v", err) 53 | case sig := <-sigs: 54 | log.Printf("terminating: %v", sig) 55 | } 56 | 57 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) 58 | defer cancel() 59 | 60 | return s.Shutdown(ctx) 61 | } 62 | -------------------------------------------------------------------------------- /internal/examples/echo/server.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "time" 9 | 10 | "golang.org/x/time/rate" 11 | 12 | "github.com/coder/websocket" 13 | ) 14 | 15 | // echoServer is the WebSocket echo server implementation. 
16 | // It ensures the client speaks the echo subprotocol and 17 | // only allows one message every 100ms with a 10 message burst. 18 | type echoServer struct { 19 | // logf controls where logs are sent. 20 | logf func(f string, v ...interface{}) 21 | } 22 | 23 | func (s echoServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { 24 | c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ 25 | Subprotocols: []string{"echo"}, 26 | }) 27 | if err != nil { 28 | s.logf("%v", err) 29 | return 30 | } 31 | defer c.CloseNow() 32 | 33 | if c.Subprotocol() != "echo" { 34 | c.Close(websocket.StatusPolicyViolation, "client must speak the echo subprotocol") 35 | return 36 | } 37 | 38 | l := rate.NewLimiter(rate.Every(time.Millisecond*100), 10) 39 | for { 40 | err = echo(c, l) 41 | if websocket.CloseStatus(err) == websocket.StatusNormalClosure { 42 | return 43 | } 44 | if err != nil { 45 | s.logf("failed to echo with %v: %v", r.RemoteAddr, err) 46 | return 47 | } 48 | } 49 | } 50 | 51 | // echo reads from the WebSocket connection and then writes 52 | // the received message back to it. 53 | // The entire function has 10s to complete. 
54 | func echo(c *websocket.Conn, l *rate.Limiter) error { 55 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) 56 | defer cancel() 57 | 58 | err := l.Wait(ctx) 59 | if err != nil { 60 | return err 61 | } 62 | 63 | typ, r, err := c.Reader(ctx) 64 | if err != nil { 65 | return err 66 | } 67 | 68 | w, err := c.Writer(ctx, typ) 69 | if err != nil { 70 | return err 71 | } 72 | 73 | _, err = io.Copy(w, r) 74 | if err != nil { 75 | return fmt.Errorf("failed to io.Copy: %w", err) 76 | } 77 | 78 | err = w.Close() 79 | return err 80 | } 81 | -------------------------------------------------------------------------------- /internal/examples/echo/server_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "net/http/httptest" 6 | "testing" 7 | "time" 8 | 9 | "github.com/coder/websocket" 10 | "github.com/coder/websocket/wsjson" 11 | ) 12 | 13 | // Test_echoServer tests the echoServer by sending it 5 different messages 14 | // and ensuring the responses all match. 
15 | func Test_echoServer(t *testing.T) { 16 | t.Parallel() 17 | 18 | s := httptest.NewServer(echoServer{ 19 | logf: t.Logf, 20 | }) 21 | defer s.Close() 22 | 23 | ctx, cancel := context.WithTimeout(context.Background(), time.Minute) 24 | defer cancel() 25 | 26 | c, _, err := websocket.Dial(ctx, s.URL, &websocket.DialOptions{ 27 | Subprotocols: []string{"echo"}, 28 | }) 29 | if err != nil { 30 | t.Fatal(err) 31 | } 32 | defer c.Close(websocket.StatusInternalError, "the sky is falling") 33 | 34 | for i := 0; i < 5; i++ { 35 | err = wsjson.Write(ctx, c, map[string]int{ 36 | "i": i, 37 | }) 38 | if err != nil { 39 | t.Fatal(err) 40 | } 41 | 42 | v := map[string]int{} 43 | err = wsjson.Read(ctx, c, &v) 44 | if err != nil { 45 | t.Fatal(err) 46 | } 47 | 48 | if v["i"] != i { 49 | t.Fatalf("expected %v but got %v", i, v) 50 | } 51 | } 52 | 53 | c.Close(websocket.StatusNormalClosure, "") 54 | } 55 | -------------------------------------------------------------------------------- /internal/examples/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/coder/websocket/examples 2 | 3 | go 1.23 4 | 5 | replace github.com/coder/websocket => ../.. 
6 | 7 | require ( 8 | github.com/coder/websocket v0.0.0-00010101000000-000000000000 9 | golang.org/x/time v0.7.0 10 | ) 11 | -------------------------------------------------------------------------------- /internal/examples/go.sum: -------------------------------------------------------------------------------- 1 | golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= 2 | golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= 3 | -------------------------------------------------------------------------------- /internal/test/assert/assert.go: -------------------------------------------------------------------------------- 1 | package assert 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "reflect" 7 | "strings" 8 | "testing" 9 | ) 10 | 11 | // Equal asserts exp == act. 12 | func Equal(t testing.TB, name string, exp, got interface{}) { 13 | t.Helper() 14 | 15 | if !reflect.DeepEqual(exp, got) { 16 | t.Fatalf("unexpected %v: expected %#v but got %#v", name, exp, got) 17 | } 18 | } 19 | 20 | // Success asserts err == nil. 21 | func Success(t testing.TB, err error) { 22 | t.Helper() 23 | 24 | if err != nil { 25 | t.Fatal(err) 26 | } 27 | } 28 | 29 | // Error asserts err != nil. 30 | func Error(t testing.TB, err error) { 31 | t.Helper() 32 | 33 | if err == nil { 34 | t.Fatal("expected error") 35 | } 36 | } 37 | 38 | // Contains asserts the fmt.Sprint(v) contains sub. 
39 | func Contains(t testing.TB, v interface{}, sub string) { 40 | t.Helper() 41 | 42 | s := fmt.Sprint(v) 43 | if !strings.Contains(s, sub) { 44 | t.Fatalf("expected %q to contain %q", s, sub) 45 | } 46 | } 47 | 48 | // ErrorIs asserts errors.Is(got, exp) 49 | func ErrorIs(t testing.TB, exp, got error) { 50 | t.Helper() 51 | 52 | if !errors.Is(got, exp) { 53 | t.Fatalf("expected %v but got %v", exp, got) 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /internal/test/doc.go: -------------------------------------------------------------------------------- 1 | // Package test contains subpackages only used in tests. 2 | package test 3 | -------------------------------------------------------------------------------- /internal/test/wstest/echo.go: -------------------------------------------------------------------------------- 1 | package wstest 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "io" 8 | "time" 9 | 10 | "github.com/coder/websocket" 11 | "github.com/coder/websocket/internal/test/xrand" 12 | "github.com/coder/websocket/internal/xsync" 13 | ) 14 | 15 | // EchoLoop echos every msg received from c until an error 16 | // occurs or the context expires. 17 | // The read limit is set to 1 << 30. 18 | func EchoLoop(ctx context.Context, c *websocket.Conn) error { 19 | defer c.Close(websocket.StatusInternalError, "") 20 | 21 | c.SetReadLimit(1 << 30) 22 | 23 | ctx, cancel := context.WithTimeout(ctx, time.Minute*5) 24 | defer cancel() 25 | 26 | b := make([]byte, 32<<10) 27 | for { 28 | typ, r, err := c.Reader(ctx) 29 | if err != nil { 30 | return err 31 | } 32 | 33 | w, err := c.Writer(ctx, typ) 34 | if err != nil { 35 | return err 36 | } 37 | 38 | _, err = io.CopyBuffer(w, r, b) 39 | if err != nil { 40 | return err 41 | } 42 | 43 | err = w.Close() 44 | if err != nil { 45 | return err 46 | } 47 | } 48 | } 49 | 50 | // Echo writes a message and ensures the same is sent back on c. 
51 | func Echo(ctx context.Context, c *websocket.Conn, max int) error { 52 | expType := websocket.MessageBinary 53 | if xrand.Bool() { 54 | expType = websocket.MessageText 55 | } 56 | 57 | msg := randMessage(expType, xrand.Int(max)) 58 | 59 | writeErr := xsync.Go(func() error { 60 | return c.Write(ctx, expType, msg) 61 | }) 62 | 63 | actType, act, err := c.Read(ctx) 64 | if err != nil { 65 | return err 66 | } 67 | 68 | err = <-writeErr 69 | if err != nil { 70 | return err 71 | } 72 | 73 | if expType != actType { 74 | return fmt.Errorf("expected message type %v but got %v", expType, actType) 75 | } 76 | 77 | if !bytes.Equal(msg, act) { 78 | return fmt.Errorf("unexpected msg read: %#v", act) 79 | } 80 | 81 | return nil 82 | } 83 | 84 | func randMessage(typ websocket.MessageType, n int) []byte { 85 | if typ == websocket.MessageBinary { 86 | return xrand.Bytes(n) 87 | } 88 | return []byte(xrand.String(n)) 89 | } 90 | -------------------------------------------------------------------------------- /internal/test/wstest/pipe.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | // +build !js 3 | 4 | package wstest 5 | 6 | import ( 7 | "bufio" 8 | "context" 9 | "net" 10 | "net/http" 11 | "net/http/httptest" 12 | 13 | "github.com/coder/websocket" 14 | ) 15 | 16 | // Pipe is used to create an in memory connection 17 | // between two websockets analogous to net.Pipe.
18 | func Pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (clientConn, serverConn *websocket.Conn) { 19 | tt := fakeTransport{ 20 | h: func(w http.ResponseWriter, r *http.Request) { 21 | serverConn, _ = websocket.Accept(w, r, acceptOpts) 22 | }, 23 | } 24 | 25 | if dialOpts == nil { 26 | dialOpts = &websocket.DialOptions{} 27 | } 28 | _dialOpts := *dialOpts 29 | dialOpts = &_dialOpts 30 | dialOpts.HTTPClient = &http.Client{ 31 | Transport: tt, 32 | } 33 | 34 | clientConn, _, _ = websocket.Dial(context.Background(), "ws://example.com", dialOpts) 35 | return clientConn, serverConn 36 | } 37 | 38 | type fakeTransport struct { 39 | h http.HandlerFunc 40 | } 41 | 42 | func (t fakeTransport) RoundTrip(r *http.Request) (*http.Response, error) { 43 | clientConn, serverConn := net.Pipe() 44 | 45 | hj := testHijacker{ 46 | ResponseRecorder: httptest.NewRecorder(), 47 | serverConn: serverConn, 48 | } 49 | 50 | t.h.ServeHTTP(hj, r) 51 | 52 | resp := hj.ResponseRecorder.Result() 53 | if resp.StatusCode == http.StatusSwitchingProtocols { 54 | resp.Body = clientConn 55 | } 56 | return resp, nil 57 | } 58 | 59 | type testHijacker struct { 60 | *httptest.ResponseRecorder 61 | serverConn net.Conn 62 | } 63 | 64 | var _ http.Hijacker = testHijacker{} 65 | 66 | func (hj testHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { 67 | return hj.serverConn, bufio.NewReadWriter(bufio.NewReader(hj.serverConn), bufio.NewWriter(hj.serverConn)), nil 68 | } 69 | -------------------------------------------------------------------------------- /internal/test/xrand/xrand.go: -------------------------------------------------------------------------------- 1 | package xrand 2 | 3 | import ( 4 | "crypto/rand" 5 | "encoding/base64" 6 | "fmt" 7 | "math/big" 8 | "strings" 9 | ) 10 | 11 | // Bytes generates random bytes with length n. 
12 | func Bytes(n int) []byte { 13 | b := make([]byte, n) 14 | _, err := rand.Reader.Read(b) 15 | if err != nil { 16 | panic(fmt.Sprintf("failed to generate rand bytes: %v", err)) 17 | } 18 | return b 19 | } 20 | 21 | // String generates a random string with length n. 22 | func String(n int) string { 23 | s := strings.ToValidUTF8(string(Bytes(n)), "_") 24 | s = strings.ReplaceAll(s, "\x00", "_") 25 | if len(s) > n { 26 | return s[:n] 27 | } 28 | if len(s) < n { 29 | // Pad with = 30 | extra := n - len(s) 31 | return s + strings.Repeat("=", extra) 32 | } 33 | return s 34 | } 35 | 36 | // Bool returns a randomly generated boolean. 37 | func Bool() bool { 38 | return Int(2) == 1 39 | } 40 | 41 | // Int returns a randomly generated integer between [0, max). 42 | func Int(max int) int { 43 | x, err := rand.Int(rand.Reader, big.NewInt(int64(max))) 44 | if err != nil { 45 | panic(fmt.Sprintf("failed to get random int: %v", err)) 46 | } 47 | return int(x.Int64()) 48 | } 49 | 50 | // Base64 returns a randomly generated base64 string of length n. 51 | func Base64(n int) string { 52 | return base64.StdEncoding.EncodeToString(Bytes(n)) 53 | } 54 | -------------------------------------------------------------------------------- /internal/thirdparty/doc.go: -------------------------------------------------------------------------------- 1 | // Package thirdparty contains third party benchmarks and tests. 
2 | package thirdparty 3 | -------------------------------------------------------------------------------- /internal/thirdparty/frame_test.go: -------------------------------------------------------------------------------- 1 | package thirdparty 2 | 3 | import ( 4 | "encoding/binary" 5 | "runtime" 6 | "strconv" 7 | "testing" 8 | _ "unsafe" 9 | 10 | "github.com/gobwas/ws" 11 | _ "github.com/gorilla/websocket" 12 | _ "github.com/lesismal/nbio/nbhttp/websocket" 13 | 14 | _ "github.com/coder/websocket" 15 | ) 16 | 17 | func basicMask(b []byte, maskKey [4]byte, pos int) int { 18 | for i := range b { 19 | b[i] ^= maskKey[pos&3] 20 | pos++ 21 | } 22 | return pos & 3 23 | } 24 | 25 | //go:linkname maskGo github.com/coder/websocket.maskGo 26 | func maskGo(b []byte, key32 uint32) int 27 | 28 | //go:linkname maskAsm github.com/coder/websocket.maskAsm 29 | func maskAsm(b *byte, len int, key32 uint32) uint32 30 | 31 | //go:linkname nbioMaskBytes github.com/lesismal/nbio/nbhttp/websocket.maskXOR 32 | func nbioMaskBytes(b, key []byte) int 33 | 34 | //go:linkname gorillaMaskBytes github.com/gorilla/websocket.maskBytes 35 | func gorillaMaskBytes(key [4]byte, pos int, b []byte) int 36 | 37 | func Benchmark_mask(b *testing.B) { 38 | b.Run(runtime.GOARCH, benchmark_mask) 39 | } 40 | 41 | func benchmark_mask(b *testing.B) { 42 | sizes := []int{ 43 | 8, 44 | 16, 45 | 32, 46 | 128, 47 | 256, 48 | 512, 49 | 1024, 50 | 2048, 51 | 4096, 52 | 8192, 53 | 16384, 54 | } 55 | 56 | fns := []struct { 57 | name string 58 | fn func(b *testing.B, key [4]byte, p []byte) 59 | }{ 60 | { 61 | name: "basic", 62 | fn: func(b *testing.B, key [4]byte, p []byte) { 63 | for i := 0; i < b.N; i++ { 64 | basicMask(p, key, 0) 65 | } 66 | }, 67 | }, 68 | 69 | { 70 | name: "nhooyr-go", 71 | fn: func(b *testing.B, key [4]byte, p []byte) { 72 | key32 := binary.LittleEndian.Uint32(key[:]) 73 | b.ResetTimer() 74 | 75 | for i := 0; i < b.N; i++ { 76 | maskGo(p, key32) 77 | } 78 | }, 79 | }, 80 | { 81 | name: 
"wdvxdr1123-asm", 82 | fn: func(b *testing.B, key [4]byte, p []byte) { 83 | key32 := binary.LittleEndian.Uint32(key[:]) 84 | b.ResetTimer() 85 | 86 | for i := 0; i < b.N; i++ { 87 | maskAsm(&p[0], len(p), key32) 88 | } 89 | }, 90 | }, 91 | 92 | { 93 | name: "gorilla", 94 | fn: func(b *testing.B, key [4]byte, p []byte) { 95 | for i := 0; i < b.N; i++ { 96 | gorillaMaskBytes(key, 0, p) 97 | } 98 | }, 99 | }, 100 | { 101 | name: "gobwas", 102 | fn: func(b *testing.B, key [4]byte, p []byte) { 103 | for i := 0; i < b.N; i++ { 104 | ws.Cipher(p, key, 0) 105 | } 106 | }, 107 | }, 108 | { 109 | name: "nbio", 110 | fn: func(b *testing.B, key [4]byte, p []byte) { 111 | keyb := key[:] 112 | for i := 0; i < b.N; i++ { 113 | nbioMaskBytes(p, keyb) 114 | } 115 | }, 116 | }, 117 | } 118 | 119 | key := [4]byte{1, 2, 3, 4} 120 | 121 | for _, fn := range fns { 122 | b.Run(fn.name, func(b *testing.B) { 123 | for _, size := range sizes { 124 | p := make([]byte, size) 125 | 126 | b.Run(strconv.Itoa(size), func(b *testing.B) { 127 | b.SetBytes(int64(size)) 128 | 129 | fn.fn(b, key, p) 130 | }) 131 | } 132 | }) 133 | } 134 | } 135 | -------------------------------------------------------------------------------- /internal/thirdparty/gin_test.go: -------------------------------------------------------------------------------- 1 | package thirdparty 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | "time" 10 | 11 | "github.com/gin-gonic/gin" 12 | 13 | "github.com/coder/websocket" 14 | "github.com/coder/websocket/internal/errd" 15 | "github.com/coder/websocket/internal/test/assert" 16 | "github.com/coder/websocket/internal/test/wstest" 17 | "github.com/coder/websocket/wsjson" 18 | ) 19 | 20 | func TestGin(t *testing.T) { 21 | t.Parallel() 22 | 23 | gin.SetMode(gin.ReleaseMode) 24 | r := gin.New() 25 | r.GET("/", func(ginCtx *gin.Context) { 26 | err := echoServer(ginCtx.Writer, ginCtx.Request, nil) 27 | if err != nil { 28 | t.Error(err) 29 | } 
30 | }) 31 | 32 | s := httptest.NewServer(r) 33 | defer s.Close() 34 | 35 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) 36 | defer cancel() 37 | 38 | c, _, err := websocket.Dial(ctx, s.URL, nil) 39 | assert.Success(t, err) 40 | defer c.Close(websocket.StatusInternalError, "") 41 | 42 | err = wsjson.Write(ctx, c, "hello") 43 | assert.Success(t, err) 44 | 45 | var v interface{} 46 | err = wsjson.Read(ctx, c, &v) 47 | assert.Success(t, err) 48 | assert.Equal(t, "read msg", "hello", v) 49 | 50 | err = c.Close(websocket.StatusNormalClosure, "") 51 | assert.Success(t, err) 52 | } 53 | 54 | func echoServer(w http.ResponseWriter, r *http.Request, opts *websocket.AcceptOptions) (err error) { 55 | defer errd.Wrap(&err, "echo server failed") 56 | 57 | c, err := websocket.Accept(w, r, opts) 58 | if err != nil { 59 | return err 60 | } 61 | defer c.Close(websocket.StatusInternalError, "") 62 | 63 | err = wstest.EchoLoop(r.Context(), c) 64 | return assertCloseStatus(websocket.StatusNormalClosure, err) 65 | } 66 | 67 | func assertCloseStatus(exp websocket.StatusCode, err error) error { 68 | if websocket.CloseStatus(err) == -1 { 69 | return fmt.Errorf("expected websocket.CloseError: %T %v", err, err) 70 | } 71 | if websocket.CloseStatus(err) != exp { 72 | return fmt.Errorf("expected close status %v but got %v", exp, err) 73 | } 74 | return nil 75 | } 76 | -------------------------------------------------------------------------------- /internal/thirdparty/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/coder/websocket/internal/thirdparty 2 | 3 | go 1.23 4 | 5 | replace github.com/coder/websocket => ../.. 
6 | 7 | require ( 8 | github.com/coder/websocket v0.0.0-00010101000000-000000000000 9 | github.com/gin-gonic/gin v1.10.0 10 | github.com/gobwas/ws v1.4.0 11 | github.com/gorilla/websocket v1.5.3 12 | github.com/lesismal/nbio v1.5.12 13 | ) 14 | 15 | require ( 16 | github.com/bytedance/sonic v1.11.6 // indirect 17 | github.com/bytedance/sonic/loader v0.1.1 // indirect 18 | github.com/cloudwego/base64x v0.1.4 // indirect 19 | github.com/cloudwego/iasm v0.2.0 // indirect 20 | github.com/gabriel-vasile/mimetype v1.4.3 // indirect 21 | github.com/gin-contrib/sse v0.1.0 // indirect 22 | github.com/go-playground/locales v0.14.1 // indirect 23 | github.com/go-playground/universal-translator v0.18.1 // indirect 24 | github.com/go-playground/validator/v10 v10.20.0 // indirect 25 | github.com/gobwas/httphead v0.1.0 // indirect 26 | github.com/gobwas/pool v0.2.1 // indirect 27 | github.com/goccy/go-json v0.10.2 // indirect 28 | github.com/json-iterator/go v1.1.12 // indirect 29 | github.com/klauspost/cpuid/v2 v2.2.7 // indirect 30 | github.com/leodido/go-urn v1.4.0 // indirect 31 | github.com/lesismal/llib v1.1.13 // indirect 32 | github.com/mattn/go-isatty v0.0.20 // indirect 33 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect 34 | github.com/modern-go/reflect2 v1.0.2 // indirect 35 | github.com/pelletier/go-toml/v2 v2.2.2 // indirect 36 | github.com/twitchyliquid64/golang-asm v0.15.1 // indirect 37 | github.com/ugorji/go/codec v1.2.12 // indirect 38 | golang.org/x/arch v0.8.0 // indirect 39 | golang.org/x/crypto v0.23.0 // indirect 40 | golang.org/x/net v0.25.0 // indirect 41 | golang.org/x/sys v0.20.0 // indirect 42 | golang.org/x/text v0.15.0 // indirect 43 | google.golang.org/protobuf v1.34.1 // indirect 44 | gopkg.in/yaml.v3 v3.0.1 // indirect 45 | ) 46 | -------------------------------------------------------------------------------- /internal/thirdparty/go.sum: 
-------------------------------------------------------------------------------- 1 | github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= 2 | github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= 3 | github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= 4 | github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= 5 | github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= 6 | github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= 7 | github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= 8 | github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= 9 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 10 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 11 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 12 | github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= 13 | github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= 14 | github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= 15 | github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= 16 | github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= 17 | github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= 18 | github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= 19 | github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= 20 | github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= 21 | github.com/go-playground/locales v0.14.1/go.mod 
h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= 22 | github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= 23 | github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= 24 | github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= 25 | github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= 26 | github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU= 27 | github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= 28 | github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= 29 | github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= 30 | github.com/gobwas/ws v1.4.0 h1:CTaoG1tojrh4ucGPcoJFiAQUAsEWekEWvLy7GsVNqGs= 31 | github.com/gobwas/ws v1.4.0/go.mod h1:G3gNqMNtPppf5XUz7O4shetPpcZ1VJ7zt18dlUeakrc= 32 | github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= 33 | github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= 34 | github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= 35 | github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 36 | github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= 37 | github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= 38 | github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= 39 | github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= 40 | github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= 41 | github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= 42 | github.com/klauspost/cpuid/v2 v2.2.7 
h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= 43 | github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= 44 | github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= 45 | github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= 46 | github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= 47 | github.com/lesismal/llib v1.1.13 h1:+w1+t0PykXpj2dXQck0+p6vdC9/mnbEXHgUy/HXDGfE= 48 | github.com/lesismal/llib v1.1.13/go.mod h1:70tFXXe7P1FZ02AU9l8LgSOK7d7sRrpnkUr3rd3gKSg= 49 | github.com/lesismal/nbio v1.5.12 h1:YcUjjmOvmKEANs6Oo175JogXvHy8CuE7i6ccjM2/tv4= 50 | github.com/lesismal/nbio v1.5.12/go.mod h1:QsxE0fKFe1PioyjuHVDn2y8ktYK7xv9MFbpkoRFj8vI= 51 | github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= 52 | github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= 53 | github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= 54 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= 55 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= 56 | github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= 57 | github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= 58 | github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= 59 | github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= 60 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 61 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 62 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 63 | 
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 64 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 65 | github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= 66 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 67 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 68 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 69 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 70 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 71 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 72 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 73 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 74 | github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= 75 | github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= 76 | github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= 77 | github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= 78 | golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= 79 | golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= 80 | golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= 81 | golang.org/x/crypto v0.0.0-20210513122933-cd7d49e622d5/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= 82 | golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= 83 | golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= 84 | golang.org/x/net 
v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= 85 | golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= 86 | golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= 87 | golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= 88 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 89 | golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 90 | golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 91 | golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 92 | golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= 93 | golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 94 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 95 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 96 | golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 97 | golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= 98 | golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= 99 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 100 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= 101 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 102 | google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= 103 | google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= 104 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 
h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 105 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 106 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 107 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 108 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 109 | nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= 110 | rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= 111 | -------------------------------------------------------------------------------- /internal/util/util.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | // WriterFunc is used to implement one off io.Writers. 4 | type WriterFunc func(p []byte) (int, error) 5 | 6 | func (f WriterFunc) Write(p []byte) (int, error) { 7 | return f(p) 8 | } 9 | 10 | // ReaderFunc is used to implement one off io.Readers. 11 | type ReaderFunc func(p []byte) (int, error) 12 | 13 | func (f ReaderFunc) Read(p []byte) (int, error) { 14 | return f(p) 15 | } 16 | -------------------------------------------------------------------------------- /internal/wsjs/wsjs_js.go: -------------------------------------------------------------------------------- 1 | //go:build js 2 | // +build js 3 | 4 | // Package wsjs implements typed access to the browser javascript WebSocket API. 5 | // 6 | // https://developer.mozilla.org/en-US/docs/Web/API/WebSocket 7 | package wsjs 8 | 9 | import ( 10 | "syscall/js" 11 | ) 12 | 13 | func handleJSError(err *error, onErr func()) { 14 | r := recover() 15 | 16 | if jsErr, ok := r.(js.Error); ok { 17 | *err = jsErr 18 | 19 | if onErr != nil { 20 | onErr() 21 | } 22 | return 23 | } 24 | 25 | if r != nil { 26 | panic(r) 27 | } 28 | } 29 | 30 | // New is a wrapper around the javascript WebSocket constructor. 
31 | func New(url string, protocols []string) (c WebSocket, err error) { 32 | defer handleJSError(&err, func() { 33 | c = WebSocket{} 34 | }) 35 | 36 | jsProtocols := make([]interface{}, len(protocols)) 37 | for i, p := range protocols { 38 | jsProtocols[i] = p 39 | } 40 | 41 | c = WebSocket{ 42 | v: js.Global().Get("WebSocket").New(url, jsProtocols), 43 | } 44 | 45 | c.setBinaryType("arraybuffer") 46 | 47 | return c, nil 48 | } 49 | 50 | // WebSocket is a wrapper around a javascript WebSocket object. 51 | type WebSocket struct { 52 | v js.Value 53 | } 54 | 55 | func (c WebSocket) setBinaryType(typ string) { 56 | c.v.Set("binaryType", string(typ)) 57 | } 58 | 59 | func (c WebSocket) addEventListener(eventType string, fn func(e js.Value)) func() { 60 | f := js.FuncOf(func(this js.Value, args []js.Value) interface{} { 61 | fn(args[0]) 62 | return nil 63 | }) 64 | c.v.Call("addEventListener", eventType, f) 65 | 66 | return func() { 67 | c.v.Call("removeEventListener", eventType, f) 68 | f.Release() 69 | } 70 | } 71 | 72 | // CloseEvent is the type passed to a WebSocket close handler. 73 | type CloseEvent struct { 74 | Code uint16 75 | Reason string 76 | WasClean bool 77 | } 78 | 79 | // OnClose registers a function to be called when the WebSocket is closed. 80 | func (c WebSocket) OnClose(fn func(CloseEvent)) (remove func()) { 81 | return c.addEventListener("close", func(e js.Value) { 82 | ce := CloseEvent{ 83 | Code: uint16(e.Get("code").Int()), 84 | Reason: e.Get("reason").String(), 85 | WasClean: e.Get("wasClean").Bool(), 86 | } 87 | fn(ce) 88 | }) 89 | } 90 | 91 | // OnError registers a function to be called when there is an error 92 | // with the WebSocket. 93 | func (c WebSocket) OnError(fn func(e js.Value)) (remove func()) { 94 | return c.addEventListener("error", fn) 95 | } 96 | 97 | // MessageEvent is the type passed to a message handler. 98 | type MessageEvent struct { 99 | // string or []byte. 
100 | Data interface{} 101 | 102 | // There are more fields to the interface but we don't use them. 103 | // See https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent 104 | } 105 | 106 | // OnMessage registers a function to be called when the WebSocket receives a message. 107 | func (c WebSocket) OnMessage(fn func(m MessageEvent)) (remove func()) { 108 | return c.addEventListener("message", func(e js.Value) { 109 | var data interface{} 110 | 111 | arrayBuffer := e.Get("data") 112 | if arrayBuffer.Type() == js.TypeString { 113 | data = arrayBuffer.String() 114 | } else { 115 | data = extractArrayBuffer(arrayBuffer) 116 | } 117 | 118 | me := MessageEvent{ 119 | Data: data, 120 | } 121 | fn(me) 122 | }) 123 | } 124 | 125 | // Subprotocol returns the WebSocket subprotocol in use. 126 | func (c WebSocket) Subprotocol() string { 127 | return c.v.Get("protocol").String() 128 | } 129 | 130 | // OnOpen registers a function to be called when the WebSocket is opened. 131 | func (c WebSocket) OnOpen(fn func(e js.Value)) (remove func()) { 132 | return c.addEventListener("open", fn) 133 | } 134 | 135 | // Close closes the WebSocket with the given code and reason. 136 | func (c WebSocket) Close(code int, reason string) (err error) { 137 | defer handleJSError(&err, nil) 138 | c.v.Call("close", code, reason) 139 | return err 140 | } 141 | 142 | // SendText sends the given string as a text message 143 | // on the WebSocket. 144 | func (c WebSocket) SendText(v string) (err error) { 145 | defer handleJSError(&err, nil) 146 | c.v.Call("send", v) 147 | return err 148 | } 149 | 150 | // SendBytes sends the given message as a binary message 151 | // on the WebSocket. 
152 | func (c WebSocket) SendBytes(v []byte) (err error) { 153 | defer handleJSError(&err, nil) 154 | c.v.Call("send", uint8Array(v)) 155 | return err 156 | } 157 | 158 | func extractArrayBuffer(arrayBuffer js.Value) []byte { 159 | uint8Array := js.Global().Get("Uint8Array").New(arrayBuffer) 160 | dst := make([]byte, uint8Array.Length()) 161 | js.CopyBytesToGo(dst, uint8Array) 162 | return dst 163 | } 164 | 165 | func uint8Array(src []byte) js.Value { 166 | uint8Array := js.Global().Get("Uint8Array").New(len(src)) 167 | js.CopyBytesToJS(uint8Array, src) 168 | return uint8Array 169 | } 170 | -------------------------------------------------------------------------------- /internal/xsync/go.go: -------------------------------------------------------------------------------- 1 | package xsync 2 | 3 | import ( 4 | "fmt" 5 | "runtime/debug" 6 | ) 7 | 8 | // Go allows running a function in another goroutine 9 | // and waiting for its error. 10 | func Go(fn func() error) <-chan error { 11 | errs := make(chan error, 1) 12 | go func() { 13 | defer func() { 14 | r := recover() 15 | if r != nil { 16 | select { 17 | case errs <- fmt.Errorf("panic in go fn: %v, %s", r, debug.Stack()): 18 | default: 19 | } 20 | } 21 | }() 22 | errs <- fn() 23 | }() 24 | 25 | return errs 26 | } 27 | -------------------------------------------------------------------------------- /internal/xsync/go_test.go: -------------------------------------------------------------------------------- 1 | package xsync 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/coder/websocket/internal/test/assert" 7 | ) 8 | 9 | func TestGoRecover(t *testing.T) { 10 | t.Parallel() 11 | 12 | errs := Go(func() error { 13 | panic("anmol") 14 | }) 15 | 16 | err := <-errs 17 | assert.Contains(t, err, "anmol") 18 | } 19 | -------------------------------------------------------------------------------- /main_test.go: -------------------------------------------------------------------------------- 1 | package websocket_test 2 | 
3 | import ( 4 | "fmt" 5 | "os" 6 | "runtime" 7 | "testing" 8 | ) 9 | 10 | func goroutineStacks() []byte { 11 | buf := make([]byte, 512) 12 | for { 13 | m := runtime.Stack(buf, true) 14 | if m < len(buf) { 15 | return buf[:m] 16 | } 17 | buf = make([]byte, len(buf)*2) 18 | } 19 | } 20 | 21 | func TestMain(m *testing.M) { 22 | code := m.Run() 23 | if runtime.GOOS != "js" && runtime.NumGoroutine() != 1 || 24 | runtime.GOOS == "js" && runtime.NumGoroutine() != 2 { 25 | fmt.Fprintf(os.Stderr, "goroutine leak detected, expected 1 but got %d goroutines\n", runtime.NumGoroutine()) 26 | fmt.Fprintf(os.Stderr, "%s\n", goroutineStacks()) 27 | os.Exit(1) 28 | } 29 | os.Exit(code) 30 | } 31 | -------------------------------------------------------------------------------- /mask.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "encoding/binary" 5 | "math/bits" 6 | ) 7 | 8 | // maskGo applies the WebSocket masking algorithm to p 9 | // with the given key. 10 | // See https://tools.ietf.org/html/rfc6455#section-5.3 11 | // 12 | // The returned value is the correctly rotated key to 13 | // to continue to mask/unmask the message. 14 | // 15 | // It is optimized for LittleEndian and expects the key 16 | // to be in little endian. 17 | // 18 | // See https://github.com/golang/go/issues/31586 19 | func maskGo(b []byte, key uint32) uint32 { 20 | if len(b) >= 8 { 21 | key64 := uint64(key)<<32 | uint64(key) 22 | 23 | // At some point in the future we can clean these unrolled loops up. 24 | // See https://github.com/golang/go/issues/31586#issuecomment-487436401 25 | 26 | // Then we xor until b is less than 128 bytes. 
27 | for len(b) >= 128 { 28 | v := binary.LittleEndian.Uint64(b) 29 | binary.LittleEndian.PutUint64(b, v^key64) 30 | v = binary.LittleEndian.Uint64(b[8:16]) 31 | binary.LittleEndian.PutUint64(b[8:16], v^key64) 32 | v = binary.LittleEndian.Uint64(b[16:24]) 33 | binary.LittleEndian.PutUint64(b[16:24], v^key64) 34 | v = binary.LittleEndian.Uint64(b[24:32]) 35 | binary.LittleEndian.PutUint64(b[24:32], v^key64) 36 | v = binary.LittleEndian.Uint64(b[32:40]) 37 | binary.LittleEndian.PutUint64(b[32:40], v^key64) 38 | v = binary.LittleEndian.Uint64(b[40:48]) 39 | binary.LittleEndian.PutUint64(b[40:48], v^key64) 40 | v = binary.LittleEndian.Uint64(b[48:56]) 41 | binary.LittleEndian.PutUint64(b[48:56], v^key64) 42 | v = binary.LittleEndian.Uint64(b[56:64]) 43 | binary.LittleEndian.PutUint64(b[56:64], v^key64) 44 | v = binary.LittleEndian.Uint64(b[64:72]) 45 | binary.LittleEndian.PutUint64(b[64:72], v^key64) 46 | v = binary.LittleEndian.Uint64(b[72:80]) 47 | binary.LittleEndian.PutUint64(b[72:80], v^key64) 48 | v = binary.LittleEndian.Uint64(b[80:88]) 49 | binary.LittleEndian.PutUint64(b[80:88], v^key64) 50 | v = binary.LittleEndian.Uint64(b[88:96]) 51 | binary.LittleEndian.PutUint64(b[88:96], v^key64) 52 | v = binary.LittleEndian.Uint64(b[96:104]) 53 | binary.LittleEndian.PutUint64(b[96:104], v^key64) 54 | v = binary.LittleEndian.Uint64(b[104:112]) 55 | binary.LittleEndian.PutUint64(b[104:112], v^key64) 56 | v = binary.LittleEndian.Uint64(b[112:120]) 57 | binary.LittleEndian.PutUint64(b[112:120], v^key64) 58 | v = binary.LittleEndian.Uint64(b[120:128]) 59 | binary.LittleEndian.PutUint64(b[120:128], v^key64) 60 | b = b[128:] 61 | } 62 | 63 | // Then we xor until b is less than 64 bytes. 
64 | for len(b) >= 64 { 65 | v := binary.LittleEndian.Uint64(b) 66 | binary.LittleEndian.PutUint64(b, v^key64) 67 | v = binary.LittleEndian.Uint64(b[8:16]) 68 | binary.LittleEndian.PutUint64(b[8:16], v^key64) 69 | v = binary.LittleEndian.Uint64(b[16:24]) 70 | binary.LittleEndian.PutUint64(b[16:24], v^key64) 71 | v = binary.LittleEndian.Uint64(b[24:32]) 72 | binary.LittleEndian.PutUint64(b[24:32], v^key64) 73 | v = binary.LittleEndian.Uint64(b[32:40]) 74 | binary.LittleEndian.PutUint64(b[32:40], v^key64) 75 | v = binary.LittleEndian.Uint64(b[40:48]) 76 | binary.LittleEndian.PutUint64(b[40:48], v^key64) 77 | v = binary.LittleEndian.Uint64(b[48:56]) 78 | binary.LittleEndian.PutUint64(b[48:56], v^key64) 79 | v = binary.LittleEndian.Uint64(b[56:64]) 80 | binary.LittleEndian.PutUint64(b[56:64], v^key64) 81 | b = b[64:] 82 | } 83 | 84 | // Then we xor until b is less than 32 bytes. 85 | for len(b) >= 32 { 86 | v := binary.LittleEndian.Uint64(b) 87 | binary.LittleEndian.PutUint64(b, v^key64) 88 | v = binary.LittleEndian.Uint64(b[8:16]) 89 | binary.LittleEndian.PutUint64(b[8:16], v^key64) 90 | v = binary.LittleEndian.Uint64(b[16:24]) 91 | binary.LittleEndian.PutUint64(b[16:24], v^key64) 92 | v = binary.LittleEndian.Uint64(b[24:32]) 93 | binary.LittleEndian.PutUint64(b[24:32], v^key64) 94 | b = b[32:] 95 | } 96 | 97 | // Then we xor until b is less than 16 bytes. 98 | for len(b) >= 16 { 99 | v := binary.LittleEndian.Uint64(b) 100 | binary.LittleEndian.PutUint64(b, v^key64) 101 | v = binary.LittleEndian.Uint64(b[8:16]) 102 | binary.LittleEndian.PutUint64(b[8:16], v^key64) 103 | b = b[16:] 104 | } 105 | 106 | // Then we xor until b is less than 8 bytes. 107 | for len(b) >= 8 { 108 | v := binary.LittleEndian.Uint64(b) 109 | binary.LittleEndian.PutUint64(b, v^key64) 110 | b = b[8:] 111 | } 112 | } 113 | 114 | // Then we xor until b is less than 4 bytes. 
115 | for len(b) >= 4 { 116 | v := binary.LittleEndian.Uint32(b) 117 | binary.LittleEndian.PutUint32(b, v^key) 118 | b = b[4:] 119 | } 120 | 121 | // xor remaining bytes. 122 | for i := range b { 123 | b[i] ^= byte(key) 124 | key = bits.RotateLeft32(key, -8) 125 | } 126 | 127 | return key 128 | } 129 | -------------------------------------------------------------------------------- /mask_amd64.s: -------------------------------------------------------------------------------- 1 | #include "textflag.h" 2 | 3 | // func maskAsm(b *byte, len int, key uint32) 4 | TEXT ·maskAsm(SB), NOSPLIT, $0-28 5 | // AX = b 6 | // CX = len (left length) 7 | // SI = key (uint32) 8 | // DI = uint64(SI) | uint64(SI)<<32 9 | MOVQ b+0(FP), AX 10 | MOVQ len+8(FP), CX 11 | MOVL key+16(FP), SI 12 | 13 | // calculate the DI 14 | // DI = SI<<32 | SI 15 | MOVL SI, DI 16 | MOVQ DI, DX 17 | SHLQ $32, DI 18 | ORQ DX, DI 19 | 20 | CMPQ CX, $15 21 | JLE less_than_16 22 | CMPQ CX, $63 23 | JLE less_than_64 24 | CMPQ CX, $128 25 | JLE sse 26 | TESTQ $31, AX 27 | JNZ unaligned 28 | 29 | unaligned_loop_1byte: 30 | XORB SI, (AX) 31 | INCQ AX 32 | DECQ CX 33 | ROLL $24, SI 34 | TESTQ $7, AX 35 | JNZ unaligned_loop_1byte 36 | 37 | // calculate DI again since SI was modified 38 | // DI = SI<<32 | SI 39 | MOVL SI, DI 40 | MOVQ DI, DX 41 | SHLQ $32, DI 42 | ORQ DX, DI 43 | 44 | TESTQ $31, AX 45 | JZ sse 46 | 47 | unaligned: 48 | TESTQ $7, AX // AND $7 & len, if not zero jump to loop_1b. 
49 | JNZ unaligned_loop_1byte 50 | 51 | unaligned_loop: 52 | // we don't need to check the CX since we know it's above 128 53 | XORQ DI, (AX) 54 | ADDQ $8, AX 55 | SUBQ $8, CX 56 | TESTQ $31, AX 57 | JNZ unaligned_loop 58 | JMP sse 59 | 60 | sse: 61 | CMPQ CX, $0x40 62 | JL less_than_64 63 | MOVQ DI, X0 64 | PUNPCKLQDQ X0, X0 65 | 66 | sse_loop: 67 | MOVOU 0*16(AX), X1 68 | MOVOU 1*16(AX), X2 69 | MOVOU 2*16(AX), X3 70 | MOVOU 3*16(AX), X4 71 | PXOR X0, X1 72 | PXOR X0, X2 73 | PXOR X0, X3 74 | PXOR X0, X4 75 | MOVOU X1, 0*16(AX) 76 | MOVOU X2, 1*16(AX) 77 | MOVOU X3, 2*16(AX) 78 | MOVOU X4, 3*16(AX) 79 | ADDQ $0x40, AX 80 | SUBQ $0x40, CX 81 | CMPQ CX, $0x40 82 | JAE sse_loop 83 | 84 | less_than_64: 85 | TESTQ $32, CX 86 | JZ less_than_32 87 | XORQ DI, (AX) 88 | XORQ DI, 8(AX) 89 | XORQ DI, 16(AX) 90 | XORQ DI, 24(AX) 91 | ADDQ $32, AX 92 | 93 | less_than_32: 94 | TESTQ $16, CX 95 | JZ less_than_16 96 | XORQ DI, (AX) 97 | XORQ DI, 8(AX) 98 | ADDQ $16, AX 99 | 100 | less_than_16: 101 | TESTQ $8, CX 102 | JZ less_than_8 103 | XORQ DI, (AX) 104 | ADDQ $8, AX 105 | 106 | less_than_8: 107 | TESTQ $4, CX 108 | JZ less_than_4 109 | XORL SI, (AX) 110 | ADDQ $4, AX 111 | 112 | less_than_4: 113 | TESTQ $2, CX 114 | JZ less_than_2 115 | XORW SI, (AX) 116 | ROLL $16, SI 117 | ADDQ $2, AX 118 | 119 | less_than_2: 120 | TESTQ $1, CX 121 | JZ done 122 | XORB SI, (AX) 123 | ROLL $24, SI 124 | 125 | done: 126 | MOVL SI, ret+24(FP) 127 | RET 128 | -------------------------------------------------------------------------------- /mask_arm64.s: -------------------------------------------------------------------------------- 1 | #include "textflag.h" 2 | 3 | // func maskAsm(b *byte, len int, key uint32) 4 | TEXT ·maskAsm(SB), NOSPLIT, $0-28 5 | // R0 = b 6 | // R1 = len 7 | // R3 = key (uint32) 8 | // R2 = uint64(key)<<32 | uint64(key) 9 | MOVD b_ptr+0(FP), R0 10 | MOVD b_len+8(FP), R1 11 | MOVWU key+16(FP), R3 12 | MOVD R3, R2 13 | ORR R2<<32, R2, R2 14 | VDUP R2, V0.D2 15 | CMP $64, 
R1 16 | BLT less_than_64 17 | 18 | loop_64: 19 | VLD1 (R0), [V1.B16, V2.B16, V3.B16, V4.B16] 20 | VEOR V1.B16, V0.B16, V1.B16 21 | VEOR V2.B16, V0.B16, V2.B16 22 | VEOR V3.B16, V0.B16, V3.B16 23 | VEOR V4.B16, V0.B16, V4.B16 24 | VST1.P [V1.B16, V2.B16, V3.B16, V4.B16], 64(R0) 25 | SUBS $64, R1 26 | CMP $64, R1 27 | BGE loop_64 28 | 29 | less_than_64: 30 | CBZ R1, end 31 | TBZ $5, R1, less_than_32 32 | VLD1 (R0), [V1.B16, V2.B16] 33 | VEOR V1.B16, V0.B16, V1.B16 34 | VEOR V2.B16, V0.B16, V2.B16 35 | VST1.P [V1.B16, V2.B16], 32(R0) 36 | 37 | less_than_32: 38 | TBZ $4, R1, less_than_16 39 | LDP (R0), (R11, R12) 40 | EOR R11, R2, R11 41 | EOR R12, R2, R12 42 | STP.P (R11, R12), 16(R0) 43 | 44 | less_than_16: 45 | TBZ $3, R1, less_than_8 46 | MOVD (R0), R11 47 | EOR R2, R11, R11 48 | MOVD.P R11, 8(R0) 49 | 50 | less_than_8: 51 | TBZ $2, R1, less_than_4 52 | MOVWU (R0), R11 53 | EORW R2, R11, R11 54 | MOVWU.P R11, 4(R0) 55 | 56 | less_than_4: 57 | TBZ $1, R1, less_than_2 58 | MOVHU (R0), R11 59 | EORW R3, R11, R11 60 | MOVHU.P R11, 2(R0) 61 | RORW $16, R3 62 | 63 | less_than_2: 64 | TBZ $0, R1, end 65 | MOVBU (R0), R11 66 | EORW R3, R11, R11 67 | MOVBU.P R11, 1(R0) 68 | RORW $8, R3 69 | 70 | end: 71 | MOVWU R3, ret+24(FP) 72 | RET 73 | -------------------------------------------------------------------------------- /mask_asm.go: -------------------------------------------------------------------------------- 1 | //go:build amd64 || arm64 2 | 3 | package websocket 4 | 5 | func mask(b []byte, key uint32) uint32 { 6 | // TODO: Will enable in v1.9.0. 7 | return maskGo(b, key) 8 | /* 9 | if len(b) > 0 { 10 | return maskAsm(&b[0], len(b), key) 11 | } 12 | return key 13 | */ 14 | } 15 | 16 | // @nhooyr: I am not confident that the amd64 or the arm64 implementations of this 17 | // function are perfect. There are almost certainly missing optimizations or 18 | // opportunities for simplification. I'm confident there are no bugs though. 
19 | // For example, the arm64 implementation doesn't align memory like the amd64. 20 | // Or the amd64 implementation could use AVX512 instead of just AVX2. 21 | // The AVX2 code I had to disable anyway as it wasn't performing as expected. 22 | // See https://github.com/nhooyr/websocket/pull/326#issuecomment-1771138049 23 | // 24 | //go:noescape 25 | //lint:ignore U1000 disabled till v1.9.0 26 | func maskAsm(b *byte, len int, key uint32) uint32 27 | -------------------------------------------------------------------------------- /mask_asm_test.go: -------------------------------------------------------------------------------- 1 | //go:build amd64 || arm64 2 | 3 | package websocket 4 | 5 | import "testing" 6 | 7 | func TestMaskASM(t *testing.T) { 8 | t.Parallel() 9 | 10 | testMask(t, "maskASM", mask) 11 | } 12 | -------------------------------------------------------------------------------- /mask_go.go: -------------------------------------------------------------------------------- 1 | //go:build !amd64 && !arm64 && !js 2 | 3 | package websocket 4 | 5 | func mask(b []byte, key uint32) uint32 { 6 | return maskGo(b, key) 7 | } 8 | -------------------------------------------------------------------------------- /mask_test.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "bytes" 5 | "crypto/rand" 6 | "encoding/binary" 7 | "math/big" 8 | "math/bits" 9 | "testing" 10 | 11 | "github.com/coder/websocket/internal/test/assert" 12 | ) 13 | 14 | func basicMask(b []byte, key uint32) uint32 { 15 | for i := range b { 16 | b[i] ^= byte(key) 17 | key = bits.RotateLeft32(key, -8) 18 | } 19 | return key 20 | } 21 | 22 | func basicMask2(b []byte, key uint32) uint32 { 23 | keyb := binary.LittleEndian.AppendUint32(nil, key) 24 | pos := 0 25 | for i := range b { 26 | b[i] ^= keyb[pos&3] 27 | pos++ 28 | } 29 | return bits.RotateLeft32(key, (pos&3)*-8) 30 | } 31 | 32 | func TestMask(t *testing.T) { 33 | 
t.Parallel() 34 | 35 | testMask(t, "basicMask", basicMask) 36 | testMask(t, "maskGo", maskGo) 37 | testMask(t, "basicMask2", basicMask2) 38 | } 39 | 40 | func testMask(t *testing.T, name string, fn func(b []byte, key uint32) uint32) { 41 | t.Run(name, func(t *testing.T) { 42 | t.Parallel() 43 | for i := 0; i < 9999; i++ { 44 | keyb := make([]byte, 4) 45 | _, err := rand.Read(keyb) 46 | assert.Success(t, err) 47 | key := binary.LittleEndian.Uint32(keyb) 48 | 49 | n, err := rand.Int(rand.Reader, big.NewInt(1<<16)) 50 | assert.Success(t, err) 51 | 52 | b := make([]byte, 1+n.Int64()) 53 | _, err = rand.Read(b) 54 | assert.Success(t, err) 55 | 56 | b2 := make([]byte, len(b)) 57 | copy(b2, b) 58 | b3 := make([]byte, len(b)) 59 | copy(b3, b) 60 | 61 | key2 := basicMask(b2, key) 62 | key3 := fn(b3, key) 63 | 64 | if key2 != key3 { 65 | t.Errorf("expected key %X but got %X", key2, key3) 66 | } 67 | if !bytes.Equal(b2, b3) { 68 | t.Error("bad bytes") 69 | return 70 | } 71 | } 72 | }) 73 | } 74 | -------------------------------------------------------------------------------- /netconn.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "math" 8 | "net" 9 | "sync/atomic" 10 | "time" 11 | ) 12 | 13 | // NetConn converts a *websocket.Conn into a net.Conn. 14 | // 15 | // It's for tunneling arbitrary protocols over WebSockets. 16 | // Few users of the library will need this but it's tricky to implement 17 | // correctly and so provided in the library. 18 | // See https://github.com/nhooyr/websocket/issues/100. 19 | // 20 | // Every Write to the net.Conn will correspond to a message write of 21 | // the given type on *websocket.Conn. 22 | // 23 | // The passed ctx bounds the lifetime of the net.Conn. If cancelled, 24 | // all reads and writes on the net.Conn will be cancelled. 
25 | // 26 | // If a message is read that is not of the correct type, the connection 27 | // will be closed with StatusUnsupportedData and an error will be returned. 28 | // 29 | // Close will close the *websocket.Conn with StatusNormalClosure. 30 | // 31 | // When a deadline is hit and there is an active read or write goroutine, the 32 | // connection will be closed. This is different from most net.Conn implementations 33 | // where only the reading/writing goroutines are interrupted but the connection 34 | // is kept alive. 35 | // 36 | // The Addr methods will return the real addresses for connections obtained 37 | // from websocket.Accept. But for connections obtained from websocket.Dial, a mock net.Addr 38 | // will be returned that gives "websocket" for Network() and "websocket/unknown-addr" for 39 | // String(). This is because websocket.Dial only exposes a io.ReadWriteCloser instead of the 40 | // full net.Conn to us. 41 | // 42 | // When running as WASM, the Addr methods will always return the mock address described above. 43 | // 44 | // A received StatusNormalClosure or StatusGoingAway close frame will be translated to 45 | // io.EOF when reading. 46 | // 47 | // Furthermore, the ReadLimit is set to -1 to disable it. 48 | func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { 49 | c.SetReadLimit(-1) 50 | 51 | nc := &netConn{ 52 | c: c, 53 | msgType: msgType, 54 | readMu: newMu(c), 55 | writeMu: newMu(c), 56 | } 57 | 58 | nc.writeCtx, nc.writeCancel = context.WithCancel(ctx) 59 | nc.readCtx, nc.readCancel = context.WithCancel(ctx) 60 | 61 | nc.writeTimer = time.AfterFunc(math.MaxInt64, func() { 62 | if !nc.writeMu.tryLock() { 63 | // If the lock cannot be acquired, then there is an 64 | // active write goroutine and so we should cancel the context. 65 | nc.writeCancel() 66 | return 67 | } 68 | defer nc.writeMu.unlock() 69 | 70 | // Prevents future writes from writing until the deadline is reset. 
71 | nc.writeExpired.Store(1) 72 | }) 73 | if !nc.writeTimer.Stop() { 74 | <-nc.writeTimer.C 75 | } 76 | 77 | nc.readTimer = time.AfterFunc(math.MaxInt64, func() { 78 | if !nc.readMu.tryLock() { 79 | // If the lock cannot be acquired, then there is an 80 | // active read goroutine and so we should cancel the context. 81 | nc.readCancel() 82 | return 83 | } 84 | defer nc.readMu.unlock() 85 | 86 | // Prevents future reads from reading until the deadline is reset. 87 | nc.readExpired.Store(1) 88 | }) 89 | if !nc.readTimer.Stop() { 90 | <-nc.readTimer.C 91 | } 92 | 93 | return nc 94 | } 95 | 96 | type netConn struct { 97 | c *Conn 98 | msgType MessageType 99 | 100 | writeTimer *time.Timer 101 | writeMu *mu 102 | writeExpired atomic.Int64 103 | writeCtx context.Context 104 | writeCancel context.CancelFunc 105 | 106 | readTimer *time.Timer 107 | readMu *mu 108 | readExpired atomic.Int64 109 | readCtx context.Context 110 | readCancel context.CancelFunc 111 | readEOFed bool 112 | reader io.Reader 113 | } 114 | 115 | var _ net.Conn = &netConn{} 116 | 117 | func (nc *netConn) Close() error { 118 | nc.writeTimer.Stop() 119 | nc.writeCancel() 120 | nc.readTimer.Stop() 121 | nc.readCancel() 122 | return nc.c.Close(StatusNormalClosure, "") 123 | } 124 | 125 | func (nc *netConn) Write(p []byte) (int, error) { 126 | nc.writeMu.forceLock() 127 | defer nc.writeMu.unlock() 128 | 129 | if nc.writeExpired.Load() == 1 { 130 | return 0, fmt.Errorf("failed to write: %w", context.DeadlineExceeded) 131 | } 132 | 133 | err := nc.c.Write(nc.writeCtx, nc.msgType, p) 134 | if err != nil { 135 | return 0, err 136 | } 137 | return len(p), nil 138 | } 139 | 140 | func (nc *netConn) Read(p []byte) (int, error) { 141 | nc.readMu.forceLock() 142 | defer nc.readMu.unlock() 143 | 144 | for { 145 | n, err := nc.read(p) 146 | if err != nil { 147 | return n, err 148 | } 149 | if n == 0 { 150 | continue 151 | } 152 | return n, nil 153 | } 154 | } 155 | 156 | func (nc *netConn) read(p []byte) (int, error) { 
157 | if nc.readExpired.Load() == 1 { 158 | return 0, fmt.Errorf("failed to read: %w", context.DeadlineExceeded) 159 | } 160 | 161 | if nc.readEOFed { 162 | return 0, io.EOF 163 | } 164 | 165 | if nc.reader == nil { 166 | typ, r, err := nc.c.Reader(nc.readCtx) 167 | if err != nil { 168 | switch CloseStatus(err) { 169 | case StatusNormalClosure, StatusGoingAway: 170 | nc.readEOFed = true 171 | return 0, io.EOF 172 | } 173 | return 0, err 174 | } 175 | if typ != nc.msgType { 176 | err := fmt.Errorf("unexpected frame type read (expected %v): %v", nc.msgType, typ) 177 | nc.c.Close(StatusUnsupportedData, err.Error()) 178 | return 0, err 179 | } 180 | nc.reader = r 181 | } 182 | 183 | n, err := nc.reader.Read(p) 184 | if err == io.EOF { 185 | nc.reader = nil 186 | err = nil 187 | } 188 | return n, err 189 | } 190 | 191 | type websocketAddr struct { 192 | } 193 | 194 | func (a websocketAddr) Network() string { 195 | return "websocket" 196 | } 197 | 198 | func (a websocketAddr) String() string { 199 | return "websocket/unknown-addr" 200 | } 201 | 202 | func (nc *netConn) SetDeadline(t time.Time) error { 203 | nc.SetWriteDeadline(t) 204 | nc.SetReadDeadline(t) 205 | return nil 206 | } 207 | 208 | func (nc *netConn) SetWriteDeadline(t time.Time) error { 209 | nc.writeExpired.Store(0) 210 | if t.IsZero() { 211 | nc.writeTimer.Stop() 212 | } else { 213 | dur := time.Until(t) 214 | if dur <= 0 { 215 | dur = 1 216 | } 217 | nc.writeTimer.Reset(dur) 218 | } 219 | return nil 220 | } 221 | 222 | func (nc *netConn) SetReadDeadline(t time.Time) error { 223 | nc.readExpired.Store(0) 224 | if t.IsZero() { 225 | nc.readTimer.Stop() 226 | } else { 227 | dur := time.Until(t) 228 | if dur <= 0 { 229 | dur = 1 230 | } 231 | nc.readTimer.Reset(dur) 232 | } 233 | return nil 234 | } 235 | -------------------------------------------------------------------------------- /netconn_js.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | 
import "net" 4 | 5 | func (nc *netConn) RemoteAddr() net.Addr { 6 | return websocketAddr{} 7 | } 8 | 9 | func (nc *netConn) LocalAddr() net.Addr { 10 | return websocketAddr{} 11 | } 12 | -------------------------------------------------------------------------------- /netconn_notjs.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | // +build !js 3 | 4 | package websocket 5 | 6 | import "net" 7 | 8 | func (nc *netConn) RemoteAddr() net.Addr { 9 | if unc, ok := nc.c.rwc.(net.Conn); ok { 10 | return unc.RemoteAddr() 11 | } 12 | return websocketAddr{} 13 | } 14 | 15 | func (nc *netConn) LocalAddr() net.Addr { 16 | if unc, ok := nc.c.rwc.(net.Conn); ok { 17 | return unc.LocalAddr() 18 | } 19 | return websocketAddr{} 20 | } 21 | -------------------------------------------------------------------------------- /read.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | // +build !js 3 | 4 | package websocket 5 | 6 | import ( 7 | "bufio" 8 | "context" 9 | "errors" 10 | "fmt" 11 | "io" 12 | "net" 13 | "strings" 14 | "sync/atomic" 15 | "time" 16 | 17 | "github.com/coder/websocket/internal/errd" 18 | "github.com/coder/websocket/internal/util" 19 | ) 20 | 21 | // Reader reads from the connection until there is a WebSocket 22 | // data message to be read. It will handle ping, pong and close frames as appropriate. 23 | // 24 | // It returns the type of the message and an io.Reader to read it. 25 | // The passed context will also bound the reader. 26 | // Ensure you read to EOF otherwise the connection will hang. 27 | // 28 | // Call CloseRead if you do not expect any data messages from the peer. 29 | // 30 | // Only one Reader may be open at a time. 31 | // 32 | // If you need a separate timeout on the Reader call and the Read itself, 33 | // use time.AfterFunc to cancel the context passed in. 
34 | // See https://github.com/nhooyr/websocket/issues/87#issue-451703332 35 | // Most users should not need this. 36 | func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { 37 | return c.reader(ctx) 38 | } 39 | 40 | // Read is a convenience method around Reader to read a single message 41 | // from the connection. 42 | func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { 43 | typ, r, err := c.Reader(ctx) 44 | if err != nil { 45 | return 0, nil, err 46 | } 47 | 48 | b, err := io.ReadAll(r) 49 | return typ, b, err 50 | } 51 | 52 | // CloseRead starts a goroutine to read from the connection until it is closed 53 | // or a data message is received. 54 | // 55 | // Once CloseRead is called you cannot read any messages from the connection. 56 | // The returned context will be cancelled when the connection is closed. 57 | // 58 | // If a data message is received, the connection will be closed with StatusPolicyViolation. 59 | // 60 | // Call CloseRead when you do not expect to read any more messages. 61 | // Since it actively reads from the connection, it will ensure that ping, pong and close 62 | // frames are responded to. This means c.Ping and c.Close will still work as expected. 63 | // 64 | // This function is idempotent. 65 | func (c *Conn) CloseRead(ctx context.Context) context.Context { 66 | c.closeReadMu.Lock() 67 | ctx2 := c.closeReadCtx 68 | if ctx2 != nil { 69 | c.closeReadMu.Unlock() 70 | return ctx2 71 | } 72 | ctx, cancel := context.WithCancel(ctx) 73 | c.closeReadCtx = ctx 74 | c.closeReadDone = make(chan struct{}) 75 | c.closeReadMu.Unlock() 76 | 77 | go func() { 78 | defer close(c.closeReadDone) 79 | defer cancel() 80 | defer c.close() 81 | _, _, err := c.Reader(ctx) 82 | if err == nil { 83 | c.Close(StatusPolicyViolation, "unexpected data message") 84 | } 85 | }() 86 | return ctx 87 | } 88 | 89 | // SetReadLimit sets the max number of bytes to read for a single message. 
90 | // It applies to the Reader and Read methods. 91 | // 92 | // By default, the connection has a message read limit of 32768 bytes. 93 | // 94 | // When the limit is hit, the connection will be closed with StatusMessageTooBig. 95 | // 96 | // Set to -1 to disable. 97 | func (c *Conn) SetReadLimit(n int64) { 98 | if n >= 0 { 99 | // We read one more byte than the limit in case 100 | // there is a fin frame that needs to be read. 101 | n++ 102 | } 103 | 104 | c.msgReader.limitReader.limit.Store(n) 105 | } 106 | 107 | const defaultReadLimit = 32768 108 | 109 | func newMsgReader(c *Conn) *msgReader { 110 | mr := &msgReader{ 111 | c: c, 112 | fin: true, 113 | } 114 | mr.readFunc = mr.read 115 | 116 | mr.limitReader = newLimitReader(c, mr.readFunc, defaultReadLimit+1) 117 | return mr 118 | } 119 | 120 | func (mr *msgReader) resetFlate() { 121 | if mr.flateContextTakeover() { 122 | if mr.dict == nil { 123 | mr.dict = &slidingWindow{} 124 | } 125 | mr.dict.init(32768) 126 | } 127 | if mr.flateBufio == nil { 128 | mr.flateBufio = getBufioReader(mr.readFunc) 129 | } 130 | 131 | if mr.flateContextTakeover() { 132 | mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf) 133 | } else { 134 | mr.flateReader = getFlateReader(mr.flateBufio, nil) 135 | } 136 | mr.limitReader.r = mr.flateReader 137 | mr.flateTail.Reset(deflateMessageTail) 138 | } 139 | 140 | func (mr *msgReader) putFlateReader() { 141 | if mr.flateReader != nil { 142 | putFlateReader(mr.flateReader) 143 | mr.flateReader = nil 144 | } 145 | } 146 | 147 | func (mr *msgReader) close() { 148 | mr.c.readMu.forceLock() 149 | mr.putFlateReader() 150 | if mr.dict != nil { 151 | mr.dict.close() 152 | mr.dict = nil 153 | } 154 | if mr.flateBufio != nil { 155 | putBufioReader(mr.flateBufio) 156 | } 157 | 158 | if mr.c.client { 159 | putBufioReader(mr.c.br) 160 | mr.c.br = nil 161 | } 162 | } 163 | 164 | func (mr *msgReader) flateContextTakeover() bool { 165 | if mr.c.client { 166 | return 
!mr.c.copts.serverNoContextTakeover 167 | } 168 | return !mr.c.copts.clientNoContextTakeover 169 | } 170 | 171 | func (c *Conn) readRSV1Illegal(h header) bool { 172 | // If compression is disabled, rsv1 is illegal. 173 | if !c.flate() { 174 | return true 175 | } 176 | // rsv1 is only allowed on data frames beginning messages. 177 | if h.opcode != opText && h.opcode != opBinary { 178 | return true 179 | } 180 | return false 181 | } 182 | 183 | func (c *Conn) readLoop(ctx context.Context) (header, error) { 184 | for { 185 | h, err := c.readFrameHeader(ctx) 186 | if err != nil { 187 | return header{}, err 188 | } 189 | 190 | if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 { 191 | err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3) 192 | c.writeError(StatusProtocolError, err) 193 | return header{}, err 194 | } 195 | 196 | if !c.client && !h.masked { 197 | return header{}, errors.New("received unmasked frame from client") 198 | } 199 | 200 | switch h.opcode { 201 | case opClose, opPing, opPong: 202 | err = c.handleControl(ctx, h) 203 | if err != nil { 204 | // Pass through CloseErrors when receiving a close frame. 205 | if h.opcode == opClose && CloseStatus(err) != -1 { 206 | return header{}, err 207 | } 208 | return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err) 209 | } 210 | case opContinuation, opText, opBinary: 211 | return h, nil 212 | default: 213 | err := fmt.Errorf("received unknown opcode %v", h.opcode) 214 | c.writeError(StatusProtocolError, err) 215 | return header{}, err 216 | } 217 | } 218 | } 219 | 220 | // prepareRead sets the readTimeout context and returns a done function 221 | // to be called after the read is done. It also returns an error if the 222 | // connection is closed. The reference to the error is used to assign 223 | // an error depending on if the connection closed or the context timed 224 | // out during use. 
Typically the referenced error is a named return 225 | // variable of the function calling this method. 226 | func (c *Conn) prepareRead(ctx context.Context, err *error) (func(), error) { 227 | select { 228 | case <-c.closed: 229 | return nil, net.ErrClosed 230 | case c.readTimeout <- ctx: 231 | } 232 | 233 | done := func() { 234 | select { 235 | case <-c.closed: 236 | if *err != nil { 237 | *err = net.ErrClosed 238 | } 239 | case c.readTimeout <- context.Background(): 240 | } 241 | if *err != nil && ctx.Err() != nil { 242 | *err = ctx.Err() 243 | } 244 | } 245 | 246 | c.closeStateMu.Lock() 247 | closeReceivedErr := c.closeReceivedErr 248 | c.closeStateMu.Unlock() 249 | if closeReceivedErr != nil { 250 | defer done() 251 | return nil, closeReceivedErr 252 | } 253 | 254 | return done, nil 255 | } 256 | 257 | func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) { 258 | readDone, err := c.prepareRead(ctx, &err) 259 | if err != nil { 260 | return header{}, err 261 | } 262 | defer readDone() 263 | 264 | h, err := readFrameHeader(c.br, c.readHeaderBuf[:]) 265 | if err != nil { 266 | return header{}, err 267 | } 268 | 269 | return h, nil 270 | } 271 | 272 | func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) { 273 | readDone, err := c.prepareRead(ctx, &err) 274 | if err != nil { 275 | return 0, err 276 | } 277 | defer readDone() 278 | 279 | n, err := io.ReadFull(c.br, p) 280 | if err != nil { 281 | return n, fmt.Errorf("failed to read frame payload: %w", err) 282 | } 283 | 284 | return n, err 285 | } 286 | 287 | func (c *Conn) handleControl(ctx context.Context, h header) (err error) { 288 | if h.payloadLength < 0 || h.payloadLength > maxControlPayload { 289 | err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength) 290 | c.writeError(StatusProtocolError, err) 291 | return err 292 | } 293 | 294 | if !h.fin { 295 | err := errors.New("received fragmented control frame") 296 | 
c.writeError(StatusProtocolError, err) 297 | return err 298 | } 299 | 300 | ctx, cancel := context.WithTimeout(ctx, time.Second*5) 301 | defer cancel() 302 | 303 | b := c.readControlBuf[:h.payloadLength] 304 | _, err = c.readFramePayload(ctx, b) 305 | if err != nil { 306 | return err 307 | } 308 | 309 | if h.masked { 310 | mask(b, h.maskKey) 311 | } 312 | 313 | switch h.opcode { 314 | case opPing: 315 | if c.onPingReceived != nil { 316 | if !c.onPingReceived(ctx, b) { 317 | return nil 318 | } 319 | } 320 | return c.writeControl(ctx, opPong, b) 321 | case opPong: 322 | if c.onPongReceived != nil { 323 | c.onPongReceived(ctx, b) 324 | } 325 | c.activePingsMu.Lock() 326 | pong, ok := c.activePings[string(b)] 327 | c.activePingsMu.Unlock() 328 | if ok { 329 | select { 330 | case pong <- struct{}{}: 331 | default: 332 | } 333 | } 334 | return nil 335 | } 336 | 337 | // opClose 338 | 339 | ce, err := parseClosePayload(b) 340 | if err != nil { 341 | err = fmt.Errorf("received invalid close payload: %w", err) 342 | c.writeError(StatusProtocolError, err) 343 | return err 344 | } 345 | 346 | err = fmt.Errorf("received close frame: %w", ce) 347 | c.closeStateMu.Lock() 348 | c.closeReceivedErr = err 349 | closeSent := c.closeSentErr != nil 350 | c.closeStateMu.Unlock() 351 | 352 | // Only unlock readMu if this connection is being closed becaue 353 | // c.close will try to acquire the readMu lock. We unlock for 354 | // writeClose as well because it may also call c.close. 
355 | if !closeSent { 356 | c.readMu.unlock() 357 | _ = c.writeClose(ce.Code, ce.Reason) 358 | } 359 | if !c.casClosing() { 360 | c.readMu.unlock() 361 | _ = c.close() 362 | } 363 | return err 364 | } 365 | 366 | func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) { 367 | defer errd.Wrap(&err, "failed to get reader") 368 | 369 | err = c.readMu.lock(ctx) 370 | if err != nil { 371 | return 0, nil, err 372 | } 373 | defer c.readMu.unlock() 374 | 375 | if !c.msgReader.fin { 376 | return 0, nil, errors.New("previous message not read to completion") 377 | } 378 | 379 | h, err := c.readLoop(ctx) 380 | if err != nil { 381 | return 0, nil, err 382 | } 383 | 384 | if h.opcode == opContinuation { 385 | err := errors.New("received continuation frame without text or binary frame") 386 | c.writeError(StatusProtocolError, err) 387 | return 0, nil, err 388 | } 389 | 390 | c.msgReader.reset(ctx, h) 391 | 392 | return MessageType(h.opcode), c.msgReader, nil 393 | } 394 | 395 | type msgReader struct { 396 | c *Conn 397 | 398 | ctx context.Context 399 | flate bool 400 | flateReader io.Reader 401 | flateBufio *bufio.Reader 402 | flateTail strings.Reader 403 | limitReader *limitReader 404 | dict *slidingWindow 405 | 406 | fin bool 407 | payloadLength int64 408 | maskKey uint32 409 | 410 | // util.ReaderFunc(mr.Read) to avoid continuous allocations. 
411 | readFunc util.ReaderFunc 412 | } 413 | 414 | func (mr *msgReader) reset(ctx context.Context, h header) { 415 | mr.ctx = ctx 416 | mr.flate = h.rsv1 417 | mr.limitReader.reset(mr.readFunc) 418 | 419 | if mr.flate { 420 | mr.resetFlate() 421 | } 422 | 423 | mr.setFrame(h) 424 | } 425 | 426 | func (mr *msgReader) setFrame(h header) { 427 | mr.fin = h.fin 428 | mr.payloadLength = h.payloadLength 429 | mr.maskKey = h.maskKey 430 | } 431 | 432 | func (mr *msgReader) Read(p []byte) (n int, err error) { 433 | err = mr.c.readMu.lock(mr.ctx) 434 | if err != nil { 435 | return 0, fmt.Errorf("failed to read: %w", err) 436 | } 437 | defer mr.c.readMu.unlock() 438 | 439 | n, err = mr.limitReader.Read(p) 440 | if mr.flate && mr.flateContextTakeover() { 441 | p = p[:n] 442 | mr.dict.write(p) 443 | } 444 | if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate { 445 | mr.putFlateReader() 446 | return n, io.EOF 447 | } 448 | if err != nil { 449 | return n, fmt.Errorf("failed to read: %w", err) 450 | } 451 | return n, nil 452 | } 453 | 454 | func (mr *msgReader) read(p []byte) (int, error) { 455 | for { 456 | if mr.payloadLength == 0 { 457 | if mr.fin { 458 | if mr.flate { 459 | return mr.flateTail.Read(p) 460 | } 461 | return 0, io.EOF 462 | } 463 | 464 | h, err := mr.c.readLoop(mr.ctx) 465 | if err != nil { 466 | return 0, err 467 | } 468 | if h.opcode != opContinuation { 469 | err := errors.New("received new data message without finishing the previous message") 470 | mr.c.writeError(StatusProtocolError, err) 471 | return 0, err 472 | } 473 | mr.setFrame(h) 474 | 475 | continue 476 | } 477 | 478 | if int64(len(p)) > mr.payloadLength { 479 | p = p[:mr.payloadLength] 480 | } 481 | 482 | n, err := mr.c.readFramePayload(mr.ctx, p) 483 | if err != nil { 484 | return n, err 485 | } 486 | 487 | mr.payloadLength -= int64(n) 488 | 489 | if !mr.c.client { 490 | mr.maskKey = mask(p, mr.maskKey) 491 | } 492 | 493 | return n, nil 494 | } 495 | } 496 | 497 | 
type limitReader struct { 498 | c *Conn 499 | r io.Reader 500 | limit atomic.Int64 501 | n int64 502 | } 503 | 504 | func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader { 505 | lr := &limitReader{ 506 | c: c, 507 | } 508 | lr.limit.Store(limit) 509 | lr.reset(r) 510 | return lr 511 | } 512 | 513 | func (lr *limitReader) reset(r io.Reader) { 514 | lr.n = lr.limit.Load() 515 | lr.r = r 516 | } 517 | 518 | func (lr *limitReader) Read(p []byte) (int, error) { 519 | if lr.n < 0 { 520 | return lr.r.Read(p) 521 | } 522 | 523 | if lr.n == 0 { 524 | err := fmt.Errorf("read limited at %v bytes", lr.limit.Load()) 525 | lr.c.writeError(StatusMessageTooBig, err) 526 | return 0, err 527 | } 528 | 529 | if int64(len(p)) > lr.n { 530 | p = p[:lr.n] 531 | } 532 | n, err := lr.r.Read(p) 533 | lr.n -= int64(n) 534 | if lr.n < 0 { 535 | lr.n = 0 536 | } 537 | return n, err 538 | } 539 | -------------------------------------------------------------------------------- /stringer.go: -------------------------------------------------------------------------------- 1 | // Code generated by "stringer -type=opcode,MessageType,StatusCode -output=stringer.go"; DO NOT EDIT. 2 | 3 | package websocket 4 | 5 | import "strconv" 6 | 7 | func _() { 8 | // An "invalid array index" compiler error signifies that the constant values have changed. 9 | // Re-run the stringer command to generate them again. 
10 | var x [1]struct{} 11 | _ = x[opContinuation-0] 12 | _ = x[opText-1] 13 | _ = x[opBinary-2] 14 | _ = x[opClose-8] 15 | _ = x[opPing-9] 16 | _ = x[opPong-10] 17 | } 18 | 19 | const ( 20 | _opcode_name_0 = "opContinuationopTextopBinary" 21 | _opcode_name_1 = "opCloseopPingopPong" 22 | ) 23 | 24 | var ( 25 | _opcode_index_0 = [...]uint8{0, 14, 20, 28} 26 | _opcode_index_1 = [...]uint8{0, 7, 13, 19} 27 | ) 28 | 29 | func (i opcode) String() string { 30 | switch { 31 | case 0 <= i && i <= 2: 32 | return _opcode_name_0[_opcode_index_0[i]:_opcode_index_0[i+1]] 33 | case 8 <= i && i <= 10: 34 | i -= 8 35 | return _opcode_name_1[_opcode_index_1[i]:_opcode_index_1[i+1]] 36 | default: 37 | return "opcode(" + strconv.FormatInt(int64(i), 10) + ")" 38 | } 39 | } 40 | func _() { 41 | // An "invalid array index" compiler error signifies that the constant values have changed. 42 | // Re-run the stringer command to generate them again. 43 | var x [1]struct{} 44 | _ = x[MessageText-1] 45 | _ = x[MessageBinary-2] 46 | } 47 | 48 | const _MessageType_name = "MessageTextMessageBinary" 49 | 50 | var _MessageType_index = [...]uint8{0, 11, 24} 51 | 52 | func (i MessageType) String() string { 53 | i -= 1 54 | if i < 0 || i >= MessageType(len(_MessageType_index)-1) { 55 | return "MessageType(" + strconv.FormatInt(int64(i+1), 10) + ")" 56 | } 57 | return _MessageType_name[_MessageType_index[i]:_MessageType_index[i+1]] 58 | } 59 | func _() { 60 | // An "invalid array index" compiler error signifies that the constant values have changed. 61 | // Re-run the stringer command to generate them again. 
62 | var x [1]struct{} 63 | _ = x[StatusNormalClosure-1000] 64 | _ = x[StatusGoingAway-1001] 65 | _ = x[StatusProtocolError-1002] 66 | _ = x[StatusUnsupportedData-1003] 67 | _ = x[statusReserved-1004] 68 | _ = x[StatusNoStatusRcvd-1005] 69 | _ = x[StatusAbnormalClosure-1006] 70 | _ = x[StatusInvalidFramePayloadData-1007] 71 | _ = x[StatusPolicyViolation-1008] 72 | _ = x[StatusMessageTooBig-1009] 73 | _ = x[StatusMandatoryExtension-1010] 74 | _ = x[StatusInternalError-1011] 75 | _ = x[StatusServiceRestart-1012] 76 | _ = x[StatusTryAgainLater-1013] 77 | _ = x[StatusBadGateway-1014] 78 | _ = x[StatusTLSHandshake-1015] 79 | } 80 | 81 | const _StatusCode_name = "StatusNormalClosureStatusGoingAwayStatusProtocolErrorStatusUnsupportedDatastatusReservedStatusNoStatusRcvdStatusAbnormalClosureStatusInvalidFramePayloadDataStatusPolicyViolationStatusMessageTooBigStatusMandatoryExtensionStatusInternalErrorStatusServiceRestartStatusTryAgainLaterStatusBadGatewayStatusTLSHandshake" 82 | 83 | var _StatusCode_index = [...]uint16{0, 19, 34, 53, 74, 88, 106, 127, 156, 177, 196, 220, 239, 259, 278, 294, 312} 84 | 85 | func (i StatusCode) String() string { 86 | i -= 1000 87 | if i < 0 || i >= StatusCode(len(_StatusCode_index)-1) { 88 | return "StatusCode(" + strconv.FormatInt(int64(i+1000), 10) + ")" 89 | } 90 | return _StatusCode_name[_StatusCode_index[i]:_StatusCode_index[i+1]] 91 | } 92 | -------------------------------------------------------------------------------- /write.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | // +build !js 3 | 4 | package websocket 5 | 6 | import ( 7 | "bufio" 8 | "compress/flate" 9 | "context" 10 | "crypto/rand" 11 | "encoding/binary" 12 | "errors" 13 | "fmt" 14 | "io" 15 | "net" 16 | "time" 17 | 18 | "github.com/coder/websocket/internal/errd" 19 | "github.com/coder/websocket/internal/util" 20 | ) 21 | 22 | // Writer returns a writer bounded by the context that will write 23 | // a WebSocket 
message of type dataType to the connection. 24 | // 25 | // You must close the writer once you have written the entire message. 26 | // 27 | // Only one writer can be open at a time, multiple calls will block until the previous writer 28 | // is closed. 29 | func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { 30 | w, err := c.writer(ctx, typ) 31 | if err != nil { 32 | return nil, fmt.Errorf("failed to get writer: %w", err) 33 | } 34 | return w, nil 35 | } 36 | 37 | // Write writes a message to the connection. 38 | // 39 | // See the Writer method if you want to stream a message. 40 | // 41 | // If compression is disabled or the compression threshold is not met, then it 42 | // will write the message in a single frame. 43 | func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { 44 | _, err := c.write(ctx, typ, p) 45 | if err != nil { 46 | return fmt.Errorf("failed to write msg: %w", err) 47 | } 48 | return nil 49 | } 50 | 51 | type msgWriter struct { 52 | c *Conn 53 | 54 | mu *mu 55 | writeMu *mu 56 | closed bool 57 | 58 | ctx context.Context 59 | opcode opcode 60 | flate bool 61 | 62 | trimWriter *trimLastFourBytesWriter 63 | flateWriter *flate.Writer 64 | } 65 | 66 | func newMsgWriter(c *Conn) *msgWriter { 67 | mw := &msgWriter{ 68 | c: c, 69 | mu: newMu(c), 70 | writeMu: newMu(c), 71 | } 72 | return mw 73 | } 74 | 75 | func (mw *msgWriter) ensureFlate() { 76 | if mw.trimWriter == nil { 77 | mw.trimWriter = &trimLastFourBytesWriter{ 78 | w: util.WriterFunc(mw.write), 79 | } 80 | } 81 | 82 | if mw.flateWriter == nil { 83 | mw.flateWriter = getFlateWriter(mw.trimWriter) 84 | } 85 | mw.flate = true 86 | } 87 | 88 | func (mw *msgWriter) flateContextTakeover() bool { 89 | if mw.c.client { 90 | return !mw.c.copts.clientNoContextTakeover 91 | } 92 | return !mw.c.copts.serverNoContextTakeover 93 | } 94 | 95 | func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { 96 | err := 
c.msgWriter.reset(ctx, typ) 97 | if err != nil { 98 | return nil, err 99 | } 100 | return c.msgWriter, nil 101 | } 102 | 103 | func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) { 104 | mw, err := c.writer(ctx, typ) 105 | if err != nil { 106 | return 0, err 107 | } 108 | 109 | if !c.flate() { 110 | defer c.msgWriter.mu.unlock() 111 | return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p) 112 | } 113 | 114 | n, err := mw.Write(p) 115 | if err != nil { 116 | return n, err 117 | } 118 | 119 | err = mw.Close() 120 | return n, err 121 | } 122 | 123 | func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error { 124 | err := mw.mu.lock(ctx) 125 | if err != nil { 126 | return err 127 | } 128 | 129 | mw.ctx = ctx 130 | mw.opcode = opcode(typ) 131 | mw.flate = false 132 | mw.closed = false 133 | 134 | mw.trimWriter.reset() 135 | 136 | return nil 137 | } 138 | 139 | func (mw *msgWriter) putFlateWriter() { 140 | if mw.flateWriter != nil { 141 | putFlateWriter(mw.flateWriter) 142 | mw.flateWriter = nil 143 | } 144 | } 145 | 146 | // Write writes the given bytes to the WebSocket connection. 
147 | func (mw *msgWriter) Write(p []byte) (_ int, err error) { 148 | err = mw.writeMu.lock(mw.ctx) 149 | if err != nil { 150 | return 0, fmt.Errorf("failed to write: %w", err) 151 | } 152 | defer mw.writeMu.unlock() 153 | 154 | if mw.closed { 155 | return 0, errors.New("cannot use closed writer") 156 | } 157 | 158 | defer func() { 159 | if err != nil { 160 | err = fmt.Errorf("failed to write: %w", err) 161 | } 162 | }() 163 | 164 | if mw.c.flate() { 165 | // Only enables flate if the length crosses the 166 | // threshold on the first frame 167 | if mw.opcode != opContinuation && len(p) >= mw.c.flateThreshold { 168 | mw.ensureFlate() 169 | } 170 | } 171 | 172 | if mw.flate { 173 | return mw.flateWriter.Write(p) 174 | } 175 | 176 | return mw.write(p) 177 | } 178 | 179 | func (mw *msgWriter) write(p []byte) (int, error) { 180 | n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p) 181 | if err != nil { 182 | return n, fmt.Errorf("failed to write data frame: %w", err) 183 | } 184 | mw.opcode = opContinuation 185 | return n, nil 186 | } 187 | 188 | // Close flushes the frame to the connection. 
189 | func (mw *msgWriter) Close() (err error) { 190 | defer errd.Wrap(&err, "failed to close writer") 191 | 192 | err = mw.writeMu.lock(mw.ctx) 193 | if err != nil { 194 | return err 195 | } 196 | defer mw.writeMu.unlock() 197 | 198 | if mw.closed { 199 | return errors.New("writer already closed") 200 | } 201 | mw.closed = true 202 | 203 | if mw.flate { 204 | err = mw.flateWriter.Flush() 205 | if err != nil { 206 | return fmt.Errorf("failed to flush flate: %w", err) 207 | } 208 | } 209 | 210 | _, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil) 211 | if err != nil { 212 | return fmt.Errorf("failed to write fin frame: %w", err) 213 | } 214 | 215 | if mw.flate && !mw.flateContextTakeover() { 216 | mw.putFlateWriter() 217 | } 218 | mw.mu.unlock() 219 | return nil 220 | } 221 | 222 | func (mw *msgWriter) close() { 223 | if mw.c.client { 224 | mw.c.writeFrameMu.forceLock() 225 | putBufioWriter(mw.c.bw) 226 | } 227 | 228 | mw.writeMu.forceLock() 229 | mw.putFlateWriter() 230 | } 231 | 232 | func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { 233 | ctx, cancel := context.WithTimeout(ctx, time.Second*5) 234 | defer cancel() 235 | 236 | _, err := c.writeFrame(ctx, true, false, opcode, p) 237 | if err != nil { 238 | return fmt.Errorf("failed to write control frame %v: %w", opcode, err) 239 | } 240 | return nil 241 | } 242 | 243 | // writeFrame handles all writes to the connection. 
244 | func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) { 245 | err = c.writeFrameMu.lock(ctx) 246 | if err != nil { 247 | return 0, err 248 | } 249 | defer c.writeFrameMu.unlock() 250 | 251 | defer func() { 252 | if c.isClosed() && opcode == opClose { 253 | err = nil 254 | } 255 | if err != nil { 256 | if ctx.Err() != nil { 257 | err = ctx.Err() 258 | } else if c.isClosed() { 259 | err = net.ErrClosed 260 | } 261 | err = fmt.Errorf("failed to write frame: %w", err) 262 | } 263 | }() 264 | 265 | c.closeStateMu.Lock() 266 | closeSentErr := c.closeSentErr 267 | c.closeStateMu.Unlock() 268 | if closeSentErr != nil { 269 | return 0, net.ErrClosed 270 | } 271 | 272 | select { 273 | case <-c.closed: 274 | return 0, net.ErrClosed 275 | case c.writeTimeout <- ctx: 276 | } 277 | defer func() { 278 | select { 279 | case <-c.closed: 280 | case c.writeTimeout <- context.Background(): 281 | } 282 | }() 283 | 284 | c.writeHeader.fin = fin 285 | c.writeHeader.opcode = opcode 286 | c.writeHeader.payloadLength = int64(len(p)) 287 | 288 | if c.client { 289 | c.writeHeader.masked = true 290 | _, err = io.ReadFull(rand.Reader, c.writeHeaderBuf[:4]) 291 | if err != nil { 292 | return 0, fmt.Errorf("failed to generate masking key: %w", err) 293 | } 294 | c.writeHeader.maskKey = binary.LittleEndian.Uint32(c.writeHeaderBuf[:]) 295 | } 296 | 297 | c.writeHeader.rsv1 = false 298 | if flate && (opcode == opText || opcode == opBinary) { 299 | c.writeHeader.rsv1 = true 300 | } 301 | 302 | err = writeFrameHeader(c.writeHeader, c.bw, c.writeHeaderBuf[:]) 303 | if err != nil { 304 | return 0, err 305 | } 306 | 307 | n, err := c.writeFramePayload(p) 308 | if err != nil { 309 | return n, err 310 | } 311 | 312 | if c.writeHeader.fin { 313 | err = c.bw.Flush() 314 | if err != nil { 315 | return n, fmt.Errorf("failed to flush: %w", err) 316 | } 317 | } 318 | 319 | if opcode == opClose { 320 | c.closeStateMu.Lock() 321 | c.closeSentErr = 
fmt.Errorf("sent close frame: %w", net.ErrClosed) 322 | closeReceived := c.closeReceivedErr != nil 323 | c.closeStateMu.Unlock() 324 | 325 | if closeReceived && !c.casClosing() { 326 | c.writeFrameMu.unlock() 327 | _ = c.close() 328 | } 329 | } 330 | 331 | return n, nil 332 | } 333 | 334 | func (c *Conn) writeFramePayload(p []byte) (n int, err error) { 335 | defer errd.Wrap(&err, "failed to write frame payload") 336 | 337 | if !c.writeHeader.masked { 338 | return c.bw.Write(p) 339 | } 340 | 341 | maskKey := c.writeHeader.maskKey 342 | for len(p) > 0 { 343 | // If the buffer is full, we need to flush. 344 | if c.bw.Available() == 0 { 345 | err = c.bw.Flush() 346 | if err != nil { 347 | return n, err 348 | } 349 | } 350 | 351 | // Start of next write in the buffer. 352 | i := c.bw.Buffered() 353 | 354 | j := len(p) 355 | if j > c.bw.Available() { 356 | j = c.bw.Available() 357 | } 358 | 359 | _, err := c.bw.Write(p[:j]) 360 | if err != nil { 361 | return n, err 362 | } 363 | 364 | maskKey = mask(c.writeBuf[i:c.bw.Buffered()], maskKey) 365 | 366 | p = p[j:] 367 | n += j 368 | } 369 | 370 | return n, nil 371 | } 372 | 373 | // extractBufioWriterBuf grabs the []byte backing a *bufio.Writer 374 | // and returns it. 
375 | func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte { 376 | var writeBuf []byte 377 | bw.Reset(util.WriterFunc(func(p2 []byte) (int, error) { 378 | writeBuf = p2[:cap(p2)] 379 | return len(p2), nil 380 | })) 381 | 382 | bw.WriteByte(0) 383 | bw.Flush() 384 | 385 | bw.Reset(w) 386 | 387 | return writeBuf 388 | } 389 | 390 | func (c *Conn) writeError(code StatusCode, err error) { 391 | c.writeClose(code, err.Error()) 392 | } 393 | -------------------------------------------------------------------------------- /ws_js_test.go: -------------------------------------------------------------------------------- 1 | package websocket_test 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "os" 7 | "testing" 8 | "time" 9 | 10 | "github.com/coder/websocket" 11 | "github.com/coder/websocket/internal/test/assert" 12 | "github.com/coder/websocket/internal/test/wstest" 13 | ) 14 | 15 | func TestWasm(t *testing.T) { 16 | t.Parallel() 17 | 18 | ctx, cancel := context.WithTimeout(context.Background(), time.Minute) 19 | defer cancel() 20 | 21 | c, resp, err := websocket.Dial(ctx, os.Getenv("WS_ECHO_SERVER_URL"), &websocket.DialOptions{ 22 | Subprotocols: []string{"echo"}, 23 | }) 24 | assert.Success(t, err) 25 | defer c.Close(websocket.StatusInternalError, "") 26 | 27 | assert.Equal(t, "subprotocol", "echo", c.Subprotocol()) 28 | assert.Equal(t, "response code", http.StatusSwitchingProtocols, resp.StatusCode) 29 | 30 | c.SetReadLimit(65536) 31 | for i := 0; i < 10; i++ { 32 | err = wstest.Echo(ctx, c, 65536) 33 | assert.Success(t, err) 34 | } 35 | 36 | err = c.Close(websocket.StatusNormalClosure, "") 37 | assert.Success(t, err) 38 | } 39 | 40 | func TestWasmDialTimeout(t *testing.T) { 41 | t.Parallel() 42 | 43 | ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) 44 | defer cancel() 45 | 46 | beforeDial := time.Now() 47 | _, _, err := websocket.Dial(ctx, "ws://example.com:9893", &websocket.DialOptions{ 48 | Subprotocols: []string{"echo"}, 49 | 
}) 50 | assert.Error(t, err) 51 | if time.Since(beforeDial) >= time.Second { 52 | t.Fatal("wasm context dial timeout is not working", time.Since(beforeDial)) 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /wsjson/wsjson.go: -------------------------------------------------------------------------------- 1 | // Package wsjson provides helpers for reading and writing JSON messages. 2 | package wsjson // import "github.com/coder/websocket/wsjson" 3 | 4 | import ( 5 | "context" 6 | "encoding/json" 7 | "fmt" 8 | 9 | "github.com/coder/websocket" 10 | "github.com/coder/websocket/internal/bpool" 11 | "github.com/coder/websocket/internal/errd" 12 | "github.com/coder/websocket/internal/util" 13 | ) 14 | 15 | // Read reads a JSON message from c into v. 16 | // It will reuse buffers in between calls to avoid allocations. 17 | func Read(ctx context.Context, c *websocket.Conn, v interface{}) error { 18 | return read(ctx, c, v) 19 | } 20 | 21 | func read(ctx context.Context, c *websocket.Conn, v interface{}) (err error) { 22 | defer errd.Wrap(&err, "failed to read JSON message") 23 | 24 | _, r, err := c.Reader(ctx) 25 | if err != nil { 26 | return err 27 | } 28 | 29 | b := bpool.Get() 30 | defer bpool.Put(b) 31 | 32 | _, err = b.ReadFrom(r) 33 | if err != nil { 34 | return err 35 | } 36 | 37 | err = json.Unmarshal(b.Bytes(), v) 38 | if err != nil { 39 | c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal JSON") 40 | return fmt.Errorf("failed to unmarshal JSON: %w", err) 41 | } 42 | 43 | return nil 44 | } 45 | 46 | // Write writes the JSON message v to c. 47 | // It will reuse buffers in between calls to avoid allocations. 
48 | func Write(ctx context.Context, c *websocket.Conn, v interface{}) error { 49 | return write(ctx, c, v) 50 | } 51 | 52 | // write encodes v with a json.Encoder whose output is sent to c as a text message via c.Write. 53 | // Errors from c.Write are returned without the "failed to marshal JSON" label (only the deferred 54 | // errd.Wrap applies), so transport failures are not misreported as JSON encoding failures. 55 | func write(ctx context.Context, c *websocket.Conn, v interface{}) (err error) { 56 | defer errd.Wrap(&err, "failed to write JSON message") 57 | 58 | // json.Marshal cannot reuse buffers between calls as it has to return 59 | // a copy of the byte slice but Encoder does as it directly writes to w. 60 | // writeErr records a failure from c.Write so it can be told apart from an encoding error. 61 | var writeErr error 62 | err = json.NewEncoder(util.WriterFunc(func(p []byte) (int, error) { 63 | writeErr = c.Write(ctx, websocket.MessageText, p) 64 | if writeErr != nil { 65 | return 0, writeErr 66 | } 67 | return len(p), nil 68 | })).Encode(v) 69 | if writeErr != nil { 70 | return writeErr 71 | } 72 | if err != nil { 73 | return fmt.Errorf("failed to marshal JSON: %w", err) 74 | } 75 | return nil 76 | } 77 | -------------------------------------------------------------------------------- /wsjson/wsjson_test.go: -------------------------------------------------------------------------------- 1 | package wsjson_test 2 | 3 | import ( 4 | "encoding/json" 5 | "io" 6 | "strconv" 7 | "testing" 8 | 9 | "github.com/coder/websocket/internal/test/xrand" 10 | ) 11 | 12 | func BenchmarkJSON(b *testing.B) { 13 | sizes := []int{ 14 | 8, 15 | 16, 16 | 32, 17 | 128, 18 | 256, 19 | 512, 20 | 1024, 21 | 2048, 22 | 4096, 23 | 8192, 24 | 16384, 25 | } 26 | 27 | b.Run("json.Encoder", func(b *testing.B) { 28 | for _, size := range sizes { 29 | b.Run(strconv.Itoa(size), func(b *testing.B) { 30 | msg := xrand.String(size) 31 | b.SetBytes(int64(size)) 32 | b.ReportAllocs() 33 | b.ResetTimer() 34 | for i := 0; i < b.N; i++ { 35 | json.NewEncoder(io.Discard).Encode(msg) 36 | } 37 | }) 38 | } 39 | }) 40 | b.Run("json.Marshal", func(b *testing.B) { 41 | for _, size := range sizes { 42 | b.Run(strconv.Itoa(size), func(b *testing.B) { 43 | msg := xrand.String(size) 44 | b.SetBytes(int64(size)) 45 | b.ReportAllocs() 46 | b.ResetTimer() 47 | for i := 0; i < b.N; i++ { 48 | json.Marshal(msg) 49 | } 50 | }) 51 | } 52 | }) 53 | } 54 |
--------------------------------------------------------------------------------