├── .github ├── dependabot.yml └── workflows │ ├── autobahn.yml │ └── main.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── autobahn ├── .gitignore ├── config │ └── fuzzingclient.json ├── docker │ ├── autobahn │ │ └── Dockerfile │ └── server │ │ └── Dockerfile ├── main.go ├── main_go17.go ├── main_go18.go └── script │ └── test.sh ├── check.go ├── cipher.go ├── cipher_test.go ├── dialer.go ├── dialer_test.go ├── dialer_tls_go17.go ├── dialer_tls_go18.go ├── doc.go ├── errors.go ├── example └── autobahn │ ├── autobahn.go │ └── autobahn_test.go ├── frame.go ├── frame_test.go ├── go.mod ├── go.sum ├── hijack_go119.go ├── hijack_go120.go ├── http.go ├── http_test.go ├── nonce.go ├── nonce_test.go ├── read.go ├── read_test.go ├── rw_test.go ├── server.go ├── server_test.go ├── tests └── deflate_test.go ├── util.go ├── util_purego.go ├── util_test.go ├── util_unsafe.go ├── write.go ├── write_test.go ├── wsflate ├── cbuf.go ├── cbuf_test.go ├── extension.go ├── helper.go ├── helper_test.go ├── parameters.go ├── parameters_test.go ├── reader.go ├── reader_test.go ├── writer.go └── writer_test.go └── wsutil ├── cipher.go ├── cipher_test.go ├── dialer.go ├── dialer_test.go ├── extenstion.go ├── handler.go ├── handler_test.go ├── helper.go ├── helper_test.go ├── reader.go ├── reader_test.go ├── upgrader.go ├── upgrader_test.go ├── utf8.go ├── utf8_test.go ├── writer.go ├── writer_test.go └── wsutil.go /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "gomod" 4 | commit-message: 5 | prefix: "deps:" 6 | directory: "/" 7 | schedule: 8 | interval: "weekly" 9 | day: "sunday" 10 | time: "09:00" 11 | - package-ecosystem: "github-actions" 12 | commit-message: 13 | prefix: "ci:" 14 | directory: "/" 15 | schedule: 16 | interval: "weekly" 17 | day: "sunday" 18 | time: "09:00" 19 | -------------------------------------------------------------------------------- 
/.github/workflows/autobahn.yml: -------------------------------------------------------------------------------- 1 | name: Autobahn 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | schedule: 9 | - cron: '0 10 * * 1' # run "At 10:00 on Monday" 10 | 11 | concurrency: 12 | group: autobahn-${{ github.workflow }}-${{ github.ref }} 13 | cancel-in-progress: true 14 | 15 | jobs: 16 | test: 17 | strategy: 18 | matrix: 19 | os: [ ubuntu-latest ] 20 | go: [ 'stable', 'oldstable' ] 21 | 22 | runs-on: ${{ matrix.os }} 23 | steps: 24 | - name: Checkout 25 | uses: actions/checkout@v4 26 | 27 | - name: Setup Go 28 | uses: actions/setup-go@v5 29 | with: 30 | go-version: ${{ matrix.go }} 31 | check-latest: true 32 | 33 | - name: Autobahn 34 | env: 35 | CRYPTOGRAPHY_ALLOW_OPENSSL_102: yes 36 | run: | 37 | make test autobahn 38 | 39 | - name: Autobahn Report Artifact 40 | uses: actions/upload-artifact@v4 41 | with: 42 | name: autobahn report ${{ matrix.go }} ${{ matrix.os }} 43 | path: autobahn/report 44 | retention-days: 7 45 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | schedule: 9 | - cron: '0 10 * * 1' # run "At 10:00 on Monday" 10 | 11 | concurrency: 12 | group: ${{ github.workflow }}-${{ github.ref }} 13 | cancel-in-progress: true 14 | 15 | jobs: 16 | test: 17 | strategy: 18 | matrix: 19 | os: [ ubuntu-latest, macos-latest, windows-latest ] 20 | go: [ 'stable', 'oldstable' ] 21 | 22 | runs-on: ${{ matrix.os }} 23 | steps: 24 | - name: Checkout 25 | uses: actions/checkout@v4 26 | 27 | - name: Setup Go 28 | uses: actions/setup-go@v5 29 | with: 30 | go-version: ${{ matrix.go }} 31 | check-latest: true 32 | 33 | - name: Go Env 34 | run: | 35 | go env 36 | 37 | - name: Go Mod 38 | run: | 
39 | go mod download 40 | 41 | - name: Go Mod Verify 42 | run: | 43 | go mod verify 44 | 45 | - name: Test 46 | run: | 47 | go test -v -race -shuffle=on -cover ./... 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | bin/ 2 | reports/ 3 | cpu.out 4 | mem.out 5 | ws.test 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2017-2021 Sergey Kamardin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | BENCH ?=. 
2 | BENCH_BASE?=master 3 | 4 | clean: 5 | rm -f bin/reporter 6 | rm -fr autobahn/report/* 7 | 8 | bin/reporter: 9 | go build -o bin/reporter ./autobahn 10 | 11 | bin/gocovmerge: 12 | go build -o bin/gocovmerge github.com/wadey/gocovmerge 13 | 14 | .PHONY: autobahn 15 | autobahn: clean bin/reporter 16 | ./autobahn/script/test.sh --build --follow-logs 17 | bin/reporter $(PWD)/autobahn/report/index.json 18 | 19 | .PHONY: autobahn/report 20 | autobahn/report: bin/reporter 21 | ./bin/reporter -http localhost:5555 ./autobahn/report/index.json 22 | # Coverage profile names written by test must match the ones merged by cover. 23 | test: 24 | go test -coverprofile=ws.coverage . 25 | go test -coverprofile=wsutil.coverage ./wsutil 26 | go test -coverprofile=wsflate.coverage ./wsflate 27 | # No statements to cover in ./tests (there are only tests). 28 | go test ./tests 29 | 30 | cover: bin/gocovmerge test autobahn 31 | bin/gocovmerge ws.coverage wsutil.coverage wsflate.coverage autobahn/report/server.coverage > total.coverage 32 | # benchcmp benchmarks the current branch against BENCH_BASE; it refuses to run with uncommitted changes. 33 | benchcmp: BENCH_BRANCH=$(shell git rev-parse --abbrev-ref HEAD) 34 | benchcmp: BENCH_OLD:=$(shell mktemp -t old.XXXX) 35 | benchcmp: BENCH_NEW:=$(shell mktemp -t new.XXXX) 36 | benchcmp: 37 | if [ !
-z "$(shell git status -s)" ]; then\ 38 | echo "could not compare with $(BENCH_BASE) – found unstaged changes";\ 39 | exit 1;\ 40 | fi;\ 41 | if [ "$(BENCH_BRANCH)" = "$(BENCH_BASE)" ]; then\ 42 | echo "comparing the same branches";\ 43 | exit 1;\ 44 | fi;\ 45 | echo "benchmarking $(BENCH_BRANCH)...";\ 46 | go test -run=none -bench=$(BENCH) -benchmem > $(BENCH_NEW);\ 47 | echo "benchmarking $(BENCH_BASE)...";\ 48 | git checkout -q $(BENCH_BASE);\ 49 | go test -run=none -bench=$(BENCH) -benchmem > $(BENCH_OLD);\ 50 | git checkout -q $(BENCH_BRANCH);\ 51 | echo "\nresults:";\ 52 | echo "========\n";\ 53 | benchcmp $(BENCH_OLD) $(BENCH_NEW);\ 54 | 55 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ws 2 | 3 | [![GoDoc][godoc-image]][godoc-url] 4 | [![CI][ci-badge]][ci-url] 5 | 6 | > [RFC6455][rfc-url] WebSocket implementation in Go. 7 | 8 | # Features 9 | 10 | - Zero-copy upgrade 11 | - No intermediate allocations during I/O 12 | - Low-level API which allows you to build your own logic of packet handling and 13 | buffer reuse 14 | - High-level wrappers and helpers around the API in the `wsutil` package, which let 15 | you start fast without digging into the protocol internals 16 | 17 | # Documentation 18 | 19 | [GoDoc][godoc-url]. 20 | 21 | # Why 22 | 23 | Existing WebSocket implementations do not allow users to reuse I/O buffers 24 | between connections in a clear way. This library aims to export an efficient 25 | low-level interface for working with the protocol without forcing a single way 26 | of using it. 27 | 28 | By the way, if you want higher-level tools, you can use the `wsutil` 29 | package. 30 | 31 | # Status 32 | 33 | The library is tagged as `v1*`, so its API must not be broken by future 34 | improvements or refactoring.
35 | 36 | This implementation of RFC6455 passes [Autobahn Test 37 | Suite](https://github.com/crossbario/autobahn-testsuite) and currently has 38 | about 78% coverage. 39 | 40 | # Examples 41 | 42 | Example applications using `ws` are developed in separate repository 43 | [ws-examples](https://github.com/gobwas/ws-examples). 44 | 45 | # Usage 46 | 47 | The higher-level example of WebSocket echo server: 48 | 49 | ```go 50 | package main 51 | 52 | import ( 53 | "net/http" 54 | 55 | "github.com/gobwas/ws" 56 | "github.com/gobwas/ws/wsutil" 57 | ) 58 | 59 | func main() { 60 | http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 61 | conn, _, _, err := ws.UpgradeHTTP(r, w) 62 | if err != nil { 63 | // handle error 64 | } 65 | go func() { 66 | defer conn.Close() 67 | 68 | for { 69 | msg, op, err := wsutil.ReadClientData(conn) 70 | if err != nil { 71 | // handle error 72 | } 73 | err = wsutil.WriteServerMessage(conn, op, msg) 74 | if err != nil { 75 | // handle error 76 | } 77 | } 78 | }() 79 | })) 80 | } 81 | ``` 82 | 83 | Lower-level, but still high-level example: 84 | 85 | 86 | ```go 87 | import ( 88 | "net/http" 89 | "io" 90 | 91 | "github.com/gobwas/ws" 92 | "github.com/gobwas/ws/wsutil" 93 | ) 94 | 95 | func main() { 96 | http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 97 | conn, _, _, err := ws.UpgradeHTTP(r, w) 98 | if err != nil { 99 | // handle error 100 | } 101 | go func() { 102 | defer conn.Close() 103 | 104 | var ( 105 | state = ws.StateServerSide 106 | reader = wsutil.NewReader(conn, state) 107 | writer = wsutil.NewWriter(conn, state, ws.OpText) 108 | ) 109 | for { 110 | header, err := reader.NextFrame() 111 | if err != nil { 112 | // handle error 113 | } 114 | 115 | // Reset writer to write frame with right operation code. 
116 | writer.Reset(conn, state, header.OpCode) 117 | 118 | if _, err = io.Copy(writer, reader); err != nil { 119 | // handle error 120 | } 121 | if err = writer.Flush(); err != nil { 122 | // handle error 123 | } 124 | } 125 | }() 126 | })) 127 | } 128 | ``` 129 | 130 | We can apply the same pattern to read and write structured responses through a JSON encoder and decoder: 131 | 132 | ```go 133 | ... 134 | var ( 135 | r = wsutil.NewReader(conn, ws.StateServerSide) 136 | w = wsutil.NewWriter(conn, ws.StateServerSide, ws.OpText) 137 | decoder = json.NewDecoder(r) 138 | encoder = json.NewEncoder(w) 139 | ) 140 | for { 141 | hdr, err = r.NextFrame() 142 | if err != nil { 143 | return err 144 | } 145 | if hdr.OpCode == ws.OpClose { 146 | return io.EOF 147 | } 148 | var req Request 149 | if err := decoder.Decode(&req); err != nil { 150 | return err 151 | } 152 | var resp Response 153 | if err := encoder.Encode(&resp); err != nil { 154 | return err 155 | } 156 | if err = w.Flush(); err != nil { 157 | return err 158 | } 159 | } 160 | ...
161 | ``` 162 | 163 | The lower-level example without `wsutil`: 164 | 165 | ```go 166 | package main 167 | import ( 168 | "io" 169 | "log" 170 | "net" 171 | 172 | "github.com/gobwas/ws" 173 | ) 174 | 175 | func main() { 176 | ln, err := net.Listen("tcp", "localhost:8080") 177 | if err != nil { 178 | log.Fatal(err) 179 | } 180 | 181 | for { 182 | conn, err := ln.Accept() 183 | if err != nil { 184 | // handle error 185 | } 186 | _, err = ws.Upgrade(conn) 187 | if err != nil { 188 | // handle error 189 | } 190 | 191 | go func() { 192 | defer conn.Close() 193 | 194 | for { 195 | header, err := ws.ReadHeader(conn) 196 | if err != nil { 197 | // handle error 198 | } 199 | 200 | payload := make([]byte, header.Length) 201 | _, err = io.ReadFull(conn, payload) 202 | if err != nil { 203 | // handle error 204 | } 205 | if header.Masked { 206 | ws.Cipher(payload, header.Mask, 0) 207 | } 208 | 209 | // Reset the Masked flag, server frames must not be masked as 210 | // RFC6455 says. 211 | header.Masked = false 212 | 213 | if err := ws.WriteHeader(conn, header); err != nil { 214 | // handle error 215 | } 216 | if _, err := conn.Write(payload); err != nil { 217 | // handle error 218 | } 219 | 220 | if header.OpCode == ws.OpClose { 221 | return 222 | } 223 | } 224 | }() 225 | } 226 | } 227 | ``` 228 | 229 | # Zero-copy upgrade 230 | 231 | Zero-copy upgrade helps avoid unnecessary allocations and copying while 232 | handling an HTTP Upgrade request. 233 | 234 | All non-websocket headers are processed in place by registered 235 | user callbacks, whose arguments are valid only until the callback returns.
236 | 237 | A simple example looks like this: 238 | 239 | ```go 240 | package main 241 | 242 | import ( 243 | "net" 244 | "log" 245 | 246 | "github.com/gobwas/ws" 247 | ) 248 | 249 | func main() { 250 | ln, err := net.Listen("tcp", "localhost:8080") 251 | if err != nil { 252 | log.Fatal(err) 253 | } 254 | u := ws.Upgrader{ 255 | OnHeader: func(key, value []byte) (err error) { 256 | log.Printf("non-websocket header: %q=%q", key, value) 257 | return 258 | }, 259 | } 260 | for { 261 | conn, err := ln.Accept() 262 | if err != nil { 263 | // handle error 264 | } 265 | 266 | _, err = u.Upgrade(conn) 267 | if err != nil { 268 | // handle error 269 | } 270 | } 271 | } 272 | ``` 273 | 274 | Using `ws.Upgrader` here lets you control incoming connections at the 275 | TCP level and simply refuse to accept them based on your own logic. 276 | 277 | Zero-copy upgrade is meant for high-load services that have to control many 278 | resources, such as connection buffers. 279 | 280 | A real-life example could look like this: 281 | 282 | ```go 283 | package main 284 | 285 | import ( 286 | "fmt" 287 | "io" 288 | "log" 289 | "net" 290 | "net/http" 291 | "runtime" 292 | 293 | "github.com/gobwas/httphead" 294 | "github.com/gobwas/ws" 295 | ) 296 | 297 | func main() { 298 | ln, err := net.Listen("tcp", "localhost:8080") 299 | if err != nil { 300 | // handle error 301 | } 302 | 303 | // Prepare handshake header writer from http.Header mapping.
304 | header := ws.HandshakeHeaderHTTP(http.Header{ 305 | "X-Go-Version": []string{runtime.Version()}, 306 | }) 307 | 308 | u := ws.Upgrader{ 309 | OnHost: func(host []byte) error { 310 | if string(host) == "github.com" { 311 | return nil 312 | } 313 | return ws.RejectConnectionError( 314 | ws.RejectionStatus(403), 315 | ws.RejectionHeader(ws.HandshakeHeaderString( 316 | "X-Want-Host: github.com\r\n", 317 | )), 318 | ) 319 | }, 320 | OnHeader: func(key, value []byte) error { 321 | if string(key) != "Cookie" { 322 | return nil 323 | } 324 | ok := httphead.ScanCookie(value, func(key, value []byte) bool { 325 | // Check session here or do some other stuff with cookies. 326 | // Maybe copy some values for future use. 327 | return true 328 | }) 329 | if ok { 330 | return nil 331 | } 332 | return ws.RejectConnectionError( 333 | ws.RejectionReason("bad cookie"), 334 | ws.RejectionStatus(400), 335 | ) 336 | }, 337 | OnBeforeUpgrade: func() (ws.HandshakeHeader, error) { 338 | return header, nil 339 | }, 340 | } 341 | for { 342 | conn, err := ln.Accept() 343 | if err != nil { 344 | log.Fatal(err) 345 | } 346 | _, err = u.Upgrade(conn) 347 | if err != nil { 348 | log.Printf("upgrade error: %s", err) 349 | } 350 | } 351 | } 352 | ``` 353 | 354 | # Compression 355 | 356 | There is a `ws/wsflate` package to support [Permessage-Deflate Compression 357 | Extension][rfc-pmce]. 358 | 359 | It provides minimalistic I/O wrappers to be used in conjunction with any 360 | deflate implementation (for example, the standard library's 361 | [compress/flate][compress/flate]). 362 | 363 | It is also compatible with `wsutil`'s reader and writer by providing 364 | `wsflate.MessageState` type, which implements `wsutil.SendExtension` and 365 | `wsutil.RecvExtension` interfaces. 
366 | 367 | ```go 368 | package main 369 | 370 | import ( 371 | "bytes" 372 | "log" 373 | "net" 374 | 375 | "github.com/gobwas/ws" 376 | "github.com/gobwas/ws/wsflate" 377 | ) 378 | 379 | func main() { 380 | ln, err := net.Listen("tcp", "localhost:8080") 381 | if err != nil { 382 | // handle error 383 | } 384 | e := wsflate.Extension{ 385 | // We are using default parameters here since we use 386 | // wsflate.{Compress,Decompress}Frame helpers below in the code. 387 | // This assumes that we use standard compress/flate package as flate 388 | // implementation. 389 | Parameters: wsflate.DefaultParameters, 390 | } 391 | u := ws.Upgrader{ 392 | Negotiate: e.Negotiate, 393 | } 394 | for { 395 | conn, err := ln.Accept() 396 | if err != nil { 397 | log.Fatal(err) 398 | } 399 | 400 | // Reset extension after previous upgrades. 401 | e.Reset() 402 | 403 | _, err = u.Upgrade(conn) 404 | if err != nil { 405 | log.Printf("upgrade error: %s", err) 406 | continue 407 | } 408 | if _, ok := e.Accepted(); !ok { 409 | log.Printf("didn't negotiate compression for %s", conn.RemoteAddr()) 410 | conn.Close() 411 | continue 412 | } 413 | 414 | go func() { 415 | defer conn.Close() 416 | for { 417 | frame, err := ws.ReadFrame(conn) 418 | if err != nil { 419 | // Handle error. 420 | return 421 | } 422 | 423 | frame = ws.UnmaskFrameInPlace(frame) 424 | 425 | if wsflate.IsCompressed(frame.Header) { 426 | // Note that even after successful negotiation of 427 | // compression extension, both sides are able to send 428 | // non-compressed messages. 429 | frame, err = wsflate.DecompressFrame(frame) 430 | if err != nil { 431 | // Handle error. 432 | return 433 | } 434 | } 435 | 436 | // Do something with frame... 437 | 438 | ack := ws.NewTextFrame([]byte("this is an acknowledgement")) 439 | 440 | // Compress response unconditionally. 441 | ack, err = wsflate.CompressFrame(ack) 442 | if err != nil { 443 | // Handle error. 
444 | return 445 | } 446 | if err = ws.WriteFrame(conn, ack); err != nil { 447 | // Handle error. 448 | return 449 | } 450 | } 451 | }() 452 | } 453 | } 454 | ``` 455 | 456 | You can use compression with `wsutil` package this way: 457 | 458 | ```go 459 | // Upgrade somehow and negotiate compression to get the conn... 460 | 461 | // Initialize flate reader. We are using nil as a source io.Reader because 462 | // we will Reset() it in the message i/o loop below. 463 | fr := wsflate.NewReader(nil, func(r io.Reader) wsflate.Decompressor { 464 | return flate.NewReader(r) 465 | }) 466 | // Initialize flate writer. We are using nil as a destination io.Writer 467 | // because we will Reset() it in the message i/o loop below. 468 | fw := wsflate.NewWriter(nil, func(w io.Writer) wsflate.Compressor { 469 | f, _ := flate.NewWriter(w, 9) 470 | return f 471 | }) 472 | 473 | // Declare compression message state variable. 474 | // 475 | // It has two goals: 476 | // - Allow users to check whether received message is compressed or not. 477 | // - Help wsutil.Reader and wsutil.Writer to set/unset appropriate 478 | // WebSocket header bits while writing next frame to the wire (it 479 | // implements wsutil.RecvExtension and wsutil.SendExtension). 480 | var msg wsflate.MessageState 481 | 482 | // Initialize WebSocket reader as previously. 483 | // Please note the use of Reader.Extensions field as well as 484 | // of ws.StateExtended flag. 485 | rd := &wsutil.Reader{ 486 | Source: conn, 487 | State: ws.StateServerSide | ws.StateExtended, 488 | Extensions: []wsutil.RecvExtension{ 489 | &msg, 490 | }, 491 | } 492 | 493 | // Initialize WebSocket writer with ws.StateExtended flag as well. 494 | wr := wsutil.NewWriter(conn, ws.StateServerSide|ws.StateExtended, 0) 495 | // Use the message state as wsutil.SendExtension. 496 | wr.SetExtensions(&msg) 497 | 498 | for { 499 | h, err := rd.NextFrame() 500 | if err != nil { 501 | // handle error. 
502 | } 503 | if h.OpCode.IsControl() { 504 | // handle control frame. 505 | } 506 | if !msg.IsCompressed() { 507 | // handle uncompressed frame (skipped for the sake of example 508 | // simplicity). 509 | } 510 | 511 | // Reset the writer to echo same op code. 512 | wr.Reset(h.OpCode) 513 | 514 | // Reset both flate reader and writer to start the new round of i/o. 515 | fr.Reset(rd) 516 | fw.Reset(wr) 517 | 518 | // Copy whole message from reader to writer decompressing it and 519 | // compressing again. 520 | if _, err := io.Copy(fw, fr); err != nil { 521 | // handle error. 522 | } 523 | // Flush any remaining buffers from flate writer to WebSocket writer. 524 | if err := fw.Close(); err != nil { 525 | // handle error. 526 | } 527 | // Flush the whole WebSocket message to the wire. 528 | if err := wr.Flush(); err != nil { 529 | // handle error. 530 | } 531 | } 532 | ``` 533 | 534 | 535 | [rfc-url]: https://tools.ietf.org/html/rfc6455 536 | [rfc-pmce]: https://tools.ietf.org/html/rfc7692#section-7 537 | [godoc-image]: https://godoc.org/github.com/gobwas/ws?status.svg 538 | [godoc-url]: https://godoc.org/github.com/gobwas/ws 539 | [compress/flate]: https://golang.org/pkg/compress/flate/ 540 | [ci-badge]: https://github.com/gobwas/ws/workflows/CI/badge.svg 541 | [ci-url]: https://github.com/gobwas/ws/actions?query=workflow%3ACI 542 | -------------------------------------------------------------------------------- /autobahn/.gitignore: -------------------------------------------------------------------------------- 1 | report/ 2 | -------------------------------------------------------------------------------- /autobahn/config/fuzzingclient.json: -------------------------------------------------------------------------------- 1 | { 2 | "outdir": "/report", 3 | "servers": [ 4 | { 5 | "agent": "ws", 6 | "url": "ws://ws-server:9001/ws" 7 | }, 8 | { 9 | "agent": "wsutil", 10 | "url": "ws://ws-server:9001/wsutil" 11 | }, 12 | { 13 | "agent": "helpers/low", 14 | "url": 
"ws://ws-server:9001/helpers/low" 15 | }, 16 | { 17 | "agent": "helpers/high", 18 | "url": "ws://ws-server:9001/helpers/high" 19 | }, 20 | { 21 | "agent": "wsflate", 22 | "url": "ws://ws-server:9001/wsflate" 23 | } 24 | ], 25 | "cases": ["*"], 26 | "exclude-cases": [], 27 | "exclude-agent-cases": { 28 | "ws": [ 29 | "12.*", "13.*" 30 | ], 31 | "wsutil": [ 32 | "12.*", "13.*" 33 | ], 34 | "helpers/low": [ 35 | "12.*", "13.*" 36 | ], 37 | "helpers/high": [ 38 | "12.*", "13.*" 39 | ], 40 | "wsflate": [ 41 | "1.*","2.*","3.*", "4.*", 42 | "5.*","6.*","7.*", "8.*", 43 | "9.*","10.*","11.*" 44 | ] 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /autobahn/docker/autobahn/Dockerfile: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/crossbario/autobahn-testsuite/blob/09cfbf74b0c8e335c6fc7df88e5c88349ca66879/docker/Dockerfile 2 | 3 | FROM pypy:2-slim 4 | 5 | # make "pypy" available as "python" 6 | RUN ln -s /usr/local/bin/pypy /usr/local/bin/python 7 | 8 | # We need this to fix pip & cryptography 9 | RUN apt-get update && apt-get install -y build-essential libssl-dev 10 | 11 | # install Autobahn|Testsuite 12 | RUN pip install -U pip typing && \ 13 | pip install autobahntestsuite=='0.8.2' 14 | 15 | VOLUME /config 16 | VOLUME /report 17 | 18 | WORKDIR / 19 | EXPOSE 9001 9001 20 | 21 | CMD ["wstest", "--mode", "fuzzingclient", "--spec", "/config/fuzzingclient.json"] 22 | -------------------------------------------------------------------------------- /autobahn/docker/server/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM golang:1.20.2-alpine3.17 2 | 3 | WORKDIR /go/src/github.com/gobwas/ws 4 | 5 | COPY go.mod . 6 | COPY go.sum . 7 | RUN go mod download 8 | 9 | COPY . . 10 | ENV CGO_ENABLED=0 11 | RUN go test -c -tags autobahn -coverpkg "github.com/gobwas/ws/..." 
github.com/gobwas/ws/example/autobahn 12 | 13 | CMD ["./autobahn.test", "-test.coverprofile", "/report/server.coverage"] 14 | -------------------------------------------------------------------------------- /autobahn/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "flag" 6 | "fmt" 7 | "html/template" 8 | "log" 9 | "net/http" 10 | "os" 11 | "path" 12 | "sort" 13 | "strconv" 14 | "strings" 15 | "text/tabwriter" 16 | ) 17 | 18 | var ( 19 | verbose = flag.Bool("verbose", false, "be verbose") 20 | web = flag.String("http", "", "serve the report directory over HTTP on this address instead") 21 | ) 22 | 23 | const ( 24 | statusOK = "OK" 25 | statusInformational = "INFORMATIONAL" 26 | statusUnimplemented = "UNIMPLEMENTED" 27 | statusNonStrict = "NON-STRICT" 28 | statusUnclean = "UNCLEAN" 29 | statusFailed = "FAILED" 30 | ) 31 | // failing reports whether a case with the given behavior makes the whole run fail. 32 | func failing(behavior string) bool { 33 | switch behavior { 34 | case statusUnclean, statusFailed, statusNonStrict: 35 | return true 36 | default: 37 | return false 38 | } 39 | } 40 | // statusCounter tallies test cases by their reported behavior. 41 | type statusCounter struct { 42 | Total int 43 | OK int 44 | Informational int 45 | Unimplemented int 46 | NonStrict int 47 | Unclean int 48 | Failed int 49 | } 50 | // Inc counts one case with status s; it panics on an unknown status. 51 | func (c *statusCounter) Inc(s string) { 52 | c.Total++ 53 | switch s { 54 | case statusOK: 55 | c.OK++ 56 | case statusInformational: 57 | c.Informational++ 58 | case statusNonStrict: 59 | c.NonStrict++ 60 | case statusUnimplemented: 61 | c.Unimplemented++ 62 | case statusUnclean: 63 | c.Unclean++ 64 | case statusFailed: 65 | c.Failed++ 66 | default: 67 | panic(fmt.Sprintf("unexpected status %q", s)) 68 | } 69 | } 70 | // main prints a per-agent summary of the report index named by the first positional argument and exits 1 if any case failed, or serves the report directory over HTTP when -http is set. 71 | func main() { 72 | log.SetFlags(0) 73 | flag.Parse() 74 | 75 | if flag.NArg() < 1 { 76 | log.Fatalf("Usage: %s [options] <path/to/index.json>", os.Args[0]) 77 | } 78 | 79 | base := path.Dir(flag.Arg(0)) 80 | 81 | if addr := *web; addr != "" { 82 | http.HandleFunc("/", handlerIndex()) 83 | http.Handle("/report/",
http.StripPrefix("/report/", 84 | http.FileServer(http.Dir(base)), 85 | )) 86 | log.Fatal(http.ListenAndServe(addr, nil)) 87 | return 88 | } 89 | // Read the index named by the positional argument; os.Args[1] would be a flag such as -verbose. 90 | var report report 91 | if err := decodeFile(flag.Arg(0), &report); err != nil { 92 | log.Fatal(err) 93 | } 94 | 95 | servers := make([]string, 0, len(report)) 96 | for s := range report { 97 | servers = append(servers, s) 98 | } 99 | sort.Strings(servers) 100 | 101 | var failed bool 102 | tw := tabwriter.NewWriter(os.Stderr, 0, 4, 1, ' ', 0) 103 | for _, server := range servers { 104 | var ( 105 | srvFailed bool 106 | hdrWritten bool 107 | counter statusCounter 108 | ) 109 | 110 | var cases []string 111 | for id := range report[server] { 112 | cases = append(cases, id) 113 | } 114 | sortBySegment(cases) 115 | for _, id := range cases { 116 | c := report[server][id] 117 | 118 | var r entryReport 119 | err := decodeFile(path.Join(base, c.ReportFile), &r) 120 | if err != nil { 121 | log.Fatal(err) 122 | } 123 | counter.Inc(c.Behavior) 124 | bad := failing(c.Behavior) 125 | if bad { 126 | srvFailed = true 127 | failed = true 128 | } 129 | if *verbose || bad { 130 | if !hdrWritten { 131 | hdrWritten = true 132 | n, _ := fmt.Fprintf(os.Stderr, "AGENT %q\n", server) 133 | fmt.Fprintf(tw, "%s\n", strings.Repeat("=", n-1)) 134 | } 135 | fmt.Fprintf(tw, "%s\t%s\t%s\n", server, id, c.Behavior) 136 | } 137 | if bad { 138 | fmt.Fprintf(tw, "\tdesc:\t%s\n", r.Description) 139 | fmt.Fprintf(tw, "\texp: \t%s\n", r.Expectation) 140 | fmt.Fprintf(tw, "\tact: \t%s\n", r.Result) 141 | } 142 | } 143 | if hdrWritten { 144 | fmt.Fprint(tw, "\n") 145 | } 146 | var status string 147 | if srvFailed { 148 | status = statusFailed 149 | } else { 150 | status = statusOK 151 | } 152 | n, _ := fmt.Fprintf(tw, "AGENT %q SUMMARY (%s)\n", server, status) 153 | fmt.Fprintf(tw, "%s\n", strings.Repeat("=", n-1)) 154 | 155 | fmt.Fprintf(tw, "TOTAL:\t%d\n", counter.Total) 156 | fmt.Fprintf(tw, "%s:\t%d\n", statusOK, counter.OK) 157 | fmt.Fprintf(tw,
"%s:\t%d\n", statusInformational, counter.Informational) 158 | fmt.Fprintf(tw, "%s:\t%d\n", statusUnimplemented, counter.Unimplemented) 159 | fmt.Fprintf(tw, "%s:\t%d\n", statusNonStrict, counter.NonStrict) 160 | fmt.Fprintf(tw, "%s:\t%d\n", statusUnclean, counter.Unclean) 161 | fmt.Fprintf(tw, "%s:\t%d\n", statusFailed, counter.Failed) 162 | fmt.Fprint(tw, "\n") 163 | tw.Flush() 164 | } 165 | var rc int 166 | if failed { 167 | rc = 1 168 | fmt.Fprintf(tw, "\n\nTEST %s\n\n", statusFailed) 169 | } else { 170 | fmt.Fprintf(tw, "\n\nTEST %s\n\n", statusOK) 171 | } 172 | 173 | tw.Flush() 174 | os.Exit(rc) 175 | } 176 | 177 | type report map[string]server 178 | 179 | type server map[string]entry 180 | 181 | type entry struct { 182 | Behavior string `json:"behavior"` 183 | BehaviorClose string `json:"behaviorClose"` 184 | Duration int `json:"duration"` 185 | RemoveCloseCode int `json:"removeCloseCode"` 186 | ReportFile string `json:"reportFile"` 187 | } 188 | 189 | type entryReport struct { 190 | Description string `json:"description"` 191 | Expectation string `json:"expectation"` 192 | Result string `json:"result"` 193 | Duration int `json:"duration"` 194 | } 195 | // decodeFile decodes the JSON document stored in the file at path into x. 196 | func decodeFile(path string, x interface{}) error { 197 | f, err := os.Open(path) 198 | if err != nil { 199 | return err 200 | } 201 | defer f.Close() 202 | 203 | d := json.NewDecoder(f) 204 | return d.Decode(x) 205 | } 206 | // compareBySegment compares dotted numeric case ids such as "1.2.10" segment by segment and returns a negative, zero or positive int; when all shared segments are equal, the id with fewer segments sorts first. 207 | func compareBySegment(a, b string) int { 208 | as := strings.Split(a, ".") 209 | bs := strings.Split(b, ".") 210 | for i := 0; i < min(len(as), len(bs)); i++ { 211 | ax := mustInt(as[i]) 212 | bx := mustInt(bs[i]) 213 | if ax == bx { 214 | continue 215 | } 216 | return ax - bx 217 | } 218 | return len(as) - len(bs) 219 | } 220 | // mustInt parses s as a base-10 int and panics if s is not a valid int. 221 | func mustInt(s string) int { 222 | const bits = 32 << (^uint(0) >> 63) // 32 or 64: the bit size of int on this platform. 223 | x, err := strconv.ParseInt(s, 10, bits) 224 | if err != nil { 225 | panic(err) 226 | } 227 | return int(x) 228 | } 229 | // min returns the smaller of a and b. 230 | func min(a, b int) int { 231 | if a < b {
232 | return a 233 | } 234 | return b 235 | } 236 | // handlerIndex returns a handler that renders the index page for "/" and responds 404 to any other path routed to it. 237 | func handlerIndex() func(w http.ResponseWriter, r *http.Request) { 238 | return func(w http.ResponseWriter, r *http.Request) { 239 | if *verbose { 240 | log.Printf("request to %s", r.URL) 241 | } 242 | if r.URL.Path != "/" { 243 | w.WriteHeader(http.StatusNotFound) 244 | return 245 | } 246 | if err := index.Execute(w, nil); err != nil { 247 | w.WriteHeader(http.StatusInternalServerError) 248 | log.Printf("render index: %v", err) // Not log.Fatal: one failed response must not stop the report server. 249 | return 250 | } 251 | } 252 | } 253 | 254 | var index = template.Must(template.New("").Parse(` 255 | 256 | 257 |

Welcome to WebSocket test server!

258 |

Ready to Autobahn!

259 | Reports 260 | 261 | 262 | `)) 263 | -------------------------------------------------------------------------------- /autobahn/main_go17.go: -------------------------------------------------------------------------------- 1 | //go:build !go1.8 2 | // +build !go1.8 3 | 4 | package main 5 | 6 | import "sort" 7 | 8 | func sortBySegment(s []string) { 9 | sort.Sort(segmentSorter(s)) 10 | } 11 | 12 | type segmentSorter []string 13 | 14 | func (s segmentSorter) Less(i, j int) bool { 15 | return compareBySegment(s[i], s[j]) < 0 16 | } 17 | 18 | func (s segmentSorter) Len() int { 19 | return len(s) 20 | } 21 | 22 | func (s segmentSorter) Swap(i, j int) { 23 | s[i], s[j] = s[j], s[i] 24 | } 25 | -------------------------------------------------------------------------------- /autobahn/main_go18.go: -------------------------------------------------------------------------------- 1 | //go:build go1.8 2 | // +build go1.8 3 | 4 | package main 5 | 6 | import "sort" 7 | 8 | func sortBySegment(s []string) { 9 | sort.Slice(s, func(i, j int) bool { 10 | return compareBySegment(s[i], s[j]) < 0 11 | }) 12 | } 13 | -------------------------------------------------------------------------------- /autobahn/script/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | FOLLOW_LOGS=0 4 | # Parse options: a case shifts only for the value it consumes; the shift after esac drops the option itself. 5 | while [[ $# -gt 0 ]]; do 6 | key="$1" 7 | case $key in 8 | --network) 9 | NETWORK="$2" 10 | shift 11 | ;; 12 | # NOTE(review): NETWORK is assigned but never read below; the script always creates its own network. 13 | --build) 14 | case "$2" in 15 | autobahn) 16 | docker build . --file autobahn/docker/autobahn/Dockerfile --tag ws-autobahn 17 | shift 18 | ;; 19 | server) 20 | docker build . --file autobahn/docker/server/Dockerfile --tag ws-server 21 | shift 22 | ;; 23 | *) 24 | docker build . --file autobahn/docker/autobahn/Dockerfile --tag ws-autobahn 25 | docker build . --file autobahn/docker/server/Dockerfile --tag ws-server 26 | ;; 27 | esac 28 | ;; 29 | 30 | --run) # Forward every remaining argument to docker run intact, then exit with its status. 31 | docker run \ 32 | --interactive \ 33 | --tty \ 34 | "${@:2}" 35 | exit $?
36 | ;; 37 | 38 | --follow-logs) 39 | FOLLOW_LOGS=1 40 | # No value to consume: the shift after esac drops the flag itself. 41 | ;; 42 | esac 43 | shift 44 | done 45 | # with_prefix PREFIX CMD...: run CMD, prefixing every line of its stdout and stderr with PREFIX. 46 | with_prefix() { 47 | local p="$1" 48 | shift 49 | 50 | local out=$(mktemp -u ws.fifo.out.XXXX) 51 | local err=$(mktemp -u ws.fifo.err.XXXX) 52 | mkfifo $out $err 53 | if [ $? -ne 0 ]; then 54 | exit 1 55 | fi 56 | 57 | # Start two background sed processes. 58 | sed "s/^/$p/" <$out & 59 | sed "s/^/$p/" <$err >&2 & 60 | 61 | # Run the program 62 | "$@" >$out 2>$err 63 | rm $out $err 64 | } 65 | 66 | random=$(xxd -l 4 -p /dev/random) 67 | server="${random}_ws-server" 68 | autobahn="${random}_ws-autobahn" 69 | 70 | network="ws-network-$random" 71 | docker network create --driver bridge "$network" 72 | if [ $? -ne 0 ]; then 73 | exit 1 74 | fi 75 | 76 | docker run \ 77 | --interactive \ 78 | --tty \ 79 | --detach \ 80 | --network="$network" \ 81 | --network-alias="ws-server" \ 82 | -v "$(pwd)/autobahn/report:/report" \ 83 | --name="$server" \ 84 | "ws-server" 85 | 86 | docker run \ 87 | --interactive \ 88 | --tty \ 89 | --detach \ 90 | --network="$network" \ 91 | -v "$(pwd)/autobahn/config:/config" \ 92 | -v "$(pwd)/autobahn/report:/report" \ 93 | --name="$autobahn" \ 94 | "ws-autobahn" 95 | 96 | 97 | if [[ $FOLLOW_LOGS -eq 1 ]]; then 98 | (with_prefix "$(tput setaf 3)[ws-autobahn]: $(tput sgr0)" docker logs --follow "$autobahn")& 99 | (with_prefix "$(tput setaf 5)[ws-server]: $(tput sgr0)" docker logs --follow "$server")& 100 | fi 101 | # On Ctrl-C, interrupt both containers, then remove them and the per-run network. 102 | trap ctrl_c INT 103 | ctrl_c () { 104 | echo "SIGINT received; cleaning up" 105 | docker kill --signal INT "$autobahn" >/dev/null 106 | docker kill --signal INT "$server" >/dev/null 107 | cleanup 108 | exit 130 109 | } 110 | # cleanup removes both containers and the per-run network. 111 | cleanup() { 112 | docker rm "$server" >/dev/null 113 | docker rm "$autobahn" >/dev/null 114 | docker network rm "$network" 115 | } 116 | 117 | docker wait "$autobahn" >/dev/null 118 | docker stop "$server" >/dev/null 119 | 120 | cleanup 121 |
-------------------------------------------------------------------------------- /check.go: -------------------------------------------------------------------------------- 1 | package ws 2 | 3 | import "unicode/utf8" 4 | 5 | // State represents state of websocket endpoint. 6 | // It used by some functions to be more strict when checking compatibility with RFC6455. 7 | type State uint8 8 | 9 | const ( 10 | // StateServerSide means that endpoint (caller) is a server. 11 | StateServerSide State = 0x1 << iota 12 | // StateClientSide means that endpoint (caller) is a client. 13 | StateClientSide 14 | // StateExtended means that extension was negotiated during handshake. 15 | StateExtended 16 | // StateFragmented means that endpoint (caller) has received fragmented 17 | // frame and waits for continuation parts. 18 | StateFragmented 19 | ) 20 | 21 | // Is checks whether the s has v enabled. 22 | func (s State) Is(v State) bool { 23 | return uint8(s)&uint8(v) != 0 24 | } 25 | 26 | // Set enables v state on s. 27 | func (s State) Set(v State) State { 28 | return s | v 29 | } 30 | 31 | // Clear disables v state on s. 32 | func (s State) Clear(v State) State { 33 | return s & (^v) 34 | } 35 | 36 | // ServerSide reports whether states represents server side. 37 | func (s State) ServerSide() bool { return s.Is(StateServerSide) } 38 | 39 | // ClientSide reports whether state represents client side. 40 | func (s State) ClientSide() bool { return s.Is(StateClientSide) } 41 | 42 | // Extended reports whether state is extended. 43 | func (s State) Extended() bool { return s.Is(StateExtended) } 44 | 45 | // Fragmented reports whether state is fragmented. 46 | func (s State) Fragmented() bool { return s.Is(StateFragmented) } 47 | 48 | // ProtocolError describes error during checking/parsing websocket frames or 49 | // headers. 50 | type ProtocolError string 51 | 52 | // Error implements error interface. 
53 | func (p ProtocolError) Error() string { return string(p) } 54 | 55 | // Errors used by the protocol checkers. 56 | var ( 57 | ErrProtocolOpCodeReserved = ProtocolError("use of reserved op code") 58 | ErrProtocolControlPayloadOverflow = ProtocolError("control frame payload limit exceeded") 59 | ErrProtocolControlNotFinal = ProtocolError("control frame is not final") 60 | ErrProtocolNonZeroRsv = ProtocolError("non-zero rsv bits with no extension negotiated") 61 | ErrProtocolMaskRequired = ProtocolError("frames from client to server must be masked") 62 | ErrProtocolMaskUnexpected = ProtocolError("frames from server to client must be not masked") 63 | ErrProtocolContinuationExpected = ProtocolError("unexpected non-continuation data frame") 64 | ErrProtocolContinuationUnexpected = ProtocolError("unexpected continuation data frame") 65 | ErrProtocolStatusCodeNotInUse = ProtocolError("status code is not in use") 66 | ErrProtocolStatusCodeApplicationLevel = ProtocolError("status code is only application level") 67 | ErrProtocolStatusCodeNoMeaning = ProtocolError("status code has no meaning yet") 68 | ErrProtocolStatusCodeUnknown = ProtocolError("status code is not defined in spec") 69 | ErrProtocolInvalidUTF8 = ProtocolError("invalid utf8 sequence in close reason") 70 | ) 71 | 72 | // CheckHeader checks h to contain valid header data for given state s. 73 | // 74 | // Note that zero state (0) means that state is clean, 75 | // neither server or client side, nor fragmented, nor extended. 76 | func CheckHeader(h Header, s State) error { 77 | if h.OpCode.IsReserved() { 78 | return ErrProtocolOpCodeReserved 79 | } 80 | if h.OpCode.IsControl() { 81 | if h.Length > MaxControlFramePayloadSize { 82 | return ErrProtocolControlPayloadOverflow 83 | } 84 | if !h.Fin { 85 | return ErrProtocolControlNotFinal 86 | } 87 | } 88 | 89 | switch { 90 | // [RFC6455]: MUST be 0 unless an extension is negotiated that defines meanings for 91 | // non-zero values. 
If a nonzero value is received and none of the 92 | // negotiated extensions defines the meaning of such a nonzero value, the 93 | // receiving endpoint MUST _Fail the WebSocket Connection_. 94 | case h.Rsv != 0 && !s.Extended(): 95 | return ErrProtocolNonZeroRsv 96 | 97 | // [RFC6455]: The server MUST close the connection upon receiving a frame that is not masked. 98 | // In this case, a server MAY send a Close frame with a status code of 1002 (protocol error) 99 | // as defined in Section 7.4.1. A server MUST NOT mask any frames that it sends to the client. 100 | // A client MUST close a connection if it detects a masked frame. In this case, it MAY use the 101 | // status code 1002 (protocol error) as defined in Section 7.4.1. 102 | case s.ServerSide() && !h.Masked: 103 | return ErrProtocolMaskRequired 104 | case s.ClientSide() && h.Masked: 105 | return ErrProtocolMaskUnexpected 106 | 107 | // [RFC6455]: See detailed explanation in 5.4 section. 108 | case s.Fragmented() && !h.OpCode.IsControl() && h.OpCode != OpContinuation: 109 | return ErrProtocolContinuationExpected 110 | case !s.Fragmented() && h.OpCode == OpContinuation: 111 | return ErrProtocolContinuationUnexpected 112 | 113 | default: 114 | return nil 115 | } 116 | } 117 | 118 | // CheckCloseFrameData checks received close information 119 | // to be valid RFC6455 compatible close info. 120 | // 121 | // Note that code.Empty() or code.IsAppLevel() will raise error. 122 | // 123 | // If endpoint sends close frame without status code (with frame.Length = 0), 124 | // application should not check its payload. 
125 | func CheckCloseFrameData(code StatusCode, reason string) error { 126 | switch { 127 | case code.IsNotUsed(): 128 | return ErrProtocolStatusCodeNotInUse 129 | 130 | case code.IsProtocolReserved(): 131 | return ErrProtocolStatusCodeApplicationLevel 132 | 133 | case code == StatusNoMeaningYet: 134 | return ErrProtocolStatusCodeNoMeaning 135 | 136 | case code.IsProtocolSpec() && !code.IsProtocolDefined(): 137 | return ErrProtocolStatusCodeUnknown 138 | 139 | case !utf8.ValidString(reason): 140 | return ErrProtocolInvalidUTF8 141 | 142 | default: 143 | return nil 144 | } 145 | } 146 | -------------------------------------------------------------------------------- /cipher.go: -------------------------------------------------------------------------------- 1 | package ws 2 | 3 | import ( 4 | "encoding/binary" 5 | ) 6 | 7 | // Cipher applies XOR cipher to the payload using mask. 8 | // Offset is used to cipher chunked data (e.g. in io.Reader implementations). 9 | // 10 | // To convert masked data into unmasked data, or vice versa, the following 11 | // algorithm is applied. The same algorithm applies regardless of the 12 | // direction of the translation, e.g., the same steps are applied to 13 | // mask the data as to unmask the data. 14 | func Cipher(payload []byte, mask [4]byte, offset int) { 15 | n := len(payload) 16 | if n < 8 { 17 | for i := 0; i < n; i++ { 18 | payload[i] ^= mask[(offset+i)%4] 19 | } 20 | return 21 | } 22 | 23 | // Calculate position in mask due to previously processed bytes number. 24 | mpos := offset % 4 25 | // Count number of bytes will processed one by one from the beginning of payload. 26 | ln := remain[mpos] 27 | // Count number of bytes will processed one by one from the end of payload. 28 | // This is done to process payload by 16 bytes in each iteration of main loop. 
29 | rn := (n - ln) % 16 30 | 31 | for i := 0; i < ln; i++ { 32 | payload[i] ^= mask[(mpos+i)%4] 33 | } 34 | for i := n - rn; i < n; i++ { 35 | payload[i] ^= mask[(mpos+i)%4] 36 | } 37 | 38 | // NOTE: we use here binary.LittleEndian regardless of what is real 39 | // endianness on machine is. To do so, we have to use binary.LittleEndian in 40 | // the masking loop below as well. 41 | var ( 42 | m = binary.LittleEndian.Uint32(mask[:]) 43 | m2 = uint64(m)<<32 | uint64(m) 44 | ) 45 | // Skip already processed right part. 46 | // Get number of uint64 parts remaining to process. 47 | n = (n - ln - rn) >> 4 48 | j := ln 49 | for i := 0; i < n; i++ { 50 | chunk := payload[j : j+16] 51 | p := binary.LittleEndian.Uint64(chunk) ^ m2 52 | p2 := binary.LittleEndian.Uint64(chunk[8:]) ^ m2 53 | binary.LittleEndian.PutUint64(chunk, p) 54 | binary.LittleEndian.PutUint64(chunk[8:], p2) 55 | j += 16 56 | } 57 | } 58 | 59 | // remain maps position in masking key [0,4) to number 60 | // of bytes that need to be processed manually inside Cipher(). 
61 | var remain = [4]int{0, 3, 2, 1} 62 | -------------------------------------------------------------------------------- /cipher_test.go: -------------------------------------------------------------------------------- 1 | package ws 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "reflect" 7 | "testing" 8 | ) 9 | 10 | func TestCipher(t *testing.T) { 11 | type test struct { 12 | name string 13 | in []byte 14 | mask [4]byte 15 | offset int 16 | } 17 | cases := []test{ 18 | { 19 | name: "simple", 20 | in: []byte("Hello, XOR!"), 21 | mask: [4]byte{1, 2, 3, 4}, 22 | }, 23 | { 24 | name: "simple", 25 | in: []byte("Hello, XOR!"), 26 | mask: [4]byte{255, 255, 255, 255}, 27 | }, 28 | } 29 | for offset := 0; offset < 4; offset++ { 30 | for tail := 0; tail < 8; tail++ { 31 | for b64 := 0; b64 < 3; b64++ { 32 | var ( 33 | ln = remain[offset] 34 | rn = tail 35 | n = b64*8 + ln + rn 36 | ) 37 | 38 | p := make([]byte, n) 39 | rand.Read(p) 40 | 41 | var m [4]byte 42 | rand.Read(m[:]) 43 | 44 | cases = append(cases, test{ 45 | in: p, 46 | mask: m, 47 | offset: offset, 48 | }) 49 | } 50 | } 51 | } 52 | for _, test := range cases { 53 | t.Run(test.name, func(t *testing.T) { 54 | // naive implementation of xor-cipher 55 | exp := cipherNaive(test.in, test.mask, test.offset) 56 | 57 | res := make([]byte, len(test.in)) 58 | copy(res, test.in) 59 | Cipher(res, test.mask, test.offset) 60 | 61 | if !reflect.DeepEqual(res, exp) { 62 | t.Errorf("Cipher(%v, %v):\nact:\t%v\nexp:\t%v\n", test.in, test.mask, res, exp) 63 | } 64 | }) 65 | } 66 | } 67 | 68 | func TestCipherChops(t *testing.T) { 69 | for n := 2; n <= 1024; n <<= 1 { 70 | t.Run(fmt.Sprintf("%d", n), func(t *testing.T) { 71 | p := make([]byte, n) 72 | b := make([]byte, n) 73 | var m [4]byte 74 | 75 | _, err := rand.Read(p) 76 | if err != nil { 77 | t.Fatal(err) 78 | } 79 | _, err = rand.Read(m[:]) 80 | if err != nil { 81 | t.Fatal(err) 82 | } 83 | 84 | exp := cipherNaive(p, m, 0) 85 | 86 | for i := 1; i <= n; i <<= 1 { 87 | copy(b, 
p) 88 | s := n / i 89 | 90 | for j := s; j <= n; j += s { 91 | l, r := j-s, j 92 | Cipher(b[l:r], m, l) 93 | if !reflect.DeepEqual(b[l:r], exp[l:r]) { 94 | t.Fatalf("unexpected Cipher([%d:%d]) = %x; want %x", l, r, b[l:r], exp[l:r]) 95 | } 96 | } 97 | } 98 | 99 | l := 0 100 | copy(b, p) 101 | for l < n { 102 | r := rand.Intn(n-l) + l + 1 103 | Cipher(b[l:r], m, l) 104 | if !reflect.DeepEqual(b[l:r], exp[l:r]) { 105 | t.Fatalf("unexpected Cipher([%d:%d]):\nact:\t%v\nexp:\t%v\nact:\t%#x\nexp:\t%#x\n\n", l, r, b[l:r], exp[l:r], b[l:r], exp[l:r]) 106 | } 107 | l = r 108 | } 109 | }) 110 | } 111 | } 112 | 113 | func cipherNaive(p []byte, m [4]byte, pos int) []byte { 114 | r := make([]byte, len(p)) 115 | copy(r, p) 116 | cipherNaiveNoCp(r, m, pos) 117 | return r 118 | } 119 | 120 | func cipherNaiveNoCp(p []byte, m [4]byte, pos int) []byte { 121 | for i := 0; i < len(p); i++ { 122 | p[i] ^= m[(pos+i)%4] 123 | } 124 | return p 125 | } 126 | 127 | func BenchmarkCipher(b *testing.B) { 128 | for _, bench := range []struct { 129 | size int 130 | offset int 131 | }{ 132 | { 133 | size: 7, 134 | offset: 1, 135 | }, 136 | { 137 | size: 125, 138 | }, 139 | { 140 | size: 1024, 141 | }, 142 | { 143 | size: 4096, 144 | }, 145 | { 146 | size: 4100, 147 | offset: 4, 148 | }, 149 | { 150 | size: 4099, 151 | offset: 3, 152 | }, 153 | { 154 | size: (1 << 15) + 7, 155 | offset: 49, 156 | }, 157 | } { 158 | bts := make([]byte, bench.size) 159 | _, err := rand.Read(bts) 160 | if err != nil { 161 | b.Fatal(err) 162 | } 163 | 164 | var mask [4]byte 165 | _, err = rand.Read(mask[:]) 166 | if err != nil { 167 | b.Fatal(err) 168 | } 169 | 170 | b.Run(fmt.Sprintf("naive_bytes=%d;offset=%d", bench.size, bench.offset), func(b *testing.B) { 171 | var sink int64 172 | b.SetBytes(int64(bench.size)) 173 | b.ResetTimer() 174 | for i := 0; i < b.N; i++ { 175 | r := cipherNaiveNoCp(bts, mask, bench.offset) 176 | sink += int64(len(r)) 177 | } 178 | sinkValue(sink) 179 | }) 180 | 
b.Run(fmt.Sprintf("bytes=%d;offset=%d", bench.size, bench.offset), func(b *testing.B) { 181 | var sink int64 182 | b.SetBytes(int64(bench.size)) 183 | b.ResetTimer() 184 | for i := 0; i < b.N; i++ { 185 | Cipher(bts, mask, bench.offset) 186 | sink += int64(len(bts)) 187 | } 188 | sinkValue(sink) 189 | }) 190 | } 191 | } 192 | 193 | // sinkValue makes variable used and prevents dead code elimination. 194 | func sinkValue(v int64) { 195 | if r := rand.Float32(); r > 2 { 196 | panic(fmt.Sprintf("impossible %g: %v", r, v)) 197 | } 198 | } 199 | -------------------------------------------------------------------------------- /dialer_tls_go17.go: -------------------------------------------------------------------------------- 1 | // +build !go1.8 2 | 3 | package ws 4 | 5 | import "crypto/tls" 6 | 7 | func tlsCloneConfig(c *tls.Config) *tls.Config { 8 | // NOTE: we copying SessionTicketsDisabled and SessionTicketKey here 9 | // without calling inner c.initOnceServer somehow because we only could get 10 | // here from the ws.Dialer code, which is obviously a client and makes 11 | // tls.Client() when it gets new net.Conn. 
12 | return &tls.Config{ 13 | Rand: c.Rand, 14 | Time: c.Time, 15 | Certificates: c.Certificates, 16 | NameToCertificate: c.NameToCertificate, 17 | GetCertificate: c.GetCertificate, 18 | RootCAs: c.RootCAs, 19 | NextProtos: c.NextProtos, 20 | ServerName: c.ServerName, 21 | ClientAuth: c.ClientAuth, 22 | ClientCAs: c.ClientCAs, 23 | InsecureSkipVerify: c.InsecureSkipVerify, 24 | CipherSuites: c.CipherSuites, 25 | PreferServerCipherSuites: c.PreferServerCipherSuites, 26 | SessionTicketsDisabled: c.SessionTicketsDisabled, 27 | SessionTicketKey: c.SessionTicketKey, 28 | ClientSessionCache: c.ClientSessionCache, 29 | MinVersion: c.MinVersion, 30 | MaxVersion: c.MaxVersion, 31 | CurvePreferences: c.CurvePreferences, 32 | DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, 33 | Renegotiation: c.Renegotiation, 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /dialer_tls_go18.go: -------------------------------------------------------------------------------- 1 | //go:build go1.8 2 | // +build go1.8 3 | 4 | package ws 5 | 6 | import "crypto/tls" 7 | 8 | func tlsCloneConfig(c *tls.Config) *tls.Config { 9 | return c.Clone() 10 | } 11 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package ws implements a client and server for the WebSocket protocol as 3 | specified in RFC 6455. 4 | 5 | The main purpose of this package is to provide simple low-level API for 6 | efficient work with protocol. 7 | 8 | Overview. 9 | 10 | Upgrade to WebSocket (or WebSocket handshake) can be done in two ways. 11 | 12 | The first way is to use `net/http` server: 13 | 14 | http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 15 | conn, _, _, err := ws.UpgradeHTTP(r, w) 16 | }) 17 | 18 | The second and much more efficient way is so-called "zero-copy upgrade". 
It 19 | avoids redundant allocations and copying of not used headers or other request 20 | data. User decides by himself which data should be copied. 21 | 22 | ln, err := net.Listen("tcp", ":8080") 23 | if err != nil { 24 | // handle error 25 | } 26 | 27 | conn, err := ln.Accept() 28 | if err != nil { 29 | // handle error 30 | } 31 | 32 | handshake, err := ws.Upgrade(conn) 33 | if err != nil { 34 | // handle error 35 | } 36 | 37 | For customization details see `ws.Upgrader` documentation. 38 | 39 | After WebSocket handshake you can work with connection in multiple ways. 40 | That is, `ws` does not force the only one way of how to work with WebSocket: 41 | 42 | header, err := ws.ReadHeader(conn) 43 | if err != nil { 44 | // handle err 45 | } 46 | 47 | buf := make([]byte, header.Length) 48 | _, err := io.ReadFull(conn, buf) 49 | if err != nil { 50 | // handle err 51 | } 52 | 53 | resp := ws.NewBinaryFrame([]byte("hello, world!")) 54 | if err := ws.WriteFrame(conn, frame); err != nil { 55 | // handle err 56 | } 57 | 58 | As you can see, it stream friendly: 59 | 60 | const N = 42 61 | 62 | ws.WriteHeader(ws.Header{ 63 | Fin: true, 64 | Length: N, 65 | OpCode: ws.OpBinary, 66 | }) 67 | 68 | io.CopyN(conn, rand.Reader, N) 69 | 70 | Or: 71 | 72 | header, err := ws.ReadHeader(conn) 73 | if err != nil { 74 | // handle err 75 | } 76 | 77 | io.CopyN(ioutil.Discard, conn, header.Length) 78 | 79 | For more info see the documentation. 80 | */ 81 | package ws 82 | -------------------------------------------------------------------------------- /errors.go: -------------------------------------------------------------------------------- 1 | package ws 2 | 3 | // RejectOption represents an option used to control the way connection is 4 | // rejected. 5 | type RejectOption func(*ConnectionRejectedError) 6 | 7 | // RejectionReason returns an option that makes connection to be rejected with 8 | // given reason. 
9 | func RejectionReason(reason string) RejectOption { 10 | return func(err *ConnectionRejectedError) { 11 | err.reason = reason 12 | } 13 | } 14 | 15 | // RejectionStatus returns an option that makes connection to be rejected with 16 | // given HTTP status code. 17 | func RejectionStatus(code int) RejectOption { 18 | return func(err *ConnectionRejectedError) { 19 | err.code = code 20 | } 21 | } 22 | 23 | // RejectionHeader returns an option that makes connection to be rejected with 24 | // given HTTP headers. 25 | func RejectionHeader(h HandshakeHeader) RejectOption { 26 | return func(err *ConnectionRejectedError) { 27 | err.header = h 28 | } 29 | } 30 | 31 | // RejectConnectionError constructs an error that could be used to control the 32 | // way handshake is rejected by Upgrader. 33 | func RejectConnectionError(options ...RejectOption) error { 34 | err := new(ConnectionRejectedError) 35 | for _, opt := range options { 36 | opt(err) 37 | } 38 | return err 39 | } 40 | 41 | // ConnectionRejectedError represents a rejection of connection during 42 | // WebSocket handshake error. 43 | // 44 | // It can be returned by Upgrader's On* hooks to indicate that WebSocket 45 | // handshake should be rejected. 46 | type ConnectionRejectedError struct { 47 | reason string 48 | code int 49 | header HandshakeHeader 50 | } 51 | 52 | // Error implements error interface. 
53 | func (r *ConnectionRejectedError) Error() string { 54 | return r.reason 55 | } 56 | 57 | func (r *ConnectionRejectedError) StatusCode() int { 58 | return r.code 59 | } 60 | -------------------------------------------------------------------------------- /example/autobahn/autobahn.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "compress/flate" 5 | "context" 6 | "flag" 7 | "fmt" 8 | "io" 9 | "io/ioutil" 10 | "log" 11 | "net" 12 | "net/http" 13 | "os" 14 | "os/signal" 15 | "syscall" 16 | "time" 17 | 18 | "github.com/gobwas/httphead" 19 | "github.com/gobwas/ws" 20 | "github.com/gobwas/ws/wsflate" 21 | "github.com/gobwas/ws/wsutil" 22 | ) 23 | 24 | var addr = flag.String("listen", ":9001", "addr to listen") 25 | 26 | func main() { 27 | log.SetFlags(0) 28 | flag.Parse() 29 | 30 | http.HandleFunc("/ws", wsHandler) 31 | http.HandleFunc("/wsutil", wsutilHandler) 32 | http.HandleFunc("/wsflate", wsflateHandler) 33 | http.HandleFunc("/helpers/low", helpersLowLevelHandler) 34 | http.HandleFunc("/helpers/high", helpersHighLevelHandler) 35 | 36 | ln, err := net.Listen("tcp", *addr) 37 | if err != nil { 38 | log.Fatalf("listen %q error: %v", *addr, err) 39 | } 40 | log.Printf("listening %s (%q)", ln.Addr(), *addr) 41 | 42 | var ( 43 | s = new(http.Server) 44 | serve = make(chan error, 1) 45 | sig = make(chan os.Signal, 1) 46 | ) 47 | signal.Notify(sig, syscall.SIGTERM) 48 | go func() { serve <- s.Serve(ln) }() 49 | 50 | select { 51 | case err := <-serve: 52 | log.Fatal(err) 53 | case sig := <-sig: 54 | const timeout = 5 * time.Second 55 | 56 | log.Printf("signal %q received; shutting down with %s timeout", sig, timeout) 57 | 58 | ctx, _ := context.WithTimeout(context.Background(), timeout) 59 | if err := s.Shutdown(ctx); err != nil { 60 | log.Fatal(err) 61 | } 62 | } 63 | } 64 | 65 | var ( 66 | closeInvalidPayload = ws.MustCompileFrame( 67 | ws.NewCloseFrame(ws.NewCloseFrameBody( 68 | 
ws.StatusInvalidFramePayloadData, "", 69 | )), 70 | ) 71 | closeProtocolError = ws.MustCompileFrame( 72 | ws.NewCloseFrame(ws.NewCloseFrameBody( 73 | ws.StatusProtocolError, "", 74 | )), 75 | ) 76 | ) 77 | 78 | func helpersHighLevelHandler(w http.ResponseWriter, r *http.Request) { 79 | conn, _, _, err := ws.UpgradeHTTP(r, w) 80 | if err != nil { 81 | log.Printf("upgrade error: %s", err) 82 | return 83 | } 84 | defer conn.Close() 85 | 86 | for { 87 | bts, op, err := wsutil.ReadClientData(conn) 88 | if err != nil { 89 | log.Printf("read message error: %v", err) 90 | return 91 | } 92 | err = wsutil.WriteServerMessage(conn, op, bts) 93 | if err != nil { 94 | log.Printf("write message error: %v", err) 95 | return 96 | } 97 | } 98 | } 99 | 100 | func helpersLowLevelHandler(w http.ResponseWriter, r *http.Request) { 101 | conn, _, _, err := ws.UpgradeHTTP(r, w) 102 | if err != nil { 103 | log.Printf("upgrade error: %s", err) 104 | return 105 | } 106 | defer conn.Close() 107 | 108 | msg := make([]wsutil.Message, 0, 4) 109 | 110 | for { 111 | msg, err = wsutil.ReadClientMessage(conn, msg[:0]) 112 | if err != nil { 113 | log.Printf("read message error: %v", err) 114 | return 115 | } 116 | for _, m := range msg { 117 | if m.OpCode.IsControl() { 118 | err := wsutil.HandleClientControlMessage(conn, m) 119 | if err != nil { 120 | log.Printf("handle control error: %v", err) 121 | return 122 | } 123 | continue 124 | } 125 | err := wsutil.WriteServerMessage(conn, m.OpCode, m.Payload) 126 | if err != nil { 127 | log.Printf("write message error: %v", err) 128 | return 129 | } 130 | } 131 | } 132 | } 133 | 134 | func wsutilHandler(res http.ResponseWriter, req *http.Request) { 135 | conn, _, _, err := ws.UpgradeHTTP(req, res) 136 | if err != nil { 137 | log.Printf("upgrade error: %s", err) 138 | return 139 | } 140 | defer conn.Close() 141 | 142 | state := ws.StateServerSide 143 | 144 | ch := wsutil.ControlFrameHandler(conn, state) 145 | r := &wsutil.Reader{ 146 | Source: conn, 147 | 
State: state, 148 | CheckUTF8: true, 149 | OnIntermediate: ch, 150 | } 151 | w := wsutil.NewWriter(conn, state, 0) 152 | 153 | for { 154 | h, err := r.NextFrame() 155 | if err != nil { 156 | log.Printf("next frame error: %v", err) 157 | return 158 | } 159 | if h.OpCode.IsControl() { 160 | if err = ch(h, r); err != nil { 161 | log.Printf("handle control error: %v", err) 162 | return 163 | } 164 | continue 165 | } 166 | 167 | w.Reset(conn, state, h.OpCode) 168 | 169 | if _, err = io.Copy(w, r); err == nil { 170 | err = w.Flush() 171 | } 172 | if err != nil { 173 | log.Printf("echo error: %s", err) 174 | return 175 | } 176 | } 177 | } 178 | 179 | func wsflateHandler(w http.ResponseWriter, r *http.Request) { 180 | e := wsflate.Extension{ 181 | Parameters: wsflate.Parameters{ 182 | ServerNoContextTakeover: true, 183 | ClientNoContextTakeover: true, 184 | }, 185 | } 186 | u := ws.HTTPUpgrader{ 187 | Negotiate: e.Negotiate, 188 | } 189 | conn, _, _, err := u.Upgrade(r, w) 190 | if err != nil { 191 | log.Printf("upgrade error: %s", err) 192 | return 193 | } 194 | defer conn.Close() 195 | 196 | if _, ok := e.Accepted(); !ok { 197 | log.Printf("no accepted extension") 198 | return 199 | } 200 | 201 | // Using nil as a destination io.Writer since we will Reset() it in the 202 | // loop below. 203 | fw := wsflate.NewWriter(nil, func(w io.Writer) wsflate.Compressor { 204 | // As flat.NewWriter() docs says: 205 | // If level is in the range [-2, 9] then the error returned will 206 | // be nil. 207 | f, _ := flate.NewWriter(w, 9) 208 | return f 209 | }) 210 | // Using nil as a source io.Reader since we will Reset() it in the loop 211 | // below. 212 | fr := wsflate.NewReader(nil, func(r io.Reader) wsflate.Decompressor { 213 | return flate.NewReader(r) 214 | }) 215 | 216 | // MessageState implements wsutil.Extension and is used to check whether 217 | // received WebSocket message is compressed. 
That is, it's generally 218 | // possible to receive uncompressed messaged even if compression extension 219 | // was negotiated. 220 | var msg wsflate.MessageState 221 | 222 | // Note that control frames are all written without compression. 223 | controlHandler := wsutil.ControlFrameHandler(conn, ws.StateServerSide) 224 | rd := wsutil.Reader{ 225 | Source: conn, 226 | State: ws.StateServerSide | ws.StateExtended, 227 | CheckUTF8: false, 228 | OnIntermediate: controlHandler, 229 | Extensions: []wsutil.RecvExtension{&msg}, 230 | } 231 | 232 | wr := wsutil.NewWriter(conn, ws.StateServerSide|ws.StateExtended, 0) 233 | wr.SetExtensions(&msg) 234 | 235 | for { 236 | h, err := rd.NextFrame() 237 | if err != nil { 238 | log.Printf("next frame error: %v", err) 239 | return 240 | } 241 | if h.OpCode.IsControl() { 242 | if err := controlHandler(h, &rd); err != nil { 243 | log.Printf("handle control frame error: %v", err) 244 | return 245 | } 246 | continue 247 | } 248 | 249 | wr.ResetOp(h.OpCode) 250 | 251 | var ( 252 | src io.Reader = &rd 253 | dst io.Writer = wr 254 | ) 255 | if msg.IsCompressed() { 256 | fr.Reset(src) 257 | fw.Reset(dst) 258 | src = fr 259 | dst = fw 260 | } 261 | // Copy incoming bytes right into writer, probably through decompressor 262 | // and compressor. 263 | if _, err = io.Copy(dst, src); err != nil { 264 | log.Fatal(err) 265 | } 266 | if msg.IsCompressed() { 267 | // Flush the flate writer. 268 | if err = fw.Close(); err != nil { 269 | log.Fatal(err) 270 | } 271 | } 272 | // Flush WebSocket fragment writer. We could send multiple fragments 273 | // for large messages. 
// NOTE(review): the lines below are a collapsed, line-numbered transcript. They hold:
//   - the tail (274-278) of a function that starts before this chunk: it flushes
//     `wr` and calls log.Fatal if the flush fails;
//   - wsHandler (280-408), presumably from example/autobahn/autobahn.go. It upgrades
//     the request with ws.HTTPUpgrader, whose Extension callback logs each offered
//     option and declines it (returns false). It then loops over frames:
//       ping  -> pong header plus the ping payload read through cipherReader
//                (presumably unmasked by it);
//       pong  -> payload drained to ioutil.Discard;
//       text / continuation of text -> payload read through wsutil.UTF8Reader
//                and validated;
//       close -> spec-range codes must be protocol-defined, then
//                ws.CheckCloseFrameData; answered with an echoed code or a
//                protocol-error close;
//       other -> echoed back with the same header, Masked cleared.
//     * `var r io.Reader` inside the loop shadows the *http.Request parameter r.
//       This is harmless because the request is not used after Upgrade.
//     * For ws.OpClose, utf8Fin is set but r stays cipherReader, so
//       utf8Reader.Valid() reports the state left by the last text reader rather
//       than anything about the close payload. Confirm this is intended, e.g.
//       when a close arrives while a fragmented text message is pending.
//     * The payload buffer is sized from the peer-supplied header.Length. Errors
//       from conn.Write, ws.WriteHeader and io.CopyN are ignored (best-effort
//       test echo server).
//   - example/autobahn/autobahn_test.go: TestCallMain runs main() when built with
//     the `autobahn` tag (legacy `// +build` line only);
//   - frame_test.go: table-driven test of OpCode.IsControl;
//   - go.mod (go 1.16 directive) and go.sum;
//   - hijack_go119.go (!go1.20): hijacks through a type assertion to http.Hijacker
//     and returns ErrNotHijacker when the writer does not implement it.
274 | if err = wr.Flush(); err != nil { 275 | log.Fatal(err) 276 | } 277 | } 278 | } 279 | 280 | func wsHandler(w http.ResponseWriter, r *http.Request) { 281 | u := ws.HTTPUpgrader{ 282 | Extension: func(opt httphead.Option) bool { 283 | log.Printf("extension: %s", opt) 284 | return false 285 | }, 286 | } 287 | conn, _, _, err := u.Upgrade(r, w) 288 | if err != nil { 289 | log.Printf("upgrade error: %s", err) 290 | return 291 | } 292 | defer conn.Close() 293 | 294 | state := ws.StateServerSide 295 | 296 | textPending := false 297 | utf8Reader := wsutil.NewUTF8Reader(nil) 298 | cipherReader := wsutil.NewCipherReader(nil, [4]byte{0, 0, 0, 0}) 299 | 300 | for { 301 | header, err := ws.ReadHeader(conn) 302 | if err != nil { 303 | log.Printf("read header error: %s", err) 304 | break 305 | } 306 | if err = ws.CheckHeader(header, state); err != nil { 307 | log.Printf("header check error: %s", err) 308 | conn.Write(closeProtocolError) 309 | return 310 | } 311 | 312 | cipherReader.Reset( 313 | io.LimitReader(conn, header.Length), 314 | header.Mask, 315 | ) 316 | 317 | var utf8Fin bool 318 | var r io.Reader = cipherReader 319 | 320 | switch header.OpCode { 321 | case ws.OpPing: 322 | header.OpCode = ws.OpPong 323 | header.Masked = false 324 | ws.WriteHeader(conn, header) 325 | io.CopyN(conn, cipherReader, header.Length) 326 | continue 327 | 328 | case ws.OpPong: 329 | io.CopyN(ioutil.Discard, conn, header.Length) 330 | continue 331 | 332 | case ws.OpClose: 333 | utf8Fin = true 334 | 335 | case ws.OpContinuation: 336 | if textPending { 337 | utf8Reader.Source = cipherReader 338 | r = utf8Reader 339 | } 340 | if header.Fin { 341 | state = state.Clear(ws.StateFragmented) 342 | textPending = false 343 | utf8Fin = true 344 | } 345 | 346 | case ws.OpText: 347 | utf8Reader.Reset(cipherReader) 348 | r = utf8Reader 349 | 350 | if !header.Fin { 351 | state = state.Set(ws.StateFragmented) 352 | textPending = true 353 | } else { 354 | utf8Fin = true 355 | } 356 | 357 | case ws.OpBinary:
358 | if !header.Fin { 359 | state = state.Set(ws.StateFragmented) 360 | } 361 | } 362 | 363 | payload := make([]byte, header.Length) 364 | _, err = io.ReadFull(r, payload) 365 | if err == nil && utf8Fin && !utf8Reader.Valid() { 366 | err = wsutil.ErrInvalidUTF8 367 | } 368 | if err != nil { 369 | log.Printf("read payload error: %s", err) 370 | if err == wsutil.ErrInvalidUTF8 { 371 | conn.Write(closeInvalidPayload) 372 | } else { 373 | conn.Write(ws.CompiledClose) 374 | } 375 | return 376 | } 377 | 378 | if header.OpCode == ws.OpClose { 379 | code, reason := ws.ParseCloseFrameData(payload) 380 | log.Printf("close frame received: %v %v", code, reason) 381 | 382 | if !code.Empty() { 383 | switch { 384 | case code.IsProtocolSpec() && !code.IsProtocolDefined(): 385 | err = fmt.Errorf("close code from spec range is not defined") 386 | default: 387 | err = ws.CheckCloseFrameData(code, reason) 388 | } 389 | if err != nil { 390 | log.Printf("invalid close data: %s", err) 391 | conn.Write(closeProtocolError) 392 | } else { 393 | ws.WriteFrame(conn, ws.NewCloseFrame(ws.NewCloseFrameBody( 394 | code, "", 395 | ))) 396 | } 397 | return 398 | } 399 | 400 | conn.Write(ws.CompiledClose) 401 | return 402 | } 403 | 404 | header.Masked = false 405 | ws.WriteHeader(conn, header) 406 | conn.Write(payload) 407 | } 408 | } 409 | -------------------------------------------------------------------------------- /example/autobahn/autobahn_test.go: -------------------------------------------------------------------------------- 1 | // +build autobahn 2 | 3 | package main 4 | 5 | import "testing" 6 | 7 | func TestCallMain(t *testing.T) { 8 | main() 9 | } 10 | -------------------------------------------------------------------------------- /frame_test.go: -------------------------------------------------------------------------------- 1 | package ws 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | ) 7 | 8 | func TestOpCodeIsControl(t *testing.T) { 9 | for _, test := range []struct { 10 | code
OpCode 11 | exp bool 12 | }{ 13 | {OpClose, true}, 14 | {OpPing, true}, 15 | {OpPong, true}, 16 | {OpBinary, false}, 17 | {OpText, false}, 18 | {OpContinuation, false}, 19 | } { 20 | t.Run(fmt.Sprintf("0x%02x", test.code), func(t *testing.T) { 21 | if act := test.code.IsControl(); act != test.exp { 22 | t.Errorf("IsControl = %v; want %v", act, test.exp) 23 | } 24 | }) 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/gobwas/ws 2 | 3 | go 1.16 4 | 5 | require ( 6 | github.com/gobwas/httphead v0.1.0 7 | github.com/gobwas/pool v0.2.1 8 | golang.org/x/sys v0.6.0 // indirect 9 | ) 10 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU= 2 | github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= 3 | github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= 4 | github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= 5 | golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= 6 | golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 7 | -------------------------------------------------------------------------------- /hijack_go119.go: -------------------------------------------------------------------------------- 1 | //go:build !go1.20 2 | // +build !go1.20 3 | 4 | package ws 5 | 6 | import ( 7 | "bufio" 8 | "net" 9 | "net/http" 10 | ) 11 | 12 | func hijack(w http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) { 13 | hj, ok := w.(http.Hijacker) 14 | if ok { 15 | return hj.Hijack() 16 | } 17 | return nil, nil, ErrNotHijacker 18 | } 19 |
// NOTE(review): the lines below are a collapsed, line-numbered transcript of:
//   - hijack_go120.go (go1.20): hijacks through http.NewResponseController(w).Hijack()
//     and translates http.ErrNotSupported into ErrNotHijacker; other errors are
//     returned unchanged;
//   - http_test.go: a table test of httpParseVersion; a check that each header-name
//     constant canonicalizes (textproto.CanonicalMIMEHeaderKey) to its *Canonical
//     twin; benchmarks for httpParseVersion and httpWriteUpgradeRequest.
//     In BenchmarkHttpWriteUpgradeRequest, b.ResetTimer() is called on the parent
//     benchmark before b.Run, so it does not affect the sub-benchmark's own timer.
//     The setup above it already runs outside the timed closure.
-------------------------------------------------------------------------------- /hijack_go120.go: -------------------------------------------------------------------------------- 1 | //go:build go1.20 2 | // +build go1.20 3 | 4 | package ws 5 | 6 | import ( 7 | "bufio" 8 | "errors" 9 | "net" 10 | "net/http" 11 | ) 12 | 13 | func hijack(w http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) { 14 | conn, rw, err := http.NewResponseController(w).Hijack() 15 | if errors.Is(err, http.ErrNotSupported) { 16 | return nil, nil, ErrNotHijacker 17 | } 18 | return conn, rw, err 19 | } 20 | -------------------------------------------------------------------------------- /http_test.go: -------------------------------------------------------------------------------- 1 | package ws 2 | 3 | import ( 4 | "bufio" 5 | "io/ioutil" 6 | "net/textproto" 7 | "net/url" 8 | "testing" 9 | 10 | "github.com/gobwas/httphead" 11 | ) 12 | 13 | type httpVersionCase struct { 14 | in []byte 15 | major int 16 | minor int 17 | ok bool 18 | } 19 | 20 | var httpVersionCases = []httpVersionCase{ 21 | {[]byte("HTTP/1.1"), 1, 1, true}, 22 | {[]byte("HTTP/1.0"), 1, 0, true}, 23 | {[]byte("HTTP/1.2"), 1, 2, true}, 24 | {[]byte("HTTP/42.1092"), 42, 1092, true}, 25 | } 26 | 27 | func TestParseHttpVersion(t *testing.T) { 28 | for _, c := range httpVersionCases { 29 | t.Run(string(c.in), func(t *testing.T) { 30 | major, minor, ok := httpParseVersion(c.in) 31 | if major != c.major || minor != c.minor || ok != c.ok { 32 | t.Errorf( 33 | "parseHttpVersion([]byte(%q)) = %v, %v, %v; want %v, %v, %v", 34 | string(c.in), major, minor, ok, c.major, c.minor, c.ok, 35 | ) 36 | } 37 | }) 38 | } 39 | } 40 | 41 | func TestHeaderNames(t *testing.T) { 42 | testCases := []struct { 43 | have, want string 44 | }{ 45 | { 46 | have: headerHost, 47 | want: headerHostCanonical, 48 | }, 49 | { 50 | have: headerUpgrade, 51 | want: headerUpgradeCanonical, 52 | }, 53 | { 54 | have: headerConnection, 55 | want:
headerConnectionCanonical, 56 | }, 57 | { 58 | have: headerSecVersion, 59 | want: headerSecVersionCanonical, 60 | }, 61 | { 62 | have: headerSecProtocol, 63 | want: headerSecProtocolCanonical, 64 | }, 65 | { 66 | have: headerSecExtensions, 67 | want: headerSecExtensionsCanonical, 68 | }, 69 | { 70 | have: headerSecKey, 71 | want: headerSecKeyCanonical, 72 | }, 73 | { 74 | have: headerSecAccept, 75 | want: headerSecAcceptCanonical, 76 | }, 77 | } 78 | 79 | for _, tc := range testCases { 80 | if have := textproto.CanonicalMIMEHeaderKey(tc.have); have != tc.want { 81 | t.Errorf("have %q want %q,", have, tc.want) 82 | } 83 | } 84 | } 85 | 86 | func BenchmarkParseHttpVersion(b *testing.B) { 87 | for _, c := range httpVersionCases { 88 | b.Run(string(c.in), func(b *testing.B) { 89 | for i := 0; i < b.N; i++ { 90 | _, _, _ = httpParseVersion(c.in) 91 | } 92 | }) 93 | } 94 | } 95 | 96 | func BenchmarkHttpWriteUpgradeRequest(b *testing.B) { 97 | for _, test := range []struct { 98 | url *url.URL 99 | protocols []string 100 | extensions []httphead.Option 101 | headers HandshakeHeaderFunc 102 | host string 103 | }{ 104 | { 105 | url: makeURL("ws://example.org"), 106 | }, 107 | { 108 | url: makeURL("ws://example.org"), 109 | host: "test-host", 110 | }, 111 | } { 112 | bw := bufio.NewWriter(ioutil.Discard) 113 | nonce := make([]byte, nonceSize) 114 | initNonce(nonce) 115 | 116 | var headers HandshakeHeader 117 | if test.headers != nil { 118 | headers = test.headers 119 | } 120 | 121 | b.ResetTimer() 122 | b.Run("", func(b *testing.B) { 123 | for i := 0; i < b.N; i++ { 124 | httpWriteUpgradeRequest(bw, 125 | test.url, 126 | nonce, 127 | test.protocols, 128 | test.extensions, 129 | headers, 130 | test.host, 131 | ) 132 | } 133 | }) 134 | } 135 | } 136 | 137 | func makeURL(s string) *url.URL { 138 | ret, err := url.Parse(s) 139 | if err != nil { 140 | panic(err) 141 | } 142 | return ret 143 | } 144 | --------------------------------------------------------------------------------
// NOTE(review): the lines below are a collapsed, line-numbered transcript of
// nonce.go, nonce_test.go and the preamble of read.go (imports, reader errors and
// ReadHeader's doc comment).
//   * initNonce draws the 16 key bytes from crypto/rand. math/rand's global source
//     is deterministically seeded before Go 1.20 (go.mod declares go 1.16) unless
//     Seed is called, so it would make every process emit the same
//     Sec-WebSocket-Key sequence. That contradicts the RFC 6455 requirement quoted
//     at lines 13-16.
//   * checkAcceptFromNonce / initAcceptFromNonce compute the expected
//     Sec-WebSocket-Accept: base64(SHA-1(nonce + magic GUID)).
//     initAcceptFromNonce panics when either buffer has the wrong length.
/nonce.go: -------------------------------------------------------------------------------- 1 | package ws 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "crypto/rand" 7 | "crypto/sha1" 8 | "encoding/base64" 9 | "fmt" 10 | ) 11 | 12 | const ( 13 | // RFC6455: The value of this header field MUST be a nonce consisting of a 14 | // randomly selected 16-byte value that has been base64-encoded (see 15 | // Section 4 of [RFC4648]). The nonce MUST be selected randomly for each 16 | // connection. 17 | nonceKeySize = 16 18 | nonceSize = 24 // base64.StdEncoding.EncodedLen(nonceKeySize) 19 | 20 | // RFC6455: The value of this header field is constructed by concatenating 21 | // /key/, defined above in step 4 in Section 4.2.2, with the string 22 | // "258EAFA5-E914-47DA-95CA-C5AB0DC85B11", taking the SHA-1 hash of this 23 | // concatenated value to obtain a 20-byte value and base64-encoding (see 24 | // Section 4 of [RFC4648]) this 20-byte hash. 25 | acceptSize = 28 // base64.StdEncoding.EncodedLen(sha1.Size) 26 | ) 27 | 28 | // initNonce fills given slice with base64-encoded nonce bytes read from crypto/rand. 29 | func initNonce(dst []byte) { 30 | // NOTE: bts does not escape. 31 | bts := make([]byte, nonceKeySize) 32 | if _, err := rand.Read(bts); err != nil { 33 | panic(fmt.Sprintf("rand read error: %s", err)) 34 | } 35 | base64.StdEncoding.Encode(dst, bts) 36 | } 37 | 38 | // checkAcceptFromNonce reports whether given accept bytes are valid for given 39 | // nonce bytes. 40 | func checkAcceptFromNonce(accept, nonce []byte) bool { 41 | if len(accept) != acceptSize { 42 | return false 43 | } 44 | // NOTE: expect does not escape. 45 | expect := make([]byte, acceptSize) 46 | initAcceptFromNonce(expect, nonce) 47 | return bytes.Equal(expect, accept) 48 | } 49 | 50 | // initAcceptFromNonce fills given slice with accept bytes generated from given 51 | // nonce bytes. Given buffer should be exactly acceptSize bytes.
52 | func initAcceptFromNonce(accept, nonce []byte) { 53 | const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" 54 | 55 | if len(accept) != acceptSize { 56 | panic("accept buffer is invalid") 57 | } 58 | if len(nonce) != nonceSize { 59 | panic("nonce is invalid") 60 | } 61 | 62 | p := make([]byte, nonceSize+len(magic)) 63 | copy(p[:nonceSize], nonce) 64 | copy(p[nonceSize:], magic) 65 | 66 | sum := sha1.Sum(p) 67 | base64.StdEncoding.Encode(accept, sum[:]) 68 | } 69 | 70 | func writeAccept(bw *bufio.Writer, nonce []byte) (int, error) { 71 | accept := make([]byte, acceptSize) 72 | initAcceptFromNonce(accept, nonce) 73 | // NOTE: write accept bytes as a string to prevent heap allocation – 74 | // WriteString() copy given string into its inner buffer, unlike Write() 75 | // which may write p directly to the underlying io.Writer – which in turn 76 | // will lead to p escape. 77 | return bw.WriteString(btsToString(accept)) 78 | } 79 | -------------------------------------------------------------------------------- /nonce_test.go: -------------------------------------------------------------------------------- 1 | package ws 2 | 3 | import "testing" 4 | 5 | func BenchmarkInitAcceptFromNonce(b *testing.B) { 6 | dst := make([]byte, acceptSize) 7 | nonce := mustMakeNonce() 8 | for i := 0; i < b.N; i++ { 9 | initAcceptFromNonce(dst, nonce) 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /read.go: -------------------------------------------------------------------------------- 1 | package ws 2 | 3 | import ( 4 | "encoding/binary" 5 | "fmt" 6 | "io" 7 | ) 8 | 9 | // Errors used by frame reader. 10 | var ( 11 | ErrHeaderLengthMSB = fmt.Errorf("header error: the most significant bit must be 0") 12 | ErrHeaderLengthUnexpected = fmt.Errorf("header error: unexpected payload length bits") 13 | ) 14 | 15 | // ReadHeader reads a frame header from r.
// NOTE(review): ReadHeader (read.go 16-88) reads the two fixed header bytes. It then
// reuses the same slice (capacity MaxHeaderSize-2) for the extended length (2 or
// 8 bytes) and, when the mask bit is set, the 4-byte mask. A 64-bit length whose
// most significant bit is set is rejected with ErrHeaderLengthMSB. The `default`
// branch of the length switch is unreachable: length is masked with 0x7f, so it
// is at most 127. Lines 90-94 are the doc comment of ReadFrame, which follows.
16 | func ReadHeader(r io.Reader) (h Header, err error) { 17 | // Make slice of bytes with capacity 12 that could hold any header. 18 | // 19 | // The maximum header size is 14, but due to the 2 hop reads, 20 | // after first hop that reads first 2 constant bytes, we could reuse 2 bytes. 21 | // So 14 - 2 = 12. 22 | bts := make([]byte, 2, MaxHeaderSize-2) 23 | 24 | // Prepare to hold first 2 bytes to choose size of next read. 25 | _, err = io.ReadFull(r, bts) 26 | if err != nil { 27 | return h, err 28 | } 29 | 30 | h.Fin = bts[0]&bit0 != 0 31 | h.Rsv = (bts[0] & 0x70) >> 4 32 | h.OpCode = OpCode(bts[0] & 0x0f) 33 | 34 | var extra int 35 | 36 | if bts[1]&bit0 != 0 { 37 | h.Masked = true 38 | extra += 4 39 | } 40 | 41 | length := bts[1] & 0x7f 42 | switch { 43 | case length < 126: 44 | h.Length = int64(length) 45 | 46 | case length == 126: 47 | extra += 2 48 | 49 | case length == 127: 50 | extra += 8 51 | 52 | default: 53 | err = ErrHeaderLengthUnexpected 54 | return h, err 55 | } 56 | 57 | if extra == 0 { 58 | return h, err 59 | } 60 | 61 | // Increase len of bts to extra bytes need to read. 62 | // Overwrite first 2 bytes that was read before. 63 | bts = bts[:extra] 64 | _, err = io.ReadFull(r, bts) 65 | if err != nil { 66 | return h, err 67 | } 68 | 69 | switch { 70 | case length == 126: 71 | h.Length = int64(binary.BigEndian.Uint16(bts[:2])) 72 | bts = bts[2:] 73 | 74 | case length == 127: 75 | if bts[0]&0x80 != 0 { 76 | err = ErrHeaderLengthMSB 77 | return h, err 78 | } 79 | h.Length = int64(binary.BigEndian.Uint64(bts[:8])) 80 | bts = bts[8:] 81 | } 82 | 83 | if h.Masked { 84 | copy(h.Mask[:], bts) 85 | } 86 | 87 | return h, nil 88 | } 89 | 90 | // ReadFrame reads a frame from r. 91 | // It is not designed for highly optimized use cases because it allocates 92 | // a buffer of frame.Header.Length bytes to read the frame payload into. 93 | // 94 | // Note that ReadFrame does not unmask payload.
// NOTE(review): ReadFrame, MustReadFrame and ParseCloseFrameData (read.go 95-136),
// followed by the doc comment of ParseCloseFrameDataUnsafe.
//   * ReadFrame allocates exactly f.Header.Length bytes, the size announced by the
//     peer. ReadHeader only rejects lengths with the MSB set. The int conversion
//     is therefore lossless only where int is 64 bits; with a 32-bit int a length
//     above math.MaxInt32 would not fit. Verify before relying on it there.
//   * MustReadFrame panics on any ReadFrame error.
//   * ParseCloseFrameData returns an empty StatusCode and reason for payloads
//     shorter than 2 bytes; otherwise it returns the big-endian code and a copy of
//     the remaining bytes as reason.
95 | func ReadFrame(r io.Reader) (f Frame, err error) { 96 | f.Header, err = ReadHeader(r) 97 | if err != nil { 98 | return f, err 99 | } 100 | 101 | if f.Header.Length > 0 { 102 | // int(f.Header.Length) does not overflow where int is 64 bits: 103 | // ReadHeader rejects lengths with the most significant bit set. 104 | f.Payload = make([]byte, int(f.Header.Length)) 105 | _, err = io.ReadFull(r, f.Payload) 106 | } 107 | 108 | return f, err 109 | } 110 | 111 | // MustReadFrame is like ReadFrame but panics if frame can not be read. 112 | func MustReadFrame(r io.Reader) Frame { 113 | f, err := ReadFrame(r) 114 | if err != nil { 115 | panic(err) 116 | } 117 | return f 118 | } 119 | 120 | // ParseCloseFrameData parses close frame status code and closure reason if any provided. 121 | // If there is no status code in the payload 122 | // the empty status code is returned (code.Empty()) with empty string as a reason. 123 | func ParseCloseFrameData(payload []byte) (code StatusCode, reason string) { 124 | if len(payload) < 2 { 125 | // We return an empty StatusCode here (not 1005), so it cannot be confused 126 | // with an endpoint actually sending code 1005, which must get ProtocolError. 127 | // 128 | // In other words, we are ignoring this rule [RFC6455:7.1.5]: 129 | // If this Close control frame contains no status code, _The WebSocket 130 | // Connection Close Code_ is considered to be 1005. 131 | return code, reason 132 | } 133 | code = StatusCode(binary.BigEndian.Uint16(payload)) 134 | reason = string(payload[2:]) 135 | return code, reason 136 | } 137 | 138 | // ParseCloseFrameDataUnsafe is like ParseCloseFrameData except that it does 139 | // not copy payload bytes into reason; it prepares an unsafe cast instead.
// NOTE(review): ParseCloseFrameDataUnsafe (read.go 140-147) builds reason with
// btsToString(payload[2:]). Under the purego build tag that helper copies
// (util_purego.go). Otherwise it is presumably a zero-copy cast, in which case
// callers must not modify payload while reason is in use; confirm against
// util_unsafe.go. It is followed by read_test.go:
//   * TestReadHeader runs RWTestCases plus one extra case, a 64-bit length with the
//     MSB set, that must fail;
//   * BenchmarkReadHeader pre-builds b.N readers before resetting the timer.
140 | func ParseCloseFrameDataUnsafe(payload []byte) (code StatusCode, reason string) { 141 | if len(payload) < 2 { 142 | return code, reason 143 | } 144 | code = StatusCode(binary.BigEndian.Uint16(payload)) 145 | reason = btsToString(payload[2:]) 146 | return code, reason 147 | } 148 | -------------------------------------------------------------------------------- /read_test.go: -------------------------------------------------------------------------------- 1 | package ws 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "reflect" 8 | "testing" 9 | ) 10 | 11 | func TestReadHeader(t *testing.T) { 12 | for i, test := range append([]RWTestCase{ 13 | { 14 | Data: bits("0000 0000 0 1111111 10000000 00000000 00000000 00000000 00000000 00000000 00000000 00000000"), 15 | // _______________________________________________________________________ 16 | // | 17 | // Length value 18 | Err: true, 19 | }, 20 | }, RWTestCases...) { 21 | t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { 22 | r := bytes.NewReader(test.Data) 23 | h, err := ReadHeader(r) 24 | if test.Err && err == nil { 25 | t.Errorf("expected error, got nil") 26 | } 27 | if !test.Err && err != nil { 28 | t.Errorf("unexpected error: %s", err) 29 | } 30 | if test.Err { 31 | return 32 | } 33 | if !reflect.DeepEqual(h, test.Header) { 34 | t.Errorf("ReadHeader()\nread:\n\t%#v\nwant:\n\t%#v", h, test.Header) 35 | } 36 | }) 37 | } 38 | } 39 | 40 | func BenchmarkReadHeader(b *testing.B) { 41 | for i, bench := range RWBenchCases { 42 | b.Run(fmt.Sprintf("%s#%d", bench.label, i), func(b *testing.B) { 43 | bts := MustCompileFrame(Frame{Header: bench.header}) 44 | rds := make([]io.Reader, b.N) 45 | for i := 0; i < b.N; i++ { 46 | rds[i] = bytes.NewReader(bts) 47 | } 48 | 49 | b.ResetTimer() 50 | 51 | for i := 0; i < b.N; i++ { 52 | _, err := ReadHeader(rds[i]) 53 | if err != nil { 54 | b.Fatal(err) 55 | } 56 | } 57 | }) 58 | } 59 | } 60 | --------------------------------------------------------------------------------
// NOTE(review): the lines below are a collapsed, line-numbered transcript of:
//   - rw_test.go: header fixtures and the bits() helper, which strips spaces and
//     scans 8-character groups with fmt.Sscanf, ignoring its errors;
//   - tests/deflate_test.go: a permessage-deflate round trip over net.Pipe. The
//     server goroutine upgrades the pipe end named `client`, while the dialer's
//     NetDial returns the end named `server`, so the names are swapped relative to
//     the roles;
//   - util.go lines 1-61.
//     * asciiToInt accepts only the bytes '0'-'9'. A high-nibble-only test
//       (b&0xf0 == 0x30) would also admit ':' ';' '<' '=' '>' '?' as digits 10-15.
//     * asciiToInt performs no overflow check.
/rw_test.go: -------------------------------------------------------------------------------- 1 | package ws 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | type RWTestCase struct { 9 | Data []byte 10 | Header Header 11 | Err bool 12 | } 13 | 14 | type RWBenchCase struct { 15 | label string 16 | header Header 17 | } 18 | 19 | var RWBenchCases = []RWBenchCase{ 20 | { 21 | "no-mask", 22 | Header{ 23 | OpCode: OpText, 24 | Fin: true, 25 | }, 26 | }, 27 | { 28 | "mask", 29 | Header{ 30 | OpCode: OpText, 31 | Fin: true, 32 | Masked: true, 33 | Mask: NewMask(), 34 | }, 35 | }, 36 | { 37 | "mask-u16", 38 | Header{ 39 | OpCode: OpText, 40 | Fin: true, 41 | Length: len16, 42 | Masked: true, 43 | Mask: NewMask(), 44 | }, 45 | }, 46 | { 47 | "mask-u64", 48 | Header{ 49 | OpCode: OpText, 50 | Fin: true, 51 | Length: len64, 52 | Masked: true, 53 | Mask: NewMask(), 54 | }, 55 | }, 56 | } 57 | 58 | var RWTestCases = []RWTestCase{ 59 | { 60 | Data: bits("1 001 0001 0 1100100"), 61 | // _ ___ ____ _ _______ 62 | // | | | | | 63 | // Fin | | Mask Length 64 | // Rsv | 65 | // TextFrame 66 | Header: Header{ 67 | Fin: true, 68 | Rsv: Rsv(false, false, true), 69 | OpCode: OpText, 70 | Length: 100, 71 | }, 72 | }, 73 | { 74 | Data: bits("1 001 0001 1 1100100 00000001 10001000 00000000 11111111"), 75 | // _ ___ ____ _ _______ ___________________________________ 76 | // | | | | | | 77 | // Fin | | Mask Length Mask value 78 | // Rsv | 79 | // TextFrame 80 | Header: Header{ 81 | Fin: true, 82 | Rsv: Rsv(false, false, true), 83 | OpCode: OpText, 84 | Length: 100, 85 | Masked: true, 86 | Mask: [4]byte{0x01, 0x88, 0x00, 0xff}, 87 | }, 88 | }, 89 | { 90 | Data: bits("0 110 0010 0 1111110 00001111 11111111"), 91 | // _ ___ ____ _ _______ _________________ 92 | // | | | | | | 93 | // Fin | | Mask Length Length value 94 | // Rsv | 95 | // BinaryFrame 96 | Header: Header{ 97 | Fin: false, 98 | Rsv: Rsv(true, true, false), 99 | OpCode: OpBinary, 100 | Length: 0x0fff, 101 | }, 102 | }, 103 | {
104 | Data: bits("1 000 1010 0 1111111 01111111 00000000 00000000 00000000 00000000 00000000 00000000 00000000"), 105 | // _ ___ ____ _ _______ _______________________________________________________________________ 106 | // | | | | | | 107 | // Fin | | Mask Length Length value 108 | // Rsv | 109 | // PongFrame 110 | Header: Header{ 111 | Fin: true, 112 | Rsv: Rsv(false, false, false), 113 | OpCode: OpPong, 114 | Length: 0x7f00000000000000, 115 | }, 116 | }, 117 | } 118 | 119 | func bits(s string) []byte { 120 | s = strings.ReplaceAll(s, " ", "") 121 | bts := make([]byte, len(s)/8) 122 | 123 | for i, j := 0, 0; i < len(s); i, j = i+8, j+1 { 124 | fmt.Sscanf(s[i:], "%08b", &bts[j]) 125 | } 126 | 127 | return bts 128 | } 129 | -------------------------------------------------------------------------------- /tests/deflate_test.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "io" 7 | "net" 8 | "testing" 9 | "time" 10 | 11 | "github.com/gobwas/httphead" 12 | "github.com/gobwas/ws" 13 | "github.com/gobwas/ws/wsflate" 14 | "github.com/gobwas/ws/wsutil" 15 | ) 16 | 17 | func TestFlateClientServer(t *testing.T) { 18 | e := wsflate.Extension{ 19 | Parameters: wsflate.DefaultParameters, 20 | } 21 | client, server := net.Pipe() 22 | 23 | serverDone := make(chan error) 24 | go func() { 25 | defer func() { 26 | client.Close() 27 | close(serverDone) 28 | }() 29 | u := ws.Upgrader{ 30 | Negotiate: e.Negotiate, 31 | } 32 | _, err := u.Upgrade(client) 33 | if err != nil { 34 | serverDone <- err 35 | return 36 | } 37 | var buf bytes.Buffer 38 | for { 39 | frame, err := ws.ReadFrame(client) 40 | if err != nil { 41 | serverDone <- err 42 | return 43 | } 44 | frame = ws.UnmaskFrameInPlace(frame) 45 | frame, err = wsflate.DecompressFrameBuffer(&buf, frame) 46 | if err != nil { 47 | serverDone <- err 48 | return 49 | } 50 | echo := ws.NewTextFrame(reverse(frame.Payload)) 51 | if err :=
ws.WriteFrame(client, echo); err != nil { 52 | serverDone <- err 53 | return 54 | } 55 | buf.Reset() 56 | } 57 | }() 58 | 59 | d := ws.Dialer{ 60 | Extensions: []httphead.Option{ 61 | e.Parameters.Option(), 62 | }, 63 | NetDial: func(_ context.Context, network, addr string) (net.Conn, error) { 64 | return server, nil 65 | }, 66 | } 67 | dd := wsutil.DebugDialer{ 68 | Dialer: d, 69 | OnRequest: func(p []byte) { 70 | t.Logf("Request:\n%s", p) 71 | }, 72 | OnResponse: func(p []byte) { 73 | t.Logf("Response:\n%s", p) 74 | }, 75 | } 76 | conn, _, _, err := dd.Dial(context.Background(), "ws://stubbed") 77 | if err != nil { 78 | t.Fatalf("unexpected Dial() error: %v", err) 79 | } 80 | 81 | payload := []byte("hello, deflate!") 82 | 83 | frame := ws.NewTextFrame(payload) 84 | frame, err = wsflate.CompressFrame(frame) 85 | if err != nil { 86 | t.Fatalf("can't compress frame: %v", err) 87 | } 88 | frame = ws.MaskFrameInPlace(frame) 89 | if err := ws.WriteFrame(server, frame); err != nil { 90 | t.Fatalf("unexpected WriteFrame() error: %v", err) 91 | } 92 | 93 | echo, err := ws.ReadFrame(server) 94 | if err != nil { 95 | t.Fatalf("unexpected ReadFrame() error: %v", err) 96 | } 97 | if !bytes.Equal(reverse(echo.Payload), payload) { 98 | t.Fatalf("unexpected echoed bytes") 99 | } 100 | 101 | conn.Close() 102 | 103 | const timeout = time.Second 104 | select { 105 | case <-time.After(timeout): 106 | t.Fatalf("server goroutine timeout: %s", timeout) 107 | 108 | case err := <-serverDone: 109 | if err != io.EOF { 110 | t.Fatalf("unexpected server goroutine error: %v", err) 111 | } 112 | } 113 | } 114 | 115 | func reverse(buf []byte) []byte { 116 | for i, j := 0, len(buf)-1; i < j; i, j = i+1, j-1 { 117 | buf[i], buf[j] = buf[j], buf[i] 118 | } 119 | return buf 120 | } 121 | -------------------------------------------------------------------------------- /util.go: -------------------------------------------------------------------------------- 1 | package ws 2 | 3 | import ( 4 |
"bufio" 5 | "bytes" 6 | "fmt" 7 | 8 | "github.com/gobwas/httphead" 9 | ) 10 | 11 | // SelectFromSlice creates accept function that could be used as Protocol/Extension 12 | // select during upgrade. 13 | func SelectFromSlice(accept []string) func(string) bool { 14 | if len(accept) > 16 { 15 | mp := make(map[string]struct{}, len(accept)) 16 | for _, p := range accept { 17 | mp[p] = struct{}{} 18 | } 19 | return func(p string) bool { 20 | _, ok := mp[p] 21 | return ok 22 | } 23 | } 24 | return func(p string) bool { 25 | for _, ok := range accept { 26 | if p == ok { 27 | return true 28 | } 29 | } 30 | return false 31 | } 32 | } 33 | 34 | // SelectEqual creates accept function that could be used as Protocol/Extension 35 | // select during upgrade. 36 | func SelectEqual(v string) func(string) bool { 37 | return func(p string) bool { 38 | return v == p 39 | } 40 | } 41 | 42 | // asciiToInt parses bts as a non-negative base-10 number written in ASCII digits. 43 | func asciiToInt(bts []byte) (ret int, err error) { 44 | // ASCII digits '0'-'9' are 0x30-0x39: high-order bits 0011 followed by 45 | // low-order bits 0000-1001. Once a byte is known to be in that range, its 46 | // low nibble is the digit value. No overflow check is performed. 47 | var n int 48 | if n = len(bts); n < 1 { 49 | return 0, fmt.Errorf("converting empty bytes to int") 50 | } 51 | for i := 0; i < n; i++ { 52 | if c := bts[i]; c < '0' || c > '9' { 53 | return 0, fmt.Errorf("%s is not a numeric character", string(bts[i])) 54 | } 55 | ret += int(bts[i]&0xf) * pow(10, n-i-1) 56 | } 57 | return ret, nil 58 | } 59 | 60 | // pow for integers implementation. 61 | // See Donald Knuth, The Art of Computer Programming, Volume 2, Section 4.6.3.
// NOTE(review): the lines below are a collapsed, line-numbered transcript of util.go
// lines 62-200, util_purego.go and util_test.go up to the start of TestHasToken,
// which continues past this chunk.
//   * limitWriter.Write asks allow() for at most the remaining len(p)-n bytes.
//     Asking for len(p) lets allow() grant more than remains once n > 0, and
//     p[n:n+m] then slices past the end of p and panics.
//   * readLine returns a slice that aliases br's buffer when the line fits in a
//     single ReadSlice call; it is only valid until the next read from br.
//   * socketPair never closes its listener ln.
62 | func pow(a, b int) int { 63 | p := 1 64 | for b > 0 { 65 | if b&1 != 0 { 66 | p *= a 67 | } 68 | b >>= 1 69 | a *= a 70 | } 71 | return p 72 | } 73 | 74 | func bsplit3(bts []byte, sep byte) (b1, b2, b3 []byte) { 75 | a := bytes.IndexByte(bts, sep) 76 | b := bytes.IndexByte(bts[a+1:], sep) 77 | if a == -1 || b == -1 { 78 | return bts, nil, nil 79 | } 80 | b += a + 1 81 | return bts[:a], bts[a+1 : b], bts[b+1:] 82 | } 83 | 84 | func btrim(bts []byte) []byte { 85 | var i, j int 86 | for i = 0; i < len(bts) && (bts[i] == ' ' || bts[i] == '\t'); { 87 | i++ 88 | } 89 | for j = len(bts); j > i && (bts[j-1] == ' ' || bts[j-1] == '\t'); { 90 | j-- 91 | } 92 | return bts[i:j] 93 | } 94 | 95 | func strHasToken(header, token string) (has bool) { 96 | return btsHasToken(strToBytes(header), strToBytes(token)) 97 | } 98 | 99 | func btsHasToken(header, token []byte) (has bool) { 100 | httphead.ScanTokens(header, func(v []byte) bool { 101 | has = bytes.EqualFold(v, token) 102 | return !has 103 | }) 104 | return has 105 | } 106 | 107 | const ( 108 | toLower = 'a' - 'A' // for use with OR. 109 | toUpper = ^byte(toLower) // for use with AND. 110 | toLower8 = uint64(toLower) | 111 | uint64(toLower)<<8 | 112 | uint64(toLower)<<16 | 113 | uint64(toLower)<<24 | 114 | uint64(toLower)<<32 | 115 | uint64(toLower)<<40 | 116 | uint64(toLower)<<48 | 117 | uint64(toLower)<<56 118 | ) 119 | 120 | // Algorithm below is like standard textproto/CanonicalMIMEHeaderKey, except 121 | // that it operates with slice of bytes and modifies it inplace without copying. 122 | func canonicalizeHeaderKey(k []byte) { 123 | upper := true 124 | for i, c := range k { 125 | if upper && 'a' <= c && c <= 'z' { 126 | k[i] &= toUpper 127 | } else if !upper && 'A' <= c && c <= 'Z' { 128 | k[i] |= toLower 129 | } 130 | upper = c == '-' 131 | } 132 | } 133 | 134 | // readLine reads line from br. It reads until '\n' and returns bytes without 135 | // '\n' or '\r\n' at the end.
136 | // It returns err if and only if line does not end in '\n'. Note that read 137 | // bytes returned in any case of error. 138 | // 139 | // It is much like the textproto/Reader.ReadLine() except the thing that it 140 | // returns raw bytes, instead of string. That is, it avoids copying bytes read 141 | // from br. 142 | // 143 | // textproto/Reader.ReadLineBytes() is also makes copy of resulting bytes to be 144 | // safe with future I/O operations on br. 145 | // 146 | // We could control I/O operations on br and do not need to make additional 147 | // copy for safety. 148 | // 149 | // NOTE: when the line fits in a single ReadSlice call, the returned slice 150 | // aliases br's buffer and is only valid until the next read from br. 151 | func readLine(br *bufio.Reader) ([]byte, error) { 152 | var line []byte 153 | for { 154 | bts, err := br.ReadSlice('\n') 155 | if err == bufio.ErrBufferFull { 156 | // Copy bytes because next read will discard them. 157 | line = append(line, bts...) 158 | continue 159 | } 160 | 161 | // Avoid copy of single read. 162 | if line == nil { 163 | line = bts 164 | } else { 165 | line = append(line, bts...) 166 | } 167 | 168 | if err != nil { 169 | return line, err 170 | } 171 | 172 | // Size of line is at least 1. 173 | // In other case bufio.ReadSlice() returns error. 174 | n := len(line) 175 | 176 | // Cut '\n' or '\r\n'.
177 | if n > 1 && line[n-2] == '\r' { 178 | line = line[:n-2] 179 | } else { 180 | line = line[:n-1] 181 | } 182 | 183 | return line, nil 184 | } 185 | } 186 | 187 | func min(a, b int) int { 188 | if a < b { 189 | return a 190 | } 191 | return b 192 | } 193 | 194 | func nonZero(a, b int) int { 195 | if a != 0 { 196 | return a 197 | } 198 | return b 199 | } 200 | -------------------------------------------------------------------------------- /util_purego.go: -------------------------------------------------------------------------------- 1 | //go:build purego 2 | // +build purego 3 | 4 | package ws 5 | 6 | func strToBytes(str string) (bts []byte) { 7 | return []byte(str) 8 | } 9 | 10 | func btsToString(bts []byte) (str string) { 11 | return string(bts) 12 | } 13 | -------------------------------------------------------------------------------- /util_test.go: -------------------------------------------------------------------------------- 1 | package ws 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "context" 7 | "fmt" 8 | "io" 9 | "net" 10 | "net/http" 11 | "net/textproto" 12 | "reflect" 13 | "strings" 14 | "sync" 15 | "testing" 16 | "time" 17 | ) 18 | 19 | var readLineCases = []struct { 20 | label string 21 | in string 22 | line []byte 23 | err error 24 | bufSize int 25 | }{ 26 | { 27 | label: "simple", 28 | in: "hello, world!", 29 | line: []byte("hello, world!"), 30 | err: io.EOF, 31 | bufSize: 1024, 32 | }, 33 | { 34 | label: "simple", 35 | in: "hello, world!\r\n", 36 | line: []byte("hello, world!"), 37 | bufSize: 1024, 38 | }, 39 | { 40 | label: "simple", 41 | in: "hello, world!\n", 42 | line: []byte("hello, world!"), 43 | bufSize: 1024, 44 | }, 45 | { 46 | // The case where "\r\n" straddles the buffer. 47 | label: "straddle", 48 | in: "hello, world!!!\r\n...", 49 | line: []byte("hello, world!!!"), 50 | bufSize: 16, 51 | }, 52 | { 53 | label: "chunked", 54 | in: "hello, world! this is a long long line!", 55 | line: []byte("hello, world!
this is a long long line!"), 56 | err: io.EOF, 57 | bufSize: 16, 58 | }, 59 | { 60 | label: "chunked", 61 | in: "hello, world! this is a long long line!\r\n", 62 | line: []byte("hello, world! this is a long long line!"), 63 | bufSize: 16, 64 | }, 65 | } 66 | 67 | func TestReadLine(t *testing.T) { 68 | for _, test := range readLineCases { 69 | t.Run(test.label, func(t *testing.T) { 70 | br := bufio.NewReaderSize(strings.NewReader(test.in), test.bufSize) 71 | bts, err := readLine(br) 72 | if err != test.err { 73 | t.Errorf("unexpected error: %v; want %v", err, test.err) 74 | } 75 | if act, exp := bts, test.line; !bytes.Equal(act, exp) { 76 | t.Errorf("readLine() result is %#q; want %#q", act, exp) 77 | } 78 | }) 79 | } 80 | } 81 | 82 | func BenchmarkReadLine(b *testing.B) { 83 | for _, test := range readLineCases { 84 | sr := strings.NewReader(test.in) 85 | br := bufio.NewReaderSize(sr, test.bufSize) 86 | b.Run(test.label, func(b *testing.B) { 87 | for i := 0; i < b.N; i++ { 88 | _, _ = readLine(br) 89 | sr.Reset(test.in) 90 | br.Reset(sr) 91 | } 92 | }) 93 | } 94 | } 95 | 96 | func TestUpgradeSlowClient(t *testing.T) { 97 | for _, test := range []struct { 98 | lim *limitWriter 99 | }{ 100 | { 101 | lim: &limitWriter{ 102 | Bandwidth: 100, 103 | Period: time.Second, 104 | Burst: 10, 105 | }, 106 | }, 107 | { 108 | lim: &limitWriter{ 109 | Bandwidth: 100, 110 | Period: time.Second, 111 | Burst: 100, 112 | }, 113 | }, 114 | } { 115 | t.Run("", func(t *testing.T) { 116 | client, server, err := socketPair() 117 | if err != nil { 118 | t.Fatal(err) 119 | } 120 | test.lim.Dest = server 121 | 122 | header := http.Header{ 123 | "X-Websocket-Test-1": []string{"Yes"}, 124 | "X-Websocket-Test-2": []string{"Yes"}, 125 | "X-Websocket-Test-3": []string{"Yes"}, 126 | "X-Websocket-Test-4": []string{"Yes"}, 127 | } 128 | d := Dialer{ 129 | NetDial: func(ctx context.Context, network, addr string) (net.Conn, error) { 130 | return connWithWriter{server, test.lim}, nil 131 | }, 132 |
Header: HandshakeHeaderHTTP(header), 133 | } 134 | var ( 135 | expHost = "example.org" 136 | expURI = "/path/to/ws" 137 | ) 138 | receivedHeader := http.Header{} 139 | u := Upgrader{ 140 | OnRequest: func(uri []byte) error { 141 | if u := string(uri); u != expURI { 142 | t.Errorf( 143 | "unexpected URI in OnRequest() callback: %q; want %q", 144 | u, expURI, 145 | ) 146 | } 147 | return nil 148 | }, 149 | OnHost: func(host []byte) error { 150 | if h := string(host); h != expHost { 151 | t.Errorf( 152 | "unexpected host in OnRequest() callback: %q; want %q", 153 | h, expHost, 154 | ) 155 | } 156 | return nil 157 | }, 158 | OnHeader: func(key, value []byte) error { 159 | receivedHeader.Add(string(key), string(value)) 160 | return nil 161 | }, 162 | } 163 | upgrade := make(chan error, 1) 164 | go func() { 165 | _, err := u.Upgrade(client) 166 | upgrade <- err 167 | }() 168 | 169 | _, _, _, err = d.Dial(context.Background(), "ws://"+expHost+expURI) 170 | if err != nil { 171 | t.Errorf("Dial() error: %v", err) 172 | } 173 | 174 | if err := <-upgrade; err != nil { 175 | t.Errorf("Upgrade() error: %v", err) 176 | } 177 | for key, values := range header { 178 | act, has := receivedHeader[key] 179 | if !has { 180 | t.Errorf("OnHeader() was not called with %q header key", key) 181 | } 182 | if !reflect.DeepEqual(act, values) { 183 | t.Errorf("OnHeader(%q) different values: %v; want %v", key, act, values) 184 | } 185 | } 186 | }) 187 | } 188 | } 189 | 190 | type connWithWriter struct { 191 | net.Conn 192 | w io.Writer 193 | } 194 | 195 | func (w connWithWriter) Write(p []byte) (int, error) { 196 | return w.w.Write(p) 197 | } 198 | 199 | type limitWriter struct { 200 | Dest io.Writer 201 | Bandwidth int 202 | Burst int 203 | Period time.Duration 204 | 205 | mu sync.Mutex 206 | cond sync.Cond 207 | once sync.Once 208 | done chan struct{} 209 | tickets int 210 | } 211 | 212 | func (w *limitWriter) init() { 213 | w.once.Do(func() { 214 | w.cond.L = &w.mu 215 | w.done = make(chan
struct{}) 216 | 217 | tick := w.Period / time.Duration(w.Bandwidth) 218 | go func() { 219 | t := time.NewTicker(tick) 220 | for { 221 | select { 222 | case <-t.C: 223 | w.mu.Lock() 224 | w.tickets = w.Burst 225 | w.mu.Unlock() 226 | w.cond.Signal() 227 | case <-w.done: 228 | t.Stop() 229 | return 230 | } 231 | } 232 | }() 233 | }) 234 | } 235 | 236 | func (w *limitWriter) allow(n int) (allowed int) { 237 | w.init() 238 | w.mu.Lock() 239 | defer w.mu.Unlock() 240 | for w.tickets == 0 { 241 | w.cond.Wait() 242 | } 243 | if w.tickets < 0 { 244 | return -1 245 | } 246 | allowed = min(w.tickets, n) 247 | w.tickets -= allowed 248 | return allowed 249 | } 250 | 251 | func (w *limitWriter) Close() error { 252 | w.init() 253 | w.mu.Lock() 254 | defer w.mu.Unlock() 255 | if w.tickets < 0 { 256 | return nil 257 | } 258 | w.tickets = -1 259 | close(w.done) 260 | w.cond.Broadcast() 261 | return nil 262 | } 263 | 264 | func (w *limitWriter) Write(p []byte) (n int, err error) { 265 | w.init() 266 | for n < len(p) { 267 | m := w.allow(len(p) - n) 268 | if m < 0 { 269 | return 0, io.ErrClosedPipe 270 | } 271 | if _, err := w.Dest.Write(p[n : n+m]); err != nil { 272 | return n, err 273 | } 274 | n += m 275 | } 276 | return n, nil 277 | } 278 | 279 | func socketPair() (client, server net.Conn, err error) { 280 | ln, err := net.Listen("tcp", "localhost:") 281 | if err != nil { 282 | return nil, nil, err 283 | } 284 | type connAndError struct { 285 | conn net.Conn 286 | err error 287 | } 288 | dial := make(chan connAndError, 1) 289 | go func() { 290 | conn, err := net.Dial("tcp", ln.Addr().String()) 291 | dial <- connAndError{conn, err} 292 | }() 293 | server, err = ln.Accept() 294 | if err != nil { 295 | return nil, nil, err 296 | } 297 | ce := <-dial 298 | if err := ce.err; err != nil { 299 | return nil, nil, err 300 | } 301 | return ce.conn, server, nil 302 | } 303 | 304 | func TestHasToken(t *testing.T) { 305 | for i, test := range []struct { 306 | header string 307 | token string 308
| exp bool 309 | }{ 310 | {"Keep-Alive, Close, Upgrade", "upgrade", true}, 311 | {"Keep-Alive, Close, upgrade, hello", "upgrade", true}, 312 | {"Keep-Alive, Close, hello", "upgrade", false}, 313 | } { 314 | t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { 315 | if has := strHasToken(test.header, test.token); has != test.exp { 316 | t.Errorf("hasToken(%q, %q) = %v; want %v", test.header, test.token, has, test.exp) 317 | } 318 | }) 319 | } 320 | } 321 | 322 | func BenchmarkHasToken(b *testing.B) { 323 | for i, bench := range []struct { 324 | header string 325 | token string 326 | }{ 327 | {"Keep-Alive, Close, Upgrade", "upgrade"}, 328 | {"Keep-Alive, Close, upgrade, hello", "upgrade"}, 329 | {"Keep-Alive, Close, hello", "upgrade"}, 330 | } { 331 | b.Run(fmt.Sprintf("#%d", i), func(b *testing.B) { 332 | for i := 0; i < b.N; i++ { 333 | _ = strHasToken(bench.header, bench.token) 334 | } 335 | }) 336 | } 337 | } 338 | 339 | func TestAsciiToInt(t *testing.T) { 340 | for _, test := range []struct { 341 | bts []byte 342 | exp int 343 | err bool 344 | }{ 345 | {[]byte{'0'}, 0, false}, 346 | {[]byte{'1'}, 1, false}, 347 | {[]byte("42"), 42, false}, 348 | {[]byte("420"), 420, false}, 349 | {[]byte("010050042"), 10050042, false}, 350 | } { 351 | t.Run(string(test.bts), func(t *testing.T) { 352 | act, err := asciiToInt(test.bts) 353 | if (test.err && err == nil) || (!test.err && err != nil) { 354 | t.Errorf("unexpected error: %v", err) 355 | } 356 | if act != test.exp { 357 | t.Errorf("asciiToInt(%v) = %v; want %v", test.bts, act, test.exp) 358 | } 359 | }) 360 | } 361 | } 362 | 363 | func TestBtrim(t *testing.T) { 364 | for _, test := range []struct { 365 | bts []byte 366 | exp []byte 367 | }{ 368 | {[]byte("abc"), []byte("abc")}, 369 | {[]byte(" abc"), []byte("abc")}, 370 | {[]byte("abc "), []byte("abc")}, 371 | {[]byte(" abc "), []byte("abc")}, 372 | } { 373 | t.Run(string(test.bts), func(t *testing.T) { 374 | if act := btrim(test.bts); !bytes.Equal(act, test.exp) { 375 | 
t.Errorf("btrim(%v) = %v; want %v", test.bts, act, test.exp) 376 | } 377 | }) 378 | } 379 | } 380 | 381 | func TestBSplit3(t *testing.T) { 382 | for _, test := range []struct { 383 | bts []byte 384 | sep byte 385 | exp1 []byte 386 | exp2 []byte 387 | exp3 []byte 388 | }{ 389 | {[]byte(""), ' ', []byte{}, nil, nil}, 390 | {[]byte("GET / HTTP/1.1"), ' ', []byte("GET"), []byte("/"), []byte("HTTP/1.1")}, 391 | } { 392 | t.Run(string(test.bts), func(t *testing.T) { 393 | b1, b2, b3 := bsplit3(test.bts, test.sep) 394 | if !bytes.Equal(b1, test.exp1) || !bytes.Equal(b2, test.exp2) || !bytes.Equal(b3, test.exp3) { 395 | t.Errorf( 396 | "bsplit3(%q) = %q, %q, %q; want %q, %q, %q", 397 | string(test.bts), string(b1), string(b2), string(b3), 398 | string(test.exp1), string(test.exp2), string(test.exp3), 399 | ) 400 | } 401 | }) 402 | } 403 | } 404 | 405 | var canonicalHeaderCases = [][]byte{ 406 | []byte("foo-"), 407 | []byte("-foo"), 408 | []byte("-"), 409 | []byte("foo----bar"), 410 | []byte("foo-bar"), 411 | []byte("FoO-BaR"), 412 | []byte("Foo-Bar"), 413 | []byte("sec-websocket-extensions"), 414 | } 415 | 416 | func TestCanonicalizeHeaderKey(t *testing.T) { 417 | for _, bts := range canonicalHeaderCases { 418 | t.Run(string(bts), func(t *testing.T) { 419 | act := append([]byte(nil), bts...) 
420 | canonicalizeHeaderKey(act) 421 | 422 | exp := strToBytes(textproto.CanonicalMIMEHeaderKey(string(bts))) 423 | 424 | if !bytes.Equal(act, exp) { 425 | t.Errorf( 426 | "canonicalizeHeaderKey(%v) = %v; want %v", 427 | string(bts), string(act), string(exp), 428 | ) 429 | } 430 | }) 431 | } 432 | } 433 | 434 | func BenchmarkCanonicalizeHeaderKey(b *testing.B) { 435 | for _, bts := range canonicalHeaderCases { 436 | b.Run(string(bts), func(b *testing.B) { 437 | for i := 0; i < b.N; i++ { 438 | canonicalizeHeaderKey(bts) 439 | } 440 | }) 441 | } 442 | } 443 | -------------------------------------------------------------------------------- /util_unsafe.go: -------------------------------------------------------------------------------- 1 | //go:build !purego 2 | // +build !purego 3 | 4 | package ws 5 | 6 | import ( 7 | "reflect" 8 | "unsafe" 9 | ) 10 | 11 | func strToBytes(str string) (bts []byte) { 12 | s := (*reflect.StringHeader)(unsafe.Pointer(&str)) 13 | b := (*reflect.SliceHeader)(unsafe.Pointer(&bts)) 14 | b.Data = s.Data 15 | b.Len = s.Len 16 | b.Cap = s.Len 17 | return bts 18 | } 19 | 20 | func btsToString(bts []byte) (str string) { 21 | return *(*string)(unsafe.Pointer(&bts)) 22 | } 23 | -------------------------------------------------------------------------------- /write.go: -------------------------------------------------------------------------------- 1 | package ws 2 | 3 | import ( 4 | "encoding/binary" 5 | "io" 6 | ) 7 | 8 | // Header size length bounds in bytes. 9 | const ( 10 | MaxHeaderSize = 14 11 | MinHeaderSize = 2 12 | ) 13 | 14 | const ( 15 | bit0 = 0x80 16 | bit1 = 0x40 17 | bit2 = 0x20 18 | bit3 = 0x10 19 | bit4 = 0x08 20 | bit5 = 0x04 21 | bit6 = 0x02 22 | bit7 = 0x01 23 | 24 | len7 = int64(125) 25 | len16 = int64(^(uint16(0))) 26 | len64 = int64(^(uint64(0)) >> 1) 27 | ) 28 | 29 | // HeaderSize returns number of bytes that are needed to encode given header. 30 | // It returns -1 if header is malformed. 
31 | func HeaderSize(h Header) (n int) { 32 | switch { 33 | case h.Length < 126: 34 | n = 2 35 | case h.Length <= len16: 36 | n = 4 37 | case h.Length <= len64: 38 | n = 10 39 | default: 40 | return -1 41 | } 42 | if h.Masked { 43 | n += len(h.Mask) 44 | } 45 | return n 46 | } 47 | 48 | // WriteHeader writes header binary representation into w. 49 | func WriteHeader(w io.Writer, h Header) error { 50 | // Make slice of bytes with capacity 14 that could hold any header. 51 | bts := make([]byte, MaxHeaderSize) 52 | 53 | if h.Fin { 54 | bts[0] |= bit0 55 | } 56 | bts[0] |= h.Rsv << 4 57 | bts[0] |= byte(h.OpCode) 58 | 59 | var n int 60 | switch { 61 | case h.Length <= len7: 62 | bts[1] = byte(h.Length) 63 | n = 2 64 | 65 | case h.Length <= len16: 66 | bts[1] = 126 67 | binary.BigEndian.PutUint16(bts[2:4], uint16(h.Length)) 68 | n = 4 69 | 70 | case h.Length <= len64: 71 | bts[1] = 127 72 | binary.BigEndian.PutUint64(bts[2:10], uint64(h.Length)) 73 | n = 10 74 | 75 | default: 76 | return ErrHeaderLengthUnexpected 77 | } 78 | 79 | if h.Masked { 80 | bts[1] |= bit0 81 | n += copy(bts[n:], h.Mask[:]) 82 | } 83 | 84 | _, err := w.Write(bts[:n]) 85 | 86 | return err 87 | } 88 | 89 | // WriteFrame writes frame binary representation into w. 90 | func WriteFrame(w io.Writer, f Frame) error { 91 | err := WriteHeader(w, f.Header) 92 | if err != nil { 93 | return err 94 | } 95 | _, err = w.Write(f.Payload) 96 | return err 97 | } 98 | 99 | // MustWriteFrame is like WriteFrame but panics if frame can not be read. 
100 | func MustWriteFrame(w io.Writer, f Frame) { 101 | if err := WriteFrame(w, f); err != nil { 102 | panic(err) 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /write_test.go: -------------------------------------------------------------------------------- 1 | package ws 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io/ioutil" 7 | "testing" 8 | ) 9 | 10 | func TestWriteHeader(t *testing.T) { 11 | for i, test := range RWTestCases { 12 | t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { 13 | buf := &bytes.Buffer{} 14 | err := WriteHeader(buf, test.Header) 15 | if test.Err && err == nil { 16 | t.Errorf("expected error, got nil") 17 | } 18 | if !test.Err && err != nil { 19 | t.Errorf("unexpected error: %s", err) 20 | } 21 | if test.Err { 22 | return 23 | } 24 | if bts := buf.Bytes(); !bytes.Equal(bts, test.Data) { 25 | t.Errorf("WriteHeader()\nwrote:\n\t%08b\nwant:\n\t%08b", bts, test.Data) 26 | } 27 | }) 28 | } 29 | } 30 | 31 | func BenchmarkWriteHeader(b *testing.B) { 32 | for _, bench := range RWBenchCases { 33 | b.Run(bench.label, func(b *testing.B) { 34 | for i := 0; i < b.N; i++ { 35 | if err := WriteHeader(ioutil.Discard, bench.header); err != nil { 36 | b.Fatal(err) 37 | } 38 | } 39 | }) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /wsflate/cbuf.go: -------------------------------------------------------------------------------- 1 | package wsflate 2 | 3 | import ( 4 | "io" 5 | ) 6 | 7 | // cbuf is a tiny proxy-buffer that writes all but 4 last bytes to the 8 | // destination. 9 | type cbuf struct { 10 | buf [4]byte 11 | n int 12 | dst io.Writer 13 | err error 14 | } 15 | 16 | // Write implements io.Writer interface. 
17 | func (c *cbuf) Write(p []byte) (int, error) { 18 | if c.err != nil { 19 | return 0, c.err 20 | } 21 | head, tail := c.split(p) 22 | n := c.n + len(tail) 23 | if n > len(c.buf) { 24 | x := n - len(c.buf) 25 | c.flush(c.buf[:x]) 26 | copy(c.buf[:], c.buf[x:]) 27 | c.n -= x 28 | } 29 | if len(head) > 0 { 30 | c.flush(head) 31 | } 32 | copy(c.buf[c.n:], tail) 33 | c.n = min(c.n+len(tail), len(c.buf)) 34 | return len(p), c.err 35 | } 36 | 37 | func (c *cbuf) flush(p []byte) { 38 | if c.err == nil { 39 | _, c.err = c.dst.Write(p) 40 | } 41 | } 42 | 43 | func (c *cbuf) split(p []byte) (head, tail []byte) { 44 | if n := len(p); n > len(c.buf) { 45 | x := n - len(c.buf) 46 | head = p[:x] 47 | tail = p[x:] 48 | return head, tail 49 | } 50 | return nil, p 51 | } 52 | 53 | func (c *cbuf) reset(dst io.Writer) { 54 | c.n = 0 55 | c.err = nil 56 | c.buf = [4]byte{0, 0, 0, 0} 57 | c.dst = dst 58 | } 59 | 60 | type suffixedReader struct { 61 | r io.Reader 62 | pos int // position in the suffix. 63 | suffix [9]byte 64 | 65 | rx struct{ io.Reader } 66 | } 67 | 68 | func (r *suffixedReader) iface() io.Reader { 69 | if _, ok := r.r.(io.ByteReader); ok { 70 | // If source io.Reader implements io.ByteReader, return full set of 71 | // methods from suffixedReader struct (Read() and ReadByte()). 72 | // This actually is an optimization needed for those Decompressor 73 | // implementations (such as default flate.Reader) which do check if 74 | // given source is already "buffered" by checking if source implements 75 | // io.ByteReader. So without this checks we will always result in 76 | // double-buffering for default decompressors. 77 | return r 78 | } 79 | // Source io.Reader doesn't support io.ByteReader, so we should cut off the 80 | // ReadByte() method from suffixedReader struct. We use r.srx field to 81 | // avoid allocations. 
82 | r.rx.Reader = r 83 | return &r.rx 84 | } 85 | 86 | func (r *suffixedReader) Read(p []byte) (n int, err error) { 87 | if r.r != nil { 88 | n, err = r.r.Read(p) 89 | if err == io.EOF { 90 | err = nil 91 | r.r = nil 92 | } 93 | return n, err 94 | } 95 | if r.pos >= len(r.suffix) { 96 | return 0, io.EOF 97 | } 98 | n = copy(p, r.suffix[r.pos:]) 99 | r.pos += n 100 | return n, nil 101 | } 102 | 103 | func (r *suffixedReader) ReadByte() (b byte, err error) { 104 | if r.r != nil { 105 | br, ok := r.r.(io.ByteReader) 106 | if !ok { 107 | panic("wsflate: internal error: incorrect use of suffixedReader") 108 | } 109 | b, err = br.ReadByte() 110 | if err == io.EOF { 111 | err = nil 112 | r.r = nil 113 | } 114 | return b, err 115 | } 116 | if r.pos >= len(r.suffix) { 117 | return 0, io.EOF 118 | } 119 | b = r.suffix[r.pos] 120 | r.pos++ 121 | return b, nil 122 | } 123 | 124 | func (r *suffixedReader) reset(src io.Reader) { 125 | r.r = src 126 | r.pos = 0 127 | } 128 | 129 | func min(a, b int) int { 130 | if a < b { 131 | return a 132 | } 133 | return b 134 | } 135 | -------------------------------------------------------------------------------- /wsflate/cbuf_test.go: -------------------------------------------------------------------------------- 1 | package wsflate 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "io/ioutil" 8 | "testing" 9 | ) 10 | 11 | func TestSuffixReader(t *testing.T) { 12 | for chunk := 1; chunk < 100; chunk++ { 13 | var ( 14 | data = []byte("hello, flate!") 15 | name = fmt.Sprintf("chunk-%d", chunk) 16 | ) 17 | t.Run(name, func(t *testing.T) { 18 | r := suffixedReader{ 19 | r: bytes.NewReader(data), 20 | suffix: [9]byte{ 21 | 1, 2, 3, 22 | 4, 5, 6, 23 | 7, 8, 9, 24 | }, 25 | } 26 | var ( 27 | act = make([]byte, 0, len(data)+len(r.suffix)) 28 | p = make([]byte, chunk) 29 | ) 30 | for len(act) < cap(act) { 31 | n, err := r.Read(p) 32 | act = append(act, p[:n]...) 
33 | if err == io.EOF { 34 | break 35 | } 36 | if err != nil { 37 | t.Fatalf("unexpected Read() error: %v", err) 38 | } 39 | } 40 | exp := append(data, r.suffix[:]...) 41 | if !bytes.Equal(act, exp) { 42 | t.Fatalf("unexpected bytes read: %#q; want %#q", act, exp) 43 | } 44 | }) 45 | } 46 | } 47 | 48 | func TestCBuf(t *testing.T) { 49 | for _, test := range []struct { 50 | name string 51 | stream [][]byte 52 | expBody []byte 53 | expTail []byte 54 | }{ 55 | { 56 | stream: [][]byte{ 57 | {1}, {2}, {3}, {4}, 58 | }, 59 | expTail: []byte{1, 2, 3, 4}, 60 | }, 61 | { 62 | stream: [][]byte{ 63 | {1, 2}, {3, 4}, 64 | }, 65 | expTail: []byte{1, 2, 3, 4}, 66 | }, 67 | { 68 | stream: [][]byte{ 69 | {1, 2, 3}, {4, 5, 6}, 70 | }, 71 | expBody: []byte{1, 2}, 72 | expTail: []byte{3, 4, 5, 6}, 73 | }, 74 | { 75 | stream: [][]byte{ 76 | {1, 2, 3, 4}, {5, 6, 7, 8}, 77 | }, 78 | expBody: []byte{1, 2, 3, 4}, 79 | expTail: []byte{5, 6, 7, 8}, 80 | }, 81 | { 82 | stream: [][]byte{ 83 | {1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}, 84 | }, 85 | expBody: []byte{1, 2, 3, 4, 5, 6}, 86 | expTail: []byte{7, 8, 9, 10}, 87 | }, 88 | { 89 | stream: [][]byte{ 90 | {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, 91 | }, 92 | expBody: []byte{1, 2, 3, 4, 5, 6}, 93 | expTail: []byte{7, 8, 9, 10}, 94 | }, 95 | { 96 | name: "xxx", 97 | stream: [][]byte{ 98 | {1, 2, 3, 4, 5}, {6}, 99 | }, 100 | expBody: []byte{1, 2}, 101 | expTail: []byte{3, 4, 5, 6}, 102 | }, 103 | } { 104 | t.Run(test.name, func(t *testing.T) { 105 | var buf bytes.Buffer 106 | w := &cbuf{ 107 | dst: &buf, 108 | } 109 | for _, bts := range test.stream { 110 | n, err := w.Write(bts) 111 | if err != nil { 112 | t.Fatalf("unexpected error: %v", err) 113 | } 114 | if act, exp := n, len(bts); act != exp { 115 | t.Fatalf( 116 | "unexpected number of bytes written: %d; want %d", 117 | act, exp, 118 | ) 119 | } 120 | } 121 | if act, exp := w.buf[:], test.expTail; !bytes.Equal(act, exp) { 122 | t.Errorf( 123 | "unexpected tail: %v; want %v", 124 | act, exp, 125 | ) 126 
| } 127 | if act, exp := buf.Bytes(), test.expBody; !bytes.Equal(act, exp) { 128 | t.Errorf( 129 | "unexpected body: %v; want %v", 130 | act, exp, 131 | ) 132 | } 133 | }) 134 | } 135 | } 136 | 137 | func BenchmarkCBuf(b *testing.B) { 138 | for _, test := range []struct { 139 | name string 140 | chunk []byte 141 | }{ 142 | { 143 | chunk: []byte{1, 2, 3, 4, 5}, 144 | }, 145 | { 146 | chunk: []byte{1, 2, 3, 4}, 147 | }, 148 | } { 149 | b.Run(test.name, func(b *testing.B) { 150 | w := &cbuf{ 151 | dst: ioutil.Discard, 152 | } 153 | for i := 0; i < b.N; i++ { 154 | w.Write(test.chunk) 155 | } 156 | }) 157 | } 158 | } 159 | -------------------------------------------------------------------------------- /wsflate/extension.go: -------------------------------------------------------------------------------- 1 | package wsflate 2 | 3 | import ( 4 | "bytes" 5 | 6 | "github.com/gobwas/httphead" 7 | "github.com/gobwas/ws" 8 | ) 9 | 10 | // Extension contains logic of compression extension parameters negotiation 11 | // made during HTTP WebSocket handshake. 12 | // It might be reused between different upgrades (but not concurrently) with 13 | // Reset() being called after each. 14 | type Extension struct { 15 | // Parameters is specification of extension parameters server is going to 16 | // accept. 17 | Parameters Parameters 18 | 19 | accepted bool 20 | params Parameters 21 | } 22 | 23 | // Negotiate parses given HTTP header option and returns (if any) header option 24 | // which describes accepted parameters. 25 | // 26 | // It may return zero option (i.e. one which Size() returns 0) alongside with 27 | // nil error. 28 | func (n *Extension) Negotiate(opt httphead.Option) (accept httphead.Option, err error) { 29 | if !bytes.Equal(opt.Name, ExtensionNameBytes) { 30 | return accept, nil 31 | } 32 | if n.accepted { 33 | // Negotiate might be called multiple times during upgrade. 
34 | // We stick to first one accepted extension since they must be passed 35 | // in ordered by preference. 36 | return accept, nil 37 | } 38 | 39 | want := n.Parameters 40 | 41 | // NOTE: Parse() resets params inside, so no worries. 42 | if err := n.params.Parse(opt); err != nil { 43 | return accept, err 44 | } 45 | { 46 | offer := n.params.ServerMaxWindowBits 47 | want := want.ServerMaxWindowBits 48 | if offer > want { 49 | // A server declines an extension negotiation offer 50 | // with this parameter if the server doesn't support 51 | // it. 52 | return accept, nil 53 | } 54 | } 55 | { 56 | // If a received extension negotiation offer has the 57 | // "client_max_window_bits" extension parameter, the server MAY 58 | // include the "client_max_window_bits" extension parameter in the 59 | // corresponding extension negotiation response to the offer. 60 | offer := n.params.ClientMaxWindowBits 61 | want := want.ClientMaxWindowBits 62 | if want > offer { 63 | return accept, nil 64 | } 65 | } 66 | { 67 | offer := n.params.ServerNoContextTakeover 68 | want := want.ServerNoContextTakeover 69 | if offer && !want { 70 | return accept, nil 71 | } 72 | } 73 | 74 | n.accepted = true 75 | 76 | return want.Option(), nil 77 | } 78 | 79 | // Accepted returns parameters parsed during last negotiation and a flag that 80 | // reports whether they were accepted. 81 | func (n *Extension) Accepted() (_ Parameters, accepted bool) { 82 | return n.params, n.accepted 83 | } 84 | 85 | // Reset resets extension for further reuse. 86 | func (n *Extension) Reset() { 87 | n.accepted = false 88 | n.params = Parameters{} 89 | } 90 | 91 | var ErrUnexpectedCompressionBit = ws.ProtocolError( 92 | "control frame or non-first fragment of data contains compression bit set", 93 | ) 94 | 95 | // UnsetBit clears the Per-Message Compression bit in header h and returns its 96 | // modified copy. It reports whether compression bit was set in header h. 
97 | // It returns non-nil error if compression bit has unexpected value. 98 | // 99 | // This function's main purpose is to be compatible with "Framing" section of 100 | // the Compression Extensions for WebSocket RFC. If you don't need to work with 101 | // chains of extensions then IsCompressed() could be enough to check if 102 | // message is compressed. 103 | // See https://tools.ietf.org/html/rfc7692#section-6.2 104 | func UnsetBit(h ws.Header) (_ ws.Header, wasSet bool, err error) { 105 | var s MessageState 106 | h, err = s.UnsetBits(h) 107 | return h, s.IsCompressed(), err 108 | } 109 | 110 | // SetBit sets the Per-Message Compression bit in header h and returns its 111 | // modified copy. 112 | // It returns non-nil error if compression bit has unexpected value. 113 | func SetBit(h ws.Header) (_ ws.Header, err error) { 114 | var s MessageState 115 | s.SetCompressed(true) 116 | return s.SetBits(h) 117 | } 118 | 119 | // IsCompressed reports whether the Per-Message Compression bit is set in 120 | // header h. 121 | // It returns non-nil error if compression bit has unexpected value. 122 | // 123 | // If you need to be fully compatible with Compression Extensions for WebSocket 124 | // RFC and work with chains of extensions, take a look at the UnsetBit() 125 | // instead. That is, IsCompressed() is a shortcut for UnsetBit() with reduced 126 | // number of return values. 127 | func IsCompressed(h ws.Header) (bool, error) { 128 | _, isSet, err := UnsetBit(h) 129 | return isSet, err 130 | } 131 | 132 | // MessageState holds message compression state. 133 | // 134 | // It is consulted during SetBits(h) call to make a decision whether we must 135 | // set the Per-Message Compression bit for given header h argument. 136 | // It is updated during UnsetBits(h) to reflect compression state of a message 137 | // represented by header h argument. 138 | // It can also be consulted/updated directly by calling 139 | // IsCompressed()/SetCompressed(). 
140 | // 141 | // In general MessageState should be used when there is no direct access to 142 | // connection to read frame from, but it is still needed to know if message 143 | // being read is compressed. For other cases SetBit() and UnsetBit() should be 144 | // used instead. 145 | // 146 | // NOTE: the compression state is updated during UnsetBits(h) only when header 147 | // h argument represents data (text or binary) frame. 148 | type MessageState struct { 149 | compressed bool 150 | } 151 | 152 | // SetCompressed marks message as "compressed" or "uncompressed". 153 | // See https://tools.ietf.org/html/rfc7692#section-6 154 | func (s *MessageState) SetCompressed(v bool) { 155 | s.compressed = v 156 | } 157 | 158 | // IsCompressed reports whether message is "compressed". 159 | // See https://tools.ietf.org/html/rfc7692#section-6 160 | func (s *MessageState) IsCompressed() bool { 161 | return s.compressed 162 | } 163 | 164 | // UnsetBits changes RSV bits of the given frame header h as if compression 165 | // extension was negotiated. It returns modified copy of h and error if header 166 | // is malformed from the RFC perspective. 167 | func (s *MessageState) UnsetBits(h ws.Header) (ws.Header, error) { 168 | r1, r2, r3 := ws.RsvBits(h.Rsv) 169 | switch { 170 | case h.OpCode.IsData() && h.OpCode != ws.OpContinuation: 171 | h.Rsv = ws.Rsv(false, r2, r3) 172 | s.SetCompressed(r1) 173 | return h, nil 174 | 175 | case r1: 176 | // An endpoint MUST NOT set the "Per-Message Compressed" 177 | // bit of control frames and non-first fragments of a data 178 | // message. An endpoint receiving such a frame MUST _Fail 179 | // the WebSocket Connection_. 180 | return h, ErrUnexpectedCompressionBit 181 | 182 | default: 183 | // NOTE: do not change the state of s.compressed since UnsetBits() 184 | // might also be called for (intermediate) control frames. 
185 | return h, nil 186 | } 187 | } 188 | 189 | // SetBits changes RSV bits of the frame header h which is being send as if 190 | // compression extension was negotiated. It returns modified copy of h and 191 | // error if header is malformed from the RFC perspective. 192 | func (s *MessageState) SetBits(h ws.Header) (ws.Header, error) { 193 | r1, r2, r3 := ws.RsvBits(h.Rsv) 194 | if r1 { 195 | return h, ErrUnexpectedCompressionBit 196 | } 197 | if !h.OpCode.IsData() || h.OpCode == ws.OpContinuation { 198 | // An endpoint MUST NOT set the "Per-Message Compressed" 199 | // bit of control frames and non-first fragments of a data 200 | // message. An endpoint receiving such a frame MUST _Fail 201 | // the WebSocket Connection_. 202 | return h, nil 203 | } 204 | if s.IsCompressed() { 205 | h.Rsv = ws.Rsv(true, r2, r3) 206 | } 207 | return h, nil 208 | } 209 | -------------------------------------------------------------------------------- /wsflate/helper.go: -------------------------------------------------------------------------------- 1 | package wsflate 2 | 3 | import ( 4 | "bytes" 5 | "compress/flate" 6 | "fmt" 7 | "io" 8 | 9 | "github.com/gobwas/ws" 10 | ) 11 | 12 | // DefaultHelper is a default helper instance holding standard library's 13 | // `compress/flate` compressor and decompressor under the hood. 14 | // 15 | // Note that use of DefaultHelper methods assumes that DefaultParameters were 16 | // used for extension negotiation during WebSocket handshake. 17 | var DefaultHelper = Helper{ 18 | Compressor: func(w io.Writer) Compressor { 19 | // No error can be returned here as NewWriter() doc says. 20 | f, _ := flate.NewWriter(w, 9) 21 | return f 22 | }, 23 | Decompressor: func(r io.Reader) Decompressor { 24 | return flate.NewReader(r) 25 | }, 26 | } 27 | 28 | // DefaultParameters holds deflate extension parameters which are assumed by 29 | // DefaultHelper to be used during WebSocket handshake. 
30 | var DefaultParameters = Parameters{ 31 | ServerNoContextTakeover: true, 32 | ClientNoContextTakeover: true, 33 | } 34 | 35 | // CompressFrame is a shortcut for DefaultHelper.CompressFrame(). 36 | // 37 | // Note that use of DefaultHelper methods assumes that DefaultParameters were 38 | // used for extension negotiation during WebSocket handshake. 39 | func CompressFrame(f ws.Frame) (ws.Frame, error) { 40 | return DefaultHelper.CompressFrame(f) 41 | } 42 | 43 | // CompressFrameBuffer is a shortcut for DefaultHelper.CompressFrameBuffer(). 44 | // 45 | // Note that use of DefaultHelper methods assumes that DefaultParameters were 46 | // used for extension negotiation during WebSocket handshake. 47 | func CompressFrameBuffer(buf Buffer, f ws.Frame) (ws.Frame, error) { 48 | return DefaultHelper.CompressFrameBuffer(buf, f) 49 | } 50 | 51 | // DecompressFrame is a shortcut for DefaultHelper.DecompressFrame(). 52 | // 53 | // Note that use of DefaultHelper methods assumes that DefaultParameters were 54 | // used for extension negotiation during WebSocket handshake. 55 | func DecompressFrame(f ws.Frame) (ws.Frame, error) { 56 | return DefaultHelper.DecompressFrame(f) 57 | } 58 | 59 | // DecompressFrameBuffer is a shortcut for 60 | // DefaultHelper.DecompressFrameBuffer(). 61 | // 62 | // Note that use of DefaultHelper methods assumes that DefaultParameters were 63 | // used for extension negotiation during WebSocket handshake. 64 | func DecompressFrameBuffer(buf Buffer, f ws.Frame) (ws.Frame, error) { 65 | return DefaultHelper.DecompressFrameBuffer(buf, f) 66 | } 67 | 68 | // Helper is a helper struct that holds common code for compression and 69 | // decompression bytes or WebSocket frames. 70 | // 71 | // Its purpose is to reduce boilerplate code in WebSocket applications. 
72 | type Helper struct { 73 | Compressor func(w io.Writer) Compressor 74 | Decompressor func(r io.Reader) Decompressor 75 | } 76 | 77 | // Buffer is an interface representing some bytes buffering object. 78 | type Buffer interface { 79 | io.Writer 80 | Bytes() []byte 81 | } 82 | 83 | // CompressFrame returns compressed version of a frame. 84 | // Note that it does memory allocations internally. To control those 85 | // allocations consider using CompressFrameBuffer(). 86 | func (h *Helper) CompressFrame(in ws.Frame) (f ws.Frame, err error) { 87 | var buf bytes.Buffer 88 | return h.CompressFrameBuffer(&buf, in) 89 | } 90 | 91 | // DecompressFrame returns decompressed version of a frame. 92 | // Note that it does memory allocations internally. To control those 93 | // allocations consider using DecompressFrameBuffer(). 94 | func (h *Helper) DecompressFrame(in ws.Frame) (f ws.Frame, err error) { 95 | var buf bytes.Buffer 96 | return h.DecompressFrameBuffer(&buf, in) 97 | } 98 | 99 | // CompressFrameBuffer compresses a frame using given buffer. 100 | // Returned frame's payload holds bytes returned by buf.Bytes(). 101 | func (h *Helper) CompressFrameBuffer(buf Buffer, f ws.Frame) (ws.Frame, error) { 102 | if !f.Header.Fin { 103 | return f, fmt.Errorf("wsflate: fragmented messages are not allowed") 104 | } 105 | if err := h.CompressTo(buf, f.Payload); err != nil { 106 | return f, err 107 | } 108 | var err error 109 | f.Payload = buf.Bytes() 110 | f.Header.Length = int64(len(f.Payload)) 111 | f.Header, err = SetBit(f.Header) 112 | if err != nil { 113 | return f, err 114 | } 115 | return f, nil 116 | } 117 | 118 | // DecompressFrameBuffer decompresses a frame using given buffer. 119 | // Returned frame's payload holds bytes returned by buf.Bytes(). 
120 | func (h *Helper) DecompressFrameBuffer(buf Buffer, f ws.Frame) (ws.Frame, error) { 121 | if !f.Header.Fin { 122 | return f, fmt.Errorf( 123 | "wsflate: fragmented messages are not supported by helper", 124 | ) 125 | } 126 | var ( 127 | compressed bool 128 | err error 129 | ) 130 | f.Header, compressed, err = UnsetBit(f.Header) 131 | if err != nil { 132 | return f, err 133 | } 134 | if !compressed { 135 | return f, nil 136 | } 137 | if err := h.DecompressTo(buf, f.Payload); err != nil { 138 | return f, err 139 | } 140 | 141 | f.Payload = buf.Bytes() 142 | f.Header.Length = int64(len(f.Payload)) 143 | 144 | return f, nil 145 | } 146 | 147 | // Compress compresses given bytes. 148 | // Note that it does memory allocations internally. To control those 149 | // allocations consider using CompressTo(). 150 | func (h *Helper) Compress(p []byte) ([]byte, error) { 151 | var buf bytes.Buffer 152 | if err := h.CompressTo(&buf, p); err != nil { 153 | return nil, err 154 | } 155 | return buf.Bytes(), nil 156 | } 157 | 158 | // Decompress decompresses given bytes. 159 | // Note that it does memory allocations internally. To control those 160 | // allocations consider using DecompressTo(). 161 | func (h *Helper) Decompress(p []byte) ([]byte, error) { 162 | var buf bytes.Buffer 163 | if err := h.DecompressTo(&buf, p); err != nil { 164 | return nil, err 165 | } 166 | return buf.Bytes(), nil 167 | } 168 | 169 | // CompressTo compresses bytes into given buffer. 170 | func (h *Helper) CompressTo(w io.Writer, p []byte) (err error) { 171 | c := NewWriter(w, h.Compressor) 172 | if _, err = c.Write(p); err != nil { 173 | return err 174 | } 175 | if err := c.Flush(); err != nil { 176 | return err 177 | } 178 | if err := c.Close(); err != nil { 179 | return err 180 | } 181 | return nil 182 | } 183 | 184 | // DecompressTo decompresses bytes into given buffer. 185 | // Returned bytes are bytes returned by buf.Bytes(). 
186 | func (h *Helper) DecompressTo(w io.Writer, p []byte) (err error) { 187 | fr := NewReader(bytes.NewReader(p), h.Decompressor) 188 | if _, err = io.Copy(w, fr); err != nil { 189 | return err 190 | } 191 | if err := fr.Close(); err != nil { 192 | return err 193 | } 194 | return nil 195 | } 196 | -------------------------------------------------------------------------------- /wsflate/helper_test.go: -------------------------------------------------------------------------------- 1 | package wsflate 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/gobwas/ws" 8 | ) 9 | 10 | func TestHelperWriteAndRead(t *testing.T) { 11 | const text = "hello, wsflate!" 12 | f := ws.NewTextFrame([]byte(text)) 13 | c, err := CompressFrame(f) 14 | if err != nil { 15 | t.Fatalf("can't compress frame: %v", err) 16 | } 17 | d, err := DecompressFrame(c) 18 | if err != nil { 19 | t.Fatalf("can't decompress frame: %v", err) 20 | } 21 | if f.Header != d.Header { 22 | t.Fatalf("original and decompressed headers are not equal") 23 | } 24 | if !bytes.Equal(f.Payload, d.Payload) { 25 | t.Fatalf("original and decompressed payload are not equal") 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /wsflate/parameters.go: -------------------------------------------------------------------------------- 1 | package wsflate 2 | 3 | import ( 4 | "fmt" 5 | "strconv" 6 | 7 | "github.com/gobwas/httphead" 8 | ) 9 | 10 | const ( 11 | ExtensionName = "permessage-deflate" 12 | 13 | serverNoContextTakeover = "server_no_context_takeover" 14 | clientNoContextTakeover = "client_no_context_takeover" 15 | serverMaxWindowBits = "server_max_window_bits" 16 | clientMaxWindowBits = "client_max_window_bits" 17 | ) 18 | 19 | var ( 20 | ExtensionNameBytes = []byte(ExtensionName) 21 | 22 | serverNoContextTakeoverBytes = []byte(serverNoContextTakeover) 23 | clientNoContextTakeoverBytes = []byte(clientNoContextTakeover) 24 | serverMaxWindowBitsBytes = 
[]byte(serverMaxWindowBits) 25 | clientMaxWindowBitsBytes = []byte(clientMaxWindowBits) 26 | ) 27 | 28 | var windowBits [8][]byte 29 | 30 | func init() { 31 | for i := range windowBits { 32 | windowBits[i] = []byte(strconv.Itoa(i + 8)) 33 | } 34 | } 35 | 36 | // Parameters contains compression extension options. 37 | type Parameters struct { 38 | ServerNoContextTakeover bool 39 | ClientNoContextTakeover bool 40 | ServerMaxWindowBits WindowBits 41 | ClientMaxWindowBits WindowBits 42 | } 43 | 44 | // WindowBits specifies window size accordingly to RFC. 45 | // Use its Bytes() method to obtain actual size of window in bytes. 46 | type WindowBits byte 47 | 48 | // Defined reports whether window bits were specified. 49 | func (b WindowBits) Defined() bool { 50 | return b > 0 51 | } 52 | 53 | // Bytes returns window size in number of bytes. 54 | func (b WindowBits) Bytes() int { 55 | return 1 << uint(b) 56 | } 57 | 58 | const ( 59 | MaxLZ77WindowSize = 32768 // 2^15 60 | ) 61 | 62 | // Parse reads parameters from given HTTP header option accordingly to RFC. 63 | // 64 | // It returns non-nil error at least in these cases: 65 | // - The negotiation offer contains an extension parameter not defined for 66 | // use in an offer/response. 67 | // - The negotiation offer/response contains an extension parameter with an 68 | // invalid value. 69 | // - The negotiation offer/response contains multiple extension parameters 70 | // with the same name. 71 | func (p *Parameters) Parse(opt httphead.Option) (err error) { 72 | const ( 73 | clientMaxWindowBitsSeen = 1 << iota 74 | serverMaxWindowBitsSeen 75 | clientNoContextTakeoverSeen 76 | serverNoContextTakeoverSeen 77 | ) 78 | 79 | // Reset to not mix parsed data from previous Parse() calls. 
80 | *p = Parameters{} 81 | 82 | var seen byte 83 | opt.Parameters.ForEach(func(key, val []byte) (ok bool) { 84 | switch string(key) { 85 | case clientMaxWindowBits: 86 | if len(val) == 0 { 87 | p.ClientMaxWindowBits = 1 88 | return true 89 | } 90 | if seen&clientMaxWindowBitsSeen != 0 { 91 | err = paramError("duplicate", key, val) 92 | return false 93 | } 94 | seen |= clientMaxWindowBitsSeen 95 | if p.ClientMaxWindowBits, ok = bitsFromASCII(val); !ok { 96 | err = paramError("invalid", key, val) 97 | return false 98 | } 99 | 100 | case serverMaxWindowBits: 101 | if len(val) == 0 { 102 | err = paramError("invalid", key, val) 103 | return false 104 | } 105 | if seen&serverMaxWindowBitsSeen != 0 { 106 | err = paramError("duplicate", key, val) 107 | return false 108 | } 109 | seen |= serverMaxWindowBitsSeen 110 | if p.ServerMaxWindowBits, ok = bitsFromASCII(val); !ok { 111 | err = paramError("invalid", key, val) 112 | return false 113 | } 114 | 115 | case clientNoContextTakeover: 116 | if len(val) > 0 { 117 | err = paramError("invalid", key, val) 118 | return false 119 | } 120 | if seen&clientNoContextTakeoverSeen != 0 { 121 | err = paramError("duplicate", key, val) 122 | return false 123 | } 124 | seen |= clientNoContextTakeoverSeen 125 | p.ClientNoContextTakeover = true 126 | 127 | case serverNoContextTakeover: 128 | if len(val) > 0 { 129 | err = paramError("invalid", key, val) 130 | return false 131 | } 132 | if seen&serverNoContextTakeoverSeen != 0 { 133 | err = paramError("duplicate", key, val) 134 | return false 135 | } 136 | seen |= serverNoContextTakeoverSeen 137 | p.ServerNoContextTakeover = true 138 | 139 | default: 140 | err = paramError("unexpected", key, val) 141 | return false 142 | } 143 | return true 144 | }) 145 | return err 146 | } 147 | 148 | // Option encodes parameters into HTTP header option. 
149 | func (p Parameters) Option() httphead.Option { 150 | opt := httphead.Option{ 151 | Name: ExtensionNameBytes, 152 | } 153 | setBool(&opt, serverNoContextTakeoverBytes, p.ServerNoContextTakeover) 154 | setBool(&opt, clientNoContextTakeoverBytes, p.ClientNoContextTakeover) 155 | setBits(&opt, serverMaxWindowBitsBytes, p.ServerMaxWindowBits) 156 | setBits(&opt, clientMaxWindowBitsBytes, p.ClientMaxWindowBits) 157 | return opt 158 | } 159 | 160 | func isValidBits(x int) bool { 161 | return 8 <= x && x <= 15 162 | } 163 | 164 | func bitsFromASCII(p []byte) (WindowBits, bool) { 165 | n, ok := httphead.IntFromASCII(p) 166 | if !ok || !isValidBits(n) { 167 | return 0, false 168 | } 169 | return WindowBits(n), true 170 | } 171 | 172 | func setBits(opt *httphead.Option, name []byte, bits WindowBits) { 173 | if bits == 0 { 174 | return 175 | } 176 | if bits == 1 { 177 | opt.Parameters.Set(name, nil) 178 | return 179 | } 180 | if !isValidBits(int(bits)) { 181 | panic(fmt.Sprintf("wsflate: invalid bits value: %d", bits)) 182 | } 183 | opt.Parameters.Set(name, windowBits[bits-8]) 184 | } 185 | 186 | func setBool(opt *httphead.Option, name []byte, flag bool) { 187 | if flag { 188 | opt.Parameters.Set(name, nil) 189 | } 190 | } 191 | 192 | func paramError(reason string, key, val []byte) error { 193 | return fmt.Errorf( 194 | "wsflate: %s extension parameter %q: %q", 195 | reason, key, val, 196 | ) 197 | } 198 | -------------------------------------------------------------------------------- /wsflate/parameters_test.go: -------------------------------------------------------------------------------- 1 | package wsflate 2 | -------------------------------------------------------------------------------- /wsflate/reader.go: -------------------------------------------------------------------------------- 1 | package wsflate 2 | 3 | import ( 4 | "io" 5 | ) 6 | 7 | // Decompressor is an interface holding deflate decompression implementation. 
8 | type Decompressor interface { 9 | io.Reader 10 | } 11 | 12 | // ReadResetter is an optional interface that Decompressor can implement. 13 | type ReadResetter interface { 14 | Reset(io.Reader) 15 | } 16 | 17 | // Reader implements decompression from an io.Reader object using Decompressor. 18 | // Essentially Reader is a thin wrapper around Decompressor interface to meet 19 | // PMCE specs. 20 | // 21 | // After all data has been written client should call Flush() method. 22 | // If any error occurs after reading from Reader, all subsequent calls to 23 | // Read() or Close() will return the error. 24 | // 25 | // Reader might be reused for different io.Reader objects after its Reset() 26 | // method has been called. 27 | type Reader struct { 28 | src io.Reader 29 | ctor func(io.Reader) Decompressor 30 | d Decompressor 31 | sr suffixedReader 32 | err error 33 | } 34 | 35 | // NewReader returns a new Reader. 36 | func NewReader(r io.Reader, ctor func(io.Reader) Decompressor) *Reader { 37 | ret := &Reader{ 38 | src: r, 39 | ctor: ctor, 40 | sr: suffixedReader{ 41 | suffix: compressionReadTail, 42 | }, 43 | } 44 | ret.Reset(r) 45 | return ret 46 | } 47 | 48 | // Reset resets Reader to decompress data from src. 49 | func (r *Reader) Reset(src io.Reader) { 50 | r.err = nil 51 | r.src = src 52 | r.sr.reset(src) 53 | 54 | if x, ok := r.d.(ReadResetter); ok { 55 | x.Reset(r.sr.iface()) 56 | } else { 57 | r.d = r.ctor(r.sr.iface()) 58 | } 59 | } 60 | 61 | // Read implements io.Reader. 62 | func (r *Reader) Read(p []byte) (n int, err error) { 63 | if r.err != nil { 64 | return 0, r.err 65 | } 66 | return r.d.Read(p) 67 | } 68 | 69 | // Close closes Reader and a Decompressor instance used under the hood (if it 70 | // implements io.Closer interface). 
71 | func (r *Reader) Close() error { 72 | if r.err != nil { 73 | return r.err 74 | } 75 | if c, ok := r.d.(io.Closer); ok { 76 | r.err = c.Close() 77 | } 78 | return r.err 79 | } 80 | 81 | // Err returns an error happened during any operation. 82 | func (r *Reader) Err() error { 83 | return r.err 84 | } 85 | -------------------------------------------------------------------------------- /wsflate/reader_test.go: -------------------------------------------------------------------------------- 1 | package wsflate 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "testing" 8 | ) 9 | 10 | func TestSuffixedReaderIface(t *testing.T) { 11 | for _, test := range []struct { 12 | src io.Reader 13 | exp bool 14 | }{ 15 | { 16 | src: bytes.NewReader(nil), 17 | exp: true, 18 | }, 19 | { 20 | src: io.TeeReader(nil, nil), 21 | exp: false, 22 | }, 23 | } { 24 | t.Run(fmt.Sprintf("%T", test.src), func(t *testing.T) { 25 | isByteReader := func(r io.Reader) bool { 26 | _, ok := r.(io.ByteReader) 27 | return ok 28 | } 29 | s := &suffixedReader{ 30 | r: test.src, 31 | } 32 | if act, exp := isByteReader(s.iface()), test.exp; act != exp { 33 | t.Fatalf("unexpected io.ByteReader: %t; want %t", act, exp) 34 | } 35 | }) 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /wsflate/writer.go: -------------------------------------------------------------------------------- 1 | package wsflate 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | ) 7 | 8 | var ( 9 | compressionTail = [4]byte{ 10 | 0, 0, 0xff, 0xff, 11 | } 12 | compressionReadTail = [9]byte{ 13 | 0, 0, 0xff, 0xff, 14 | 1, 15 | 0, 0, 0xff, 0xff, 16 | } 17 | ) 18 | 19 | // Compressor is an interface holding deflate compression implementation. 20 | type Compressor interface { 21 | io.Writer 22 | Flush() error 23 | } 24 | 25 | // WriteResetter is an optional interface that Compressor can implement. 
26 | type WriteResetter interface { 27 | Reset(io.Writer) 28 | } 29 | 30 | // Writer implements compression for an io.Writer object using Compressor. 31 | // Essentially Writer is a thin wrapper around Compressor interface to meet 32 | // PMCE specs. 33 | // 34 | // After all data has been written client should call Flush() method. 35 | // If any error occurs after writing to or flushing a Writer, all subsequent 36 | // calls to Write(), Flush() or Close() will return the error. 37 | // 38 | // Writer might be reused for different io.Writer objects after its Reset() 39 | // method has been called. 40 | type Writer struct { 41 | // NOTE: Writer uses compressor constructor function instead of field to 42 | // reach these goals: 43 | // 1. To shrink Compressor interface and make it easier to be implemented. 44 | // 2. If used as a field (and argument to the NewWriter()), Compressor object 45 | // will probably be initialized twice - first time to pass into Writer, and 46 | // second time during Writer initialization (which does Reset() internally). 47 | // 3. To get rid of wrappers if Reset() would be a part of Compressor. 48 | // E.g. non conformant implementations would have to provide it somehow, 49 | // probably making a wrapper with the same constructor function. 50 | // 4. To make Reader and Writer API the same. That is, there is no Reset() 51 | // method for flate.Reader already, so we need to provide it as a wrapper 52 | // (see point #3), or drop the Reader.Reset() method. 53 | dest io.Writer 54 | ctor func(io.Writer) Compressor 55 | c Compressor 56 | cbuf cbuf 57 | err error 58 | } 59 | 60 | // NewWriter returns a new Writer. 61 | func NewWriter(w io.Writer, ctor func(io.Writer) Compressor) *Writer { 62 | // NOTE: NewWriter() is chosen against structure with exported fields here 63 | // due its Reset() method, which in case of structure, would change 64 | // exported field. 
65 | ret := &Writer{ 66 | dest: w, 67 | ctor: ctor, 68 | } 69 | ret.Reset(w) 70 | return ret 71 | } 72 | 73 | // Reset resets Writer to compress data into dest. 74 | // Any not flushed data will be lost. 75 | func (w *Writer) Reset(dest io.Writer) { 76 | w.err = nil 77 | w.cbuf.reset(dest) 78 | if x, ok := w.c.(WriteResetter); ok { 79 | x.Reset(&w.cbuf) 80 | } else { 81 | w.c = w.ctor(&w.cbuf) 82 | } 83 | } 84 | 85 | // Write implements io.Writer. 86 | func (w *Writer) Write(p []byte) (n int, err error) { 87 | if w.err != nil { 88 | return 0, w.err 89 | } 90 | n, w.err = w.c.Write(p) 91 | return n, w.err 92 | } 93 | 94 | // Flush writes any pending data into w.Dest. 95 | func (w *Writer) Flush() error { 96 | if w.err != nil { 97 | return w.err 98 | } 99 | w.err = w.c.Flush() 100 | w.checkTail() 101 | return w.err 102 | } 103 | 104 | // Close closes Writer and a Compressor instance used under the hood (if it 105 | // implements io.Closer interface). 106 | func (w *Writer) Close() error { 107 | if w.err != nil { 108 | return w.err 109 | } 110 | if c, ok := w.c.(io.Closer); ok { 111 | w.err = c.Close() 112 | } 113 | w.checkTail() 114 | return w.err 115 | } 116 | 117 | // Err returns an error happened during any operation. 
118 | func (w *Writer) Err() error { 119 | return w.err 120 | } 121 | 122 | func (w *Writer) checkTail() { 123 | if w.err == nil && w.cbuf.buf != compressionTail { 124 | w.err = fmt.Errorf( 125 | "wsflate: bad compressor: unexpected stream tail: %#x vs %#x", 126 | w.cbuf.buf, compressionTail, 127 | ) 128 | } 129 | } 130 | -------------------------------------------------------------------------------- /wsflate/writer_test.go: -------------------------------------------------------------------------------- 1 | package wsflate 2 | 3 | import ( 4 | "bytes" 5 | "compress/flate" 6 | "fmt" 7 | "io" 8 | "io/ioutil" 9 | "net" 10 | "net/url" 11 | "testing" 12 | 13 | "github.com/gobwas/httphead" 14 | "github.com/gobwas/ws" 15 | ) 16 | 17 | func TestWriter(t *testing.T) { 18 | var buf bytes.Buffer 19 | w := NewWriter(&buf, func(w io.Writer) Compressor { 20 | fw, _ := flate.NewWriter(w, 9) 21 | return fw 22 | }) 23 | data := []byte("hello, flate!") 24 | for _, p := range bytes.SplitAfter(data, []byte{','}) { 25 | w.Write(p) 26 | w.Flush() 27 | } 28 | if err := w.Close(); err != nil { 29 | t.Fatalf("unexpected Close() error: %v", err) 30 | } 31 | if err := w.Err(); err != nil { 32 | t.Fatalf("unexpected Writer error: %v", err) 33 | } 34 | 35 | r := NewReader(&buf, func(r io.Reader) Decompressor { 36 | return flate.NewReader(r) 37 | }) 38 | act, err := ioutil.ReadAll(r) 39 | if err != nil { 40 | t.Fatalf("unexpected Reader error: %v", err) 41 | } 42 | if exp := data; !bytes.Equal(act, exp) { 43 | t.Fatalf("unexpected bytes: %#q; want %#q", act, exp) 44 | } 45 | } 46 | 47 | func TestExtensionNegotiation(t *testing.T) { 48 | client, server := net.Pipe() 49 | 50 | done := make(chan error) 51 | go func() { 52 | defer close(done) 53 | var ( 54 | req bytes.Buffer 55 | res bytes.Buffer 56 | ) 57 | conn := struct { 58 | io.Reader 59 | io.Writer 60 | }{ 61 | io.TeeReader(server, &req), 62 | io.MultiWriter(server, &res), 63 | } 64 | e := Extension{ 65 | Parameters: Parameters{ 66 | 
ServerNoContextTakeover: true, 67 | ClientNoContextTakeover: true, 68 | }, 69 | } 70 | u := ws.Upgrader{ 71 | Negotiate: e.Negotiate, 72 | } 73 | hs, err := u.Upgrade(&conn) 74 | if err != nil { 75 | done <- err 76 | return 77 | } 78 | 79 | p, ok := e.Accepted() 80 | t.Logf("accepted: %t %+v", ok, p) 81 | 82 | fmt.Println(req.String()) 83 | fmt.Println(res.String()) 84 | t.Logf("server: %+v", hs) 85 | }() 86 | 87 | d := ws.Dialer{ 88 | Extensions: []httphead.Option{ 89 | (Parameters{ 90 | ServerNoContextTakeover: true, 91 | ClientNoContextTakeover: true, 92 | ClientMaxWindowBits: 8, 93 | ServerMaxWindowBits: 10, 94 | }).Option(), 95 | (Parameters{ 96 | ClientMaxWindowBits: 1, 97 | }).Option(), 98 | (Parameters{}).Option(), 99 | }, 100 | } 101 | 102 | uri, err := url.Parse("ws://example.com") 103 | if err != nil { 104 | t.Fatal(err) 105 | } 106 | _, hs, err := d.Upgrade(client, uri) 107 | if err != nil { 108 | t.Fatalf("client: %v", err) 109 | } 110 | if n := len(hs.Extensions); n != 1 { 111 | t.Fatalf("unexpected number of accepted extensions: %d", n) 112 | } 113 | var p Parameters 114 | if err := p.Parse(hs.Extensions[0]); err != nil { 115 | t.Fatalf("parse extension error: %v", err) 116 | } 117 | t.Logf("client params: %+v", p) 118 | if err := <-done; err != nil { 119 | t.Fatalf("server Upgrade() error: %v", err) 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /wsutil/cipher.go: -------------------------------------------------------------------------------- 1 | package wsutil 2 | 3 | import ( 4 | "io" 5 | 6 | "github.com/gobwas/pool/pbytes" 7 | "github.com/gobwas/ws" 8 | ) 9 | 10 | // CipherReader implements io.Reader that applies xor-cipher to the bytes read 11 | // from source. 12 | // It could help to unmask WebSocket frame payload on the fly. 
13 | type CipherReader struct { 14 | r io.Reader 15 | mask [4]byte 16 | pos int 17 | } 18 | 19 | // NewCipherReader creates xor-cipher reader from r with given mask. 20 | func NewCipherReader(r io.Reader, mask [4]byte) *CipherReader { 21 | return &CipherReader{r, mask, 0} 22 | } 23 | 24 | // Reset resets CipherReader to read from r with given mask. 25 | func (c *CipherReader) Reset(r io.Reader, mask [4]byte) { 26 | c.r = r 27 | c.mask = mask 28 | c.pos = 0 29 | } 30 | 31 | // Read implements io.Reader interface. It applies mask given during 32 | // initialization to every read byte. 33 | func (c *CipherReader) Read(p []byte) (n int, err error) { 34 | n, err = c.r.Read(p) 35 | ws.Cipher(p[:n], c.mask, c.pos) 36 | c.pos += n 37 | return n, err 38 | } 39 | 40 | // CipherWriter implements io.Writer that applies xor-cipher to the bytes 41 | // written to the destination writer. It does not modify the original bytes. 42 | type CipherWriter struct { 43 | w io.Writer 44 | mask [4]byte 45 | pos int 46 | } 47 | 48 | // NewCipherWriter creates xor-cipher writer to w with given mask. 49 | func NewCipherWriter(w io.Writer, mask [4]byte) *CipherWriter { 50 | return &CipherWriter{w, mask, 0} 51 | } 52 | 53 | // Reset reset CipherWriter to write to w with given mask. 54 | func (c *CipherWriter) Reset(w io.Writer, mask [4]byte) { 55 | c.w = w 56 | c.mask = mask 57 | c.pos = 0 58 | } 59 | 60 | // Write implements io.Writer interface. It applies masking during 61 | // initialization to every sent byte. It does not modify original slice. 
62 | func (c *CipherWriter) Write(p []byte) (n int, err error) { 63 | cp := pbytes.GetLen(len(p)) 64 | defer pbytes.Put(cp) 65 | 66 | copy(cp, p) 67 | ws.Cipher(cp, c.mask, c.pos) 68 | n, err = c.w.Write(cp) 69 | c.pos += n 70 | 71 | return n, err 72 | } 73 | -------------------------------------------------------------------------------- /wsutil/cipher_test.go: -------------------------------------------------------------------------------- 1 | package wsutil 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io/ioutil" 7 | "reflect" 8 | "testing" 9 | 10 | "github.com/gobwas/ws" 11 | ) 12 | 13 | func TestCipherReader(t *testing.T) { 14 | for i, test := range []struct { 15 | label string 16 | data []byte 17 | chop int 18 | }{ 19 | { 20 | label: "simple", 21 | data: []byte("hello, websockets!"), 22 | chop: 512, 23 | }, 24 | { 25 | label: "chopped", 26 | data: []byte("hello, websockets!"), 27 | chop: 3, 28 | }, 29 | } { 30 | t.Run(fmt.Sprintf("%s#%d", test.label, i), func(t *testing.T) { 31 | mask := ws.NewMask() 32 | masked := make([]byte, len(test.data)) 33 | copy(masked, test.data) 34 | ws.Cipher(masked, mask, 0) 35 | 36 | src := &chopReader{bytes.NewReader(masked), test.chop} 37 | rd := NewCipherReader(src, mask) 38 | 39 | bts, err := ioutil.ReadAll(rd) 40 | if err != nil { 41 | t.Fatalf("unexpected error: %s", err) 42 | } 43 | if !reflect.DeepEqual(bts, test.data) { 44 | t.Fatalf("read data is not equal:\n\tact:\t%#v\n\texp:\t%#x\n", bts, test.data) 45 | } 46 | }) 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /wsutil/dialer.go: -------------------------------------------------------------------------------- 1 | package wsutil 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "context" 7 | "io" 8 | "io/ioutil" 9 | "net" 10 | "net/http" 11 | 12 | "github.com/gobwas/ws" 13 | ) 14 | 15 | // DebugDialer is a wrapper around ws.Dialer. It tracks i/o of WebSocket 16 | // handshake. 
That is, it gives ability to receive copied HTTP request and 17 | // response bytes that made inside Dialer.Dial(). 18 | // 19 | // Note that it must not be used in production applications that requires 20 | // Dial() to be efficient. 21 | type DebugDialer struct { 22 | // Dialer contains WebSocket connection establishment options. 23 | Dialer ws.Dialer 24 | 25 | // OnRequest and OnResponse are the callbacks that will be called with the 26 | // HTTP request and response respectively. 27 | OnRequest, OnResponse func([]byte) 28 | } 29 | 30 | // Dial connects to the url host and upgrades connection to WebSocket. It makes 31 | // it by calling d.Dialer.Dial(). 32 | func (d *DebugDialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *bufio.Reader, hs ws.Handshake, err error) { 33 | // Need to copy Dialer to prevent original object mutation. 34 | dialer := d.Dialer 35 | var ( 36 | reqBuf bytes.Buffer 37 | resBuf bytes.Buffer 38 | 39 | resContentLength int64 40 | ) 41 | userWrap := dialer.WrapConn 42 | dialer.WrapConn = func(c net.Conn) net.Conn { 43 | if userWrap != nil { 44 | c = userWrap(c) 45 | } 46 | 47 | // Save the pointer to the raw connection. 48 | conn = c 49 | 50 | var ( 51 | r io.Reader = conn 52 | w io.Writer = conn 53 | ) 54 | if d.OnResponse != nil { 55 | r = &prefetchResponseReader{ 56 | source: conn, 57 | buffer: &resBuf, 58 | contentLength: &resContentLength, 59 | } 60 | } 61 | if d.OnRequest != nil { 62 | w = io.MultiWriter(conn, &reqBuf) 63 | } 64 | return rwConn{conn, r, w} 65 | } 66 | 67 | _, br, hs, err = dialer.Dial(ctx, urlstr) 68 | 69 | if onRequest := d.OnRequest; onRequest != nil { 70 | onRequest(reqBuf.Bytes()) 71 | } 72 | if onResponse := d.OnResponse; onResponse != nil { 73 | // We must split response inside buffered bytes from other received 74 | // bytes from server. 75 | p := resBuf.Bytes() 76 | n := bytes.Index(p, headEnd) 77 | h := n + len(headEnd) // Head end index. 78 | n = h + int(resContentLength) // Body end index. 
79 | 80 | onResponse(p[:n]) 81 | 82 | if br != nil { 83 | // If br is non-nil, then it mean two things. First is that 84 | // handshake is OK and server has sent additional bytes – probably 85 | // immediate sent frames (or weird but possible response body). 86 | // Second, the bad one, is that br buffer's source is now rwConn 87 | // instance from above WrapConn call. It is incorrect, so we must 88 | // fix it. 89 | var r io.Reader = conn 90 | if len(p) > h { 91 | // Buffer contains more than just HTTP headers bytes. 92 | r = io.MultiReader( 93 | bytes.NewReader(p[h:]), 94 | conn, 95 | ) 96 | } 97 | br.Reset(r) 98 | // Must make br.Buffered() to be non-zero. 99 | br.Peek(len(p[h:])) 100 | } 101 | } 102 | 103 | return conn, br, hs, err 104 | } 105 | 106 | type rwConn struct { 107 | net.Conn 108 | 109 | r io.Reader 110 | w io.Writer 111 | } 112 | 113 | func (rwc rwConn) Read(p []byte) (int, error) { 114 | return rwc.r.Read(p) 115 | } 116 | 117 | func (rwc rwConn) Write(p []byte) (int, error) { 118 | return rwc.w.Write(p) 119 | } 120 | 121 | var headEnd = []byte("\r\n\r\n") 122 | 123 | type prefetchResponseReader struct { 124 | source io.Reader // Original connection source. 125 | reader io.Reader // Wrapped reader used to read from by clients. 
126 | buffer *bytes.Buffer 127 | 128 | contentLength *int64 129 | } 130 | 131 | func (r *prefetchResponseReader) Read(p []byte) (int, error) { 132 | if r.reader == nil { 133 | resp, err := http.ReadResponse(bufio.NewReader( 134 | io.TeeReader(r.source, r.buffer), 135 | ), nil) 136 | if err == nil { 137 | *r.contentLength, _ = io.Copy(ioutil.Discard, resp.Body) 138 | resp.Body.Close() 139 | } 140 | bts := r.buffer.Bytes() 141 | r.reader = io.MultiReader( 142 | bytes.NewReader(bts), 143 | r.source, 144 | ) 145 | } 146 | return r.reader.Read(p) 147 | } 148 | -------------------------------------------------------------------------------- /wsutil/dialer_test.go: -------------------------------------------------------------------------------- 1 | package wsutil 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "context" 7 | "io" 8 | "io/ioutil" 9 | "net" 10 | "net/http" 11 | "testing" 12 | 13 | "github.com/gobwas/ws" 14 | ) 15 | 16 | var bg = context.Background() 17 | 18 | func TestDebugDialer(t *testing.T) { 19 | for _, test := range []struct { 20 | name string 21 | resp *http.Response 22 | body []byte 23 | err error 24 | }{ 25 | { 26 | name: "base", 27 | }, 28 | { 29 | name: "base with footer", 30 | body: []byte("hello, additional bytes!"), 31 | }, 32 | { 33 | name: "fail", 34 | resp: &http.Response{ 35 | StatusCode: http.StatusSwitchingProtocols, 36 | ProtoMajor: 1, 37 | ProtoMinor: 1, 38 | }, 39 | err: ws.ErrHandshakeBadUpgrade, 40 | }, 41 | { 42 | name: "fail", 43 | resp: &http.Response{ 44 | StatusCode: http.StatusBadRequest, 45 | ProtoMajor: 42, 46 | ProtoMinor: 1, 47 | }, 48 | err: ws.ErrHandshakeBadProtocol, 49 | }, 50 | { 51 | name: "fail", 52 | resp: &http.Response{ 53 | StatusCode: http.StatusBadRequest, 54 | ProtoMajor: 1, 55 | ProtoMinor: 1, 56 | }, 57 | err: ws.StatusError(400), 58 | }, 59 | { 60 | name: "fail footer", 61 | resp: &http.Response{ 62 | StatusCode: http.StatusBadRequest, 63 | ProtoMajor: 1, 64 | ProtoMinor: 1, 65 | }, 66 | err: 
ws.StatusError(400), 67 | }, 68 | 69 | { 70 | name: "big response", 71 | // This test expects that even when server sent unsuccessful 72 | // response with body that does not fit to Dialer read buffer, 73 | // OnResponse will still be called with full response bytes. 74 | resp: &http.Response{ 75 | StatusCode: http.StatusOK, 76 | ProtoMajor: 1, 77 | ProtoMinor: 1, 78 | Body: ioutil.NopCloser(bytes.NewReader( 79 | bytes.Repeat([]byte("x"), 5000), 80 | )), 81 | ContentLength: 5000, 82 | }, 83 | // Additional data sent. We expect it will not be shown in 84 | // OnResponse. 85 | body: bytes.Repeat([]byte("y"), 1000), 86 | err: ws.StatusError(200), 87 | }, 88 | } { 89 | t.Run(test.name, func(t *testing.T) { 90 | client, server := net.Pipe() 91 | 92 | var ( 93 | actReq, actRes []byte 94 | expReq, expRes []byte 95 | ) 96 | dd := DebugDialer{ 97 | Dialer: ws.Dialer{ 98 | NetDial: func(_ context.Context, _, _ string) (net.Conn, error) { 99 | return client, nil 100 | }, 101 | }, 102 | OnRequest: func(p []byte) { actReq = p }, 103 | OnResponse: func(p []byte) { actRes = p }, 104 | } 105 | go func() { 106 | var ( 107 | reqBuf bytes.Buffer 108 | resBuf bytes.Buffer 109 | ) 110 | var ( 111 | tr = io.TeeReader(server, &reqBuf) 112 | bw = bufio.NewWriterSize(server, 65536) 113 | mw = io.MultiWriter(bw, &resBuf) 114 | ) 115 | conn := struct { 116 | io.Reader 117 | io.Writer 118 | }{ 119 | tr, mw, 120 | } 121 | if test.resp == nil { 122 | _, err := ws.Upgrade(conn) 123 | if err != nil { 124 | panic(err) 125 | } 126 | } else { 127 | if _, err := http.ReadRequest(bufio.NewReader(conn)); err != nil { 128 | panic(err) 129 | } 130 | if err := test.resp.Write(conn); err != nil { 131 | panic(err) 132 | } 133 | } 134 | 135 | expReq = reqBuf.Bytes() 136 | expRes = resBuf.Bytes() 137 | 138 | if test.body != nil { 139 | bw.Write(test.body) 140 | } 141 | bw.Flush() 142 | server.Close() 143 | }() 144 | 145 | conn, br, _, err := dd.Dial(bg, "ws://stub") 146 | if err != test.err { 147 | 
t.Fatalf("unexpected error: %v; want %v", err, test.err) 148 | } 149 | if conn != client { 150 | t.Errorf("returned connection is non raw") 151 | } 152 | if br != nil { 153 | body, err := ioutil.ReadAll(br) 154 | if err != nil { 155 | t.Fatal(err) 156 | } 157 | if !bytes.Equal(body, test.body) { 158 | t.Errorf("unexpected buffered body: %q; want %q", body, test.body) 159 | } 160 | } 161 | if !bytes.Equal(actReq, expReq) { 162 | t.Errorf( 163 | "unexpected request bytes:\nact %d bytes:\n%s\nexp %d bytes:\n%s\n", 164 | len(actReq), actReq, len(expReq), expReq, 165 | ) 166 | } 167 | if !bytes.Equal(actRes, expRes) { 168 | t.Errorf( 169 | "unexpected response bytes:\nact %d bytes:\n%s\nexp %d bytes:\n%s\n", 170 | len(actRes), actRes, len(expRes), expRes, 171 | ) 172 | } 173 | }) 174 | } 175 | } 176 | -------------------------------------------------------------------------------- /wsutil/extenstion.go: -------------------------------------------------------------------------------- 1 | package wsutil 2 | 3 | import "github.com/gobwas/ws" 4 | 5 | // RecvExtension is an interface for clearing fragment header RSV bits. 6 | type RecvExtension interface { 7 | UnsetBits(ws.Header) (ws.Header, error) 8 | } 9 | 10 | // RecvExtensionFunc is an adapter to allow the use of ordinary functions as 11 | // RecvExtension. 12 | type RecvExtensionFunc func(ws.Header) (ws.Header, error) 13 | 14 | // BitsRecv implements RecvExtension. 15 | func (fn RecvExtensionFunc) UnsetBits(h ws.Header) (ws.Header, error) { 16 | return fn(h) 17 | } 18 | 19 | // SendExtension is an interface for setting fragment header RSV bits. 20 | type SendExtension interface { 21 | SetBits(ws.Header) (ws.Header, error) 22 | } 23 | 24 | // SendExtensionFunc is an adapter to allow the use of ordinary functions as 25 | // SendExtension. 26 | type SendExtensionFunc func(ws.Header) (ws.Header, error) 27 | 28 | // BitsSend implements SendExtension. 
29 | func (fn SendExtensionFunc) SetBits(h ws.Header) (ws.Header, error) { 30 | return fn(h) 31 | } 32 | -------------------------------------------------------------------------------- /wsutil/handler.go: -------------------------------------------------------------------------------- 1 | package wsutil 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | "io/ioutil" 7 | "strconv" 8 | 9 | "github.com/gobwas/pool/pbytes" 10 | "github.com/gobwas/ws" 11 | ) 12 | 13 | // ClosedError returned when peer has closed the connection with appropriate 14 | // code and a textual reason. 15 | type ClosedError struct { 16 | Code ws.StatusCode 17 | Reason string 18 | } 19 | 20 | // Error implements error interface. 21 | func (err ClosedError) Error() string { 22 | return "ws closed: " + strconv.FormatUint(uint64(err.Code), 10) + " " + err.Reason 23 | } 24 | 25 | // ControlHandler contains logic of handling control frames. 26 | // 27 | // The intentional way to use it is to read the next frame header from the 28 | // connection, optionally check its validity via ws.CheckHeader() and if it is 29 | // not a ws.OpText of ws.OpBinary (or ws.OpContinuation) – pass it to Handle() 30 | // method. 31 | // 32 | // That is, passed header should be checked to get rid of unexpected errors. 33 | // 34 | // The Handle() method will read out all control frame payload (if any) and 35 | // write necessary bytes as a rfc compatible response. 36 | type ControlHandler struct { 37 | Src io.Reader 38 | Dst io.Writer 39 | State ws.State 40 | 41 | // DisableSrcCiphering disables unmasking payload data read from Src. 42 | // It is useful when wsutil.Reader is used or when frame payload already 43 | // pulled and ciphered out from the connection (and introduced by 44 | // bytes.Reader, for example). 45 | DisableSrcCiphering bool 46 | } 47 | 48 | // ErrNotControlFrame is returned by ControlHandler to indicate that given 49 | // header could not be handled. 
50 | var ErrNotControlFrame = errors.New("not a control frame") 51 | 52 | // Handle handles control frames regarding to the c.State and writes responses 53 | // to the c.Dst when needed. 54 | // 55 | // It returns ErrNotControlFrame when given header is not of ws.OpClose, 56 | // ws.OpPing or ws.OpPong operation code. 57 | func (c ControlHandler) Handle(h ws.Header) error { 58 | switch h.OpCode { 59 | case ws.OpPing: 60 | return c.HandlePing(h) 61 | case ws.OpPong: 62 | return c.HandlePong(h) 63 | case ws.OpClose: 64 | return c.HandleClose(h) 65 | } 66 | return ErrNotControlFrame 67 | } 68 | 69 | // HandlePing handles ping frame and writes specification compatible response 70 | // to the c.Dst. 71 | func (c ControlHandler) HandlePing(h ws.Header) error { 72 | if h.Length == 0 { 73 | // The most common case when ping is empty. 74 | // Note that when sending masked frame the mask for empty payload is 75 | // just four zero bytes. 76 | return ws.WriteHeader(c.Dst, ws.Header{ 77 | Fin: true, 78 | OpCode: ws.OpPong, 79 | Masked: c.State.ClientSide(), 80 | }) 81 | } 82 | 83 | // In other way reply with Pong frame with copied payload. 84 | p := pbytes.GetLen(int(h.Length) + ws.HeaderSize(ws.Header{ 85 | Length: h.Length, 86 | Masked: c.State.ClientSide(), 87 | })) 88 | defer pbytes.Put(p) 89 | 90 | // Deal with ciphering i/o: 91 | // Masking key is used to mask the "Payload data" defined in the same 92 | // section as frame-payload-data, which includes "Extension data" and 93 | // "Application data". 94 | // 95 | // See https://tools.ietf.org/html/rfc6455#section-5.3 96 | // 97 | // NOTE: We prefer ControlWriter with preallocated buffer to 98 | // ws.WriteHeader because it performs one syscall instead of two. 
99 | w := NewControlWriterBuffer(c.Dst, c.State, ws.OpPong, p) 100 | r := c.Src 101 | if c.State.ServerSide() && !c.DisableSrcCiphering { 102 | r = NewCipherReader(r, h.Mask) 103 | } 104 | 105 | _, err := io.Copy(w, r) 106 | if err == nil { 107 | err = w.Flush() 108 | } 109 | 110 | return err 111 | } 112 | 113 | // HandlePong handles pong frame by discarding it. 114 | func (c ControlHandler) HandlePong(h ws.Header) error { 115 | if h.Length == 0 { 116 | return nil 117 | } 118 | 119 | buf := pbytes.GetLen(int(h.Length)) 120 | defer pbytes.Put(buf) 121 | 122 | // Discard pong message according to the RFC6455: 123 | // A Pong frame MAY be sent unsolicited. This serves as a 124 | // unidirectional heartbeat. A response to an unsolicited Pong frame 125 | // is not expected. 126 | _, err := io.CopyBuffer(ioutil.Discard, c.Src, buf) 127 | 128 | return err 129 | } 130 | 131 | // HandleClose handles close frame, makes protocol validity checks and writes 132 | // specification compatible response to the c.Dst. 133 | func (c ControlHandler) HandleClose(h ws.Header) error { 134 | if h.Length == 0 { 135 | err := ws.WriteHeader(c.Dst, ws.Header{ 136 | Fin: true, 137 | OpCode: ws.OpClose, 138 | Masked: c.State.ClientSide(), 139 | }) 140 | if err != nil { 141 | return err 142 | } 143 | 144 | // Due to RFC, we should interpret the code as no status code 145 | // received: 146 | // If this Close control frame contains no status code, _The WebSocket 147 | // Connection Close Code_ is considered to be 1005. 148 | // 149 | // See https://tools.ietf.org/html/rfc6455#section-7.1.5 150 | return ClosedError{ 151 | Code: ws.StatusNoStatusRcvd, 152 | } 153 | } 154 | 155 | // Prepare bytes both for reading reason and sending response. 156 | p := pbytes.GetLen(int(h.Length) + ws.HeaderSize(ws.Header{ 157 | Length: h.Length, 158 | Masked: c.State.ClientSide(), 159 | })) 160 | defer pbytes.Put(p) 161 | 162 | // Get the subslice to read the frame payload out. 
163 | subp := p[:h.Length] 164 | 165 | r := c.Src 166 | if c.State.ServerSide() && !c.DisableSrcCiphering { 167 | r = NewCipherReader(r, h.Mask) 168 | } 169 | if _, err := io.ReadFull(r, subp); err != nil { 170 | return err 171 | } 172 | 173 | code, reason := ws.ParseCloseFrameData(subp) 174 | if err := ws.CheckCloseFrameData(code, reason); err != nil { 175 | // Here we could not use the prepared bytes because there is no 176 | // guarantee that it may fit our protocol error closure code and a 177 | // reason. 178 | c.closeWithProtocolError(err) 179 | return err 180 | } 181 | 182 | // Deal with ciphering i/o: 183 | // Masking key is used to mask the "Payload data" defined in the same 184 | // section as frame-payload-data, which includes "Extension data" and 185 | // "Application data". 186 | // 187 | // See https://tools.ietf.org/html/rfc6455#section-5.3 188 | // 189 | // NOTE: We prefer ControlWriter with preallocated buffer to 190 | // ws.WriteHeader because it performs one syscall instead of two. 191 | w := NewControlWriterBuffer(c.Dst, c.State, ws.OpClose, p) 192 | 193 | // RFC6455#5.5.1: 194 | // If an endpoint receives a Close frame and did not previously 195 | // send a Close frame, the endpoint MUST send a Close frame in 196 | // response. (When sending a Close frame in response, the endpoint 197 | // typically echoes the status code it received.) 
198 | _, err := w.Write(p[:2]) 199 | if err != nil { 200 | return err 201 | } 202 | if err := w.Flush(); err != nil { 203 | return err 204 | } 205 | return ClosedError{ 206 | Code: code, 207 | Reason: reason, 208 | } 209 | } 210 | 211 | func (c ControlHandler) closeWithProtocolError(reason error) error { 212 | f := ws.NewCloseFrame(ws.NewCloseFrameBody( 213 | ws.StatusProtocolError, reason.Error(), 214 | )) 215 | if c.State.ClientSide() { 216 | ws.MaskFrameInPlace(f) 217 | } 218 | return ws.WriteFrame(c.Dst, f) 219 | } 220 | -------------------------------------------------------------------------------- /wsutil/handler_test.go: -------------------------------------------------------------------------------- 1 | package wsutil 2 | 3 | import ( 4 | "bytes" 5 | "runtime" 6 | "testing" 7 | 8 | "github.com/gobwas/ws" 9 | ) 10 | 11 | func TestControlHandler(t *testing.T) { 12 | for _, test := range []struct { 13 | name string 14 | state ws.State 15 | in ws.Frame 16 | out ws.Frame 17 | noOut bool 18 | err error 19 | }{ 20 | { 21 | name: "ping", 22 | in: ws.NewPingFrame(nil), 23 | out: ws.NewPongFrame(nil), 24 | }, 25 | { 26 | name: "ping", 27 | in: ws.NewPingFrame([]byte("catch the ball")), 28 | out: ws.NewPongFrame([]byte("catch the ball")), 29 | }, 30 | { 31 | name: "ping", 32 | state: ws.StateServerSide, 33 | in: ws.MaskFrame(ws.NewPingFrame([]byte("catch the ball"))), 34 | out: ws.NewPongFrame([]byte("catch the ball")), 35 | }, 36 | { 37 | name: "ping", 38 | in: ws.NewPingFrame(bytes.Repeat([]byte{0xfe}, 125)), 39 | out: ws.NewPongFrame(bytes.Repeat([]byte{0xfe}, 125)), 40 | }, 41 | { 42 | name: "pong", 43 | in: ws.NewPongFrame(nil), 44 | noOut: true, 45 | }, 46 | { 47 | name: "pong", 48 | in: ws.NewPongFrame([]byte("caught")), 49 | noOut: true, 50 | }, 51 | { 52 | name: "close", 53 | in: ws.NewCloseFrame(nil), 54 | out: ws.NewCloseFrame(nil), 55 | err: ClosedError{ 56 | Code: ws.StatusNoStatusRcvd, 57 | }, 58 | }, 59 | { 60 | name: "close", 61 | in: 
ws.NewCloseFrame(ws.NewCloseFrameBody( 62 | ws.StatusGoingAway, "goodbye!", 63 | )), 64 | out: ws.NewCloseFrame(ws.NewCloseFrameBody( 65 | ws.StatusGoingAway, "", 66 | )), 67 | err: ClosedError{ 68 | Code: ws.StatusGoingAway, 69 | Reason: "goodbye!", 70 | }, 71 | }, 72 | { 73 | name: "close", 74 | in: ws.NewCloseFrame(ws.NewCloseFrameBody( 75 | ws.StatusGoingAway, "bye", 76 | )), 77 | out: ws.NewCloseFrame(ws.NewCloseFrameBody( 78 | ws.StatusGoingAway, "", 79 | )), 80 | err: ClosedError{ 81 | Code: ws.StatusGoingAway, 82 | Reason: "bye", 83 | }, 84 | }, 85 | { 86 | name: "close", 87 | state: ws.StateServerSide, 88 | in: ws.MaskFrame(ws.NewCloseFrame(ws.NewCloseFrameBody( 89 | ws.StatusGoingAway, "goodbye!", 90 | ))), 91 | out: ws.NewCloseFrame(ws.NewCloseFrameBody( 92 | ws.StatusGoingAway, "", 93 | )), 94 | err: ClosedError{ 95 | Code: ws.StatusGoingAway, 96 | Reason: "goodbye!", 97 | }, 98 | }, 99 | { 100 | name: "close", 101 | in: ws.NewCloseFrame(ws.NewCloseFrameBody( 102 | ws.StatusNormalClosure, string([]byte{0, 200}), 103 | )), 104 | out: ws.NewCloseFrame(ws.NewCloseFrameBody( 105 | ws.StatusProtocolError, ws.ErrProtocolInvalidUTF8.Error(), 106 | )), 107 | err: ws.ErrProtocolInvalidUTF8, 108 | }, 109 | } { 110 | t.Run(test.name, func(t *testing.T) { 111 | defer func() { 112 | if err := recover(); err != nil { 113 | stack := make([]byte, 4096) 114 | n := runtime.Stack(stack, true) 115 | t.Fatalf( 116 | "panic recovered: %v\n%s", 117 | err, stack[:n], 118 | ) 119 | } 120 | }() 121 | var ( 122 | out = bytes.NewBuffer(nil) 123 | in = bytes.NewReader(test.in.Payload) 124 | ) 125 | c := ControlHandler{ 126 | Src: in, 127 | Dst: out, 128 | State: test.state, 129 | } 130 | 131 | err := c.Handle(test.in.Header) 132 | if err != test.err { 133 | t.Errorf("unexpected error: %v; want %v", err, test.err) 134 | } 135 | 136 | if in.Len() != 0 { 137 | t.Errorf("handler did not drained the input") 138 | } 139 | 140 | act := out.Bytes() 141 | switch { 142 | case len(act) == 0 
&& test.noOut: 143 | return 144 | case len(act) == 0 && !test.noOut: 145 | t.Errorf("unexpected silence") 146 | case len(act) > 0 && test.noOut: 147 | t.Errorf("unexpected sent frame") 148 | default: 149 | exp := ws.MustCompileFrame(test.out) 150 | if !bytes.Equal(act, exp) { 151 | fa := ws.MustReadFrame(bytes.NewReader(act)) 152 | fe := ws.MustReadFrame(bytes.NewReader(exp)) 153 | t.Errorf( 154 | "unexpected sent frame:\n\tact: %+v\n\texp: %+v\nbytes:\n\tact: %v\n\texp: %v", 155 | fa, fe, act, exp, 156 | ) 157 | } 158 | } 159 | }) 160 | } 161 | } 162 | -------------------------------------------------------------------------------- /wsutil/helper.go: -------------------------------------------------------------------------------- 1 | package wsutil 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "io/ioutil" 7 | 8 | "github.com/gobwas/ws" 9 | ) 10 | 11 | // Message represents a message from peer, that could be presented in one or 12 | // more frames. That is, it contains payload of all message fragments and 13 | // operation code of initial frame for this message. 14 | type Message struct { 15 | OpCode ws.OpCode 16 | Payload []byte 17 | } 18 | 19 | // ReadMessage is a helper function that reads next message from r. It appends 20 | // received message(s) to the third argument and returns the result of it and 21 | // an error if some failure happened. That is, it probably could receive more 22 | // than one message when peer sending fragmented message in multiple frames and 23 | // want to send some control frame between fragments. Then returned slice will 24 | // contain those control frames at first, and then result of gluing fragments. 25 | // 26 | // TODO(gobwas): add DefaultReader with buffer size options. 
27 | func ReadMessage(r io.Reader, s ws.State, m []Message) ([]Message, error) { 28 | rd := Reader{ 29 | Source: r, 30 | State: s, 31 | CheckUTF8: true, 32 | OnIntermediate: func(hdr ws.Header, src io.Reader) error { 33 | bts, err := ioutil.ReadAll(src) 34 | if err != nil { 35 | return err 36 | } 37 | m = append(m, Message{hdr.OpCode, bts}) 38 | return nil 39 | }, 40 | } 41 | h, err := rd.NextFrame() 42 | if err != nil { 43 | return m, err 44 | } 45 | var p []byte 46 | if h.Fin { 47 | // No more frames will be read. Use fixed sized buffer to read payload. 48 | p = make([]byte, h.Length) 49 | // It is not possible to receive io.EOF here because Reader does not 50 | // return EOF if frame payload was successfully fetched. 51 | // Thus we consistent here with io.Reader behavior. 52 | _, err = io.ReadFull(&rd, p) 53 | } else { 54 | // Frame is fragmented, thus use ioutil.ReadAll behavior. 55 | var buf bytes.Buffer 56 | _, err = buf.ReadFrom(&rd) 57 | p = buf.Bytes() 58 | } 59 | if err != nil { 60 | return m, err 61 | } 62 | return append(m, Message{h.OpCode, p}), nil 63 | } 64 | 65 | // ReadClientMessage reads next message from r, considering that caller 66 | // represents server side. 67 | // It is a shortcut for ReadMessage(r, ws.StateServerSide, m). 68 | func ReadClientMessage(r io.Reader, m []Message) ([]Message, error) { 69 | return ReadMessage(r, ws.StateServerSide, m) 70 | } 71 | 72 | // ReadServerMessage reads next message from r, considering that caller 73 | // represents client side. 74 | // It is a shortcut for ReadMessage(r, ws.StateClientSide, m). 75 | func ReadServerMessage(r io.Reader, m []Message) ([]Message, error) { 76 | return ReadMessage(r, ws.StateClientSide, m) 77 | } 78 | 79 | // ReadData is a helper function that reads next data (non-control) message 80 | // from rw. 81 | // It takes care on handling all control frames. It will write response on 82 | // control frames to the write part of rw. 
It blocks until some data frame 83 | // will be received. 84 | // 85 | // Note this may handle and write control frames into the writer part of a 86 | // given io.ReadWriter. 87 | func ReadData(rw io.ReadWriter, s ws.State) ([]byte, ws.OpCode, error) { 88 | return readData(rw, s, ws.OpText|ws.OpBinary) 89 | } 90 | 91 | // ReadClientData reads next data message from rw, considering that caller 92 | // represents server side. It is a shortcut for ReadData(rw, ws.StateServerSide). 93 | // 94 | // Note this may handle and write control frames into the writer part of a 95 | // given io.ReadWriter. 96 | func ReadClientData(rw io.ReadWriter) ([]byte, ws.OpCode, error) { 97 | return ReadData(rw, ws.StateServerSide) 98 | } 99 | 100 | // ReadClientText reads next text message from rw, considering that caller 101 | // represents server side. It is a shortcut for ReadData(rw, ws.StateServerSide). 102 | // It discards received binary messages. 103 | // 104 | // Note this may handle and write control frames into the writer part of a 105 | // given io.ReadWriter. 106 | func ReadClientText(rw io.ReadWriter) ([]byte, error) { 107 | p, _, err := readData(rw, ws.StateServerSide, ws.OpText) 108 | return p, err 109 | } 110 | 111 | // ReadClientBinary reads next binary message from rw, considering that caller 112 | // represents server side. It is a shortcut for ReadData(rw, ws.StateServerSide). 113 | // It discards received text messages. 114 | // 115 | // Note this may handle and write control frames into the writer part of a given 116 | // io.ReadWriter. 117 | func ReadClientBinary(rw io.ReadWriter) ([]byte, error) { 118 | p, _, err := readData(rw, ws.StateServerSide, ws.OpBinary) 119 | return p, err 120 | } 121 | 122 | // ReadServerData reads next data message from rw, considering that caller 123 | // represents client side. It is a shortcut for ReadData(rw, ws.StateClientSide). 
124 | // 125 | // Note this may handle and write control frames into the writer part of a 126 | // given io.ReadWriter. 127 | func ReadServerData(rw io.ReadWriter) ([]byte, ws.OpCode, error) { 128 | return ReadData(rw, ws.StateClientSide) 129 | } 130 | 131 | // ReadServerText reads next text message from rw, considering that caller 132 | // represents client side. It is a shortcut for ReadData(rw, ws.StateClientSide). 133 | // It discards received binary messages. 134 | // 135 | // Note this may handle and write control frames into the writer part of a given 136 | // io.ReadWriter. 137 | func ReadServerText(rw io.ReadWriter) ([]byte, error) { 138 | p, _, err := readData(rw, ws.StateClientSide, ws.OpText) 139 | return p, err 140 | } 141 | 142 | // ReadServerBinary reads next binary message from rw, considering that caller 143 | // represents client side. It is a shortcut for ReadData(rw, ws.StateClientSide). 144 | // It discards received text messages. 145 | // 146 | // Note this may handle and write control frames into the writer part of a 147 | // given io.ReadWriter. 148 | func ReadServerBinary(rw io.ReadWriter) ([]byte, error) { 149 | p, _, err := readData(rw, ws.StateClientSide, ws.OpBinary) 150 | return p, err 151 | } 152 | 153 | // WriteMessage is a helper function that writes message to the w. It 154 | // constructs single frame with given operation code and payload. 155 | // It uses given state to prepare side-dependent things, like cipher 156 | // payload bytes from client to server. It will not mutate p bytes if 157 | // cipher must be made. 158 | // 159 | // If you want to write message in fragmented frames, use Writer instead. 160 | func WriteMessage(w io.Writer, s ws.State, op ws.OpCode, p []byte) error { 161 | return writeFrame(w, s, op, true, p) 162 | } 163 | 164 | // WriteServerMessage writes message to w, considering that caller 165 | // represents server side. 
166 | func WriteServerMessage(w io.Writer, op ws.OpCode, p []byte) error { 167 | return WriteMessage(w, ws.StateServerSide, op, p) 168 | } 169 | 170 | // WriteServerText is the same as WriteServerMessage with 171 | // ws.OpText. 172 | func WriteServerText(w io.Writer, p []byte) error { 173 | return WriteServerMessage(w, ws.OpText, p) 174 | } 175 | 176 | // WriteServerBinary is the same as WriteServerMessage with 177 | // ws.OpBinary. 178 | func WriteServerBinary(w io.Writer, p []byte) error { 179 | return WriteServerMessage(w, ws.OpBinary, p) 180 | } 181 | 182 | // WriteClientMessage writes message to w, considering that caller 183 | // represents client side. 184 | func WriteClientMessage(w io.Writer, op ws.OpCode, p []byte) error { 185 | return WriteMessage(w, ws.StateClientSide, op, p) 186 | } 187 | 188 | // WriteClientText is the same as WriteClientMessage with 189 | // ws.OpText. 190 | func WriteClientText(w io.Writer, p []byte) error { 191 | return WriteClientMessage(w, ws.OpText, p) 192 | } 193 | 194 | // WriteClientBinary is the same as WriteClientMessage with 195 | // ws.OpBinary. 196 | func WriteClientBinary(w io.Writer, p []byte) error { 197 | return WriteClientMessage(w, ws.OpBinary, p) 198 | } 199 | 200 | // HandleClientControlMessage handles control frame from conn and writes 201 | // response when needed. 202 | // 203 | // It considers that caller represents server side. 204 | func HandleClientControlMessage(conn io.Writer, msg Message) error { 205 | return HandleControlMessage(conn, ws.StateServerSide, msg) 206 | } 207 | 208 | // HandleServerControlMessage handles control frame from conn and writes 209 | // response when needed. 210 | // 211 | // It considers that caller represents client side. 212 | func HandleServerControlMessage(conn io.Writer, msg Message) error { 213 | return HandleControlMessage(conn, ws.StateClientSide, msg) 214 | } 215 | 216 | // HandleControlMessage handles message which was read by ReadMessage() 217 | // functions. 
218 | // 219 | // That is, it is expected, that payload is already unmasked and frame header 220 | // were checked by ws.CheckHeader() call. 221 | func HandleControlMessage(conn io.Writer, state ws.State, msg Message) error { 222 | return (ControlHandler{ 223 | DisableSrcCiphering: true, 224 | Src: bytes.NewReader(msg.Payload), 225 | Dst: conn, 226 | State: state, 227 | }).Handle(ws.Header{ 228 | Length: int64(len(msg.Payload)), 229 | OpCode: msg.OpCode, 230 | Fin: true, 231 | Masked: state.ServerSide(), 232 | }) 233 | } 234 | 235 | // ControlFrameHandler returns FrameHandlerFunc for handling control frames. 236 | // For more info see ControlHandler docs. 237 | func ControlFrameHandler(w io.Writer, state ws.State) FrameHandlerFunc { 238 | return func(h ws.Header, r io.Reader) error { 239 | return (ControlHandler{ 240 | DisableSrcCiphering: true, 241 | Src: r, 242 | Dst: w, 243 | State: state, 244 | }).Handle(h) 245 | } 246 | } 247 | 248 | func readData(rw io.ReadWriter, s ws.State, want ws.OpCode) ([]byte, ws.OpCode, error) { 249 | controlHandler := ControlFrameHandler(rw, s) 250 | rd := Reader{ 251 | Source: rw, 252 | State: s, 253 | CheckUTF8: true, 254 | SkipHeaderCheck: false, 255 | OnIntermediate: controlHandler, 256 | } 257 | for { 258 | hdr, err := rd.NextFrame() 259 | if err != nil { 260 | return nil, 0, err 261 | } 262 | if hdr.OpCode.IsControl() { 263 | if err := controlHandler(hdr, &rd); err != nil { 264 | return nil, 0, err 265 | } 266 | continue 267 | } 268 | if hdr.OpCode&want == 0 { 269 | if err := rd.Discard(); err != nil { 270 | return nil, 0, err 271 | } 272 | continue 273 | } 274 | 275 | bts, err := ioutil.ReadAll(&rd) 276 | 277 | return bts, hdr.OpCode, err 278 | } 279 | } 280 | -------------------------------------------------------------------------------- /wsutil/helper_test.go: -------------------------------------------------------------------------------- 1 | package wsutil 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "testing" 7 | 8 | 
"github.com/gobwas/ws" 9 | ) 10 | 11 | func TestReadMessageEOF(t *testing.T) { 12 | for _, test := range []struct { 13 | source func() io.Reader 14 | messages []Message 15 | err error 16 | }{ 17 | { 18 | source: func() io.Reader { return eofReader }, 19 | err: io.EOF, 20 | }, 21 | { 22 | source: func() io.Reader { 23 | // This case tests that ReadMessage still fails after 24 | // successfully reading header bytes frame via ws.ReadHeader() 25 | // and non-successfully read of the body. 26 | var buf bytes.Buffer 27 | f := ws.NewTextFrame([]byte("this part will be lost")) 28 | if err := ws.WriteHeader(&buf, f.Header); err != nil { 29 | t.Fatal(err) 30 | } 31 | return &buf 32 | }, 33 | err: io.ErrUnexpectedEOF, 34 | }, 35 | { 36 | source: func() io.Reader { 37 | // This case tests that ReadMessage not fail when reading 38 | // fragmented messages. 39 | var buf bytes.Buffer 40 | fs := []ws.Frame{ 41 | ws.NewFrame(ws.OpText, false, []byte("fragment1")), 42 | ws.NewFrame(ws.OpContinuation, false, []byte(",")), 43 | ws.NewFrame(ws.OpContinuation, true, []byte("fragment2")), 44 | } 45 | for _, f := range fs { 46 | if err := ws.WriteFrame(&buf, f); err != nil { 47 | t.Fatal(err) 48 | } 49 | } 50 | return &buf 51 | }, 52 | messages: []Message{ 53 | {ws.OpText, []byte("fragment1,fragment2")}, 54 | }, 55 | }, 56 | } { 57 | t.Run("", func(t *testing.T) { 58 | ms, err := ReadMessage(test.source(), 0, nil) 59 | if err != test.err { 60 | t.Errorf("unexpected error: %v; want %v", err, test.err) 61 | } 62 | if n := len(ms); n != len(test.messages) { 63 | t.Fatalf("unexpected number of read messages: %d; want %d", n, 0) 64 | } 65 | for i, exp := range test.messages { 66 | act := ms[i] 67 | if act.OpCode != exp.OpCode { 68 | t.Errorf( 69 | "unexpected #%d message op code: %v; want %v", 70 | i, act.OpCode, exp.OpCode, 71 | ) 72 | } 73 | if !bytes.Equal(act.Payload, exp.Payload) { 74 | t.Errorf( 75 | "unexpected #%d message payload: %q; want %q", 76 | i, string(act.Payload), 
string(exp.Payload), 77 | ) 78 | } 79 | } 80 | }) 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /wsutil/reader.go: -------------------------------------------------------------------------------- 1 | package wsutil 2 | 3 | import ( 4 | "encoding/binary" 5 | "errors" 6 | "io" 7 | "io/ioutil" 8 | 9 | "github.com/gobwas/ws" 10 | ) 11 | 12 | // ErrNoFrameAdvance means that Reader's Read() method was called without 13 | // preceding NextFrame() call. 14 | var ErrNoFrameAdvance = errors.New("no frame advance") 15 | 16 | // ErrFrameTooLarge indicates that a message of length higher than 17 | // MaxFrameSize was being read. 18 | var ErrFrameTooLarge = errors.New("frame too large") 19 | 20 | // FrameHandlerFunc handles parsed frame header and its body represented by 21 | // io.Reader. 22 | // 23 | // Note that reader represents already unmasked body. 24 | type FrameHandlerFunc func(ws.Header, io.Reader) error 25 | 26 | // Reader is a wrapper around source io.Reader which represents WebSocket 27 | // connection. It contains options for reading messages from source. 28 | // 29 | // Reader implements io.Reader, which Read() method reads payload of incoming 30 | // WebSocket frames. It also takes care on fragmented frames and possibly 31 | // intermediate control frames between them. 32 | // 33 | // Note that Reader's methods are not goroutine safe. 34 | type Reader struct { 35 | Source io.Reader 36 | State ws.State 37 | 38 | // SkipHeaderCheck disables checking header bits to be RFC6455 compliant. 39 | SkipHeaderCheck bool 40 | 41 | // CheckUTF8 enables UTF-8 checks for text frames payload. If incoming 42 | // bytes are not valid UTF-8 sequence, ErrInvalidUTF8 returned. 43 | CheckUTF8 bool 44 | 45 | // Extensions is a list of negotiated extensions for reader Source. 46 | // It is used to meet the specs and clear appropriate bits in fragment 47 | // header RSV segment. 
48 | Extensions []RecvExtension 49 | 50 | // MaxFrameSize controls the maximum frame size in bytes 51 | // that can be read. A message exceeding that size will return 52 | // a ErrFrameTooLarge to the application. 53 | // 54 | // Not setting this field means there is no limit. 55 | MaxFrameSize int64 56 | 57 | OnContinuation FrameHandlerFunc 58 | OnIntermediate FrameHandlerFunc 59 | 60 | opCode ws.OpCode // Used to store message op code on fragmentation. 61 | frame io.Reader // Used to as frame reader. 62 | raw io.LimitedReader // Used to discard frames without cipher. 63 | utf8 UTF8Reader // Used to check UTF8 sequences if CheckUTF8 is true. 64 | tmp [ws.MaxHeaderSize - 2]byte // Used for reading headers. 65 | cr *CipherReader // Used by NextFrame() to unmask frame payload. 66 | } 67 | 68 | // NewReader creates new frame reader that reads from r keeping given state to 69 | // make some protocol validity checks when it needed. 70 | func NewReader(r io.Reader, s ws.State) *Reader { 71 | return &Reader{ 72 | Source: r, 73 | State: s, 74 | } 75 | } 76 | 77 | // NewClientSideReader is a helper function that calls NewReader with r and 78 | // ws.StateClientSide. 79 | func NewClientSideReader(r io.Reader) *Reader { 80 | return NewReader(r, ws.StateClientSide) 81 | } 82 | 83 | // NewServerSideReader is a helper function that calls NewReader with r and 84 | // ws.StateServerSide. 85 | func NewServerSideReader(r io.Reader) *Reader { 86 | return NewReader(r, ws.StateServerSide) 87 | } 88 | 89 | // Read implements io.Reader. It reads the next message payload into p. 90 | // It takes care on fragmented messages. 91 | // 92 | // The error is io.EOF only if all of message bytes were read. 93 | // If an io.EOF happens during reading some but not all the message bytes 94 | // Read() returns io.ErrUnexpectedEOF. 95 | // 96 | // The error is ErrNoFrameAdvance if no NextFrame() call was made before 97 | // reading next message bytes. 
func (r *Reader) Read(p []byte) (n int, err error) {
	if r.frame == nil {
		if !r.fragmented() {
			// Every new Read() must be preceded by NextFrame() call.
			return 0, ErrNoFrameAdvance
		}
		// Read next continuation or intermediate control frame.
		_, err := r.NextFrame()
		if err != nil {
			return 0, err
		}
		if r.frame == nil {
			// We handled intermediate control and now got nothing to read.
			// The caller is expected to call Read() again to continue with
			// the data message.
			return 0, nil
		}
	}

	n, err = r.frame.Read(p)
	if err != nil && err != io.EOF {
		return n, err
	}
	// Current frame still has unread payload bytes: a regular partial read.
	if err == nil && r.raw.N != 0 {
		return n, nil
	}

	// EOF condition (either err is io.EOF or r.raw.N is zero).
	switch {
	case r.raw.N != 0:
		// Source ended before the declared frame length was read.
		err = io.ErrUnexpectedEOF

	case r.fragmented():
		// Current fragment is exhausted but more fragments follow: drop the
		// per-frame readers so that the next Read() advances via NextFrame(),
		// while keeping the UTF-8 validation state of the message.
		err = nil
		r.resetFragment()

	case r.CheckUTF8 && !r.utf8.Valid():
		// NOTE: check utf8 only when full message received, since partial
		// reads may be invalid.
		n = r.utf8.Accepted()
		err = ErrInvalidUTF8

	default:
		// Whole message consumed: clear per-message state and report EOF.
		r.reset()
		err = io.EOF
	}

	return n, err
}

// Discard discards current message unread bytes.
// It discards all frames of fragmented message.
func (r *Reader) Discard() (err error) {
	for {
		// Drain the rest of the current frame via the raw reader, bypassing
		// unmasking and UTF-8 checks.
		_, err = io.Copy(ioutil.Discard, &r.raw)
		if err != nil {
			break
		}
		if !r.fragmented() {
			break
		}
		// More fragments follow: advance to the next frame and drain it too.
		if _, err = r.NextFrame(); err != nil {
			break
		}
	}
	// Clear per-message state regardless of whether an error occurred.
	r.reset()
	return err
}

// NextFrame prepares r to read the next frame of a message. It returns the
// received frame header, and a non-nil error on failure.
//
// Note that next NextFrame() call must be done after receiving or discarding
// all current message bytes.
func (r *Reader) NextFrame() (hdr ws.Header, err error) {
	hdr, err = r.readHeader(r.Source)
	if err == io.EOF && r.fragmented() {
		// If we are in fragmented state EOF means that is was totally
		// unexpected.
		//
		// NOTE: This is necessary to prevent callers such that
		// ioutil.ReadAll to receive some amount of bytes without an error.
		// ReadAll() ignores an io.EOF error, thus caller may think that
		// whole message fetched, but actually only part of it.
		err = io.ErrUnexpectedEOF
	}
	if err == nil && !r.SkipHeaderCheck {
		err = ws.CheckHeader(hdr, r.State)
	}
	if err != nil {
		return hdr, err
	}

	// The limit applies per frame; a non-positive MaxFrameSize disables it.
	if n := r.MaxFrameSize; n > 0 && hdr.Length > n {
		return hdr, ErrFrameTooLarge
	}

	// Save raw reader to use it on discarding frame without ciphering and
	// other streaming checks.
	r.raw = io.LimitedReader{
		R: r.Source,
		N: hdr.Length,
	}

	frame := io.Reader(&r.raw)
	if hdr.Masked {
		// Reuse a single CipherReader across frames instead of allocating one
		// per masked frame.
		if r.cr == nil {
			r.cr = NewCipherReader(frame, hdr.Mask)
		} else {
			r.cr.Reset(frame, hdr.Mask)
		}
		frame = r.cr
	}

	// Let negotiated extensions clear the RSV bits they are responsible for.
	for _, x := range r.Extensions {
		hdr, err = x.UnsetBits(hdr)
		if err != nil {
			return hdr, err
		}
	}

	if r.fragmented() {
		if hdr.OpCode.IsControl() {
			// Control frame interleaved between fragments of a data message.
			if cb := r.OnIntermediate; cb != nil {
				err = cb(hdr, frame)
			}
			if err == nil {
				// Ensure that src is empty.
				_, err = io.Copy(ioutil.Discard, &r.raw)
			}
			// Returning here skips the r.frame assignment and the Fin-based
			// State update below, so the interleaved control frame does not
			// end the fragmented data message.
			return hdr, err
		}
	} else {
		// First frame of a message: remember its op code so that the
		// continuation frames of a text message are UTF-8 checked too.
		r.opCode = hdr.OpCode
	}
	if r.CheckUTF8 && (hdr.OpCode == ws.OpText || (r.fragmented() && r.opCode == ws.OpText)) {
		r.utf8.Source = frame
		frame = &r.utf8
	}

	// Save reader with ciphering and other streaming checks.
	r.frame = frame

	if hdr.OpCode == ws.OpContinuation {
		if cb := r.OnContinuation; cb != nil {
			err = cb(hdr, frame)
		}
	}

	// Track whether further fragments of this message are expected.
	if hdr.Fin {
		r.State = r.State.Clear(ws.StateFragmented)
	} else {
		r.State = r.State.Set(ws.StateFragmented)
	}

	return hdr, err
}

// fragmented reports whether r is in the middle of a fragmented message.
func (r *Reader) fragmented() bool {
	return r.State.Fragmented()
}

// resetFragment drops the readers of the current, exhausted fragment while
// keeping message-level state such as r.opCode and the UTF-8 validation
// progress.
func (r *Reader) resetFragment() {
	r.raw = io.LimitedReader{}
	r.frame = nil
	// Reset source of the UTF8Reader, but not the state.
	r.utf8.Source = nil
}

// reset clears all per-message state (frame readers, UTF-8 validation state
// and the saved op code).
func (r *Reader) reset() {
	r.raw = io.LimitedReader{}
	r.frame = nil
	r.utf8 = UTF8Reader{}
	r.opCode = 0
}

// readHeader reads a frame header from in.
func (r *Reader) readHeader(in io.Reader) (h ws.Header, err error) {
	// Make slice of bytes with capacity 12 that could hold any header.
	//
	// The maximum header size is 14, but due to the 2 hop reads,
	// after first hop that reads first 2 constant bytes, we could reuse 2 bytes.
	// So 14 - 2 = 12.
	bts := r.tmp[:2]

	// Prepare to hold first 2 bytes to choose size of next read.
	_, err = io.ReadFull(in, bts)
	if err != nil {
		return h, err
	}
	// bit0 is the most significant bit of a byte: FIN in the first byte and
	// MASK in the second one.
	const bit0 = 0x80

	// First byte layout: FIN (1 bit), RSV (3 bits), opcode (4 bits).
	h.Fin = bts[0]&bit0 != 0
	h.Rsv = (bts[0] & 0x70) >> 4
	h.OpCode = ws.OpCode(bts[0] & 0x0f)

	// extra counts the header bytes still to be read after the first two.
	var extra int

	if bts[1]&bit0 != 0 {
		h.Masked = true
		extra += 4
	}

	// Second byte low 7 bits: payload length, or 126/127 as markers of a
	// 16-bit/64-bit extended length that follows.
	length := bts[1] & 0x7f
	switch {
	case length < 126:
		h.Length = int64(length)

	case length == 126:
		extra += 2

	case length == 127:
		extra += 8

	default:
		// Unreachable since length is masked to 7 bits; kept as a safeguard.
		err = ws.ErrHeaderLengthUnexpected
		return h, err
	}

	if extra == 0 {
		return h, err
	}

	// Increase len of bts to extra bytes need to read.
	// Overwrite first 2 bytes that was read before.
	bts = bts[:extra]
	_, err = io.ReadFull(in, bts)
	if err != nil {
		return h, err
	}

	switch {
	case length == 126:
		h.Length = int64(binary.BigEndian.Uint16(bts[:2]))
		bts = bts[2:]

	case length == 127:
		// The most significant bit of the 64-bit length must be zero.
		if bts[0]&0x80 != 0 {
			err = ws.ErrHeaderLengthMSB
			return h, err
		}
		h.Length = int64(binary.BigEndian.Uint64(bts[:8]))
		bts = bts[8:]
	}

	// Whatever remains after the extended length is the 4-byte masking key.
	if h.Masked {
		copy(h.Mask[:], bts)
	}

	return h, nil
}

// NextReader prepares next message read from r. It returns header that
// describes the message and io.Reader to read message's payload. It returns
// non-nil error when it is not possible to read message's initial frame.
//
// Note that next NextReader() on the same r should be done after reading all
// bytes from previously returned io.Reader. For more performant way to discard
// message use Reader and its Discard() method.
//
// Note that it will not handle any "intermediate" frames, that possibly could
// be received between text/binary continuation frames.
That is, if peer sent 358 | // text/binary frame with fin flag "false", then it could send ping frame, and 359 | // eventually remaining part of text/binary frame with fin "true" – with 360 | // NextReader() the ping frame will be dropped without any notice. To handle 361 | // this rare, but possible situation (and if you do not know exactly which 362 | // frames peer could send), you could use Reader with OnIntermediate field set. 363 | func NextReader(r io.Reader, s ws.State) (ws.Header, io.Reader, error) { 364 | rd := &Reader{ 365 | Source: r, 366 | State: s, 367 | } 368 | header, err := rd.NextFrame() 369 | if err != nil { 370 | return header, nil, err 371 | } 372 | return header, rd, nil 373 | } 374 | -------------------------------------------------------------------------------- /wsutil/reader_test.go: -------------------------------------------------------------------------------- 1 | package wsutil 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "io/ioutil" 7 | "testing" 8 | "unicode/utf8" 9 | 10 | "github.com/gobwas/ws" 11 | ) 12 | 13 | // TODO(gobwas): test continuation discard. 14 | // test discard when NextFrame(). 
15 | 16 | var eofReader = bytes.NewReader(nil) 17 | 18 | func TestReadFromWithIntermediateControl(t *testing.T) { 19 | var buf bytes.Buffer 20 | 21 | ws.MustWriteFrame(&buf, ws.NewFrame(ws.OpText, false, []byte("foo"))) 22 | ws.MustWriteFrame(&buf, ws.NewPingFrame([]byte("ping"))) 23 | ws.MustWriteFrame(&buf, ws.NewFrame(ws.OpContinuation, false, []byte("bar"))) 24 | ws.MustWriteFrame(&buf, ws.NewPongFrame([]byte("pong"))) 25 | ws.MustWriteFrame(&buf, ws.NewFrame(ws.OpContinuation, true, []byte("baz"))) 26 | 27 | var intermediate [][]byte 28 | r := Reader{ 29 | Source: &buf, 30 | OnIntermediate: func(h ws.Header, r io.Reader) error { 31 | bts, err := ioutil.ReadAll(r) 32 | if err != nil { 33 | t.Fatal(err) 34 | } 35 | intermediate = append( 36 | intermediate, 37 | append([]byte(nil), bts...), 38 | ) 39 | return nil 40 | }, 41 | } 42 | 43 | h, err := r.NextFrame() 44 | if err != nil { 45 | t.Fatal(err) 46 | } 47 | exp := ws.Header{ 48 | Length: 3, 49 | Fin: false, 50 | OpCode: ws.OpText, 51 | } 52 | if act := h; act != exp { 53 | t.Fatalf("unexpected NextFrame() header: %+v; want %+v", act, exp) 54 | } 55 | 56 | act, err := ioutil.ReadAll(&r) 57 | if err != nil { 58 | t.Fatal(err) 59 | } 60 | if exp := []byte("foobarbaz"); !bytes.Equal(act, exp) { 61 | t.Errorf("unexpected all bytes: %q; want %q", act, exp) 62 | } 63 | if act, exp := len(intermediate), 2; act != exp { 64 | t.Errorf("unexpected intermediate payload: %d; want %d", act, exp) 65 | } else { 66 | for i, exp := range [][]byte{ 67 | []byte("ping"), 68 | []byte("pong"), 69 | } { 70 | if act := intermediate[i]; !bytes.Equal(act, exp) { 71 | t.Errorf( 72 | "unexpected #%d intermediate payload: %q; want %q", 73 | i, act, exp, 74 | ) 75 | } 76 | } 77 | } 78 | } 79 | 80 | func TestReaderNoFrameAdvance(t *testing.T) { 81 | r := Reader{ 82 | Source: eofReader, 83 | } 84 | if _, err := r.Read(make([]byte, 10)); err != ErrNoFrameAdvance { 85 | t.Errorf("Read() returned %v; want %v", err, ErrNoFrameAdvance) 86 | } 87 
| } 88 | 89 | func TestReaderNextFrameAndReadEOF(t *testing.T) { 90 | for _, test := range []struct { 91 | source func() io.Reader 92 | nextFrameErr error 93 | readErr error 94 | }{ 95 | { 96 | source: func() io.Reader { return eofReader }, 97 | nextFrameErr: io.EOF, 98 | readErr: ErrNoFrameAdvance, 99 | }, 100 | { 101 | source: func() io.Reader { 102 | // This case tests that ReadMessage still fails after 103 | // successfully reading header bytes frame via ws.ReadHeader() 104 | // and non-successfully read of the body. 105 | var buf bytes.Buffer 106 | f := ws.NewTextFrame([]byte("this part will be lost")) 107 | if err := ws.WriteHeader(&buf, f.Header); err != nil { 108 | t.Fatal(err) 109 | } 110 | return &buf 111 | }, 112 | nextFrameErr: nil, 113 | readErr: io.ErrUnexpectedEOF, 114 | }, 115 | { 116 | source: func() io.Reader { 117 | var buf bytes.Buffer 118 | f := ws.NewTextFrame([]byte("foobar")) 119 | if err := ws.WriteHeader(&buf, f.Header); err != nil { 120 | t.Fatal(err) 121 | } 122 | buf.WriteString("foo") 123 | return &buf 124 | }, 125 | nextFrameErr: nil, 126 | readErr: io.ErrUnexpectedEOF, 127 | }, 128 | { 129 | source: func() io.Reader { 130 | var buf bytes.Buffer 131 | f := ws.NewFrame(ws.OpText, false, []byte("payload")) 132 | if err := ws.WriteFrame(&buf, f); err != nil { 133 | t.Fatal(err) 134 | } 135 | return &buf 136 | }, 137 | nextFrameErr: nil, 138 | readErr: io.ErrUnexpectedEOF, 139 | }, 140 | } { 141 | t.Run("", func(t *testing.T) { 142 | r := Reader{ 143 | Source: test.source(), 144 | } 145 | _, err := r.NextFrame() 146 | if err != test.nextFrameErr { 147 | t.Errorf("NextFrame() = %v; want %v", err, test.nextFrameErr) 148 | } 149 | var ( 150 | p = make([]byte, 4096) 151 | i = 0 152 | ) 153 | for { 154 | if i == 100 { 155 | t.Fatal(io.ErrNoProgress) 156 | } 157 | _, err := r.Read(p) 158 | if err == nil { 159 | continue 160 | } 161 | if err != test.readErr { 162 | t.Errorf("Read() = %v; want %v", err, test.readErr) 163 | } 164 | break 165 | } 
166 | }) 167 | } 168 | } 169 | 170 | func TestMaxFrameSize(t *testing.T) { 171 | var buf bytes.Buffer 172 | msg := []byte("small frame") 173 | f := ws.NewTextFrame(msg) 174 | if err := ws.WriteFrame(&buf, f); err != nil { 175 | t.Fatal(err) 176 | } 177 | r := Reader{ 178 | Source: &buf, 179 | MaxFrameSize: int64(len(msg)) - 1, 180 | } 181 | 182 | _, err := r.NextFrame() 183 | if got, want := err, ErrFrameTooLarge; got != want { 184 | t.Errorf("NextFrame() error = %v; want %v", got, want) 185 | } 186 | 187 | p := make([]byte, 100) 188 | n, err := r.Read(p) 189 | if got, want := err, ErrNoFrameAdvance; got != want { 190 | t.Errorf("Read() error = %v; want %v", got, want) 191 | } 192 | if got, want := n, 0; got != want { 193 | t.Errorf("Read() bytes returned = %v; want %v", got, want) 194 | } 195 | } 196 | 197 | func TestReaderUTF8(t *testing.T) { 198 | yo := []byte("Ё") 199 | if !utf8.ValidString(string(yo)) { 200 | t.Fatal("bad fixture") 201 | } 202 | 203 | var buf bytes.Buffer 204 | ws.WriteFrame(&buf, 205 | ws.NewFrame(ws.OpText, false, yo[:1]), 206 | ) 207 | ws.WriteFrame(&buf, 208 | ws.NewFrame(ws.OpContinuation, true, yo[1:]), 209 | ) 210 | 211 | r := Reader{ 212 | Source: &buf, 213 | CheckUTF8: true, 214 | } 215 | if _, err := r.NextFrame(); err != nil { 216 | t.Fatal(err) 217 | } 218 | bts, err := ioutil.ReadAll(&r) 219 | if err != nil { 220 | t.Errorf("unexpected error: %v", err) 221 | } 222 | if !bytes.Equal(bts, yo) { 223 | t.Errorf("ReadAll(r) = %v; want %v", bts, yo) 224 | } 225 | } 226 | 227 | func TestNextReader(t *testing.T) { 228 | for _, test := range []struct { 229 | name string 230 | seq []ws.Frame 231 | chop int 232 | exp []byte 233 | err error 234 | }{ 235 | { 236 | name: "empty", 237 | seq: []ws.Frame{}, 238 | err: io.EOF, 239 | }, 240 | { 241 | name: "single", 242 | seq: []ws.Frame{ 243 | ws.NewTextFrame([]byte("Привет, Мир!")), 244 | }, 245 | exp: []byte("Привет, Мир!"), 246 | }, 247 | { 248 | name: "single_masked", 249 | seq: []ws.Frame{ 250 
| ws.MaskFrame(ws.NewTextFrame([]byte("Привет, Мир!"))), 251 | }, 252 | exp: []byte("Привет, Мир!"), 253 | }, 254 | { 255 | name: "fragmented", 256 | seq: []ws.Frame{ 257 | ws.NewFrame(ws.OpText, false, []byte("Привет,")), 258 | ws.NewFrame(ws.OpContinuation, false, []byte(" о дивный,")), 259 | ws.NewFrame(ws.OpContinuation, false, []byte(" новый ")), 260 | ws.NewFrame(ws.OpContinuation, true, []byte("Мир!")), 261 | 262 | ws.NewTextFrame([]byte("Hello, Brave New World!")), 263 | }, 264 | exp: []byte("Привет, о дивный, новый Мир!"), 265 | }, 266 | { 267 | name: "fragmented_masked", 268 | seq: []ws.Frame{ 269 | ws.MaskFrame(ws.NewFrame(ws.OpText, false, []byte("Привет,"))), 270 | ws.MaskFrame(ws.NewFrame(ws.OpContinuation, false, []byte(" о дивный,"))), 271 | ws.MaskFrame(ws.NewFrame(ws.OpContinuation, false, []byte(" новый "))), 272 | ws.MaskFrame(ws.NewFrame(ws.OpContinuation, true, []byte("Мир!"))), 273 | 274 | ws.MaskFrame(ws.NewTextFrame([]byte("Hello, Brave New World!"))), 275 | }, 276 | exp: []byte("Привет, о дивный, новый Мир!"), 277 | }, 278 | { 279 | name: "fragmented_and_control", 280 | seq: []ws.Frame{ 281 | ws.NewFrame(ws.OpText, false, []byte("Привет,")), 282 | ws.NewFrame(ws.OpPing, true, nil), 283 | ws.NewFrame(ws.OpContinuation, false, []byte(" о дивный,")), 284 | ws.NewFrame(ws.OpPing, true, nil), 285 | ws.NewFrame(ws.OpContinuation, false, []byte(" новый ")), 286 | ws.NewFrame(ws.OpPing, true, nil), 287 | ws.NewFrame(ws.OpPing, true, []byte("ping info")), 288 | ws.NewFrame(ws.OpContinuation, true, []byte("Мир!")), 289 | }, 290 | exp: []byte("Привет, о дивный, новый Мир!"), 291 | }, 292 | { 293 | name: "fragmented_and_control_mask", 294 | seq: []ws.Frame{ 295 | ws.MaskFrame(ws.NewFrame(ws.OpText, false, []byte("Привет,"))), 296 | ws.MaskFrame(ws.NewFrame(ws.OpPing, true, nil)), 297 | ws.MaskFrame(ws.NewFrame(ws.OpContinuation, false, []byte(" о дивный,"))), 298 | ws.MaskFrame(ws.NewFrame(ws.OpPing, true, nil)), 299 | 
ws.MaskFrame(ws.NewFrame(ws.OpContinuation, false, []byte(" новый "))), 300 | ws.MaskFrame(ws.NewFrame(ws.OpPing, true, nil)), 301 | ws.MaskFrame(ws.NewFrame(ws.OpPing, true, []byte("ping info"))), 302 | ws.MaskFrame(ws.NewFrame(ws.OpContinuation, true, []byte("Мир!"))), 303 | }, 304 | exp: []byte("Привет, о дивный, новый Мир!"), 305 | }, 306 | } { 307 | t.Run(test.name, func(t *testing.T) { 308 | // Prepare input. 309 | buf := &bytes.Buffer{} 310 | for _, f := range test.seq { 311 | if err := ws.WriteFrame(buf, f); err != nil { 312 | t.Fatal(err) 313 | } 314 | } 315 | 316 | conn := &chopReader{ 317 | src: bytes.NewReader(buf.Bytes()), 318 | sz: test.chop, 319 | } 320 | 321 | var bts []byte 322 | _, reader, err := NextReader(conn, 0) 323 | if err == nil { 324 | bts, err = ioutil.ReadAll(reader) 325 | } 326 | if err != test.err { 327 | t.Fatalf("unexpected error; got %v; want %v", err, test.err) 328 | } 329 | if test.err == nil && !bytes.Equal(bts, test.exp) { 330 | t.Errorf( 331 | "ReadAll from reader:\nact:\t%#x\nexp:\t%#x\nact:\t%s\nexp:\t%s\n", 332 | bts, test.exp, string(bts), string(test.exp), 333 | ) 334 | } 335 | }) 336 | } 337 | } 338 | 339 | type chopReader struct { 340 | src io.Reader 341 | sz int 342 | } 343 | 344 | func (c chopReader) Read(p []byte) (n int, err error) { 345 | sz := c.sz 346 | if sz == 0 { 347 | sz = 1 348 | } 349 | if sz > len(p) { 350 | sz = len(p) 351 | } 352 | return c.src.Read(p[:sz]) 353 | } 354 | -------------------------------------------------------------------------------- /wsutil/upgrader.go: -------------------------------------------------------------------------------- 1 | package wsutil 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "io" 7 | "io/ioutil" 8 | "net/http" 9 | 10 | "github.com/gobwas/ws" 11 | ) 12 | 13 | // DebugUpgrader is a wrapper around ws.Upgrader. It tracks I/O of a 14 | // WebSocket handshake. 
15 | // 16 | // Note that it must not be used in production applications that requires 17 | // Upgrade() to be efficient. 18 | type DebugUpgrader struct { 19 | // Upgrader contains upgrade to WebSocket options. 20 | Upgrader ws.Upgrader 21 | 22 | // OnRequest and OnResponse are the callbacks that will be called with the 23 | // HTTP request and response respectively. 24 | OnRequest, OnResponse func([]byte) 25 | } 26 | 27 | // Upgrade calls Upgrade() on underlying ws.Upgrader and tracks I/O on conn. 28 | func (d *DebugUpgrader) Upgrade(conn io.ReadWriter) (hs ws.Handshake, err error) { 29 | var ( 30 | // Take the Reader and Writer parts from conn to be probably replaced 31 | // below. 32 | r io.Reader = conn 33 | w io.Writer = conn 34 | ) 35 | if onRequest := d.OnRequest; onRequest != nil { 36 | var buf bytes.Buffer 37 | // First, we must read the entire request. 38 | req, err := http.ReadRequest(bufio.NewReader( 39 | io.TeeReader(conn, &buf), 40 | )) 41 | if err == nil { 42 | // Fulfill the buffer with the response body. 43 | io.Copy(ioutil.Discard, req.Body) 44 | req.Body.Close() 45 | } 46 | onRequest(buf.Bytes()) 47 | 48 | r = io.MultiReader( 49 | &buf, conn, 50 | ) 51 | } 52 | 53 | if onResponse := d.OnResponse; onResponse != nil { 54 | var buf bytes.Buffer 55 | // Intercept the response stream written by the Upgrade(). 
56 | w = io.MultiWriter( 57 | conn, &buf, 58 | ) 59 | defer func() { 60 | onResponse(buf.Bytes()) 61 | }() 62 | } 63 | 64 | return d.Upgrader.Upgrade(struct { 65 | io.Reader 66 | io.Writer 67 | }{r, w}) 68 | } 69 | -------------------------------------------------------------------------------- /wsutil/upgrader_test.go: -------------------------------------------------------------------------------- 1 | package wsutil 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "net/url" 8 | "testing" 9 | 10 | "github.com/gobwas/ws" 11 | ) 12 | 13 | func TestDebugUpgrader(t *testing.T) { 14 | for _, test := range []struct { 15 | name string 16 | upgrader ws.Upgrader 17 | req []byte 18 | }{ 19 | { 20 | // Base case. 21 | }, 22 | { 23 | req: []byte("" + 24 | "GET /test HTTP/1.1\r\n" + 25 | "Host: example.org\r\n" + 26 | "\r\n", 27 | ), 28 | }, 29 | { 30 | req: []byte("PUT /fail HTTP/1.1\r\n\r\n"), 31 | }, 32 | { 33 | req: []byte("GET /fail HTTP/1.0\r\n\r\n"), 34 | }, 35 | } { 36 | t.Run(test.name, func(t *testing.T) { 37 | var ( 38 | reqBuf bytes.Buffer 39 | resBuf bytes.Buffer 40 | 41 | expReq, expRes []byte 42 | actReq, actRes []byte 43 | ) 44 | if test.req == nil { 45 | var dialer ws.Dialer 46 | dialer.Upgrade(struct { 47 | io.Reader 48 | io.Writer 49 | }{ 50 | new(falseReader), 51 | &reqBuf, 52 | }, makeURL("wss://example.org")) 53 | } else { 54 | reqBuf.Write(test.req) 55 | } 56 | 57 | // Need to save bytes before they will be read by Upgrade(). 
58 | expReq = reqBuf.Bytes() 59 | 60 | du := DebugUpgrader{ 61 | Upgrader: test.upgrader, 62 | OnRequest: func(p []byte) { actReq = p }, 63 | OnResponse: func(p []byte) { actRes = p }, 64 | } 65 | du.Upgrade(struct { 66 | io.Reader 67 | io.Writer 68 | }{ 69 | &reqBuf, 70 | &resBuf, 71 | }) 72 | 73 | expRes = resBuf.Bytes() 74 | 75 | if !bytes.Equal(actReq, expReq) { 76 | t.Errorf( 77 | "unexpected request bytes:\nact:\n%s\nwant:\n%s\n", 78 | actReq, expReq, 79 | ) 80 | } 81 | if !bytes.Equal(actRes, expRes) { 82 | t.Errorf( 83 | "unexpected response bytes:\nact:\n%s\nwant:\n%s\n", 84 | actRes, expRes, 85 | ) 86 | } 87 | }) 88 | } 89 | } 90 | 91 | type falseReader struct{} 92 | 93 | func (f falseReader) Read([]byte) (int, error) { 94 | return 0, fmt.Errorf("falsy read") 95 | } 96 | 97 | func makeURL(s string) *url.URL { 98 | u, err := url.Parse(s) 99 | if err != nil { 100 | panic(err) 101 | } 102 | return u 103 | } 104 | -------------------------------------------------------------------------------- /wsutil/utf8.go: -------------------------------------------------------------------------------- 1 | package wsutil 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | ) 7 | 8 | // ErrInvalidUTF8 is returned by UTF8 reader on invalid utf8 sequence. 9 | var ErrInvalidUTF8 = fmt.Errorf("invalid utf8") 10 | 11 | // UTF8Reader implements io.Reader that calculates utf8 validity state after 12 | // every read byte from Source. 13 | // 14 | // Note that in some cases client must call r.Valid() after all bytes are read 15 | // to ensure that all of them are valid utf8 sequences. That is, some io helper 16 | // functions such io.ReadAtLeast or io.ReadFull could discard the error 17 | // information returned by the reader when they receive all of requested bytes. 18 | // For example, the last read sequence is invalid and UTF8Reader returns number 19 | // of bytes read and an error. 
But helper function decides to discard received 20 | // error due to all requested bytes are completely read from the source. 21 | // 22 | // Another possible case is when some valid sequence become split by the read 23 | // bound. Then UTF8Reader can not make decision about validity of the last 24 | // sequence cause it is not fully read yet. And if the read stops, Valid() will 25 | // return false, even if Read() by itself dit not. 26 | type UTF8Reader struct { 27 | Source io.Reader 28 | 29 | accepted int 30 | 31 | state uint32 32 | codep uint32 33 | } 34 | 35 | // NewUTF8Reader creates utf8 reader that reads from r. 36 | func NewUTF8Reader(r io.Reader) *UTF8Reader { 37 | return &UTF8Reader{ 38 | Source: r, 39 | } 40 | } 41 | 42 | // Reset resets utf8 reader to read from r. 43 | func (u *UTF8Reader) Reset(r io.Reader) { 44 | u.Source = r 45 | u.state = 0 46 | u.codep = 0 47 | } 48 | 49 | // Read implements io.Reader. 50 | func (u *UTF8Reader) Read(p []byte) (n int, err error) { 51 | n, err = u.Source.Read(p) 52 | 53 | accepted := 0 54 | s, c := u.state, u.codep 55 | for i := 0; i < n; i++ { 56 | c, s = decode(s, c, p[i]) 57 | if s == utf8Reject { 58 | u.state = s 59 | return accepted, ErrInvalidUTF8 60 | } 61 | if s == utf8Accept { 62 | accepted = i + 1 63 | } 64 | } 65 | u.state, u.codep = s, c 66 | u.accepted = accepted 67 | 68 | return n, err 69 | } 70 | 71 | // Valid checks current reader state. It returns true if all read bytes are 72 | // valid UTF-8 sequences, and false if not. 73 | func (u *UTF8Reader) Valid() bool { 74 | return u.state == utf8Accept 75 | } 76 | 77 | // Accepted returns number of valid bytes in last Read(). 
78 | func (u *UTF8Reader) Accepted() int { 79 | return u.accepted 80 | } 81 | 82 | // Below is port of UTF-8 decoder from http://bjoern.hoehrmann.de/utf-8/decoder/dfa/ 83 | // 84 | // Copyright (c) 2008-2009 Bjoern Hoehrmann 85 | // 86 | // Permission is hereby granted, free of charge, to any person obtaining a copy 87 | // of this software and associated documentation files (the "Software"), to 88 | // deal in the Software without restriction, including without limitation the 89 | // rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 90 | // sell copies of the Software, and to permit persons to whom the Software is 91 | // furnished to do so, subject to the following conditions: 92 | // 93 | // The above copyright notice and this permission notice shall be included in 94 | // all copies or substantial portions of the Software. 95 | // 96 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 97 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 98 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 99 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 100 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 101 | // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 102 | // IN THE SOFTWARE. 103 | 104 | const ( 105 | utf8Accept = 0 106 | utf8Reject = 12 107 | ) 108 | 109 | var utf8d = [...]byte{ 110 | // The first part of the table maps bytes to character classes that 111 | // to reduce the size of the transition table and create bitmasks. 
112 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 113 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 114 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 115 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 116 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 117 | 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 118 | 8, 8, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 119 | 10, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 3, 11, 6, 6, 6, 5, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 120 | 121 | // The second part is a transition table that maps a combination 122 | // of a state of the automaton and a character class to a state. 123 | 0, 12, 24, 36, 60, 96, 84, 12, 12, 12, 48, 72, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 124 | 12, 0, 12, 12, 12, 12, 12, 0, 12, 0, 12, 12, 12, 24, 12, 12, 12, 12, 12, 24, 12, 24, 12, 12, 125 | 12, 12, 12, 12, 12, 12, 12, 24, 12, 12, 12, 12, 12, 24, 12, 12, 12, 12, 12, 12, 12, 24, 12, 12, 126 | 12, 12, 12, 12, 12, 12, 12, 36, 12, 36, 12, 12, 12, 36, 12, 12, 12, 12, 12, 36, 12, 36, 12, 12, 127 | 12, 36, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 128 | } 129 | 130 | func decode(state, codep uint32, b byte) (uint32, uint32) { 131 | t := uint32(utf8d[b]) 132 | 133 | if state != utf8Accept { 134 | codep = (uint32(b) & 0x3f) | (codep << 6) 135 | } else { 136 | codep = (0xff >> t) & uint32(b) 137 | } 138 | 139 | return codep, uint32(utf8d[256+state+t]) 140 | } 141 | -------------------------------------------------------------------------------- /wsutil/utf8_test.go: -------------------------------------------------------------------------------- 1 | package wsutil 2 | 3 | import ( 4 | "bytes" 5 | "encoding/hex" 6 | 
"fmt" 7 | "io" 8 | "io/ioutil" 9 | "testing" 10 | "unicode/utf8" 11 | ) 12 | 13 | func TestUTF8ReaderReadFull(t *testing.T) { 14 | for _, test := range []struct { 15 | hex string 16 | err bool 17 | valid bool 18 | n int 19 | }{ 20 | { 21 | hex: "cebae1bdb9cf83cebcceb5eda080656469746564", 22 | err: true, 23 | valid: false, 24 | n: 11, 25 | }, 26 | { 27 | hex: "cebae1bdb9cf83cebcceb5eda080656469746564", 28 | valid: false, 29 | err: true, 30 | n: 11, 31 | }, 32 | { 33 | hex: "7f7f7fdf", 34 | valid: false, 35 | err: false, 36 | n: 4, 37 | }, 38 | { 39 | hex: "dfbf", 40 | n: 2, 41 | valid: true, 42 | err: false, 43 | }, 44 | } { 45 | t.Run("", func(t *testing.T) { 46 | bts, err := hex.DecodeString(test.hex) 47 | if err != nil { 48 | t.Fatal(err) 49 | } 50 | 51 | src := bytes.NewReader(bts) 52 | r := NewUTF8Reader(src) 53 | 54 | p := make([]byte, src.Len()) 55 | n, err := io.ReadFull(r, p) 56 | 57 | if err != nil && !utf8.Valid(bts[:n]) { 58 | // Should return only number of valid bytes read. 59 | t.Errorf("read n bytes is actually invalid utf8 sequence") 60 | } 61 | if n := r.Accepted(); err == nil && !utf8.Valid(bts[:n]) { 62 | // Should return only number of valid bytes read. 
63 | t.Errorf("read n bytes is actually invalid utf8 sequence") 64 | } 65 | if test.err && err == nil { 66 | t.Errorf("expected read error; got nil") 67 | } 68 | if !test.err && err != nil { 69 | t.Errorf("unexpected read error: %s", err) 70 | } 71 | if n != test.n { 72 | t.Errorf("ReadFull() read %d; want %d", n, test.n) 73 | } 74 | if act, exp := r.Valid(), test.valid; act != exp { 75 | t.Errorf("Valid() = %v; want %v", act, exp) 76 | } 77 | }) 78 | } 79 | } 80 | 81 | func TestUTF8Reader(t *testing.T) { 82 | for i, test := range []struct { 83 | label string 84 | 85 | data []byte 86 | // or 87 | hex string 88 | 89 | chop int 90 | 91 | err bool 92 | valid bool 93 | at int 94 | }{ 95 | { 96 | data: []byte("hello, world!"), 97 | valid: true, 98 | chop: 2, 99 | }, 100 | { 101 | data: []byte{0x7f, 0xf0, 0x00}, 102 | valid: false, 103 | err: true, 104 | at: 2, 105 | chop: 1, 106 | }, 107 | { 108 | hex: "48656c6c6f2dc2b540c39fc3b6c3a4c3bcc3a0c3a12d5554462d382121", 109 | valid: true, 110 | chop: 1, 111 | }, 112 | { 113 | hex: "cebae1bdb9cf83cebcceb5eda080656469746564", 114 | valid: false, 115 | err: true, 116 | at: 12, 117 | chop: 1, 118 | }, 119 | } { 120 | t.Run(fmt.Sprintf("%s#%d", test.label, i), func(t *testing.T) { 121 | data := test.data 122 | if h := test.hex; h != "" { 123 | var err error 124 | if data, err = hex.DecodeString(h); err != nil { 125 | t.Fatal(err) 126 | } 127 | } 128 | 129 | cr := &chopReader{ 130 | src: bytes.NewReader(data), 131 | sz: test.chop, 132 | } 133 | 134 | r := NewUTF8Reader(cr) 135 | 136 | bts := make([]byte, 2*len(data)) 137 | 138 | var ( 139 | i, n int 140 | err error 141 | ) 142 | for { 143 | n, err = r.Read(bts[i:]) 144 | i += n 145 | if err != nil { 146 | if err == io.EOF { 147 | err = nil 148 | } 149 | bts = bts[:i] 150 | break 151 | } 152 | } 153 | if test.err && err == nil { 154 | t.Fatalf("want error; got nil") 155 | } 156 | if !test.err && err != nil { 157 | t.Fatalf("unexpected error: %s", err) 158 | } 159 | if test.err && err 
== ErrInvalidUTF8 && i != test.at { 160 | t.Fatalf("received error at %d; want at %d", i, test.at) 161 | } 162 | if act, exp := r.Valid(), test.valid; act != exp { 163 | t.Fatalf("Valid() = %v; want %v", act, exp) 164 | } 165 | if !test.err && !bytes.Equal(bts, data) { 166 | t.Errorf("bytes are not equal") 167 | } 168 | }) 169 | } 170 | } 171 | 172 | func BenchmarkUTF8Reader(b *testing.B) { 173 | for i, bench := range []struct { 174 | label string 175 | data []byte 176 | chop int 177 | err bool 178 | }{ 179 | { 180 | data: bytes.Repeat([]byte("x"), 1024), 181 | chop: 128, 182 | }, 183 | { 184 | data: append( 185 | bytes.Repeat([]byte("x"), 1024), 186 | append( 187 | []byte{0x7f, 0xf0}, 188 | bytes.Repeat([]byte("x"), 128)..., 189 | )..., 190 | ), 191 | err: true, 192 | chop: 7, 193 | }, 194 | } { 195 | b.Run(fmt.Sprintf("%s#%d", bench.label, i), func(b *testing.B) { 196 | for i := 0; i < b.N; i++ { 197 | cr := &chopReader{ 198 | src: bytes.NewReader(bench.data), 199 | sz: bench.chop, 200 | } 201 | r := NewUTF8Reader(cr) 202 | _, err := ioutil.ReadAll(r) 203 | if !bench.err && err != nil { 204 | b.Fatal(err) 205 | } 206 | } 207 | }) 208 | } 209 | } 210 | -------------------------------------------------------------------------------- /wsutil/wsutil.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package wsutil provides utilities for working with WebSocket protocol. 3 | 4 | Overview: 5 | 6 | // Read masked text message from peer and check utf8 encoding. 7 | header, err := ws.ReadHeader(conn) 8 | if err != nil { 9 | // handle err 10 | } 11 | 12 | // Prepare to read payload. 
13 | r := io.LimitReader(conn, header.Length) 14 | r = wsutil.NewCipherReader(r, header.Mask) 15 | r = wsutil.NewUTF8Reader(r) 16 | 17 | payload, err := ioutil.ReadAll(r) 18 | if err != nil { 19 | // handle err 20 | } 21 | 22 | You could get the same behavior using just `wsutil.Reader` (note that NextFrame() must be called before the payload can be read): 23 | 24 | r := wsutil.Reader{ 25 | Source: conn, 26 | CheckUTF8: true, 27 | } 28 | if _, err := r.NextFrame(); err != nil { 29 | // handle err 30 | } 31 | 32 | payload, err := ioutil.ReadAll(&r) 33 | if err != nil { 34 | // handle err 35 | } 36 | 37 | Or, even simpler: 38 | 39 | payload, err := wsutil.ReadClientText(conn) 40 | if err != nil { 41 | // handle err 42 | } 43 | 44 | The package also exports tools for buffered writing: 45 | 46 | // Create a buffered writer that will buffer output bytes and send them as 47 | // 128-length fragments (with the exception of large writes, see the doc). 48 | writer := wsutil.NewWriterSize(conn, ws.StateServerSide, ws.OpText, 128) 49 | 50 | _, err := io.CopyN(writer, rand.Reader, 100) 51 | if err == nil { 52 | err = writer.Flush() 53 | } 54 | if err != nil { 55 | // handle error 56 | } 57 | 58 | For more utils and helpers see the documentation. 59 | */ 60 | package wsutil 61 | --------------------------------------------------------------------------------