├── .github ├── dependabot.yml └── workflows │ ├── ci.yml │ ├── daily.yml │ └── static.yml ├── LICENSE.txt ├── Makefile ├── README.md ├── accept.go ├── accept_test.go ├── autobahn_test.go ├── ci ├── bench.sh ├── fmt.sh ├── lint.sh ├── out │ └── .gitignore └── test.sh ├── close.go ├── close_test.go ├── compress.go ├── compress_test.go ├── conn.go ├── conn_test.go ├── dial.go ├── dial_test.go ├── doc.go ├── example_test.go ├── export_test.go ├── frame.go ├── frame_test.go ├── go.mod ├── go.sum ├── hijack.go ├── hijack_go120_test.go ├── internal ├── bpool │ └── bpool.go ├── errd │ └── wrap.go ├── examples │ ├── README.md │ ├── chat │ │ ├── README.md │ │ ├── chat.go │ │ ├── chat_test.go │ │ ├── index.css │ │ ├── index.html │ │ ├── index.js │ │ └── main.go │ ├── echo │ │ ├── README.md │ │ ├── main.go │ │ ├── server.go │ │ └── server_test.go │ ├── go.mod │ └── go.sum ├── test │ ├── assert │ │ └── assert.go │ ├── doc.go │ ├── wstest │ │ ├── echo.go │ │ └── pipe.go │ └── xrand │ │ └── xrand.go ├── thirdparty │ ├── doc.go │ ├── frame_test.go │ ├── gin_test.go │ ├── go.mod │ └── go.sum ├── util │ └── util.go ├── wsjs │ └── wsjs_js.go └── xsync │ ├── go.go │ └── go_test.go ├── main_test.go ├── mask.go ├── mask_amd64.s ├── mask_arm64.s ├── mask_asm.go ├── mask_asm_test.go ├── mask_go.go ├── mask_test.go ├── netconn.go ├── netconn_js.go ├── netconn_notjs.go ├── read.go ├── stringer.go ├── write.go ├── ws_js.go ├── ws_js_test.go └── wsjson ├── wsjson.go └── wsjson_test.go /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | # Track in case we ever add dependencies. 4 | - package-ecosystem: 'gomod' 5 | directory: '/' 6 | schedule: 7 | interval: 'weekly' 8 | commit-message: 9 | prefix: 'chore' 10 | 11 | # Keep example and test/benchmark deps up-to-date. 
12 | - package-ecosystem: 'gomod' 13 | directories: 14 | - '/internal/examples' 15 | - '/internal/thirdparty' 16 | schedule: 17 | interval: 'monthly' 18 | commit-message: 19 | prefix: 'chore' 20 | labels: [] 21 | groups: 22 | internal-deps: 23 | patterns: 24 | - '*' 25 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | on: 3 | push: 4 | branches: 5 | - master 6 | pull_request: 7 | branches: 8 | - master 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | fmt: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v4 18 | - uses: actions/setup-go@v5 19 | with: 20 | go-version-file: ./go.mod 21 | - run: make fmt 22 | 23 | lint: 24 | runs-on: ubuntu-latest 25 | steps: 26 | - uses: actions/checkout@v4 27 | - run: go version 28 | - uses: actions/setup-go@v5 29 | with: 30 | go-version-file: ./go.mod 31 | - run: make lint 32 | 33 | test: 34 | runs-on: ubuntu-latest 35 | steps: 36 | - name: Disable AppArmor 37 | if: runner.os == 'Linux' 38 | run: | 39 | # Disable AppArmor for Ubuntu 23.10+. 
40 | # https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md 41 | echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns 42 | - uses: actions/checkout@v4 43 | - uses: actions/setup-go@v5 44 | with: 45 | go-version-file: ./go.mod 46 | - run: make test 47 | - uses: actions/upload-artifact@v4 48 | with: 49 | name: coverage.html 50 | path: ./ci/out/coverage.html 51 | 52 | bench: 53 | runs-on: ubuntu-latest 54 | steps: 55 | - uses: actions/checkout@v4 56 | - uses: actions/setup-go@v5 57 | with: 58 | go-version-file: ./go.mod 59 | - run: make bench 60 | -------------------------------------------------------------------------------- /.github/workflows/daily.yml: -------------------------------------------------------------------------------- 1 | name: daily 2 | on: 3 | workflow_dispatch: 4 | schedule: 5 | - cron: '42 0 * * *' # daily at 00:42 6 | concurrency: 7 | group: ${{ github.workflow }} 8 | cancel-in-progress: true 9 | 10 | jobs: 11 | bench: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v4 15 | - uses: actions/setup-go@v5 16 | with: 17 | go-version-file: ./go.mod 18 | - run: AUTOBAHN=1 make bench 19 | test: 20 | runs-on: ubuntu-latest 21 | steps: 22 | - name: Disable AppArmor 23 | if: runner.os == 'Linux' 24 | run: | 25 | # Disable AppArmor for Ubuntu 23.10+. 
26 | # https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md 27 | echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns 28 | - uses: actions/checkout@v4 29 | - uses: actions/setup-go@v5 30 | with: 31 | go-version-file: ./go.mod 32 | - run: AUTOBAHN=1 make test 33 | - uses: actions/upload-artifact@v4 34 | with: 35 | name: coverage.html 36 | path: ./ci/out/coverage.html 37 | bench-dev: 38 | runs-on: ubuntu-latest 39 | steps: 40 | - uses: actions/checkout@v4 41 | with: 42 | ref: dev 43 | - uses: actions/setup-go@v5 44 | with: 45 | go-version-file: ./go.mod 46 | - run: AUTOBAHN=1 make bench 47 | test-dev: 48 | runs-on: ubuntu-latest 49 | steps: 50 | - name: Disable AppArmor 51 | if: runner.os == 'Linux' 52 | run: | 53 | # Disable AppArmor for Ubuntu 23.10+. 54 | # https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md 55 | echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns 56 | - uses: actions/checkout@v4 57 | with: 58 | ref: dev 59 | - uses: actions/setup-go@v5 60 | with: 61 | go-version-file: ./go.mod 62 | - run: AUTOBAHN=1 make test 63 | - uses: actions/upload-artifact@v4 64 | with: 65 | name: coverage-dev.html 66 | path: ./ci/out/coverage.html 67 | -------------------------------------------------------------------------------- /.github/workflows/static.yml: -------------------------------------------------------------------------------- 1 | name: static 2 | 3 | on: 4 | push: 5 | branches: ['master'] 6 | workflow_dispatch: 7 | 8 | # Set permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages. 
9 | permissions: 10 | contents: read 11 | pages: write 12 | id-token: write 13 | 14 | concurrency: 15 | group: pages 16 | cancel-in-progress: true 17 | 18 | jobs: 19 | deploy: 20 | environment: 21 | name: github-pages 22 | url: ${{ steps.deployment.outputs.page_url }} 23 | runs-on: ubuntu-latest 24 | steps: 25 | - name: Disable AppArmor 26 | if: runner.os == 'Linux' 27 | run: | 28 | # Disable AppArmor for Ubuntu 23.10+. 29 | # https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md 30 | echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns 31 | - name: Checkout 32 | uses: actions/checkout@v4 33 | - name: Setup Pages 34 | uses: actions/configure-pages@v5 35 | - name: Setup Go 36 | uses: actions/setup-go@v5 37 | with: 38 | go-version-file: ./go.mod 39 | - name: Generate coverage and badge 40 | run: | 41 | make test 42 | mkdir -p ./ci/out/static 43 | cp ./ci/out/coverage.html ./ci/out/static/coverage.html 44 | percent=$(go tool cover -func ./ci/out/coverage.prof | tail -n1 | awk '{print $3}' | tr -d '%') 45 | wget -O ./ci/out/static/coverage.svg "https://img.shields.io/badge/coverage-${percent}%25-success" 46 | - name: Upload artifact 47 | uses: actions/upload-pages-artifact@v3 48 | with: 49 | path: ./ci/out/static/ 50 | - name: Deploy to GitHub Pages 51 | id: deployment 52 | uses: actions/deploy-pages@v4 53 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2025 Coder 2 | 3 | Permission to use, copy, modify, and distribute this software for any 4 | purpose with or without fee is hereby granted, provided that the above 5 | copyright notice and this permission notice appear in all copies. 
6 | 7 | THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 14 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all 2 | all: fmt lint test 3 | 4 | .PHONY: fmt 5 | fmt: 6 | ./ci/fmt.sh 7 | 8 | .PHONY: lint 9 | lint: 10 | ./ci/lint.sh 11 | 12 | .PHONY: test 13 | test: 14 | ./ci/test.sh 15 | 16 | .PHONY: bench 17 | bench: 18 | ./ci/bench.sh -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # websocket 2 | 3 | [](https://pkg.go.dev/github.com/coder/websocket) 4 | [](https://coder.github.io/websocket/coverage.html) 5 | 6 | websocket is a minimal and idiomatic WebSocket library for Go. 7 | 8 | ## Install 9 | 10 | ```sh 11 | go get github.com/coder/websocket 12 | ``` 13 | 14 | > [!NOTE] 15 | > Coder now maintains this project as explained in [this blog post](https://coder.com/blog/websocket). 16 | > We're grateful to [nhooyr](https://github.com/nhooyr) for authoring and maintaining this project from 17 | > 2019 to 2024. 
18 | 19 | ## Highlights 20 | 21 | - Minimal and idiomatic API 22 | - First class [context.Context](https://blog.golang.org/context) support 23 | - Fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) 24 | - [Zero dependencies](https://pkg.go.dev/github.com/coder/websocket?tab=imports) 25 | - JSON helpers in the [wsjson](https://pkg.go.dev/github.com/coder/websocket/wsjson) subpackage 26 | - Zero alloc reads and writes 27 | - Concurrent writes 28 | - [Close handshake](https://pkg.go.dev/github.com/coder/websocket#Conn.Close) 29 | - [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) wrapper 30 | - [Ping pong](https://pkg.go.dev/github.com/coder/websocket#Conn.Ping) API 31 | - [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression 32 | - [CloseRead](https://pkg.go.dev/github.com/coder/websocket#Conn.CloseRead) helper for write only connections 33 | - Compile to [Wasm](https://pkg.go.dev/github.com/coder/websocket#hdr-Wasm) 34 | 35 | ## Roadmap 36 | 37 | See GitHub issues for minor issues but the major future enhancements are: 38 | 39 | - [ ] Perfect examples [#217](https://github.com/nhooyr/websocket/issues/217) 40 | - [ ] wstest.Pipe for in memory testing [#340](https://github.com/nhooyr/websocket/issues/340) 41 | - [ ] Ping pong heartbeat helper [#267](https://github.com/nhooyr/websocket/issues/267) 42 | - [ ] Ping pong instrumentation callbacks [#246](https://github.com/nhooyr/websocket/issues/246) 43 | - [ ] Graceful shutdown helpers [#209](https://github.com/nhooyr/websocket/issues/209) 44 | - [ ] Assembly for WebSocket masking [#16](https://github.com/nhooyr/websocket/issues/16) 45 | - WIP at [#326](https://github.com/nhooyr/websocket/pull/326), about 3x faster 46 | - [ ] HTTP/2 [#4](https://github.com/nhooyr/websocket/issues/4) 47 | - [ ] The holy grail [#402](https://github.com/nhooyr/websocket/issues/402) 48 | 49 | ## Examples 50 | 51 | For a production quality example that 
demonstrates the complete API, see the 52 | [echo example](./internal/examples/echo). 53 | 54 | For a full stack example, see the [chat example](./internal/examples/chat). 55 | 56 | ### Server 57 | 58 | ```go 59 | http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { 60 | c, err := websocket.Accept(w, r, nil) 61 | if err != nil { 62 | // ... 63 | } 64 | defer c.CloseNow() 65 | 66 | // Set the context as needed. Use of r.Context() is not recommended 67 | // to avoid surprising behavior (see http.Hijacker). 68 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) 69 | defer cancel() 70 | 71 | var v interface{} 72 | err = wsjson.Read(ctx, c, &v) 73 | if err != nil { 74 | // ... 75 | } 76 | 77 | log.Printf("received: %v", v) 78 | 79 | c.Close(websocket.StatusNormalClosure, "") 80 | }) 81 | ``` 82 | 83 | ### Client 84 | 85 | ```go 86 | ctx, cancel := context.WithTimeout(context.Background(), time.Minute) 87 | defer cancel() 88 | 89 | c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) 90 | if err != nil { 91 | // ... 92 | } 93 | defer c.CloseNow() 94 | 95 | err = wsjson.Write(ctx, c, "hi") 96 | if err != nil { 97 | // ... 98 | } 99 | 100 | c.Close(websocket.StatusNormalClosure, "") 101 | ``` 102 | 103 | ## Comparison 104 | 105 | ### gorilla/websocket 106 | 107 | Advantages of [gorilla/websocket](https://github.com/gorilla/websocket): 108 | 109 | - Mature and widely used 110 | - [Prepared writes](https://pkg.go.dev/github.com/gorilla/websocket#PreparedMessage) 111 | - Configurable [buffer sizes](https://pkg.go.dev/github.com/gorilla/websocket#hdr-Buffers) 112 | - No extra goroutine per connection to support cancellation with context.Context. This costs github.com/coder/websocket 2 KB of memory per connection. 113 | - Will be removed soon with [context.AfterFunc](https://github.com/golang/go/issues/57928). 
See [#411](https://github.com/nhooyr/websocket/issues/411) 114 | 115 | Advantages of github.com/coder/websocket: 116 | 117 | - Minimal and idiomatic API 118 | - Compare godoc of [github.com/coder/websocket](https://pkg.go.dev/github.com/coder/websocket) with [gorilla/websocket](https://pkg.go.dev/github.com/gorilla/websocket) side by side. 119 | - [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) wrapper 120 | - Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535)) 121 | - Full [context.Context](https://blog.golang.org/context) support 122 | - Dial uses [net/http.Client](https://golang.org/pkg/net/http/#Client) 123 | - Will enable easy HTTP/2 support in the future 124 | - Gorilla writes directly to a net.Conn and so duplicates features of net/http.Client. 125 | - Concurrent writes 126 | - Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448)) 127 | - Idiomatic [ping pong](https://pkg.go.dev/github.com/coder/websocket#Conn.Ping) API 128 | - Gorilla requires registering a pong callback before sending a Ping 129 | - Can target Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432)) 130 | - Transparent message buffer reuse with [wsjson](https://pkg.go.dev/github.com/coder/websocket/wsjson) subpackage 131 | - [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go 132 | - Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/). 
133 | Soon we'll have assembly and be 3x faster [#326](https://github.com/nhooyr/websocket/pull/326) 134 | - Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support 135 | - Gorilla only supports no context takeover mode 136 | - [CloseRead](https://pkg.go.dev/github.com/coder/websocket#Conn.CloseRead) helper for write only connections ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492)) 137 | 138 | #### golang.org/x/net/websocket 139 | 140 | [golang.org/x/net/websocket](https://pkg.go.dev/golang.org/x/net/websocket) is deprecated. 141 | See [golang/go/issues/18152](https://github.com/golang/go/issues/18152). 142 | 143 | The [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) wrapper can help in transitioning 144 | to github.com/coder/websocket. 145 | 146 | #### gobwas/ws 147 | 148 | [gobwas/ws](https://github.com/gobwas/ws) has an extremely flexible API that allows it to be used 149 | in an event-driven style for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb). 150 | 151 | However, it is quite bloated. See https://pkg.go.dev/github.com/gobwas/ws 152 | 153 | When writing idiomatic Go, github.com/coder/websocket will be faster and easier to use. 154 | 155 | #### lesismal/nbio 156 | 157 | [lesismal/nbio](https://github.com/lesismal/nbio) is similar to gobwas/ws in that the API is 158 | event-driven for performance reasons. 159 | 160 | However, it is quite bloated. See https://pkg.go.dev/github.com/lesismal/nbio 161 | 162 | When writing idiomatic Go, github.com/coder/websocket will be faster and easier to use.
163 | -------------------------------------------------------------------------------- /accept.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | // +build !js 3 | 4 | package websocket 5 | 6 | import ( 7 | "bytes" 8 | "context" 9 | "crypto/sha1" 10 | "encoding/base64" 11 | "errors" 12 | "fmt" 13 | "io" 14 | "log" 15 | "net/http" 16 | "net/textproto" 17 | "net/url" 18 | "path" 19 | "strings" 20 | 21 | "github.com/coder/websocket/internal/errd" 22 | ) 23 | 24 | // AcceptOptions represents Accept's options. 25 | type AcceptOptions struct { 26 | // Subprotocols lists the WebSocket subprotocols that Accept will negotiate with the client. 27 | // The empty subprotocol will always be negotiated as per RFC 6455. If you would like to 28 | // reject it, close the connection when c.Subprotocol() == "". 29 | Subprotocols []string 30 | 31 | // InsecureSkipVerify is used to disable Accept's origin verification behaviour. 32 | // 33 | // You probably want to use OriginPatterns instead. 34 | InsecureSkipVerify bool 35 | 36 | // OriginPatterns lists the host patterns for authorized origins. 37 | // The request host is always authorized. 38 | // Use this to enable cross origin WebSockets. 39 | // 40 | // i.e javascript running on example.com wants to access a WebSocket server at chat.example.com. 41 | // In such a case, example.com is the origin and chat.example.com is the request host. 42 | // One would set this field to []string{"example.com"} to authorize example.com to connect. 43 | // 44 | // Each pattern is matched case insensitively against the request origin host 45 | // with path.Match. 46 | // See https://golang.org/pkg/path/#Match 47 | // 48 | // Please ensure you understand the ramifications of enabling this. 49 | // If used incorrectly your WebSocket server will be open to CSRF attacks. 
50 | // 51 | // Do not use * as a pattern to allow any origin, prefer to use InsecureSkipVerify instead 52 | // to bring attention to the danger of such a setting. 53 | OriginPatterns []string 54 | 55 | // CompressionMode controls the compression mode. 56 | // Defaults to CompressionDisabled. 57 | // 58 | // See docs on CompressionMode for details. 59 | CompressionMode CompressionMode 60 | 61 | // CompressionThreshold controls the minimum size of a message before compression is applied. 62 | // 63 | // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes 64 | // for CompressionContextTakeover. 65 | CompressionThreshold int 66 | 67 | // OnPingReceived is an optional callback invoked synchronously when a ping frame is received. 68 | // 69 | // The payload contains the application data of the ping frame. 70 | // If the callback returns false, the subsequent pong frame will not be sent. 71 | // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine. 72 | OnPingReceived func(ctx context.Context, payload []byte) bool 73 | 74 | // OnPongReceived is an optional callback invoked synchronously when a pong frame is received. 75 | // 76 | // The payload contains the application data of the pong frame. 77 | // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine. 78 | // 79 | // Unlike OnPingReceived, this callback does not return a value because a pong frame 80 | // is a response to a ping and does not trigger any further frame transmission. 81 | OnPongReceived func(ctx context.Context, payload []byte) 82 | } 83 | 84 | func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions { 85 | var o AcceptOptions 86 | if opts != nil { 87 | o = *opts 88 | } 89 | return &o 90 | } 91 | 92 | // Accept accepts a WebSocket handshake from a client and upgrades the 93 | // the connection to a WebSocket. 94 | // 95 | // Accept will not allow cross origin requests by default. 
96 | // See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests. 97 | // 98 | // Accept will write a response to w on all errors. 99 | // 100 | // Note that using the http.Request Context after Accept returns may lead to 101 | // unexpected behavior (see http.Hijacker). 102 | func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { 103 | return accept(w, r, opts) 104 | } 105 | 106 | func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) { 107 | defer errd.Wrap(&err, "failed to accept WebSocket connection") 108 | 109 | errCode, err := verifyClientRequest(w, r) 110 | if err != nil { 111 | http.Error(w, err.Error(), errCode) 112 | return nil, err 113 | } 114 | 115 | opts = opts.cloneWithDefaults() 116 | if !opts.InsecureSkipVerify { 117 | err = authenticateOrigin(r, opts.OriginPatterns) 118 | if err != nil { 119 | if errors.Is(err, path.ErrBadPattern) { 120 | log.Printf("websocket: %v", err) 121 | err = errors.New(http.StatusText(http.StatusForbidden)) 122 | } 123 | http.Error(w, err.Error(), http.StatusForbidden) 124 | return nil, err 125 | } 126 | } 127 | 128 | hj, ok := hijacker(w) 129 | if !ok { 130 | err = errors.New("http.ResponseWriter does not implement http.Hijacker") 131 | http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) 132 | return nil, err 133 | } 134 | 135 | w.Header().Set("Upgrade", "websocket") 136 | w.Header().Set("Connection", "Upgrade") 137 | 138 | key := r.Header.Get("Sec-WebSocket-Key") 139 | w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) 140 | 141 | subproto := selectSubprotocol(r, opts.Subprotocols) 142 | if subproto != "" { 143 | w.Header().Set("Sec-WebSocket-Protocol", subproto) 144 | } 145 | 146 | copts, ok := selectDeflate(websocketExtensions(r.Header), opts.CompressionMode) 147 | if ok { 148 | w.Header().Set("Sec-WebSocket-Extensions", copts.String()) 149 | } 150 | 151 | 
w.WriteHeader(http.StatusSwitchingProtocols) 152 | // See https://github.com/nhooyr/websocket/issues/166 153 | if ginWriter, ok := w.(interface { 154 | WriteHeaderNow() 155 | }); ok { 156 | ginWriter.WriteHeaderNow() 157 | } 158 | 159 | netConn, brw, err := hj.Hijack() 160 | if err != nil { 161 | err = fmt.Errorf("failed to hijack connection: %w", err) 162 | http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 163 | return nil, err 164 | } 165 | 166 | // https://github.com/golang/go/issues/32314 167 | b, _ := brw.Reader.Peek(brw.Reader.Buffered()) 168 | brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn)) 169 | 170 | return newConn(connConfig{ 171 | subprotocol: w.Header().Get("Sec-WebSocket-Protocol"), 172 | rwc: netConn, 173 | client: false, 174 | copts: copts, 175 | flateThreshold: opts.CompressionThreshold, 176 | onPingReceived: opts.OnPingReceived, 177 | onPongReceived: opts.OnPongReceived, 178 | 179 | br: brw.Reader, 180 | bw: brw.Writer, 181 | }), nil 182 | } 183 | 184 | func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) { 185 | if !r.ProtoAtLeast(1, 1) { 186 | return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) 187 | } 188 | 189 | if !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade") { 190 | w.Header().Set("Connection", "Upgrade") 191 | w.Header().Set("Upgrade", "websocket") 192 | return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) 193 | } 194 | 195 | if !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket") { 196 | w.Header().Set("Connection", "Upgrade") 197 | w.Header().Set("Upgrade", "websocket") 198 | return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) 199 | } 200 | 
201 | if r.Method != "GET" { 202 | return http.StatusMethodNotAllowed, fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method) 203 | } 204 | 205 | if r.Header.Get("Sec-WebSocket-Version") != "13" { 206 | w.Header().Set("Sec-WebSocket-Version", "13") 207 | return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) 208 | } 209 | 210 | websocketSecKeys := r.Header.Values("Sec-WebSocket-Key") 211 | if len(websocketSecKeys) == 0 { 212 | return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key") 213 | } 214 | 215 | if len(websocketSecKeys) > 1 { 216 | return http.StatusBadRequest, errors.New("WebSocket protocol violation: multiple Sec-WebSocket-Key headers") 217 | } 218 | 219 | // The RFC states to remove any leading or trailing whitespace. 220 | websocketSecKey := strings.TrimSpace(websocketSecKeys[0]) 221 | if v, err := base64.StdEncoding.DecodeString(websocketSecKey); err != nil || len(v) != 16 { 222 | return http.StatusBadRequest, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Key %q, must be a 16 byte base64 encoded string", websocketSecKey) 223 | } 224 | 225 | return 0, nil 226 | } 227 | 228 | func authenticateOrigin(r *http.Request, originHosts []string) error { 229 | origin := r.Header.Get("Origin") 230 | if origin == "" { 231 | return nil 232 | } 233 | 234 | u, err := url.Parse(origin) 235 | if err != nil { 236 | return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) 237 | } 238 | 239 | if strings.EqualFold(r.Host, u.Host) { 240 | return nil 241 | } 242 | 243 | for _, hostPattern := range originHosts { 244 | matched, err := match(hostPattern, u.Host) 245 | if err != nil { 246 | return fmt.Errorf("failed to parse path pattern %q: %w", hostPattern, err) 247 | } 248 | if matched { 249 | return nil 250 | } 251 | } 252 | if u.Host == "" { 253 | return 
fmt.Errorf("request Origin %q is not a valid URL with a host", origin) 254 | } 255 | return fmt.Errorf("request Origin %q is not authorized for Host %q", u.Host, r.Host) 256 | } 257 | 258 | func match(pattern, s string) (bool, error) { 259 | return path.Match(strings.ToLower(pattern), strings.ToLower(s)) 260 | } 261 | 262 | func selectSubprotocol(r *http.Request, subprotocols []string) string { 263 | cps := headerTokens(r.Header, "Sec-WebSocket-Protocol") 264 | for _, sp := range subprotocols { 265 | for _, cp := range cps { 266 | if strings.EqualFold(sp, cp) { 267 | return cp 268 | } 269 | } 270 | } 271 | return "" 272 | } 273 | 274 | func selectDeflate(extensions []websocketExtension, mode CompressionMode) (*compressionOptions, bool) { 275 | if mode == CompressionDisabled { 276 | return nil, false 277 | } 278 | for _, ext := range extensions { 279 | switch ext.name { 280 | // We used to implement x-webkit-deflate-frame too for Safari but Safari has bugs... 281 | // See https://github.com/nhooyr/websocket/issues/218 282 | case "permessage-deflate": 283 | copts, ok := acceptDeflate(ext, mode) 284 | if ok { 285 | return copts, true 286 | } 287 | } 288 | } 289 | return nil, false 290 | } 291 | 292 | func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOptions, bool) { 293 | copts := mode.opts() 294 | for _, p := range ext.params { 295 | switch p { 296 | case "client_no_context_takeover": 297 | copts.clientNoContextTakeover = true 298 | continue 299 | case "server_no_context_takeover": 300 | copts.serverNoContextTakeover = true 301 | continue 302 | case "client_max_window_bits", 303 | "server_max_window_bits=15": 304 | continue 305 | } 306 | 307 | if strings.HasPrefix(p, "client_max_window_bits=") { 308 | // We can't adjust the deflate window, but decoding with a larger window is acceptable. 
309 | continue 310 | } 311 | return nil, false 312 | } 313 | return copts, true 314 | } 315 | 316 | func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool { 317 | for _, t := range headerTokens(h, key) { 318 | if strings.EqualFold(t, token) { 319 | return true 320 | } 321 | } 322 | return false 323 | } 324 | 325 | type websocketExtension struct { 326 | name string 327 | params []string 328 | } 329 | 330 | func websocketExtensions(h http.Header) []websocketExtension { 331 | var exts []websocketExtension 332 | extStrs := headerTokens(h, "Sec-WebSocket-Extensions") 333 | for _, extStr := range extStrs { 334 | if extStr == "" { 335 | continue 336 | } 337 | 338 | vals := strings.Split(extStr, ";") 339 | for i := range vals { 340 | vals[i] = strings.TrimSpace(vals[i]) 341 | } 342 | 343 | e := websocketExtension{ 344 | name: vals[0], 345 | params: vals[1:], 346 | } 347 | 348 | exts = append(exts, e) 349 | } 350 | return exts 351 | } 352 | 353 | func headerTokens(h http.Header, key string) []string { 354 | key = textproto.CanonicalMIMEHeaderKey(key) 355 | var tokens []string 356 | for _, v := range h[key] { 357 | v = strings.TrimSpace(v) 358 | for _, t := range strings.Split(v, ",") { 359 | t = strings.TrimSpace(t) 360 | tokens = append(tokens, t) 361 | } 362 | } 363 | return tokens 364 | } 365 | 366 | var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") 367 | 368 | func secWebSocketAccept(secWebSocketKey string) string { 369 | h := sha1.New() 370 | h.Write([]byte(secWebSocketKey)) 371 | h.Write(keyGUID) 372 | 373 | return base64.StdEncoding.EncodeToString(h.Sum(nil)) 374 | } 375 | -------------------------------------------------------------------------------- /autobahn_test.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | // +build !js 3 | 4 | package websocket_test 5 | 6 | import ( 7 | "context" 8 | "encoding/json" 9 | "errors" 10 | "fmt" 11 | "io" 12 | "net" 13 | "os" 14 | "os/exec" 
15 | "strconv" 16 | "strings" 17 | "testing" 18 | "time" 19 | 20 | "github.com/coder/websocket" 21 | "github.com/coder/websocket/internal/errd" 22 | "github.com/coder/websocket/internal/test/assert" 23 | "github.com/coder/websocket/internal/test/wstest" 24 | "github.com/coder/websocket/internal/util" 25 | ) 26 | 27 | var excludedAutobahnCases = []string{ 28 | // We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just 29 | // more performance overhead. 30 | "6.*", "7.5.1", 31 | 32 | // We skip the tests related to requestMaxWindowBits as that is unimplemented due 33 | // to limitations in compress/flate. See https://github.com/golang/go/issues/3155 34 | "13.3.*", "13.4.*", "13.5.*", "13.6.*", 35 | } 36 | 37 | var autobahnCases = []string{"*"} 38 | 39 | // Used to run individual test cases. autobahnCases runs only those cases matched 40 | // and not excluded by excludedAutobahnCases. Adding cases here means excludedAutobahnCases 41 | // is niled. 42 | var onlyAutobahnCases = []string{} 43 | 44 | func TestAutobahn(t *testing.T) { 45 | t.Parallel() 46 | 47 | if os.Getenv("AUTOBAHN") == "" { 48 | t.SkipNow() 49 | } 50 | 51 | if os.Getenv("AUTOBAHN") == "fast" { 52 | // These are the slow tests. 
53 | excludedAutobahnCases = append(excludedAutobahnCases, 54 | "9.*", "12.*", "13.*", 55 | ) 56 | } 57 | 58 | if len(onlyAutobahnCases) > 0 { 59 | excludedAutobahnCases = []string{} 60 | autobahnCases = onlyAutobahnCases 61 | } 62 | 63 | ctx, cancel := context.WithTimeout(context.Background(), time.Hour) 64 | defer cancel() 65 | 66 | wstestURL, closeFn, err := wstestServer(t, ctx) 67 | assert.Success(t, err) 68 | defer func() { 69 | assert.Success(t, closeFn()) 70 | }() 71 | 72 | err = waitWS(ctx, wstestURL) 73 | assert.Success(t, err) 74 | 75 | cases, err := wstestCaseCount(ctx, wstestURL) 76 | assert.Success(t, err) 77 | 78 | t.Run("cases", func(t *testing.T) { 79 | for i := 1; i <= cases; i++ { 80 | i := i 81 | t.Run("", func(t *testing.T) { 82 | ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) 83 | defer cancel() 84 | 85 | c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/runCase?case=%v&agent=main", i), &websocket.DialOptions{ 86 | CompressionMode: websocket.CompressionContextTakeover, 87 | }) 88 | assert.Success(t, err) 89 | err = wstest.EchoLoop(ctx, c) 90 | t.Logf("echoLoop: %v", err) 91 | }) 92 | } 93 | }) 94 | 95 | c, _, err := websocket.Dial(ctx, wstestURL+"/updateReports?agent=main", nil) 96 | assert.Success(t, err) 97 | c.Close(websocket.StatusNormalClosure, "") 98 | 99 | checkWSTestIndex(t, "./ci/out/autobahn-report/index.json") 100 | } 101 | // waitWS dials url until a connection succeeds, ctx is done, or 5 seconds pass. 102 | func waitWS(ctx context.Context, url string) error { 103 | ctx, cancel := context.WithTimeout(ctx, time.Second*5) 104 | defer cancel() 105 | for ctx.Err() == nil { 106 | c, _, err := websocket.Dial(ctx, url, nil) 107 | if err != nil { 108 | // The server may not be listening yet; back off instead of busy looping. 109 | time.Sleep(100 * time.Millisecond) 110 | continue 111 | } 112 | c.Close(websocket.StatusNormalClosure, "") 113 | return nil 114 | } 115 | return ctx.Err() 116 | } 117 | 118 | func wstestServer(tb testing.TB, ctx context.Context) (url string, closeFn func() error, err error) { 119 | defer errd.Wrap(&err, "failed to start autobahn wstest server") 120 | 121 | serverAddr,
err := unusedListenAddr() 122 | if err != nil { 123 | return "", nil, err 124 | } 125 | _, serverPort, err := net.SplitHostPort(serverAddr) 126 | if err != nil { 127 | return "", nil, err 128 | } 129 | 130 | url = "ws://" + serverAddr 131 | const outDir = "ci/out/autobahn-report" 132 | 133 | specFile, err := tempJSONFile(map[string]interface{}{ 134 | "url": url, 135 | "outdir": outDir, 136 | "cases": autobahnCases, 137 | "exclude-cases": excludedAutobahnCases, 138 | }) 139 | if err != nil { 140 | return "", nil, fmt.Errorf("failed to write spec: %w", err) 141 | } 142 | 143 | ctx, cancel := context.WithTimeout(ctx, time.Hour) 144 | defer func() { 145 | if err != nil { 146 | cancel() 147 | } 148 | }() 149 | 150 | dockerPull := exec.CommandContext(ctx, "docker", "pull", "crossbario/autobahn-testsuite") 151 | dockerPull.Stdout = util.WriterFunc(func(p []byte) (int, error) { 152 | tb.Log(string(p)) 153 | return len(p), nil 154 | }) 155 | dockerPull.Stderr = util.WriterFunc(func(p []byte) (int, error) { 156 | tb.Log(string(p)) 157 | return len(p), nil 158 | }) 159 | tb.Log(dockerPull) 160 | err = dockerPull.Run() 161 | if err != nil { 162 | return "", nil, fmt.Errorf("failed to pull docker image: %w", err) 163 | } 164 | 165 | wd, err := os.Getwd() 166 | if err != nil { 167 | return "", nil, err 168 | } 169 | 170 | var args []string 171 | args = append(args, "run", "-i", "--rm", 172 | "-v", fmt.Sprintf("%s:%[1]s", specFile), 173 | "-v", fmt.Sprintf("%s/ci:/ci", wd), 174 | fmt.Sprintf("-p=%s:%s", serverAddr, serverPort), 175 | "crossbario/autobahn-testsuite", 176 | ) 177 | args = append(args, "wstest", "--mode", "fuzzingserver", "--spec", specFile, 178 | // Disables some server that runs as part of fuzzingserver mode. 179 | // See https://github.com/crossbario/autobahn-testsuite/blob/058db3a36b7c3a1edf68c282307c6b899ca4857f/autobahntestsuite/autobahntestsuite/wstest.py#L124 180 | "--webport=0", 181 | ) 182 | wstest := exec.CommandContext(ctx, "docker", args...)
183 | wstest.Stdout = util.WriterFunc(func(p []byte) (int, error) { 184 | tb.Log(string(p)) 185 | return len(p), nil 186 | }) 187 | wstest.Stderr = util.WriterFunc(func(p []byte) (int, error) { 188 | tb.Log(string(p)) 189 | return len(p), nil 190 | }) 191 | tb.Log(wstest) 192 | err = wstest.Start() 193 | if err != nil { 194 | return "", nil, fmt.Errorf("failed to start wstest: %w", err) 195 | } 196 | 197 | return url, func() error { 198 | defer cancel() 199 | if err := wstest.Process.Kill(); err != nil { 200 | return fmt.Errorf("failed to kill wstest: %w", err) 201 | } 202 | err := wstest.Wait() 203 | var ee *exec.ExitError 204 | if errors.As(err, &ee) && ee.ExitCode() == -1 { 205 | return nil 206 | } 207 | return err 208 | }, nil 209 | } 210 | 211 | func wstestCaseCount(ctx context.Context, url string) (cases int, err error) { 212 | defer errd.Wrap(&err, "failed to get case count") 213 | 214 | c, _, err := websocket.Dial(ctx, url+"/getCaseCount", nil) 215 | if err != nil { 216 | return 0, err 217 | } 218 | defer c.Close(websocket.StatusInternalError, "") 219 | 220 | _, r, err := c.Reader(ctx) 221 | if err != nil { 222 | return 0, err 223 | } 224 | b, err := io.ReadAll(r) 225 | if err != nil { 226 | return 0, err 227 | } 228 | cases, err = strconv.Atoi(string(b)) 229 | if err != nil { 230 | return 0, err 231 | } 232 | 233 | c.Close(websocket.StatusNormalClosure, "") 234 | 235 | return cases, nil 236 | } 237 | 238 | func checkWSTestIndex(t *testing.T, path string) { 239 | wstestOut, err := os.ReadFile(path) 240 | assert.Success(t, err) 241 | 242 | var indexJSON map[string]map[string]struct { 243 | Behavior string `json:"behavior"` 244 | BehaviorClose string `json:"behaviorClose"` 245 | } 246 | err = json.Unmarshal(wstestOut, &indexJSON) 247 | assert.Success(t, err) 248 | 249 | for _, tests := range indexJSON { 250 | for test, result := range tests { 251 | t.Run(test, func(t *testing.T) { 252 | switch result.BehaviorClose { 253 | case "OK", "INFORMATIONAL": 254 | default: 255 |
t.Errorf("bad close behaviour: %q", result.BehaviorClose) 256 | } 257 | 258 | switch result.Behavior { 259 | case "OK", "NON-STRICT", "INFORMATIONAL": 260 | default: 261 | t.Errorf("failed: behaviour %q", result.Behavior) 262 | } 263 | }) 264 | } 265 | } 266 | 267 | if t.Failed() { 268 | htmlPath := strings.Replace(path, ".json", ".html", 1) 269 | t.Errorf("detected autobahn violation, see %q", htmlPath) 270 | } 271 | } 272 | 273 | func unusedListenAddr() (_ string, err error) { 274 | defer errd.Wrap(&err, "failed to get unused listen address") 275 | l, err := net.Listen("tcp", "localhost:0") 276 | if err != nil { 277 | return "", err 278 | } 279 | l.Close() 280 | return l.Addr().String(), nil 281 | } 282 | 283 | func tempJSONFile(v interface{}) (string, error) { 284 | f, err := os.CreateTemp("", "temp.json") 285 | if err != nil { 286 | return "", fmt.Errorf("temp file: %w", err) 287 | } 288 | defer f.Close() 289 | 290 | e := json.NewEncoder(f) 291 | e.SetIndent("", "\t") 292 | err = e.Encode(v) 293 | if err != nil { 294 | return "", fmt.Errorf("json encode: %w", err) 295 | } 296 | 297 | err = f.Close() 298 | if err != nil { 299 | return "", fmt.Errorf("close temp file: %w", err) 300 | } 301 | 302 | return f.Name(), nil 303 | } 304 | -------------------------------------------------------------------------------- /ci/bench.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | set -eu 3 | cd -- "$(dirname "$0")/.." 4 | 5 | go test --run=^$ --bench=. --benchmem "$@" ./... 6 | # For profiling add: --memprofile ci/out/prof.mem --cpuprofile ci/out/prof.cpu -o ci/out/websocket.test 7 | ( 8 | cd ./internal/thirdparty 9 | go test --run=^$ --bench=. --benchmem "$@" . 10 | 11 | GOARCH=arm64 go test -c -o ../../ci/out/thirdparty-arm64.test "$@" .
12 | if [ "$#" -eq 0 ]; then 13 | if [ "${CI-}" ]; then 14 | sudo apt-get update 15 | sudo apt-get install -y qemu-user-static 16 | ln -s /usr/bin/qemu-aarch64-static /usr/local/bin/qemu-aarch64 17 | fi 18 | qemu-aarch64 ../../ci/out/thirdparty-arm64.test --test.run=^$ --test.bench=Benchmark_mask --test.benchmem 19 | fi 20 | ) 21 | -------------------------------------------------------------------------------- /ci/fmt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | set -eu 3 | cd -- "$(dirname "$0")/.." 4 | 5 | X_TOOLS_VERSION=v0.31.0 6 | 7 | go mod tidy 8 | (cd ./internal/thirdparty && go mod tidy) 9 | (cd ./internal/examples && go mod tidy) 10 | gofmt -w -s . 11 | go run golang.org/x/tools/cmd/goimports@${X_TOOLS_VERSION} -w "-local=$(go list -m)" . 12 | 13 | git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html" | xargs npx prettier@3.3.3 \ 14 | --check \ 15 | --log-level=warn \ 16 | --print-width=90 \ 17 | --no-semi \ 18 | --single-quote \ 19 | --arrow-parens=avoid 20 | 21 | go run golang.org/x/tools/cmd/stringer@${X_TOOLS_VERSION} -type=opcode,MessageType,StatusCode -output=stringer.go 22 | 23 | if [ "${CI-}" ]; then 24 | git diff --exit-code 25 | fi 26 | -------------------------------------------------------------------------------- /ci/lint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | set -eu 3 | cd -- "$(dirname "$0")/.." 4 | 5 | STATICCHECK_VERSION=v0.6.1 6 | GOVULNCHECK_VERSION=v1.1.4 7 | 8 | go vet ./... 9 | GOOS=js GOARCH=wasm go vet ./... 10 | 11 | go install honnef.co/go/tools/cmd/staticcheck@${STATICCHECK_VERSION} 12 | staticcheck ./... 13 | GOOS=js GOARCH=wasm staticcheck ./... 14 | 15 | govulncheck() { 16 | tmpf=$(mktemp) 17 | if ! command govulncheck "$@" >"$tmpf" 2>&1; then 18 | cat "$tmpf" 19 | fi 20 | } 21 | go install golang.org/x/vuln/cmd/govulncheck@${GOVULNCHECK_VERSION} 22 | govulncheck ./...
23 | GOOS=js GOARCH=wasm govulncheck ./... 24 | 25 | ( 26 | cd ./internal/examples 27 | go vet ./... 28 | staticcheck ./... 29 | govulncheck ./... 30 | ) 31 | ( 32 | cd ./internal/thirdparty 33 | go vet ./... 34 | staticcheck ./... 35 | govulncheck ./... 36 | ) 37 | -------------------------------------------------------------------------------- /ci/out/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | -------------------------------------------------------------------------------- /ci/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | set -eu 3 | cd -- "$(dirname "$0")/.." 4 | 5 | ( 6 | cd ./internal/examples 7 | go test "$@" ./... 8 | ) 9 | ( 10 | cd ./internal/thirdparty 11 | go test "$@" ./... 12 | ) 13 | 14 | ( 15 | GOARCH=arm64 go test -c -o ./ci/out/websocket-arm64.test "$@" . 16 | if [ "$#" -eq 0 ]; then 17 | if [ "${CI-}" ]; then 18 | sudo apt-get update 19 | sudo apt-get install -y qemu-user-static 20 | ln -s /usr/bin/qemu-aarch64-static /usr/local/bin/qemu-aarch64 21 | fi 22 | qemu-aarch64 ./ci/out/websocket-arm64.test -test.run=TestMask 23 | fi 24 | ) 25 | 26 | 27 | go install github.com/agnivade/wasmbrowsertest@8be019f6c6dceae821467b4c589eb195c2b761ce 28 | go test --race --bench=. --timeout=1h --covermode=atomic --coverprofile=ci/out/coverage.prof --coverpkg=./... "$@" ./... 29 | sed -i.bak '/stringer\.go/d' ci/out/coverage.prof 30 | sed -i.bak '/github\.com\/coder\/websocket\/internal\/test/d' ci/out/coverage.prof 31 | sed -i.bak '/examples/d' ci/out/coverage.prof 32 | 33 | # Last line is the total coverage.
34 | go tool cover -func ci/out/coverage.prof | tail -n1 35 | 36 | go tool cover -html=ci/out/coverage.prof -o=ci/out/coverage.html 37 | -------------------------------------------------------------------------------- /close.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | // +build !js 3 | 4 | package websocket 5 | 6 | import ( 7 | "context" 8 | "encoding/binary" 9 | "errors" 10 | "fmt" 11 | "net" 12 | "time" 13 | 14 | "github.com/coder/websocket/internal/errd" 15 | ) 16 | 17 | // StatusCode represents a WebSocket status code. 18 | // https://tools.ietf.org/html/rfc6455#section-7.4 19 | type StatusCode int 20 | 21 | // https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number 22 | // 23 | // These are only the status codes defined by the protocol. 24 | // 25 | // You can define custom codes in the 3000-4999 range. 26 | // The 3000-3999 range is reserved for use by libraries, frameworks and applications. 27 | // The 4000-4999 range is reserved for private use. 28 | const ( 29 | StatusNormalClosure StatusCode = 1000 30 | StatusGoingAway StatusCode = 1001 31 | StatusProtocolError StatusCode = 1002 32 | StatusUnsupportedData StatusCode = 1003 33 | 34 | // 1004 is reserved and so unexported. 35 | statusReserved StatusCode = 1004 36 | 37 | // StatusNoStatusRcvd cannot be sent in a close message. 38 | // It is reserved for when a close message is received without 39 | // a status code. 40 | StatusNoStatusRcvd StatusCode = 1005 41 | 42 | // StatusAbnormalClosure is exported for use only with Wasm. 43 | // In non Wasm Go, the returned error will indicate whether the 44 | // connection was closed abnormally. 
45 | StatusAbnormalClosure StatusCode = 1006 46 | 47 | StatusInvalidFramePayloadData StatusCode = 1007 48 | StatusPolicyViolation StatusCode = 1008 49 | StatusMessageTooBig StatusCode = 1009 50 | StatusMandatoryExtension StatusCode = 1010 51 | StatusInternalError StatusCode = 1011 52 | StatusServiceRestart StatusCode = 1012 53 | StatusTryAgainLater StatusCode = 1013 54 | StatusBadGateway StatusCode = 1014 55 | 56 | // StatusTLSHandshake is only exported for use with Wasm. 57 | // In non Wasm Go, the returned error will indicate whether there was 58 | // a TLS handshake failure. 59 | StatusTLSHandshake StatusCode = 1015 60 | ) 61 | 62 | // CloseError is returned when the connection is closed with a status and reason. 63 | // 64 | // Use Go 1.13's errors.As to check for this error. 65 | // Also see the CloseStatus helper. 66 | type CloseError struct { 67 | Code StatusCode 68 | Reason string 69 | } 70 | 71 | func (ce CloseError) Error() string { 72 | return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason) 73 | } 74 | 75 | // CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab 76 | // the status code from a CloseError. 77 | // 78 | // -1 will be returned if the passed error is nil or not a CloseError. 79 | func CloseStatus(err error) StatusCode { 80 | var ce CloseError 81 | if errors.As(err, &ce) { 82 | return ce.Code 83 | } 84 | return -1 85 | } 86 | 87 | // Close performs the WebSocket close handshake with the given status code and reason. 88 | // 89 | // It will write a WebSocket close frame with a timeout of 5s and then wait 5s for 90 | // the peer to send a close frame. 91 | // All data messages received from the peer during the close handshake will be discarded. 92 | // 93 | // The connection can only be closed once. Additional calls to Close 94 | // are no-ops. 95 | // 96 | // The maximum length of reason must be 125 bytes. Avoid sending a dynamic reason. 
97 | // 98 | // Close will unblock all goroutines interacting with the connection once 99 | // complete. 100 | func (c *Conn) Close(code StatusCode, reason string) (err error) { 101 | defer errd.Wrap(&err, "failed to close WebSocket") 102 | 103 | if c.casClosing() { 104 | err = c.waitGoroutines() 105 | if err != nil { 106 | return err 107 | } 108 | return net.ErrClosed 109 | } 110 | defer func() { 111 | if errors.Is(err, net.ErrClosed) { 112 | err = nil 113 | } 114 | }() 115 | 116 | err = c.closeHandshake(code, reason) 117 | 118 | err2 := c.close() 119 | if err == nil && err2 != nil { 120 | err = err2 121 | } 122 | 123 | err2 = c.waitGoroutines() 124 | if err == nil && err2 != nil { 125 | err = err2 126 | } 127 | 128 | return err 129 | } 130 | 131 | // CloseNow closes the WebSocket connection without attempting a close handshake. 132 | // Use when you do not want the overhead of the close handshake. 133 | func (c *Conn) CloseNow() (err error) { 134 | defer errd.Wrap(&err, "failed to immediately close WebSocket") 135 | 136 | if c.casClosing() { 137 | err = c.waitGoroutines() 138 | if err != nil { 139 | return err 140 | } 141 | return net.ErrClosed 142 | } 143 | defer func() { 144 | if errors.Is(err, net.ErrClosed) { 145 | err = nil 146 | } 147 | }() 148 | 149 | err = c.close() 150 | 151 | err2 := c.waitGoroutines() 152 | if err == nil && err2 != nil { 153 | err = err2 154 | } 155 | return err 156 | } 157 | 158 | func (c *Conn) closeHandshake(code StatusCode, reason string) error { 159 | err := c.writeClose(code, reason) 160 | if err != nil { 161 | return err 162 | } 163 | 164 | err = c.waitCloseHandshake() 165 | if CloseStatus(err) != code { 166 | return err 167 | } 168 | return nil 169 | } 170 | 171 | func (c *Conn) writeClose(code StatusCode, reason string) error { 172 | ce := CloseError{ 173 | Code: code, 174 | Reason: reason, 175 | } 176 | 177 | var p []byte 178 | var err error 179 | if ce.Code != StatusNoStatusRcvd { 180 | p, err = ce.bytes() 181 | if err != 
nil { 182 | return err 183 | } 184 | } 185 | 186 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) 187 | defer cancel() 188 | 189 | err = c.writeControl(ctx, opClose, p) 190 | // If the connection closed as we're writing we ignore the error as we might 191 | // have written the close frame, the peer responded and then someone else read it 192 | // and closed the connection. 193 | if err != nil && !errors.Is(err, net.ErrClosed) { 194 | return err 195 | } 196 | return nil 197 | } 198 | 199 | func (c *Conn) waitCloseHandshake() error { 200 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) 201 | defer cancel() 202 | 203 | err := c.readMu.lock(ctx) 204 | if err != nil { 205 | return err 206 | } 207 | defer c.readMu.unlock() 208 | 209 | for i := int64(0); i < c.msgReader.payloadLength; i++ { 210 | _, err := c.br.ReadByte() 211 | if err != nil { 212 | return err 213 | } 214 | } 215 | 216 | for { 217 | h, err := c.readLoop(ctx) 218 | if err != nil { 219 | return err 220 | } 221 | 222 | for i := int64(0); i < h.payloadLength; i++ { 223 | _, err := c.br.ReadByte() 224 | if err != nil { 225 | return err 226 | } 227 | } 228 | } 229 | } 230 | 231 | func (c *Conn) waitGoroutines() error { 232 | t := time.NewTimer(time.Second * 15) 233 | defer t.Stop() 234 | 235 | select { 236 | case <-c.timeoutLoopDone: 237 | case <-t.C: 238 | return errors.New("failed to wait for timeoutLoop goroutine to exit") 239 | } 240 | 241 | c.closeReadMu.Lock() 242 | closeRead := c.closeReadCtx != nil 243 | c.closeReadMu.Unlock() 244 | if closeRead { 245 | select { 246 | case <-c.closeReadDone: 247 | case <-t.C: 248 | return errors.New("failed to wait for close read goroutine to exit") 249 | } 250 | } 251 | 252 | select { 253 | case <-c.closed: 254 | case <-t.C: 255 | return errors.New("failed to wait for connection to be closed") 256 | } 257 | 258 | return nil 259 | } 260 | 261 | func parseClosePayload(p []byte) (CloseError, error) { 262 | if len(p) == 0 { 
263 | return CloseError{ 264 | Code: StatusNoStatusRcvd, 265 | }, nil 266 | } 267 | 268 | if len(p) < 2 { 269 | return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) 270 | } 271 | 272 | ce := CloseError{ 273 | Code: StatusCode(binary.BigEndian.Uint16(p)), 274 | Reason: string(p[2:]), 275 | } 276 | 277 | if !validWireCloseCode(ce.Code) { 278 | return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code) 279 | } 280 | 281 | return ce, nil 282 | } 283 | 284 | // See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number 285 | // and https://tools.ietf.org/html/rfc6455#section-7.4.1 286 | func validWireCloseCode(code StatusCode) bool { 287 | switch code { 288 | case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: 289 | return false 290 | } 291 | 292 | if code >= StatusNormalClosure && code <= StatusBadGateway { 293 | return true 294 | } 295 | if code >= 3000 && code <= 4999 { 296 | return true 297 | } 298 | 299 | return false 300 | } 301 | 302 | func (ce CloseError) bytes() ([]byte, error) { 303 | p, err := ce.bytesErr() 304 | if err != nil { 305 | err = fmt.Errorf("failed to marshal close frame: %w", err) 306 | ce = CloseError{ 307 | Code: StatusInternalError, 308 | } 309 | p, _ = ce.bytesErr() 310 | } 311 | return p, err 312 | } 313 | 314 | const maxCloseReason = maxControlPayload - 2 315 | 316 | func (ce CloseError) bytesErr() ([]byte, error) { 317 | if len(ce.Reason) > maxCloseReason { 318 | return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason)) 319 | } 320 | 321 | if !validWireCloseCode(ce.Code) { 322 | return nil, fmt.Errorf("status code %v cannot be set", ce.Code) 323 | } 324 | 325 | buf := make([]byte, 2+len(ce.Reason)) 326 | binary.BigEndian.PutUint16(buf, uint16(ce.Code)) 327 | copy(buf[2:], ce.Reason) 328 | return buf, nil 329 | } 330 | 331 | func (c *Conn) casClosing() bool { 332 | 
return c.closing.Swap(true) 333 | } 334 | 335 | func (c *Conn) isClosed() bool { 336 | select { 337 | case <-c.closed: 338 | return true 339 | default: 340 | return false 341 | } 342 | } 343 | -------------------------------------------------------------------------------- /close_test.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | // +build !js 3 | 4 | package websocket 5 | 6 | import ( 7 | "io" 8 | "math" 9 | "strings" 10 | "testing" 11 | 12 | "github.com/coder/websocket/internal/test/assert" 13 | ) 14 | 15 | func TestCloseError(t *testing.T) { 16 | t.Parallel() 17 | 18 | testCases := []struct { 19 | name string 20 | ce CloseError 21 | success bool 22 | }{ 23 | { 24 | name: "normal", 25 | ce: CloseError{ 26 | Code: StatusNormalClosure, 27 | Reason: strings.Repeat("x", maxCloseReason), 28 | }, 29 | success: true, 30 | }, 31 | { 32 | name: "bigReason", 33 | ce: CloseError{ 34 | Code: StatusNormalClosure, 35 | Reason: strings.Repeat("x", maxCloseReason+1), 36 | }, 37 | success: false, 38 | }, 39 | { 40 | name: "bigCode", 41 | ce: CloseError{ 42 | Code: math.MaxUint16, 43 | Reason: strings.Repeat("x", maxCloseReason), 44 | }, 45 | success: false, 46 | }, 47 | } 48 | 49 | for _, tc := range testCases { 50 | tc := tc 51 | t.Run(tc.name, func(t *testing.T) { 52 | t.Parallel() 53 | 54 | _, err := tc.ce.bytesErr() 55 | if tc.success { 56 | assert.Success(t, err) 57 | } else { 58 | assert.Error(t, err) 59 | } 60 | }) 61 | } 62 | 63 | t.Run("Error", func(t *testing.T) { 64 | exp := `status = StatusInternalError and reason = "meow"` 65 | act := CloseError{ 66 | Code: StatusInternalError, 67 | Reason: "meow", 68 | }.Error() 69 | assert.Equal(t, "CloseError.Error()", exp, act) 70 | }) 71 | } 72 | 73 | func Test_parseClosePayload(t *testing.T) { 74 | t.Parallel() 75 | 76 | testCases := []struct { 77 | name string 78 | p []byte 79 | success bool 80 | ce CloseError 81 | }{ 82 | { 83 | name: "normal", 84 | p: 
append([]byte{0x3, 0xE8}, []byte("hello")...), 85 | success: true, 86 | ce: CloseError{ 87 | Code: StatusNormalClosure, 88 | Reason: "hello", 89 | }, 90 | }, 91 | { 92 | name: "nothing", 93 | success: true, 94 | ce: CloseError{ 95 | Code: StatusNoStatusRcvd, 96 | }, 97 | }, 98 | { 99 | name: "oneByte", 100 | p: []byte{0}, 101 | success: false, 102 | }, 103 | { 104 | name: "badStatusCode", 105 | p: []byte{0x17, 0x70}, 106 | success: false, 107 | }, 108 | } 109 | 110 | for _, tc := range testCases { 111 | tc := tc 112 | t.Run(tc.name, func(t *testing.T) { 113 | t.Parallel() 114 | 115 | ce, err := parseClosePayload(tc.p) 116 | if tc.success { 117 | assert.Success(t, err) 118 | assert.Equal(t, "close payload", tc.ce, ce) 119 | } else { 120 | assert.Error(t, err) 121 | } 122 | }) 123 | } 124 | } 125 | 126 | func Test_validWireCloseCode(t *testing.T) { 127 | t.Parallel() 128 | 129 | testCases := []struct { 130 | name string 131 | code StatusCode 132 | valid bool 133 | }{ 134 | { 135 | name: "normal", 136 | code: StatusNormalClosure, 137 | valid: true, 138 | }, 139 | { 140 | name: "noStatus", 141 | code: StatusNoStatusRcvd, 142 | valid: false, 143 | }, 144 | { 145 | name: "3000", 146 | code: 3000, 147 | valid: true, 148 | }, 149 | { 150 | name: "4999", 151 | code: 4999, 152 | valid: true, 153 | }, 154 | { 155 | name: "unknown", 156 | code: 5000, 157 | valid: false, 158 | }, 159 | } 160 | 161 | for _, tc := range testCases { 162 | tc := tc 163 | t.Run(tc.name, func(t *testing.T) { 164 | t.Parallel() 165 | 166 | act := validWireCloseCode(tc.code) 167 | assert.Equal(t, "wire close code", tc.valid, act) 168 | }) 169 | } 170 | } 171 | 172 | func TestCloseStatus(t *testing.T) { 173 | t.Parallel() 174 | 175 | testCases := []struct { 176 | name string 177 | in error 178 | exp StatusCode 179 | }{ 180 | { 181 | name: "nil", 182 | in: nil, 183 | exp: -1, 184 | }, 185 | { 186 | name: "io.EOF", 187 | in: io.EOF, 188 | exp: -1, 189 | }, 190 | { 191 | name: "StatusInternalError", 192 | 
in: CloseError{ 193 | Code: StatusInternalError, 194 | }, 195 | exp: StatusInternalError, 196 | }, 197 | } 198 | 199 | for _, tc := range testCases { 200 | tc := tc 201 | t.Run(tc.name, func(t *testing.T) { 202 | t.Parallel() 203 | 204 | act := CloseStatus(tc.in) 205 | assert.Equal(t, "close status", tc.exp, act) 206 | }) 207 | } 208 | } 209 | -------------------------------------------------------------------------------- /compress.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | // +build !js 3 | 4 | package websocket 5 | 6 | import ( 7 | "compress/flate" 8 | "io" 9 | "sync" 10 | ) 11 | 12 | // CompressionMode represents the modes available to the permessage-deflate extension. 13 | // See https://tools.ietf.org/html/rfc7692 14 | // 15 | // Works in all modern browsers except Safari which does not implement the permessage-deflate extension. 16 | // 17 | // Compression is only used if the peer supports the mode selected. 18 | type CompressionMode int 19 | 20 | const ( 21 | // CompressionDisabled disables the negotiation of the permessage-deflate extension. 22 | // 23 | // This is the default. Do not enable compression without benchmarking for your particular use case first. 24 | CompressionDisabled CompressionMode = iota 25 | 26 | // CompressionContextTakeover compresses each message greater than 128 bytes reusing the 32 KB sliding window from 27 | // previous messages. i.e compression context across messages is preserved. 28 | // 29 | // As most WebSocket protocols are text based and repetitive, this compression mode can be very efficient. 30 | // 31 | // The memory overhead is a fixed 32 KB sliding window, a fixed 1.2 MB flate.Writer and a sync.Pool of 40 KB flate.Reader's 32 | // that are used when reading and then returned. 33 | // 34 | // Thus, it uses more memory than CompressionNoContextTakeover but compresses more efficiently. 
35 | // 36 | // If the peer does not support CompressionContextTakeover then we will fall back to CompressionNoContextTakeover. 37 | CompressionContextTakeover 38 | 39 | // CompressionNoContextTakeover compresses each message greater than 512 bytes. Each message is compressed with 40 | // a new 1.2 MB flate.Writer pulled from a sync.Pool. Each message is read with a 40 KB flate.Reader pulled from 41 | // a sync.Pool. 42 | // 43 | // This means less efficient compression as the sliding window from previous messages will not be used but the 44 | // memory overhead will be lower as there will be no fixed cost for the flate.Writer nor the 32 KB sliding window. 45 | // Especially if the connections are long lived and seldom written to. 46 | // 47 | // Thus, it uses less memory than CompressionContextTakeover but compresses less efficiently. 48 | // 49 | // If the peer does not support CompressionNoContextTakeover then we will fall back to CompressionDisabled. 50 | CompressionNoContextTakeover 51 | ) 52 | 53 | func (m CompressionMode) opts() *compressionOptions { 54 | return &compressionOptions{ 55 | clientNoContextTakeover: m == CompressionNoContextTakeover, 56 | serverNoContextTakeover: m == CompressionNoContextTakeover, 57 | } 58 | } 59 | 60 | type compressionOptions struct { 61 | clientNoContextTakeover bool 62 | serverNoContextTakeover bool 63 | } 64 | 65 | func (copts *compressionOptions) String() string { 66 | s := "permessage-deflate" 67 | if copts.clientNoContextTakeover { 68 | s += "; client_no_context_takeover" 69 | } 70 | if copts.serverNoContextTakeover { 71 | s += "; server_no_context_takeover" 72 | } 73 | return s 74 | } 75 | 76 | // These bytes are required to get flate.Reader to return. 77 | // They are removed when sending to avoid the overhead as 78 | // WebSocket framing tell's when the message has ended but then 79 | // we need to add them back otherwise flate.Reader keeps 80 | // trying to read more bytes. 
81 | const deflateMessageTail = "\x00\x00\xff\xff" 82 | 83 | type trimLastFourBytesWriter struct { 84 | w io.Writer 85 | tail []byte 86 | } 87 | 88 | func (tw *trimLastFourBytesWriter) reset() { 89 | if tw != nil && tw.tail != nil { 90 | tw.tail = tw.tail[:0] 91 | } 92 | } 93 | 94 | func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) { 95 | if tw.tail == nil { 96 | tw.tail = make([]byte, 0, 4) 97 | } 98 | 99 | extra := len(tw.tail) + len(p) - 4 100 | 101 | if extra <= 0 { 102 | tw.tail = append(tw.tail, p...) 103 | return len(p), nil 104 | } 105 | 106 | // Now we need to write as many extra bytes as we can from the previous tail. 107 | if extra > len(tw.tail) { 108 | extra = len(tw.tail) 109 | } 110 | if extra > 0 { 111 | _, err := tw.w.Write(tw.tail[:extra]) 112 | if err != nil { 113 | return 0, err 114 | } 115 | 116 | // Shift remaining bytes in tail over. 117 | n := copy(tw.tail, tw.tail[extra:]) 118 | tw.tail = tw.tail[:n] 119 | } 120 | 121 | // If p is less than or equal to 4 bytes, 122 | // all of it is is part of the tail. 123 | if len(p) <= 4 { 124 | tw.tail = append(tw.tail, p...) 125 | return len(p), nil 126 | } 127 | 128 | // Otherwise, only the last 4 bytes are. 129 | tw.tail = append(tw.tail, p[len(p)-4:]...) 
130 | 131 | p = p[:len(p)-4] 132 | n, err := tw.w.Write(p) 133 | return n + 4, err 134 | } 135 | 136 | var flateReaderPool sync.Pool 137 | 138 | func getFlateReader(r io.Reader, dict []byte) io.Reader { 139 | fr, ok := flateReaderPool.Get().(io.Reader) 140 | if !ok { 141 | return flate.NewReaderDict(r, dict) 142 | } 143 | fr.(flate.Resetter).Reset(r, dict) 144 | return fr 145 | } 146 | 147 | func putFlateReader(fr io.Reader) { 148 | flateReaderPool.Put(fr) 149 | } 150 | 151 | var flateWriterPool sync.Pool 152 | 153 | func getFlateWriter(w io.Writer) *flate.Writer { 154 | fw, ok := flateWriterPool.Get().(*flate.Writer) 155 | if !ok { 156 | fw, _ = flate.NewWriter(w, flate.BestSpeed) 157 | return fw 158 | } 159 | fw.Reset(w) 160 | return fw 161 | } 162 | 163 | func putFlateWriter(w *flate.Writer) { 164 | flateWriterPool.Put(w) 165 | } 166 | 167 | type slidingWindow struct { 168 | buf []byte 169 | } 170 | 171 | var swPoolMu sync.RWMutex 172 | var swPool = map[int]*sync.Pool{} 173 | 174 | func slidingWindowPool(n int) *sync.Pool { 175 | swPoolMu.RLock() 176 | p, ok := swPool[n] 177 | swPoolMu.RUnlock() 178 | if ok { 179 | return p 180 | } 181 | 182 | p = &sync.Pool{} 183 | 184 | swPoolMu.Lock() 185 | swPool[n] = p 186 | swPoolMu.Unlock() 187 | 188 | return p 189 | } 190 | 191 | func (sw *slidingWindow) init(n int) { 192 | if sw.buf != nil { 193 | return 194 | } 195 | 196 | if n == 0 { 197 | n = 32768 198 | } 199 | 200 | p := slidingWindowPool(n) 201 | sw2, ok := p.Get().(*slidingWindow) 202 | if ok { 203 | *sw = *sw2 204 | } else { 205 | sw.buf = make([]byte, 0, n) 206 | } 207 | } 208 | 209 | func (sw *slidingWindow) close() { 210 | sw.buf = sw.buf[:0] 211 | swPoolMu.Lock() 212 | swPool[cap(sw.buf)].Put(sw) 213 | swPoolMu.Unlock() 214 | } 215 | 216 | func (sw *slidingWindow) write(p []byte) { 217 | if len(p) >= cap(sw.buf) { 218 | sw.buf = sw.buf[:cap(sw.buf)] 219 | p = p[len(p)-cap(sw.buf):] 220 | copy(sw.buf, p) 221 | return 222 | } 223 | 224 | left := cap(sw.buf) - 
len(sw.buf) 225 | if left < len(p) { 226 | // We need to shift spaceNeeded bytes from the end to make room for p at the end. 227 | spaceNeeded := len(p) - left 228 | copy(sw.buf, sw.buf[spaceNeeded:]) 229 | sw.buf = sw.buf[:len(sw.buf)-spaceNeeded] 230 | } 231 | 232 | sw.buf = append(sw.buf, p...) 233 | } 234 | -------------------------------------------------------------------------------- /compress_test.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | // +build !js 3 | 4 | package websocket 5 | 6 | import ( 7 | "bytes" 8 | "compress/flate" 9 | "io" 10 | "strings" 11 | "testing" 12 | 13 | "github.com/coder/websocket/internal/test/assert" 14 | "github.com/coder/websocket/internal/test/xrand" 15 | ) 16 | 17 | func Test_slidingWindow(t *testing.T) { 18 | t.Parallel() 19 | 20 | const testCount = 99 21 | const maxWindow = 99999 22 | for i := 0; i < testCount; i++ { 23 | t.Run("", func(t *testing.T) { 24 | t.Parallel() 25 | 26 | input := xrand.String(maxWindow) 27 | windowLength := xrand.Int(maxWindow) 28 | var sw slidingWindow 29 | sw.init(windowLength) 30 | sw.write([]byte(input)) 31 | 32 | assert.Equal(t, "window length", windowLength, cap(sw.buf)) 33 | if !strings.HasSuffix(input, string(sw.buf)) { 34 | t.Fatalf("r.buf is not a suffix of input: %q and %q", input, sw.buf) 35 | } 36 | }) 37 | } 38 | } 39 | 40 | func BenchmarkFlateWriter(b *testing.B) { 41 | b.ReportAllocs() 42 | for i := 0; i < b.N; i++ { 43 | w, _ := flate.NewWriter(io.Discard, flate.BestSpeed) 44 | // We have to write a byte to get the writer to allocate to its full extent. 
45 | w.Write([]byte{'a'}) 46 | w.Flush() 47 | } 48 | } 49 | 50 | func BenchmarkFlateReader(b *testing.B) { 51 | b.ReportAllocs() 52 | 53 | var buf bytes.Buffer 54 | w, _ := flate.NewWriter(&buf, flate.BestSpeed) 55 | w.Write([]byte{'a'}) 56 | w.Flush() 57 | 58 | for i := 0; i < b.N; i++ { 59 | r := flate.NewReader(bytes.NewReader(buf.Bytes())) 60 | io.ReadAll(r) 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /conn.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | // +build !js 3 | 4 | package websocket 5 | 6 | import ( 7 | "bufio" 8 | "context" 9 | "fmt" 10 | "io" 11 | "net" 12 | "runtime" 13 | "strconv" 14 | "sync" 15 | "sync/atomic" 16 | ) 17 | 18 | // MessageType represents the type of a WebSocket message. 19 | // See https://tools.ietf.org/html/rfc6455#section-5.6 20 | type MessageType int 21 | 22 | // MessageType constants. 23 | const ( 24 | // MessageText is for UTF-8 encoded text messages like JSON. 25 | MessageText MessageType = iota + 1 26 | // MessageBinary is for binary messages like protobufs. 27 | MessageBinary 28 | ) 29 | 30 | // Conn represents a WebSocket connection. 31 | // All methods may be called concurrently except for Reader and Read. 32 | // 33 | // You must always read from the connection. Otherwise control 34 | // frames will not be handled. See Reader and CloseRead. 35 | // 36 | // Be sure to call Close on the connection when you 37 | // are finished with it to release associated resources. 38 | // 39 | // On any error from any method, the connection is closed 40 | // with an appropriate reason. 41 | // 42 | // This applies to context expirations as well unfortunately. 
43 | // See https://github.com/nhooyr/websocket/issues/242#issuecomment-633182220 44 | type Conn struct { 45 | noCopy noCopy 46 | 47 | subprotocol string 48 | rwc io.ReadWriteCloser 49 | client bool 50 | copts *compressionOptions 51 | flateThreshold int 52 | br *bufio.Reader 53 | bw *bufio.Writer 54 | 55 | readTimeout chan context.Context 56 | writeTimeout chan context.Context 57 | timeoutLoopDone chan struct{} // Closed when timeoutLoop returns. 58 | 59 | // Read state. 60 | readMu *mu 61 | readHeaderBuf [8]byte 62 | readControlBuf [maxControlPayload]byte 63 | msgReader *msgReader 64 | 65 | // Write state. 66 | msgWriter *msgWriter 67 | writeFrameMu *mu 68 | writeBuf []byte 69 | writeHeaderBuf [8]byte 70 | writeHeader header 71 | 72 | // Close handshake state. 73 | closeStateMu sync.RWMutex 74 | closeReceivedErr error 75 | closeSentErr error 76 | 77 | // CloseRead state. 78 | closeReadMu sync.Mutex 79 | closeReadCtx context.Context 80 | closeReadDone chan struct{} 81 | 82 | closing atomic.Bool 83 | closeMu sync.Mutex // Protects following.
84 | closed chan struct{} 85 | 86 | pingCounter atomic.Int64 87 | activePingsMu sync.Mutex 88 | activePings map[string]chan<- struct{} 89 | onPingReceived func(context.Context, []byte) bool 90 | onPongReceived func(context.Context, []byte) 91 | } 92 | 93 | type connConfig struct { 94 | subprotocol string 95 | rwc io.ReadWriteCloser 96 | client bool 97 | copts *compressionOptions 98 | flateThreshold int 99 | onPingReceived func(context.Context, []byte) bool 100 | onPongReceived func(context.Context, []byte) 101 | 102 | br *bufio.Reader 103 | bw *bufio.Writer 104 | } 105 | 106 | func newConn(cfg connConfig) *Conn { 107 | c := &Conn{ 108 | subprotocol: cfg.subprotocol, 109 | rwc: cfg.rwc, 110 | client: cfg.client, 111 | copts: cfg.copts, 112 | flateThreshold: cfg.flateThreshold, 113 | 114 | br: cfg.br, 115 | bw: cfg.bw, 116 | 117 | readTimeout: make(chan context.Context), 118 | writeTimeout: make(chan context.Context), 119 | timeoutLoopDone: make(chan struct{}), 120 | 121 | closed: make(chan struct{}), 122 | activePings: make(map[string]chan<- struct{}), 123 | onPingReceived: cfg.onPingReceived, 124 | onPongReceived: cfg.onPongReceived, 125 | } 126 | 127 | c.readMu = newMu(c) 128 | c.writeFrameMu = newMu(c) 129 | 130 | c.msgReader = newMsgReader(c) 131 | 132 | c.msgWriter = newMsgWriter(c) 133 | if c.client { 134 | c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) 135 | } 136 | 137 | if c.flate() && c.flateThreshold == 0 { // Choose a default only when the caller left the threshold unset. 138 | c.flateThreshold = 128 139 | if !c.msgWriter.flateContextTakeover() { 140 | c.flateThreshold = 512 141 | } 142 | } 143 | 144 | runtime.SetFinalizer(c, func(c *Conn) { // Close a Conn that is garbage collected while still open; close() clears this. 145 | c.close() 146 | }) 147 | 148 | go c.timeoutLoop() 149 | 150 | return c 151 | } 152 | 153 | // Subprotocol returns the negotiated subprotocol. 154 | // An empty string means the default protocol.
155 | func (c *Conn) Subprotocol() string { 156 | return c.subprotocol 157 | } 158 | 159 | func (c *Conn) close() error { 160 | c.closeMu.Lock() 161 | defer c.closeMu.Unlock() 162 | 163 | if c.isClosed() { // Already closed: report it rather than closing twice. 164 | return net.ErrClosed 165 | } 166 | runtime.SetFinalizer(c, nil) 167 | close(c.closed) 168 | 169 | // Have to close after c.closed is closed to ensure any goroutine that wakes up 170 | // from the connection being closed also sees that c.closed is closed and returns 171 | // net.ErrClosed, as mu.lock and ping do. 172 | err := c.rwc.Close() 173 | // With the close of rwc, these become safe to close. 174 | c.msgWriter.close() 175 | c.msgReader.close() 176 | return err 177 | } 178 | 179 | func (c *Conn) timeoutLoop() { 180 | defer close(c.timeoutLoopDone) 181 | 182 | readCtx := context.Background() 183 | writeCtx := context.Background() 184 | 185 | for { 186 | select { 187 | case <-c.closed: 188 | return 189 | 190 | case writeCtx = <-c.writeTimeout: 191 | case readCtx = <-c.readTimeout: 192 | 193 | case <-readCtx.Done(): // The tracked read or write context ended: tear down the connection. 194 | c.close() 195 | return 196 | case <-writeCtx.Done(): 197 | c.close() 198 | return 199 | } 200 | } 201 | } 202 | 203 | func (c *Conn) flate() bool { 204 | return c.copts != nil 205 | } 206 | 207 | // Ping sends a ping to the peer and waits for a pong. 208 | // Use this to measure latency or ensure the peer is responsive. 209 | // Ping must be called concurrently with Reader as it does 210 | // not read from the connection but instead waits for a Reader call 211 | // to read the pong. 212 | // 213 | // TCP Keepalives should suffice for most use cases.
214 | func (c *Conn) Ping(ctx context.Context) error { 215 | p := c.pingCounter.Add(1) // Unique payload; ping uses it as the activePings key for the matching pong. 216 | 217 | err := c.ping(ctx, strconv.FormatInt(p, 10)) 218 | if err != nil { 219 | return fmt.Errorf("failed to ping: %w", err) 220 | } 221 | return nil 222 | } 223 | 224 | func (c *Conn) ping(ctx context.Context, p string) error { 225 | pong := make(chan struct{}, 1) 226 | 227 | c.activePingsMu.Lock() 228 | c.activePings[p] = pong 229 | c.activePingsMu.Unlock() 230 | 231 | defer func() { 232 | c.activePingsMu.Lock() 233 | delete(c.activePings, p) 234 | c.activePingsMu.Unlock() 235 | }() 236 | 237 | err := c.writeControl(ctx, opPing, []byte(p)) 238 | if err != nil { 239 | return err 240 | } 241 | 242 | select { 243 | case <-c.closed: 244 | return net.ErrClosed 245 | case <-ctx.Done(): 246 | return fmt.Errorf("failed to wait for pong: %w", ctx.Err()) 247 | case <-pong: 248 | return nil 249 | } 250 | } 251 | 252 | type mu struct { 253 | c *Conn 254 | ch chan struct{} // Capacity 1 (see newMu); a buffered value means the lock is held. 255 | } 256 | 257 | func newMu(c *Conn) *mu { 258 | return &mu{ 259 | c: c, 260 | ch: make(chan struct{}, 1), 261 | } 262 | } 263 | 264 | func (m *mu) forceLock() { 265 | m.ch <- struct{}{} 266 | } 267 | 268 | func (m *mu) tryLock() bool { 269 | select { 270 | case m.ch <- struct{}{}: 271 | return true 272 | default: 273 | return false 274 | } 275 | } 276 | 277 | func (m *mu) lock(ctx context.Context) error { 278 | select { 279 | case <-m.c.closed: 280 | return net.ErrClosed 281 | case <-ctx.Done(): 282 | return fmt.Errorf("failed to acquire lock: %w", ctx.Err()) 283 | case m.ch <- struct{}{}: 284 | // To make sure the connection is certainly alive. 285 | // As it's possible the send on m.ch was selected 286 | // over the receive on closed. 287 | select { 288 | case <-m.c.closed: 289 | // Make sure to release.
290 | m.unlock() 291 | return net.ErrClosed 292 | default: 293 | } 294 | return nil 295 | } 296 | } 297 | 298 | func (m *mu) unlock() { 299 | select { 300 | case <-m.ch: 301 | default: // Not held; unlocking is a no-op. 302 | } 303 | } 304 | 305 | type noCopy struct{} // Its Lock method lets go vet's copylocks check flag copies of Conn. 306 | 307 | func (*noCopy) Lock() {} 308 | -------------------------------------------------------------------------------- /dial.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | // +build !js 3 | 4 | package websocket 5 | 6 | import ( 7 | "bufio" 8 | "bytes" 9 | "context" 10 | "crypto/rand" 11 | "encoding/base64" 12 | "fmt" 13 | "io" 14 | "net/http" 15 | "net/url" 16 | "strings" 17 | "sync" 18 | "time" 19 | 20 | "github.com/coder/websocket/internal/errd" 21 | ) 22 | 23 | // DialOptions represents Dial's options. 24 | type DialOptions struct { 25 | // HTTPClient is used for the connection. 26 | // Its Transport must return writable bodies for WebSocket handshakes. 27 | // http.Transport does so beginning with Go 1.12. 28 | HTTPClient *http.Client 29 | 30 | // HTTPHeader specifies the HTTP headers included in the handshake request. 31 | HTTPHeader http.Header 32 | 33 | // Host optionally overrides the Host HTTP header to send. If empty, the value 34 | // of URL.Host will be used. 35 | Host string 36 | 37 | // Subprotocols lists the WebSocket subprotocols to negotiate with the server. 38 | Subprotocols []string 39 | 40 | // CompressionMode controls the compression mode. 41 | // Defaults to CompressionDisabled. 42 | // 43 | // See docs on CompressionMode for details. 44 | CompressionMode CompressionMode 45 | 46 | // CompressionThreshold controls the minimum size of a message before compression is applied. 47 | // 48 | // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes 49 | // for CompressionContextTakeover. 50 | CompressionThreshold int 51 | 52 | // OnPingReceived is an optional callback invoked synchronously when a ping frame is received.
53 | // 54 | // The payload contains the application data of the ping frame. 55 | // If the callback returns false, the subsequent pong frame will not be sent. 56 | // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine. 57 | OnPingReceived func(ctx context.Context, payload []byte) bool 58 | 59 | // OnPongReceived is an optional callback invoked synchronously when a pong frame is received. 60 | // 61 | // The payload contains the application data of the pong frame. 62 | // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine. 63 | // 64 | // Unlike OnPingReceived, this callback does not return a value because a pong frame 65 | // is a response to a ping and does not trigger any further frame transmission. 66 | OnPongReceived func(ctx context.Context, payload []byte) 67 | } 68 | 69 | func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) { 70 | var cancel context.CancelFunc 71 | 72 | var o DialOptions 73 | if opts != nil { 74 | o = *opts 75 | } 76 | if o.HTTPClient == nil { 77 | o.HTTPClient = http.DefaultClient 78 | } 79 | if o.HTTPClient.Timeout > 0 { 80 | ctx, cancel = context.WithTimeout(ctx, o.HTTPClient.Timeout) // Bound only the handshake: dial cancels this ctx on return. 81 | 82 | newClient := *o.HTTPClient 83 | newClient.Timeout = 0 // Client.Timeout would keep running and interrupt reads of the upgraded body. 84 | o.HTTPClient = &newClient 85 | } 86 | if o.HTTPHeader == nil { 87 | o.HTTPHeader = http.Header{} 88 | } 89 | newClient := *o.HTTPClient 90 | oldCheckRedirect := o.HTTPClient.CheckRedirect 91 | newClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { 92 | switch req.URL.Scheme { // Map ws/wss redirect targets to the http/https schemes net/http uses. 93 | case "ws": 94 | req.URL.Scheme = "http" 95 | case "wss": 96 | req.URL.Scheme = "https" 97 | } 98 | if oldCheckRedirect != nil { 99 | return oldCheckRedirect(req, via) 100 | } 101 | return nil 102 | } 103 | o.HTTPClient = &newClient 104 | 105 | return ctx, cancel, &o 106 | } 107 | 108 | // Dial performs a WebSocket handshake on url.
109 | // 110 | // The response is the WebSocket handshake response from the server. 111 | // You never need to close resp.Body yourself. 112 | // 113 | // If an error occurs, the returned response may be non-nil. 114 | // However, you can only read the first 1024 bytes of the body. 115 | // 116 | // This function requires at least Go 1.12 as it uses a new feature 117 | // in net/http to perform WebSocket handshakes. 118 | // See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861 119 | // 120 | // URLs with http/https schemes will work and are interpreted as ws/wss. 121 | func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { 122 | return dial(ctx, u, opts, nil) 123 | } 124 | 125 | func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) { 126 | defer errd.Wrap(&err, "failed to WebSocket dial") 127 | 128 | var cancel context.CancelFunc 129 | ctx, cancel, opts = opts.cloneWithDefaults(ctx) 130 | if cancel != nil { 131 | defer cancel() 132 | } 133 | 134 | secWebSocketKey, err := secWebSocketKey(rand) 135 | if err != nil { 136 | return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err) 137 | } 138 | 139 | var copts *compressionOptions 140 | if opts.CompressionMode != CompressionDisabled { 141 | copts = opts.CompressionMode.opts() 142 | } 143 | 144 | resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey) 145 | if err != nil { 146 | return nil, resp, err 147 | } 148 | respBody := resp.Body 149 | resp.Body = nil // On success the body becomes the Conn's rwc, so callers never see or close it. 150 | defer func() { 151 | if err != nil { 152 | // We read a bit of the body for easier debugging.
153 | r := io.LimitReader(respBody, 1024) 154 | 155 | timer := time.AfterFunc(time.Second*3, func() { // Unblock ReadAll if the body stalls. 156 | respBody.Close() 157 | }) 158 | defer timer.Stop() 159 | 160 | b, _ := io.ReadAll(r) 161 | respBody.Close() 162 | resp.Body = io.NopCloser(bytes.NewReader(b)) 163 | } 164 | }() 165 | 166 | copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp) 167 | if err != nil { 168 | return nil, resp, err 169 | } 170 | 171 | rwc, ok := respBody.(io.ReadWriteCloser) 172 | if !ok { 173 | return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody) 174 | } 175 | 176 | return newConn(connConfig{ 177 | subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"), 178 | rwc: rwc, 179 | client: true, 180 | copts: copts, 181 | flateThreshold: opts.CompressionThreshold, 182 | onPingReceived: opts.OnPingReceived, 183 | onPongReceived: opts.OnPongReceived, 184 | br: getBufioReader(rwc), 185 | bw: getBufioWriter(rwc), 186 | }), resp, nil 187 | } 188 | 189 | func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) { 190 | u, err := url.Parse(urls) 191 | if err != nil { 192 | return nil, fmt.Errorf("failed to parse url: %w", err) 193 | } 194 | 195 | switch u.Scheme { 196 | case "ws": 197 | u.Scheme = "http" 198 | case "wss": 199 | u.Scheme = "https" 200 | case "http", "https": // Already usable as-is. 201 | default: 202 | return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme) 203 | } 204 | 205 | req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) 206 | if err != nil { 207 | return nil, fmt.Errorf("failed to create new http request: %w", err) 208 | } 209 | if len(opts.Host) > 0 { 210 | req.Host = opts.Host 211 | } 212 | req.Header = opts.HTTPHeader.Clone() 213 | req.Header.Set("Connection", "Upgrade") 214 | req.Header.Set("Upgrade", "websocket") 215 | req.Header.Set("Sec-WebSocket-Version", "13") 216 | req.Header.Set("Sec-WebSocket-Key", secWebSocketKey) 217 |
if len(opts.Subprotocols) > 0 { 218 | req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) 219 | } 220 | if copts != nil { 221 | req.Header.Set("Sec-WebSocket-Extensions", copts.String()) 222 | } 223 | 224 | resp, err := opts.HTTPClient.Do(req) 225 | if err != nil { 226 | return nil, fmt.Errorf("failed to send handshake request: %w", err) 227 | } 228 | return resp, nil 229 | } 230 | 231 | func secWebSocketKey(rr io.Reader) (string, error) { 232 | if rr == nil { 233 | rr = rand.Reader 234 | } 235 | b := make([]byte, 16) // RFC 6455 section 4.1: a random 16-byte nonce, base64-encoded. 236 | _, err := io.ReadFull(rr, b) 237 | if err != nil { 238 | return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err) 239 | } 240 | return base64.StdEncoding.EncodeToString(b), nil 241 | } 242 | 243 | func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) { 244 | if resp.StatusCode != http.StatusSwitchingProtocols { 245 | return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) 246 | } 247 | 248 | if !headerContainsTokenIgnoreCase(resp.Header, "Connection", "Upgrade") { 249 | return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) 250 | } 251 | 252 | if !headerContainsTokenIgnoreCase(resp.Header, "Upgrade", "WebSocket") { 253 | return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) 254 | } 255 | 256 | if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) { 257 | return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", 258 | resp.Header.Get("Sec-WebSocket-Accept"), 259 | secWebSocketKey, 260 | ) 261 | } 262 | 263 | err := verifySubprotocol(opts.Subprotocols, resp) 264 | if err != nil { 265 | return nil, err 266 | } 267 | 268 |
return verifyServerExtensions(copts, resp.Header) 269 | } 270 | 271 | func verifySubprotocol(subprotos []string, resp *http.Response) error { 272 | proto := resp.Header.Get("Sec-WebSocket-Protocol") 273 | if proto == "" { // The server selected no subprotocol. 274 | return nil 275 | } 276 | 277 | for _, sp2 := range subprotos { 278 | if strings.EqualFold(sp2, proto) { 279 | return nil 280 | } 281 | } 282 | 283 | return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) 284 | } 285 | 286 | func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) { 287 | exts := websocketExtensions(h) 288 | if len(exts) == 0 { // No extension accepted: nil copts leaves compression off. 289 | return nil, nil 290 | } 291 | 292 | ext := exts[0] 293 | if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil { // Only the single permessage-deflate we offered is acceptable. 294 | return nil, fmt.Errorf("WebSocket protocol violation: unsupported extensions from server: %+v", exts) 295 | } 296 | 297 | _copts := *copts // Copy so the caller's compression options are never mutated. 298 | copts = &_copts 299 | 300 | for _, p := range ext.params { 301 | switch p { 302 | case "client_no_context_takeover": 303 | copts.clientNoContextTakeover = true 304 | continue 305 | case "server_no_context_takeover": 306 | copts.serverNoContextTakeover = true 307 | continue 308 | } 309 | if strings.HasPrefix(p, "server_max_window_bits=") { 310 | // We can't adjust the deflate window, but decoding with a larger window is acceptable.
311 | continue 312 | } 313 | 314 | return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p) 315 | } 316 | 317 | return copts, nil 318 | } 319 | 320 | var bufioReaderPool sync.Pool // Recycled through getBufioReader/putBufioReader. 321 | 322 | func getBufioReader(r io.Reader) *bufio.Reader { 323 | br, ok := bufioReaderPool.Get().(*bufio.Reader) 324 | if !ok { 325 | return bufio.NewReader(r) 326 | } 327 | br.Reset(r) 328 | return br 329 | } 330 | 331 | func putBufioReader(br *bufio.Reader) { 332 | bufioReaderPool.Put(br) 333 | } 334 | 335 | var bufioWriterPool sync.Pool // Recycled through getBufioWriter/putBufioWriter. 336 | 337 | func getBufioWriter(w io.Writer) *bufio.Writer { 338 | bw, ok := bufioWriterPool.Get().(*bufio.Writer) 339 | if !ok { 340 | return bufio.NewWriter(w) 341 | } 342 | bw.Reset(w) 343 | return bw 344 | } 345 | 346 | func putBufioWriter(bw *bufio.Writer) { 347 | bufioWriterPool.Put(bw) 348 | } 349 | -------------------------------------------------------------------------------- /dial_test.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | // +build !js 3 | 4 | package websocket_test 5 | 6 | import ( 7 | "bytes" 8 | "context" 9 | "crypto/rand" 10 | "io" 11 | "net/http" 12 | "net/http/httptest" 13 | "net/url" 14 | "strings" 15 | "testing" 16 | "time" 17 | 18 | "github.com/coder/websocket" 19 | "github.com/coder/websocket/internal/test/assert" 20 | "github.com/coder/websocket/internal/util" 21 | "github.com/coder/websocket/internal/xsync" 22 | ) 23 | 24 | func TestBadDials(t *testing.T) { 25 | t.Parallel() 26 | 27 | t.Run("badReq", func(t *testing.T) { 28 | t.Parallel() 29 | 30 | testCases := []struct { 31 | name string 32 | url string 33 | opts *websocket.DialOptions 34 | rand util.ReaderFunc 35 | nilCtx bool 36 | }{ 37 | { 38 | name: "badURL", 39 | url: "://noscheme", 40 | }, 41 | { 42 | name: "badURLScheme", 43 | url: "ftp://nhooyr.io", 44 | }, 45 | { 46 | name: "badTLS", 47 | url: "wss://totallyfake.nhooyr.io", 48 | }, 49 | { 50 | name: "badReader", 51 | rand: func(p
[]byte) (int, error) { 52 | return 0, io.EOF // Simulates a failing random source. 53 | }, 54 | }, 55 | { 56 | name: "nilContext", 57 | url: "http://localhost", 58 | nilCtx: true, 59 | }, 60 | } 61 | 62 | for _, tc := range testCases { 63 | tc := tc 64 | t.Run(tc.name, func(t *testing.T) { 65 | t.Parallel() 66 | 67 | var ctx context.Context 68 | var cancel func() 69 | if !tc.nilCtx { 70 | ctx, cancel = context.WithTimeout(context.Background(), time.Second*5) 71 | defer cancel() 72 | } 73 | 74 | if tc.rand == nil { 75 | tc.rand = rand.Reader.Read 76 | } 77 | 78 | _, _, err := websocket.ExportedDial(ctx, tc.url, tc.opts, tc.rand) 79 | assert.Error(t, err) 80 | }) 81 | } 82 | }) 83 | 84 | t.Run("badResponse", func(t *testing.T) { 85 | t.Parallel() 86 | 87 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) 88 | defer cancel() 89 | 90 | _, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{ 91 | HTTPClient: mockHTTPClient(func(*http.Request) (*http.Response, error) { 92 | return &http.Response{ 93 | Body: io.NopCloser(strings.NewReader("hi")), 94 | }, nil 95 | }), 96 | }) 97 | assert.Contains(t, err, "failed to WebSocket dial: expected handshake response status code 101 but got 0") 98 | }) 99 | 100 | t.Run("badBody", func(t *testing.T) { 101 | t.Parallel() 102 | 103 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) 104 | defer cancel() 105 | 106 | rt := func(r *http.Request) (*http.Response, error) { 107 | h := http.Header{} 108 | h.Set("Connection", "Upgrade") 109 | h.Set("Upgrade", "websocket") 110 | h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key"))) 111 | 112 | return &http.Response{ 113 | StatusCode: http.StatusSwitchingProtocols, 114 | Header: h, 115 | Body: io.NopCloser(strings.NewReader("hi")), 116 | }, nil 117 | } 118 | 119 | _, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{ 120 | HTTPClient: mockHTTPClient(rt), 121 | }) 122 | assert.Contains(t, err, "response
body is not a io.ReadWriteCloser") 123 | }) 124 | } 125 | 126 | func Test_verifyHostOverride(t *testing.T) { 127 | testCases := []struct { 128 | name string 129 | host string 130 | exp string 131 | }{ 132 | { 133 | name: "noOverride", 134 | host: "", 135 | exp: "example.com", 136 | }, 137 | { 138 | name: "hostOverride", 139 | host: "example.net", 140 | exp: "example.net", 141 | }, 142 | } 143 | 144 | for _, tc := range testCases { 145 | tc := tc 146 | t.Run(tc.name, func(t *testing.T) { 147 | t.Parallel() 148 | 149 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) 150 | defer cancel() 151 | 152 | rt := func(r *http.Request) (*http.Response, error) { 153 | assert.Equal(t, "Host", tc.exp, r.Host) 154 | 155 | h := http.Header{} 156 | h.Set("Connection", "Upgrade") 157 | h.Set("Upgrade", "websocket") 158 | h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key"))) 159 | 160 | return &http.Response{ 161 | StatusCode: http.StatusSwitchingProtocols, 162 | Header: h, 163 | Body: mockBody{bytes.NewBufferString("hi")}, 164 | }, nil 165 | } 166 | 167 | c, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{ 168 | HTTPClient: mockHTTPClient(rt), 169 | Host: tc.host, 170 | }) 171 | assert.Success(t, err) 172 | c.CloseNow() 173 | }) 174 | } 175 | 176 | } 177 | 178 | type mockBody struct { // Adds a no-op Close so the body is an io.ReadWriteCloser, which dial requires. 179 | *bytes.Buffer 180 | } 181 | 182 | func (mb mockBody) Close() error { 183 | return nil 184 | } 185 | 186 | func Test_verifyServerHandshake(t *testing.T) { 187 | t.Parallel() 188 | 189 | testCases := []struct { 190 | name string 191 | response func(w http.ResponseWriter) 192 | success bool 193 | }{ 194 | { 195 | name: "badStatus", 196 | response: func(w http.ResponseWriter) { 197 | w.WriteHeader(http.StatusOK) 198 | }, 199 | success: false, 200 | }, 201 | { 202 | name: "badConnection", 203 | response: func(w http.ResponseWriter) { 204 | w.Header().Set("Connection", "???") 205 |
w.WriteHeader(http.StatusSwitchingProtocols) 206 | }, 207 | success: false, 208 | }, 209 | { 210 | name: "badUpgrade", 211 | response: func(w http.ResponseWriter) { 212 | w.Header().Set("Connection", "Upgrade") 213 | w.Header().Set("Upgrade", "???") 214 | w.WriteHeader(http.StatusSwitchingProtocols) 215 | }, 216 | success: false, 217 | }, 218 | { 219 | name: "badSecWebSocketAccept", 220 | response: func(w http.ResponseWriter) { 221 | w.Header().Set("Connection", "Upgrade") 222 | w.Header().Set("Upgrade", "websocket") 223 | w.Header().Set("Sec-WebSocket-Accept", "xd") 224 | w.WriteHeader(http.StatusSwitchingProtocols) 225 | }, 226 | success: false, 227 | }, 228 | { 229 | name: "badSecWebSocketProtocol", 230 | response: func(w http.ResponseWriter) { 231 | w.Header().Set("Connection", "Upgrade") 232 | w.Header().Set("Upgrade", "websocket") 233 | w.Header().Set("Sec-WebSocket-Protocol", "xd") 234 | w.WriteHeader(http.StatusSwitchingProtocols) 235 | }, 236 | success: false, 237 | }, 238 | { 239 | name: "unsupportedExtension", 240 | response: func(w http.ResponseWriter) { 241 | w.Header().Set("Connection", "Upgrade") 242 | w.Header().Set("Upgrade", "websocket") 243 | w.Header().Set("Sec-WebSocket-Extensions", "meow") 244 | w.WriteHeader(http.StatusSwitchingProtocols) 245 | }, 246 | success: false, 247 | }, 248 | { 249 | name: "unsupportedDeflateParam", 250 | response: func(w http.ResponseWriter) { 251 | w.Header().Set("Connection", "Upgrade") 252 | w.Header().Set("Upgrade", "websocket") 253 | w.Header().Set("Sec-WebSocket-Extensions", "permessage-deflate; meow") 254 | w.WriteHeader(http.StatusSwitchingProtocols) 255 | }, 256 | success: false, 257 | }, 258 | { 259 | name: "success", 260 | response: func(w http.ResponseWriter) { 261 | w.Header().Set("Connection", "Upgrade") 262 | w.Header().Set("Upgrade", "websocket") 263 | w.WriteHeader(http.StatusSwitchingProtocols) 264 | }, 265 | success: true, 266 | }, 267 | } 268 | 269 | for _, tc := range testCases { 270 | tc := tc
271 | t.Run(tc.name, func(t *testing.T) { 272 | t.Parallel() 273 | 274 | w := httptest.NewRecorder() 275 | tc.response(w) 276 | resp := w.Result() 277 | 278 | r := httptest.NewRequest("GET", "/", nil) 279 | key, err := websocket.SecWebSocketKey(rand.Reader) 280 | assert.Success(t, err) 281 | r.Header.Set("Sec-WebSocket-Key", key) 282 | 283 | if resp.Header.Get("Sec-WebSocket-Accept") == "" { // Use a valid accept value unless the case set its own. 284 | resp.Header.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(key)) 285 | } 286 | 287 | opts := &websocket.DialOptions{ 288 | Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","), 289 | } 290 | _, err = websocket.VerifyServerResponse(opts, websocket.CompressionModeOpts(opts.CompressionMode), key, resp) 291 | if tc.success { 292 | assert.Success(t, err) 293 | } else { 294 | assert.Error(t, err) 295 | } 296 | }) 297 | } 298 | } 299 | 300 | func mockHTTPClient(fn roundTripperFunc) *http.Client { 301 | return &http.Client{ 302 | Transport: fn, 303 | } 304 | } 305 | 306 | type roundTripperFunc func(*http.Request) (*http.Response, error) // Adapts a plain func to http.RoundTripper. 307 | 308 | func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { 309 | return f(r) 310 | } 311 | 312 | func TestDialRedirect(t *testing.T) { 313 | t.Parallel() 314 | 315 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) 316 | defer cancel() 317 | 318 | _, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{ 319 | HTTPClient: mockHTTPClient(func(r *http.Request) (*http.Response, error) { 320 | resp := &http.Response{ 321 | Header: http.Header{}, 322 | } 323 | if r.URL.Scheme != "https" { // First hop: redirect to wss, which CheckRedirect rewrites to https. 324 | resp.Header.Set("Location", "wss://example.com") 325 | resp.StatusCode = http.StatusFound 326 | return resp, nil 327 | } 328 | resp.Header.Set("Connection", "Upgrade") 329 | resp.Header.Set("Upgrade", "meow") 330 | resp.StatusCode = http.StatusSwitchingProtocols 331 | return resp, nil 332 | }), 333 | }) 334 | assert.Contains(t, err, "failed to WebSocket dial:
WebSocket protocol violation: Upgrade header \"meow\" does not contain websocket") 335 | } 336 | 337 | type forwardProxy struct { 338 | hc *http.Client 339 | } 340 | 341 | func newForwardProxy() *forwardProxy { 342 | return &forwardProxy{ 343 | hc: &http.Client{}, 344 | } 345 | } 346 | 347 | func (fc *forwardProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { 348 | ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) 349 | defer cancel() 350 | 351 | r = r.WithContext(ctx) 352 | r.RequestURI = "" 353 | resp, err := fc.hc.Do(r) 354 | if err != nil { 355 | http.Error(w, err.Error(), http.StatusBadRequest) 356 | return 357 | } 358 | defer resp.Body.Close() 359 | 360 | for k, v := range resp.Header { 361 | w.Header()[k] = v 362 | } 363 | w.Header().Set("PROXIED", "true") 364 | w.WriteHeader(resp.StatusCode) 365 | if resprw, ok := resp.Body.(io.ReadWriter); ok { 366 | c, brw, err := w.(http.Hijacker).Hijack() 367 | if err != nil { 368 | http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 369 | return 370 | } 371 | defer c.Close() // After Hijack the handler owns the conn; closing it also ends both copy goroutines. 372 | brw.Flush() 373 | errc1 := xsync.Go(func() error { 374 | _, err := io.Copy(c, resprw) 375 | return err 376 | }) 377 | errc2 := xsync.Go(func() error { 378 | _, err := io.Copy(resprw, brw.Reader) // brw.Reader may already hold client bytes read off c. 379 | return err 380 | }) 381 | select { 382 | case <-errc1: 383 | case <-errc2: 384 | case <-r.Context().Done(): 385 | } 386 | } else { 387 | io.Copy(w, resp.Body) 388 | } 389 | } 390 | 391 | func TestDialViaProxy(t *testing.T) { 392 | t.Parallel() 393 | 394 | ps := httptest.NewServer(newForwardProxy()) 395 | defer ps.Close() 396 | 397 | s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 398 | err := echoServer(w, r, nil) 399 | assert.Success(t, err) 400 | })) 401 | defer s.Close() 402 | 403 | psu, err := url.Parse(ps.URL) 404 | assert.Success(t, err) 405 | proxyTransport := http.DefaultTransport.(*http.Transport).Clone() 406 | proxyTransport.Proxy = http.ProxyURL(psu) 407 |
408 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) 409 | defer cancel() 410 | c, resp, err := websocket.Dial(ctx, s.URL, &websocket.DialOptions{ 411 | HTTPClient: &http.Client{ 412 | Transport: proxyTransport, 413 | }, 414 | }) 415 | assert.Success(t, err) 416 | assert.Equal(t, "", "true", resp.Header.Get("PROXIED")) 417 | 418 | assertEcho(t, ctx, c) 419 | assertClose(t, c) 420 | } 421 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | // +build !js 3 | 4 | // Package websocket implements the RFC 6455 WebSocket protocol. 5 | // 6 | // https://tools.ietf.org/html/rfc6455 7 | // 8 | // Use Dial to dial a WebSocket server. 9 | // 10 | // Use Accept to accept a WebSocket client. 11 | // 12 | // Conn represents the resulting WebSocket connection. 13 | // 14 | // The examples are the best way to understand how to correctly use the library. 15 | // 16 | // The wsjson subpackage contains helpers for JSON messages. 17 | // 18 | // More documentation at https://github.com/coder/websocket. 19 | // 20 | // # Wasm 21 | // 22 | // The client side supports compiling to Wasm. 23 | // It wraps the WebSocket browser API.
24 | // 25 | // See https://developer.mozilla.org/en-US/docs/Web/API/WebSocket 26 | // 27 | // Some important caveats to be aware of: 28 | // 29 | // - Accept always errors out 30 | // - Conn.Ping is a no-op 31 | // - Conn.CloseNow is Close(StatusGoingAway, "") 32 | // - HTTPClient, HTTPHeader and CompressionMode in DialOptions are no-ops 33 | // - *http.Response from Dial is &http.Response{} with a 101 status code on success 34 | package websocket // import "github.com/coder/websocket" 35 | -------------------------------------------------------------------------------- /example_test.go: -------------------------------------------------------------------------------- 1 | package websocket_test 2 | 3 | import ( 4 | "context" 5 | "log" 6 | "net/http" 7 | "time" 8 | 9 | "github.com/coder/websocket" 10 | "github.com/coder/websocket/wsjson" 11 | ) 12 | 13 | func ExampleAccept() { 14 | // This handler accepts a WebSocket connection, reads a single JSON 15 | // message from the client and then closes the connection. 16 | 17 | fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 18 | c, err := websocket.Accept(w, r, nil) 19 | if err != nil { 20 | log.Println(err) 21 | return 22 | } 23 | defer c.CloseNow() 24 | 25 | ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) 26 | defer cancel() 27 | 28 | var v interface{} 29 | err = wsjson.Read(ctx, c, &v) 30 | if err != nil { 31 | log.Println(err) 32 | return 33 | } 34 | 35 | c.Close(websocket.StatusNormalClosure, "") 36 | }) 37 | 38 | err := http.ListenAndServe("localhost:8080", fn) 39 | log.Fatal(err) 40 | } 41 | 42 | func ExampleDial() { 43 | // Dials a server, writes a single JSON message and then 44 | // closes the connection.
45 | 46 | ctx, cancel := context.WithTimeout(context.Background(), time.Minute) 47 | defer cancel() 48 | 49 | c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) 50 | if err != nil { 51 | log.Fatal(err) 52 | } 53 | defer c.CloseNow() 54 | 55 | err = wsjson.Write(ctx, c, "hi") 56 | if err != nil { 57 | log.Fatal(err) 58 | } 59 | 60 | c.Close(websocket.StatusNormalClosure, "") 61 | } 62 | 63 | func ExampleCloseStatus() { 64 | // Dials a server and then expects to be disconnected with status code 65 | // websocket.StatusNormalClosure. 66 | 67 | ctx, cancel := context.WithTimeout(context.Background(), time.Minute) 68 | defer cancel() 69 | 70 | c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) 71 | if err != nil { 72 | log.Fatal(err) 73 | } 74 | defer c.CloseNow() 75 | 76 | _, _, err = c.Reader(ctx) 77 | if websocket.CloseStatus(err) != websocket.StatusNormalClosure { 78 | log.Fatalf("expected to be disconnected with StatusNormalClosure but got: %v", err) 79 | } 80 | } 81 | 82 | func Example_writeOnly() { 83 | // This handler demonstrates how to correctly handle a write-only WebSocket connection. 84 | // i.e. you only expect to write messages and do not expect to read any messages.
85 | fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 86 | c, err := websocket.Accept(w, r, nil) 87 | if err != nil { 88 | log.Println(err) 89 | return 90 | } 91 | defer c.CloseNow() 92 | 93 | ctx, cancel := context.WithTimeout(r.Context(), time.Minute*10) 94 | defer cancel() 95 | 96 | ctx = c.CloseRead(ctx) 97 | 98 | t := time.NewTicker(time.Second * 30) 99 | defer t.Stop() 100 | 101 | for { 102 | select { 103 | case <-ctx.Done(): 104 | c.Close(websocket.StatusNormalClosure, "") 105 | return 106 | case <-t.C: 107 | err = wsjson.Write(ctx, c, "hi") 108 | if err != nil { 109 | log.Println(err) 110 | return 111 | } 112 | } 113 | } 114 | }) 115 | 116 | err := http.ListenAndServe("localhost:8080", fn) 117 | log.Fatal(err) 118 | } 119 | 120 | func Example_crossOrigin() { 121 | // This handler demonstrates how to safely accept cross-origin WebSockets 122 | // from the origin example.com. 123 | fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 124 | c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ 125 | OriginPatterns: []string{"example.com"}, 126 | }) 127 | if err != nil { 128 | log.Println(err) 129 | return 130 | } 131 | c.Close(websocket.StatusNormalClosure, "cross origin WebSocket accepted") 132 | }) 133 | 134 | err := http.ListenAndServe("localhost:8080", fn) 135 | log.Fatal(err) 136 | } 137 | 138 | func ExampleConn_Ping() { 139 | // Dials a server and pings it 5 times. 140 | 141 | ctx, cancel := context.WithTimeout(context.Background(), time.Minute) 142 | defer cancel() 143 | 144 | c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) 145 | if err != nil { 146 | log.Fatal(err) 147 | } 148 | defer c.CloseNow() 149 | 150 | // Required to read the Pongs from the server.
151 | ctx = c.CloseRead(ctx) 152 | 153 | for i := 0; i < 5; i++ { 154 | err = c.Ping(ctx) 155 | if err != nil { 156 | log.Fatal(err) 157 | } 158 | } 159 | 160 | c.Close(websocket.StatusNormalClosure, "") 161 | } 162 | 163 | // This example demonstrates full stack chat with an automated test. 164 | func Example_fullStackChat() { 165 | // https://github.com/nhooyr/websocket/tree/master/internal/examples/chat 166 | } 167 | 168 | // This example demonstrates a echo server. 169 | func Example_echo() { 170 | // https://github.com/nhooyr/websocket/tree/master/internal/examples/echo 171 | } 172 | -------------------------------------------------------------------------------- /export_test.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | // +build !js 3 | 4 | package websocket 5 | 6 | import ( 7 | "net" 8 | 9 | "github.com/coder/websocket/internal/util" 10 | ) 11 | 12 | func (c *Conn) RecordBytesWritten() *int { 13 | var bytesWritten int 14 | c.bw.Reset(util.WriterFunc(func(p []byte) (int, error) { 15 | bytesWritten += len(p) 16 | return c.rwc.Write(p) 17 | })) 18 | return &bytesWritten 19 | } 20 | 21 | func (c *Conn) RecordBytesRead() *int { 22 | var bytesRead int 23 | c.br.Reset(util.ReaderFunc(func(p []byte) (int, error) { 24 | n, err := c.rwc.Read(p) 25 | bytesRead += n 26 | return n, err 27 | })) 28 | return &bytesRead 29 | } 30 | 31 | var ErrClosed = net.ErrClosed 32 | 33 | var ExportedDial = dial 34 | var SecWebSocketAccept = secWebSocketAccept 35 | var SecWebSocketKey = secWebSocketKey 36 | var VerifyServerResponse = verifyServerResponse 37 | 38 | var CompressionModeOpts = CompressionMode.opts 39 | -------------------------------------------------------------------------------- /frame.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | 3 | package websocket 4 | 5 | import ( 6 | "bufio" 7 | "encoding/binary" 8 | "fmt" 9 | "io" 10 | "math" 11 | 12 | 
"github.com/coder/websocket/internal/errd" 13 | ) 14 | 15 | // opcode represents a WebSocket opcode. 16 | type opcode int 17 | 18 | // https://tools.ietf.org/html/rfc6455#section-11.8. 19 | const ( 20 | opContinuation opcode = iota 21 | opText 22 | opBinary 23 | // 3 - 7 are reserved for further non-control frames. 24 | _ 25 | _ 26 | _ 27 | _ 28 | _ 29 | opClose 30 | opPing 31 | opPong 32 | // 11-16 are reserved for further control frames. 33 | ) 34 | 35 | // header represents a WebSocket frame header. 36 | // See https://tools.ietf.org/html/rfc6455#section-5.2. 37 | type header struct { 38 | fin bool 39 | rsv1 bool 40 | rsv2 bool 41 | rsv3 bool 42 | opcode opcode 43 | 44 | payloadLength int64 45 | 46 | masked bool 47 | maskKey uint32 48 | } 49 | 50 | // readFrameHeader reads a header from the reader. 51 | // See https://tools.ietf.org/html/rfc6455#section-5.2. 52 | func readFrameHeader(r *bufio.Reader, readBuf []byte) (h header, err error) { 53 | defer errd.Wrap(&err, "failed to read frame header") 54 | 55 | b, err := r.ReadByte() 56 | if err != nil { 57 | return header{}, err 58 | } 59 | 60 | h.fin = b&(1<<7) != 0 61 | h.rsv1 = b&(1<<6) != 0 62 | h.rsv2 = b&(1<<5) != 0 63 | h.rsv3 = b&(1<<4) != 0 64 | 65 | h.opcode = opcode(b & 0xf) 66 | 67 | b, err = r.ReadByte() 68 | if err != nil { 69 | return header{}, err 70 | } 71 | 72 | h.masked = b&(1<<7) != 0 73 | 74 | payloadLength := b &^ (1 << 7) 75 | switch { 76 | case payloadLength < 126: 77 | h.payloadLength = int64(payloadLength) 78 | case payloadLength == 126: 79 | _, err = io.ReadFull(r, readBuf[:2]) 80 | h.payloadLength = int64(binary.BigEndian.Uint16(readBuf)) 81 | case payloadLength == 127: 82 | _, err = io.ReadFull(r, readBuf) 83 | h.payloadLength = int64(binary.BigEndian.Uint64(readBuf)) 84 | } 85 | if err != nil { 86 | return header{}, err 87 | } 88 | 89 | if h.payloadLength < 0 { 90 | return header{}, fmt.Errorf("received negative payload length: %v", h.payloadLength) 91 | } 92 | 93 | if h.masked { 94 | 
_, err = io.ReadFull(r, readBuf[:4]) 95 | if err != nil { 96 | return header{}, err 97 | } 98 | h.maskKey = binary.LittleEndian.Uint32(readBuf) 99 | } 100 | 101 | return h, nil 102 | } 103 | 104 | // maxControlPayload is the maximum length of a control frame payload. 105 | // See https://tools.ietf.org/html/rfc6455#section-5.5. 106 | const maxControlPayload = 125 107 | 108 | // writeFrameHeader writes the bytes of the header to w. 109 | // See https://tools.ietf.org/html/rfc6455#section-5.2 110 | func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) { 111 | defer errd.Wrap(&err, "failed to write frame header") 112 | 113 | var b byte 114 | if h.fin { 115 | b |= 1 << 7 116 | } 117 | if h.rsv1 { 118 | b |= 1 << 6 119 | } 120 | if h.rsv2 { 121 | b |= 1 << 5 122 | } 123 | if h.rsv3 { 124 | b |= 1 << 4 125 | } 126 | 127 | b |= byte(h.opcode) 128 | 129 | err = w.WriteByte(b) 130 | if err != nil { 131 | return err 132 | } 133 | 134 | lengthByte := byte(0) 135 | if h.masked { 136 | lengthByte |= 1 << 7 137 | } 138 | 139 | switch { 140 | case h.payloadLength > math.MaxUint16: 141 | lengthByte |= 127 142 | case h.payloadLength > 125: 143 | lengthByte |= 126 144 | case h.payloadLength >= 0: 145 | lengthByte |= byte(h.payloadLength) 146 | } 147 | err = w.WriteByte(lengthByte) 148 | if err != nil { 149 | return err 150 | } 151 | 152 | switch { 153 | case h.payloadLength > math.MaxUint16: 154 | binary.BigEndian.PutUint64(buf, uint64(h.payloadLength)) 155 | _, err = w.Write(buf) 156 | case h.payloadLength > 125: 157 | binary.BigEndian.PutUint16(buf, uint16(h.payloadLength)) 158 | _, err = w.Write(buf[:2]) 159 | } 160 | if err != nil { 161 | return err 162 | } 163 | 164 | if h.masked { 165 | binary.LittleEndian.PutUint32(buf, h.maskKey) 166 | _, err = w.Write(buf[:4]) 167 | if err != nil { 168 | return err 169 | } 170 | } 171 | 172 | return nil 173 | } 174 | -------------------------------------------------------------------------------- /frame_test.go: 
-------------------------------------------------------------------------------- 1 | //go:build !js 2 | // +build !js 3 | 4 | package websocket 5 | 6 | import ( 7 | "bufio" 8 | "bytes" 9 | "encoding/binary" 10 | "math/bits" 11 | "math/rand" 12 | "strconv" 13 | "testing" 14 | "time" 15 | 16 | "github.com/coder/websocket/internal/test/assert" 17 | ) 18 | 19 | func TestHeader(t *testing.T) { 20 | t.Parallel() 21 | 22 | t.Run("lengths", func(t *testing.T) { 23 | t.Parallel() 24 | 25 | lengths := []int{ 26 | 124, 27 | 125, 28 | 126, 29 | 127, 30 | 31 | 65534, 32 | 65535, 33 | 65536, 34 | 65537, 35 | } 36 | 37 | for _, n := range lengths { 38 | n := n 39 | t.Run(strconv.Itoa(n), func(t *testing.T) { 40 | t.Parallel() 41 | 42 | testHeader(t, header{ 43 | payloadLength: int64(n), 44 | }) 45 | }) 46 | } 47 | }) 48 | 49 | t.Run("fuzz", func(t *testing.T) { 50 | t.Parallel() 51 | 52 | r := rand.New(rand.NewSource(time.Now().UnixNano())) 53 | randBool := func() bool { 54 | return r.Intn(2) == 0 55 | } 56 | 57 | for i := 0; i < 10000; i++ { 58 | h := header{ 59 | fin: randBool(), 60 | rsv1: randBool(), 61 | rsv2: randBool(), 62 | rsv3: randBool(), 63 | opcode: opcode(r.Intn(16)), 64 | 65 | masked: randBool(), 66 | payloadLength: r.Int63(), 67 | } 68 | if h.masked { 69 | h.maskKey = r.Uint32() 70 | } 71 | 72 | testHeader(t, h) 73 | } 74 | }) 75 | } 76 | 77 | func testHeader(t *testing.T, h header) { 78 | b := &bytes.Buffer{} 79 | w := bufio.NewWriter(b) 80 | r := bufio.NewReader(b) 81 | 82 | err := writeFrameHeader(h, w, make([]byte, 8)) 83 | assert.Success(t, err) 84 | 85 | err = w.Flush() 86 | assert.Success(t, err) 87 | 88 | h2, err := readFrameHeader(r, make([]byte, 8)) 89 | assert.Success(t, err) 90 | 91 | assert.Equal(t, "read header", h, h2) 92 | } 93 | 94 | func Test_mask(t *testing.T) { 95 | t.Parallel() 96 | 97 | key := []byte{0xa, 0xb, 0xc, 0xff} 98 | key32 := binary.LittleEndian.Uint32(key) 99 | p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc} 100 | gotKey32 := mask(p, key32) 
101 | 102 | expP := []byte{0, 0, 0, 0x0d, 0x6} 103 | assert.Equal(t, "p", expP, p) 104 | 105 | expKey32 := bits.RotateLeft32(key32, -8) 106 | assert.Equal(t, "key32", expKey32, gotKey32) 107 | } 108 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/coder/websocket 2 | 3 | go 1.23 4 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coder/websocket/efb626be44240d7979b57427265d9b6402166b96/go.sum -------------------------------------------------------------------------------- /hijack.go: -------------------------------------------------------------------------------- 1 | //go:build !js 2 | 3 | package websocket 4 | 5 | import ( 6 | "net/http" 7 | ) 8 | 9 | type rwUnwrapper interface { 10 | Unwrap() http.ResponseWriter 11 | } 12 | 13 | // hijacker returns the Hijacker interface of the http.ResponseWriter. 14 | // It follows the Unwrap method of the http.ResponseWriter if available, 15 | // matching the behavior of http.ResponseController. If the Hijacker 16 | // interface is not found, it returns false. 17 | // 18 | // Since the http.ResponseController is not available in Go 1.19, and 19 | // does not support checking the presence of the Hijacker interface, 20 | // this function is used to provide a consistent way to check for the 21 | // Hijacker interface across Go versions. 
22 | func hijacker(rw http.ResponseWriter) (http.Hijacker, bool) { 23 | for { 24 | switch t := rw.(type) { 25 | case http.Hijacker: 26 | return t, true 27 | case rwUnwrapper: 28 | rw = t.Unwrap() 29 | default: 30 | return nil, false 31 | } 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /hijack_go120_test.go: -------------------------------------------------------------------------------- 1 | //go:build !js && go1.20 2 | 3 | package websocket 4 | 5 | import ( 6 | "bufio" 7 | "errors" 8 | "net" 9 | "net/http" 10 | "net/http/httptest" 11 | "testing" 12 | 13 | "github.com/coder/websocket/internal/test/assert" 14 | ) 15 | 16 | func Test_hijackerHTTPResponseControllerCompatibility(t *testing.T) { 17 | t.Parallel() 18 | 19 | rr := httptest.NewRecorder() 20 | w := mockUnwrapper{ 21 | ResponseWriter: rr, 22 | unwrap: func() http.ResponseWriter { 23 | return mockHijacker{ 24 | ResponseWriter: rr, 25 | hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) { 26 | return nil, nil, errors.New("haha") 27 | }, 28 | } 29 | }, 30 | } 31 | 32 | _, _, err := http.NewResponseController(w).Hijack() 33 | assert.Contains(t, err, "haha") 34 | hj, ok := hijacker(w) 35 | assert.Equal(t, "hijacker found", ok, true) 36 | _, _, err = hj.Hijack() 37 | assert.Contains(t, err, "haha") 38 | } 39 | -------------------------------------------------------------------------------- /internal/bpool/bpool.go: -------------------------------------------------------------------------------- 1 | package bpool 2 | 3 | import ( 4 | "bytes" 5 | "sync" 6 | ) 7 | 8 | var bpool = sync.Pool{ 9 | New: func() any { 10 | return &bytes.Buffer{} 11 | }, 12 | } 13 | 14 | // Get returns a buffer from the pool or creates a new one if 15 | // the pool is empty. 16 | func Get() *bytes.Buffer { 17 | b := bpool.Get() 18 | return b.(*bytes.Buffer) 19 | } 20 | 21 | // Put returns a buffer into the pool. 
22 | func Put(b *bytes.Buffer) { 23 | b.Reset() 24 | bpool.Put(b) 25 | } 26 | -------------------------------------------------------------------------------- /internal/errd/wrap.go: -------------------------------------------------------------------------------- 1 | package errd 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | // Wrap wraps err with fmt.Errorf if err is non nil. 8 | // Intended for use with defer and a named error return. 9 | // Inspired by https://github.com/golang/go/issues/32676. 10 | func Wrap(err *error, f string, v ...interface{}) { 11 | if *err != nil { 12 | *err = fmt.Errorf(f+": %w", append(v, *err)...) 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /internal/examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | This directory contains more involved examples unsuitable 4 | for display with godoc. 5 | -------------------------------------------------------------------------------- /internal/examples/chat/README.md: -------------------------------------------------------------------------------- 1 | # Chat Example 2 | 3 | This directory contains a full stack example of a simple chat webapp using github.com/coder/websocket. 4 | 5 | ```bash 6 | $ cd examples/chat 7 | $ go run . localhost:0 8 | listening on ws://127.0.0.1:51055 9 | ``` 10 | 11 | Visit the printed URL to submit and view broadcasted messages in a browser. 12 | 13 |  14 | 15 | ## Structure 16 | 17 | The frontend is contained in `index.html`, `index.js` and `index.css`. It sets up the 18 | DOM with a scrollable div at the top that is populated with new messages as they are broadcast. 19 | At the bottom it adds a form to submit messages. 20 | 21 | The messages are received via the WebSocket `/subscribe` endpoint and published via 22 | the HTTP POST `/publish` endpoint. 
The reason for not publishing messages over the WebSocket
23 | is so that you can easily publish a message with curl.
24 |
25 | The server portion is `main.go` and `chat.go` and implements serving the static frontend
26 | assets, the `/subscribe` WebSocket endpoint and the HTTP POST `/publish` endpoint.
27 |
28 | The code is well commented. I would recommend starting in `main.go` and then `chat.go` followed by
29 | `index.html` and then `index.js`.
30 |
31 | There are two automated tests for the server included in `chat_test.go`. The first is a simple
32 | single-client echo test. It publishes a single message and ensures it's received.
33 |
34 | The second is a complex concurrency test where 16 clients each send 128 unique messages
35 | of fewer than 128 bytes concurrently. The test ensures all messages are seen by every client.
36 |
--------------------------------------------------------------------------------
/internal/examples/chat/chat.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"context"
5 | 	"errors"
6 | 	"io"
7 | 	"log"
8 | 	"net"
9 | 	"net/http"
10 | 	"sync"
11 | 	"time"
12 |
13 | 	"golang.org/x/time/rate"
14 |
15 | 	"github.com/coder/websocket"
16 | )
17 |
18 | // chatServer enables broadcasting to a set of subscribers.
19 | type chatServer struct {
20 | 	// subscriberMessageBuffer controls the max number of messages that
21 | 	// can be queued for a subscriber before it is kicked.
22 | 	//
23 | 	// Defaults to 16.
24 | 	subscriberMessageBuffer int
25 |
26 | 	// publishLimiter controls the rate limit applied to the publish endpoint.
27 | 	//
28 | 	// Defaults to one publish every 100ms with a burst of 8.
29 | 	publishLimiter *rate.Limiter
30 |
31 | 	// logf controls where logs are sent.
32 | 	// Defaults to log.Printf.
33 | 	logf func(f string, v ...interface{})
34 |
35 | 	// serveMux routes the various endpoints to the appropriate handler.
36 | 	serveMux http.ServeMux
37 |
38 | 	subscribersMu sync.Mutex // guards subscribers
39 | 	subscribers   map[*subscriber]struct{}
40 | }
41 |
42 | // newChatServer constructs a chatServer with the defaults.
43 | func newChatServer() *chatServer {
44 | 	cs := &chatServer{
45 | 		subscriberMessageBuffer: 16,
46 | 		logf:                    log.Printf,
47 | 		subscribers:             make(map[*subscriber]struct{}),
48 | 		publishLimiter:          rate.NewLimiter(rate.Every(time.Millisecond*100), 8),
49 | 	}
50 | 	cs.serveMux.Handle("/", http.FileServer(http.Dir("."))) // serves every file under the working directory
51 | 	cs.serveMux.HandleFunc("/subscribe", cs.subscribeHandler)
52 | 	cs.serveMux.HandleFunc("/publish", cs.publishHandler)
53 |
54 | 	return cs
55 | }
56 |
57 | // subscriber represents a subscriber.
58 | // Messages are sent on the msgs channel and if the client
59 | // cannot keep up with the messages, closeSlow is called.
60 | type subscriber struct {
61 | 	msgs      chan []byte
62 | 	closeSlow func()
63 | }
64 |
65 | // ServeHTTP dispatches the request to the handler registered on cs.serveMux.
66 | func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
67 | 	cs.serveMux.ServeHTTP(w, r)
68 | }
69 |
70 | // subscribeHandler accepts the WebSocket connection and subscribes it to all future
71 | // messages. Expected endings (cancellation, normal or going-away close) are not logged.
72 | func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) {
73 | 	err := cs.subscribe(w, r)
74 | 	if errors.Is(err, context.Canceled) {
75 | 		return
76 | 	}
77 | 	switch websocket.CloseStatus(err) {
78 | 	case websocket.StatusNormalClosure, websocket.StatusGoingAway:
79 | 		return
80 | 	}
81 | 	if err != nil {
82 | 		cs.logf("%v", err)
83 | 		return
84 | 	}
85 | }
86 |
87 | // publishHandler reads the request body with a limit of 8192 bytes and then publishes
88 | // the received message.
89 | func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) { 90 | if r.Method != "POST" { 91 | http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) 92 | return 93 | } 94 | body := http.MaxBytesReader(w, r.Body, 8192) 95 | msg, err := io.ReadAll(body) 96 | if err != nil { 97 | http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge) 98 | return 99 | } 100 | 101 | cs.publish(msg) 102 | 103 | w.WriteHeader(http.StatusAccepted) 104 | } 105 | 106 | // subscribe subscribes the given WebSocket to all broadcast messages. 107 | // It creates a subscriber with a buffered msgs chan to give some room to slower 108 | // connections and then registers the subscriber. It then listens for all messages 109 | // and writes them to the WebSocket. If the context is cancelled or 110 | // an error occurs, it returns and deletes the subscription. 111 | // 112 | // It uses CloseRead to keep reading from the connection to process control 113 | // messages and cancel the context if the connection drops. 
114 | func (cs *chatServer) subscribe(w http.ResponseWriter, r *http.Request) error { 115 | var mu sync.Mutex 116 | var c *websocket.Conn 117 | var closed bool 118 | s := &subscriber{ 119 | msgs: make(chan []byte, cs.subscriberMessageBuffer), 120 | closeSlow: func() { 121 | mu.Lock() 122 | defer mu.Unlock() 123 | closed = true 124 | if c != nil { 125 | c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages") 126 | } 127 | }, 128 | } 129 | cs.addSubscriber(s) 130 | defer cs.deleteSubscriber(s) 131 | 132 | c2, err := websocket.Accept(w, r, nil) 133 | if err != nil { 134 | return err 135 | } 136 | mu.Lock() 137 | if closed { 138 | mu.Unlock() 139 | return net.ErrClosed 140 | } 141 | c = c2 142 | mu.Unlock() 143 | defer c.CloseNow() 144 | 145 | ctx := c.CloseRead(context.Background()) 146 | 147 | for { 148 | select { 149 | case msg := <-s.msgs: 150 | err := writeTimeout(ctx, time.Second*5, c, msg) 151 | if err != nil { 152 | return err 153 | } 154 | case <-ctx.Done(): 155 | return ctx.Err() 156 | } 157 | } 158 | } 159 | 160 | // publish publishes the msg to all subscribers. 161 | // It never blocks and so messages to slow subscribers 162 | // are dropped. 163 | func (cs *chatServer) publish(msg []byte) { 164 | cs.subscribersMu.Lock() 165 | defer cs.subscribersMu.Unlock() 166 | 167 | cs.publishLimiter.Wait(context.Background()) 168 | 169 | for s := range cs.subscribers { 170 | select { 171 | case s.msgs <- msg: 172 | default: 173 | go s.closeSlow() 174 | } 175 | } 176 | } 177 | 178 | // addSubscriber registers a subscriber. 179 | func (cs *chatServer) addSubscriber(s *subscriber) { 180 | cs.subscribersMu.Lock() 181 | cs.subscribers[s] = struct{}{} 182 | cs.subscribersMu.Unlock() 183 | } 184 | 185 | // deleteSubscriber deletes the given subscriber. 
186 | func (cs *chatServer) deleteSubscriber(s *subscriber) {
187 | 	cs.subscribersMu.Lock()
188 | 	delete(cs.subscribers, s)
189 | 	cs.subscribersMu.Unlock()
190 | }
191 |
192 | // writeTimeout writes msg to c as a single text message, bounding the
193 | // write with timeout on top of any deadline already carried by ctx.
194 | func writeTimeout(ctx context.Context, timeout time.Duration, c *websocket.Conn, msg []byte) error {
195 | 	ctx, cancel := context.WithTimeout(ctx, timeout)
196 | 	defer cancel()
197 |
198 | 	return c.Write(ctx, websocket.MessageText, msg)
199 | }
200 |
--------------------------------------------------------------------------------
/internal/examples/chat/chat_test.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"context"
5 | 	"crypto/rand"
6 | 	"fmt"
7 | 	"math/big"
8 | 	"net/http"
9 | 	"net/http/httptest"
10 | 	"strings"
11 | 	"sync"
12 | 	"testing"
13 | 	"time"
14 |
15 | 	"golang.org/x/time/rate"
16 |
17 | 	"github.com/coder/websocket"
18 | )
19 |
20 | // Test_chatServer exercises the chat server end to end over real HTTP and
21 | // WebSocket connections served by httptest.
22 | func Test_chatServer(t *testing.T) {
23 | 	t.Parallel()
24 |
25 | 	// This is a simple echo test with a single client.
26 | 	// The client sends a message and ensures it receives
27 | 	// it on its WebSocket.
28 | 	t.Run("simple", func(t *testing.T) {
29 | 		t.Parallel()
30 |
31 | 		url, closeFn := setupTest(t)
32 | 		defer closeFn()
33 |
34 | 		ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
35 | 		defer cancel()
36 |
37 | 		cl, err := newClient(ctx, url)
38 | 		assertSuccess(t, err)
39 | 		defer cl.Close()
40 |
41 | 		expMsg := randString(512)
42 | 		err = cl.publish(ctx, expMsg)
43 | 		assertSuccess(t, err)
44 |
45 | 		msg, err := cl.nextMessage()
46 | 		assertSuccess(t, err)
47 |
48 | 		if expMsg != msg {
49 | 			t.Fatalf("expected %v but got %v", expMsg, msg)
50 | 		}
51 | 	})
52 |
53 | 	// This test is a complex concurrency test.
54 | 	// nclients (16) clients are started that each send nmessages (128)
55 | 	// different messages of fewer than maxMessageSize (128) bytes concurrently.
56 | 	//
57 | 	// The test verifies that every message is seen by every client
58 | 	// and no errors occur anywhere.
59 | 	t.Run("concurrency", func(t *testing.T) {
60 | 		t.Parallel()
61 |
62 | 		const nmessages = 128
63 | 		const maxMessageSize = 128
64 | 		const nclients = 16
65 |
66 | 		url, closeFn := setupTest(t)
67 | 		defer closeFn()
68 |
69 | 		ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
70 | 		defer cancel()
71 |
72 | 		var clients []*client
73 | 		var clientMsgs []map[string]struct{}
74 | 		for i := 0; i < nclients; i++ {
75 | 			cl, err := newClient(ctx, url)
76 | 			assertSuccess(t, err)
77 | 			defer cl.Close()
78 |
79 | 			clients = append(clients, cl)
80 | 			clientMsgs = append(clientMsgs, randMessages(nmessages, maxMessageSize))
81 | 		}
82 |
83 | 		// Every client is expected to receive the union of all clients' messages.
84 | 		allMessages := make(map[string]struct{})
85 | 		for _, msgs := range clientMsgs {
86 | 			for m := range msgs {
87 | 				allMessages[m] = struct{}{}
88 | 			}
89 | 		}
90 |
91 | 		var wg sync.WaitGroup
92 | 		for i, cl := range clients {
93 | 			// Per-iteration copies for the goroutines below; only required
94 | 			// under pre-Go 1.22 loop variable semantics.
95 | 			i := i
96 | 			cl := cl
97 |
98 | 			wg.Add(1)
99 | 			go func() {
100 | 				defer wg.Done()
101 | 				err := cl.publishMsgs(ctx, clientMsgs[i])
102 | 				if err != nil {
103 | 					t.Errorf("client %d failed to publish all messages: %v", i, err)
104 | 				}
105 | 			}()
106 |
107 | 			wg.Add(1)
108 | 			go func() {
109 | 				defer wg.Done()
110 | 				err := testAllMessagesReceived(cl, nclients*nmessages, allMessages)
111 | 				if err != nil {
112 | 					t.Errorf("client %d failed to receive all messages: %v", i, err)
113 | 				}
114 | 			}()
115 | 		}
116 |
117 | 		wg.Wait()
118 | 	})
119 | }
120 |
121 | // setupTest sets up chatServer that can be used
122 | // via the returned url.
123 | //
124 | // Defer closeFn to ensure everything is cleaned up at
125 | // the end of the test.
126 | //
127 | // chatServer logs will be logged via t.Logf.
128 | func setupTest(t *testing.T) (url string, closeFn func()) {
129 | 	cs := newChatServer()
130 | 	cs.logf = t.Logf
131 |
132 | 	// To ensure tests run quickly under even -race.
133 | 	cs.subscriberMessageBuffer = 4096
134 | 	cs.publishLimiter.SetLimit(rate.Inf)
135 |
136 | 	s := httptest.NewServer(cs)
137 | 	return s.URL, func() {
138 | 		s.Close()
139 | 	}
140 | }
141 |
142 | // testAllMessagesReceived ensures that after n reads, all msgs in msgs
143 | // have been read.
144 | func testAllMessagesReceived(cl *client, n int, msgs map[string]struct{}) error {
145 | 	// Work on a copy: the caller shares msgs between concurrent goroutines.
146 | 	msgs = cloneMessages(msgs)
147 |
148 | 	for i := 0; i < n; i++ {
149 | 		msg, err := cl.nextMessage()
150 | 		if err != nil {
151 | 			return err
152 | 		}
153 | 		delete(msgs, msg)
154 | 	}
155 |
156 | 	if len(msgs) != 0 {
157 | 		return fmt.Errorf("did not receive all expected messages: %q", msgs)
158 | 	}
159 | 	return nil
160 | }
161 |
162 | // cloneMessages returns a copy of the message set msgs.
163 | func cloneMessages(msgs map[string]struct{}) map[string]struct{} {
164 | 	msgs2 := make(map[string]struct{}, len(msgs))
165 | 	for m := range msgs {
166 | 		msgs2[m] = struct{}{}
167 | 	}
168 | 	return msgs2
169 | }
170 |
171 | // randMessages returns a set of n distinct random strings, each shorter than
172 | // maxMessageLength bytes.
173 | func randMessages(n, maxMessageLength int) map[string]struct{} {
174 | 	msgs := make(map[string]struct{})
175 | 	for i := 0; i < n; i++ {
176 | 		m := randString(randInt(maxMessageLength))
177 | 		if _, ok := msgs[m]; ok {
178 | 			// Duplicate: retry this slot.
179 | 			i--
180 | 			continue
181 | 		}
182 | 		msgs[m] = struct{}{}
183 | 	}
184 | 	return msgs
185 | }
186 |
187 | // assertSuccess fails the test immediately if err is non-nil.
188 | func assertSuccess(t *testing.T, err error) {
189 | 	t.Helper()
190 | 	if err != nil {
191 | 		t.Fatal(err)
192 | 	}
193 | }
194 |
195 | // client is a test chat client: a WebSocket subscribed to the server plus
196 | // the base URL used for publishing over HTTP.
197 | type client struct {
198 | 	url string
199 | 	c   *websocket.Conn
200 | }
201 |
202 | // newClient dials the /subscribe endpoint of the server at url.
203 | func newClient(ctx context.Context, url string) (*client, error) {
204 | 	c, _, err := websocket.Dial(ctx, url+"/subscribe", nil)
205 | 	if err != nil {
206 | 		return nil, err
207 | 	}
208 |
209 | 	cl := &client{
210 | 		url: url,
211 | 		c:   c,
212 | 	}
213 |
214 | 	return cl, nil
215 | }
216 |
217 | // publish POSTs msg to the server's /publish endpoint and expects
218 | // 202 Accepted. On any failure it also closes cl's WebSocket.
219 | func (cl *client) publish(ctx context.Context, msg string) (err error) {
220 | 	defer func() {
221 | 		if err != nil {
222 | 			cl.c.Close(websocket.StatusInternalError, "publish failed")
223 | 		}
224 | 	}()
225 |
226 | 	// NOTE(review): the error is discarded; if cl.url were malformed, req would be nil.
227 | 	req, _ := http.NewRequestWithContext(ctx, http.MethodPost, cl.url+"/publish", strings.NewReader(msg))
228 | 	resp, err := http.DefaultClient.Do(req)
229 | 	if err != nil {
230 | 		return err
231 | 	}
232 | 	defer resp.Body.Close()
233 | 	if resp.StatusCode != http.StatusAccepted {
234 | 		return fmt.Errorf("publish request failed: %v", resp.StatusCode)
235 | 	}
236 | 	return nil
237 | }
238 |
239 | // publishMsgs publishes every message in msgs, stopping at the first error.
240 | func (cl *client) publishMsgs(ctx context.Context, msgs map[string]struct{}) error {
241 | 	for m := range msgs {
242 | 		err := cl.publish(ctx, m)
243 | 		if err != nil {
244 | 			return err
245 | 		}
246 | 	}
247 | 	return nil
248 | }
249 |
250 | // nextMessage reads the next message from cl's WebSocket. It has no timeout
251 | // of its own, and it fails (closing the connection) if the message is not text.
252 | func (cl *client) nextMessage() (string, error) {
253 | 	typ, b, err := cl.c.Read(context.Background())
254 | 	if err != nil {
255 | 		return "", err
256 | 	}
257 |
258 | 	if typ != websocket.MessageText {
259 | 		cl.c.Close(websocket.StatusUnsupportedData, "expected text message")
260 | 		return "", fmt.Errorf("expected text message but got %v", typ)
261 | 	}
262 | 	return string(b), nil
263 | }
264 |
265 | // Close closes cl's WebSocket with StatusNormalClosure.
266 | func (cl *client) Close() error {
267 | 	return cl.c.Close(websocket.StatusNormalClosure, "")
268 | }
269 |
270 | // randString generates a random string with length n.
271 | //
272 | // Bytes come from crypto/rand; invalid UTF-8 runs and NUL bytes become '_'
273 | // and the result is padded with '=' up to n bytes. ToValidUTF8 with a
274 | // one-byte replacement never lengthens the string, so in practice only the
275 | // padding branch can change its length.
276 | func randString(n int) string {
277 | 	b := make([]byte, n)
278 | 	_, err := rand.Reader.Read(b)
279 | 	if err != nil {
280 | 		panic(fmt.Sprintf("failed to generate rand bytes: %v", err))
281 | 	}
282 |
283 | 	s := strings.ToValidUTF8(string(b), "_")
284 | 	s = strings.ReplaceAll(s, "\x00", "_")
285 | 	if len(s) > n {
286 | 		return s[:n]
287 | 	}
288 | 	if len(s) < n {
289 | 		// Pad with =
290 | 		extra := n - len(s)
291 | 		return s + strings.Repeat("=", extra)
292 | 	}
293 | 	return s
294 | }
295 |
296 | // randInt returns a randomly generated integer between [0, max).
297 | // It panics if max <= 0, as crypto/rand.Int does.
271 | func randInt(max int) int {
272 | 	// rand.Int panics if max <= 0, so callers must pass a positive bound.
273 | 	x, err := rand.Int(rand.Reader, big.NewInt(int64(max)))
274 | 	if err != nil {
275 | 		panic(fmt.Sprintf("failed to get random int: %v", err))
276 | 	}
277 | 	return int(x.Int64())
278 | }
279 |
--------------------------------------------------------------------------------
/internal/examples/chat/index.css:
--------------------------------------------------------------------------------
1 | body {
2 |   width: 100vw;
3 |   min-width: 320px;
4 | }
5 |
6 | #root {
7 |   padding: 40px 20px;
8 |   max-width: 600px;
9 |   margin: auto;
10 |   height: 100vh;
11 |
12 |   display: flex;
13 |   flex-direction: column;
14 |   align-items: center;
15 |   justify-content: center;
16 | }
17 |
18 | #root > * + * {
19 |   margin: 20px 0 0 0;
20 | }
21 |
22 | /* 100vh on safari does not include the bottom bar. */
23 | @supports (-webkit-overflow-scrolling: touch) {
24 |   #root {
25 |     height: 85vh;
26 |   }
27 | }
28 |
29 | #message-log {
30 |   width: 100%;
31 |   flex-grow: 1;
32 |   overflow: auto;
33 | }
34 |
35 | #message-log p:first-child {
36 |   margin: 0;
37 | }
38 |
39 | #message-log > * + * {
40 |   margin: 10px 0 0 0;
41 | }
42 |
43 | #publish-form-container {
44 |   width: 100%;
45 | }
46 |
47 | #publish-form {
48 |   width: 100%;
49 |   display: flex;
50 |   height: 40px;
51 | }
52 |
53 | #publish-form > * + * {
54 |   margin: 0 0 0 10px;
55 | }
56 |
57 | #publish-form input[type='text'] {
58 |   flex-grow: 1;
59 |
60 |   -moz-appearance: none;
61 |   -webkit-appearance: none;
62 |   appearance: none;
63 |   word-break: normal;
64 |   border-radius: 5px;
65 |   border: 1px solid #ccc;
66 | }
67 |
68 | #publish-form input[type='submit'] {
69 |   color: white;
70 |   background-color: black;
71 |   border-radius: 5px;
72 |   padding: 5px 10px;
73 |   border: none;
74 | }
75 |
76 | #publish-form input[type='submit']:hover {
77 |   background-color: red;
78 | }
79 |
80 | #publish-form input[type='submit']:active {
81 |   background-color: red;
82 | }
83 |
-------------------------------------------------------------------------------- /internal/examples/chat/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 |
4 | 5 |