├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── pr-check.yml │ └── tests.yml ├── .gitignore ├── .golangci.yaml ├── .licenserc.yaml ├── AUTHORS ├── LICENSE ├── README.md ├── _typos.toml ├── client.go ├── client_test.go ├── compression.go ├── compression_test.go ├── conn.go ├── conn_broadcast_test.go ├── conn_test.go ├── examples ├── autobahn │ ├── README.md │ ├── config │ │ └── fuzzingclient.json │ └── server.go ├── chat │ ├── README.md │ ├── client.go │ ├── home.html │ ├── hub.go │ └── main.go ├── command │ ├── README.md │ ├── home.html │ └── main.go ├── echo │ ├── README.md │ ├── client.go │ └── server.go └── filewatch │ ├── README.md │ └── main.go ├── go.mod ├── go.sum ├── json.go ├── json_test.go ├── licenses ├── LICENSE-gotils └── LICENSE-websocket ├── mask.go ├── mask_safe.go ├── mask_test.go ├── prepared.go ├── prepared_test.go ├── server.go └── util.go /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | 12 | A clear and concise description of what the bug is. 13 | 14 | **To Reproduce** 15 | 16 | Steps to reproduce the behavior: 17 | 1. Go to '...' 18 | 2. Click on '....' 19 | 3. Scroll down to '....' 20 | 4. See error 21 | 22 | **Expected behavior** 23 | 24 | A clear and concise description of what you expected to happen. 25 | 26 | **Screenshots** 27 | 28 | If applicable, add screenshots to help explain your problem. 29 | 30 | **Hertz version:** 31 | 32 | Please provide the version of Hertz you are using. 33 | 34 | **Environment:** 35 | 36 | The output of `go env`. 37 | 38 | **Additional context** 39 | 40 | Add any other context about the problem here. 
41 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | 12 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 13 | 14 | **Describe the solution you'd like** 15 | 16 | A clear and concise description of what you want to happen. 17 | 18 | **Describe alternatives you've considered** 19 | 20 | A clear and concise description of any alternative solutions or features you've considered. 21 | 22 | **Additional context** 23 | 24 | Add any other context or screenshots about the feature request here. 25 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | #### What type of PR is this? 
2 | 3 | 18 | 19 | #### What this PR does / why we need it (English/Chinese): 20 | 21 | 25 | 26 | #### Which issue(s) this PR fixes: 27 | 28 | 32 | -------------------------------------------------------------------------------- /.github/workflows/pr-check.yml: -------------------------------------------------------------------------------- 1 | name: Pull Request Check 2 | 3 | on: [ pull_request ] 4 | 5 | jobs: 6 | compliant: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v4 10 | 11 | - name: Check License Header 12 | uses: apache/skywalking-eyes/header@v0.4.0 13 | env: 14 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 15 | 16 | - name: Check Spell 17 | uses: crate-ci/typos@master 18 | 19 | golangci-lint: 20 | runs-on: ubuntu-latest 21 | steps: 22 | - uses: actions/checkout@v4 23 | - name: Set up Go 24 | uses: actions/setup-go@v5 25 | with: 26 | go-version: stable 27 | # for self-hosted, the cache path is shared across projects 28 | # and it works well without the cache of github actions 29 | # Enable it if we're going to use Github only 30 | cache: true 31 | 32 | - name: Golangci Lint 33 | # https://golangci-lint.run/ 34 | uses: golangci/golangci-lint-action@v6 35 | with: 36 | version: latest 37 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: [ push, pull_request ] 4 | 5 | jobs: 6 | benchmark: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v4 10 | - name: Set up Go 11 | uses: actions/setup-go@v5 12 | with: 13 | go-version: stable 14 | 15 | - name: Benchmark 16 | run: go test -bench=. -benchmem ./... 
17 | 18 | uinttest: 19 | strategy: 20 | matrix: 21 | go: [ "1.18", "1.19", "1.20", "1.21", "1.22", "1.23" ] 22 | runs-on: ubuntu-latest 23 | steps: 24 | - uses: actions/checkout@v4 25 | - name: Set up Go 26 | uses: actions/setup-go@v5 27 | with: 28 | go-version: ${{ matrix.go }} 29 | cache: true # set false for self-hosted 30 | - name: Unit Test 31 | run: go test -race ./... 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, build with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # goland 15 | .idea 16 | # vscode 17 | .vscode 18 | 19 | # Go workspace file 20 | go.work 21 | go.work.sum 22 | -------------------------------------------------------------------------------- /.golangci.yaml: -------------------------------------------------------------------------------- 1 | linters: # https://golangci-lint.run/usage/linters/ 2 | disable-all: true 3 | enable: 4 | # - errcheck # can not skip _test.go ? 5 | - gosimple 6 | - govet 7 | - ineffassign 8 | - staticcheck 9 | - unused 10 | - unconvert 11 | -------------------------------------------------------------------------------- /.licenserc.yaml: -------------------------------------------------------------------------------- 1 | header: 2 | license: 3 | spdx-id: Apache-2.0 4 | copyright-owner: CloudWeGo Authors 5 | content: | 6 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 7 | // Use of this source code is governed by a BSD-style 8 | // license that can be found in the LICENSE file. 9 | // 10 | // This file may have been modified by CloudWeGo authors. All CloudWeGo 11 | // Modifications are Copyright 2022 CloudWeGo Authors. 
12 | 13 | paths: 14 | - '**/*.go' 15 | - '**/*.s' 16 | 17 | comment: on-failure 18 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | # This is the official list of Gorilla WebSocket authors for copyright 2 | # purposes. 3 | # 4 | # Please keep the list sorted. 5 | 6 | Gary Burd 7 | Google LLC (https://opensource.google.com/) 8 | Joachim Bauch 9 | 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Hertz-WebSocket (This is a community-driven project) 2 | 3 | 4 | This repo is forked from [Gorilla WebSocket](https://github.com/gorilla/websocket/) and adapted to Hertz.
5 | 6 | ### How to use 7 | ```go 8 | package main 9 | 10 | import ( 11 | "context" 12 | "log" 13 | 14 | "github.com/cloudwego/hertz/pkg/app" 15 | "github.com/cloudwego/hertz/pkg/app/server" 16 | "github.com/hertz-contrib/websocket" 17 | ) 18 | 19 | var upgrader = websocket.HertzUpgrader{} // use default options 20 | 21 | func echo(_ context.Context, c *app.RequestContext) { 22 | err := upgrader.Upgrade(c, func(conn *websocket.Conn) { 23 | for { 24 | mt, message, err := conn.ReadMessage() 25 | if err != nil { 26 | log.Println("read:", err) 27 | break 28 | } 29 | log.Printf("recv: %s", message) 30 | err = conn.WriteMessage(mt, message) 31 | if err != nil { 32 | log.Println("write:", err) 33 | break 34 | } 35 | } 36 | }) 37 | if err != nil { 38 | log.Print("upgrade:", err) 39 | return 40 | } 41 | } 42 | 43 | 44 | func main() { 45 | h := server.Default(server.WithHostPorts(addr)) 46 | // https://github.com/cloudwego/hertz/issues/121 47 | h.NoHijackConnPool = true 48 | h.GET("/echo", echo) 49 | h.Spin() 50 | } 51 | 52 | ``` 53 | 54 | ### More info 55 | 56 | See [examples](examples/) 57 | 58 | -------------------------------------------------------------------------------- /_typos.toml: -------------------------------------------------------------------------------- 1 | # Typo check: https://github.com/crate-ci/typos 2 | 3 | [files] 4 | extend-exclude = ["go.sum"] 5 | 6 | [default.extend-identifiers] 7 | # *sigh* this just isn't worth the cost of fixing 8 | ConnTLSer = "ConnTLSer" 9 | flate = "flate" 10 | TestCompressFlateSerial = "TestCompressFlateSerial" 11 | testCompressFlate = "testCompressFlate" 12 | TestCompressFlateConcurrent = "TestCompressFlateConcurrent" 13 | trUe = "trUe" 14 | OPTIO = "OPTIO" 15 | contant = "contant" 16 | referer = "referer" 17 | HeaderReferer = "HeaderReferer" 18 | expectedReferer = "expectedReferer" 19 | Referer = "Referer" 20 | flateWriterPools = "flateWriterPools" 21 | flateReaderPool = "flateReaderPool" 22 | flateWriteWrapper = 
"flateWriteWrapper" 23 | flateReadWrapper = "flateReadWrapper" 24 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | // 5 | // This file may have been modified by CloudWeGo authors. All CloudWeGo 6 | // Modifications are Copyright 2022 CloudWeGo Authors. 7 | 8 | package websocket 9 | 10 | import ( 11 | "bytes" 12 | "errors" 13 | "fmt" 14 | "time" 15 | 16 | "github.com/cloudwego/hertz/pkg/protocol" 17 | ) 18 | 19 | // ErrBadHandshake is returned when the server response to opening handshake is 20 | // invalid. 21 | var ErrBadHandshake = errors.New("websocket: bad handshake") 22 | 23 | // ClientUpgrader is a helper for upgrading hertz http response to websocket conn. 24 | // See ExampleClient for usage 25 | type ClientUpgrader struct { 26 | // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer 27 | // size is zero, then buffers allocated by the HTTP server are used. The 28 | // I/O buffer sizes do not limit the size of the messages that can be sent 29 | // or received. 30 | ReadBufferSize, WriteBufferSize int 31 | 32 | // WriteBufferPool is a pool of buffers for write operations. If the value 33 | // is not set, then write buffers are allocated to the connection for the 34 | // lifetime of the connection. 35 | // 36 | // A pool is most useful when the application has a modest volume of writes 37 | // across a large number of connections. 38 | // 39 | // Applications should use a single pool for each unique value of 40 | // WriteBufferSize. 41 | WriteBufferPool BufferPool 42 | 43 | // EnableCompression specifies whether the client should attempt to negotiate per 44 | // message compression (RFC 7692).
Setting this value to true does not 45 | // guarantee that compression will be supported. Currently only "no context 46 | // takeover" modes are supported. 47 | EnableCompression bool 48 | } 49 | 50 | // PrepareRequest prepares request for websocket 51 | // 52 | // It adds headers for websocket, 53 | // and it must be called BEFORE sending http request via cli.DoXXX 54 | func (p *ClientUpgrader) PrepareRequest(req *protocol.Request) { 55 | req.Header.Set("Upgrade", "websocket") 56 | req.Header.Set("Connection", "Upgrade") 57 | req.Header.Set("Sec-WebSocket-Version", "13") 58 | req.Header.Set("Sec-WebSocket-Key", generateChallengeKey()) 59 | if p.EnableCompression { 60 | req.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover") 61 | } 62 | } 63 | 64 | // UpgradeResponse upgrades a response to websocket conn 65 | // 66 | // It returns Conn if success. ErrBadHandshake is returned if headers go wrong. 67 | // This method must be called after PrepareRequest and (*.Client).DoXXX 68 | func (p *ClientUpgrader) UpgradeResponse(req *protocol.Request, resp *protocol.Response) (*Conn, error) { 69 | if resp.StatusCode() != 101 || 70 | !tokenContainsValue(resp.Header.Get("Upgrade"), "websocket") || 71 | !tokenContainsValue(resp.Header.Get("Connection"), "Upgrade") || 72 | resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKeyBytes(req.Header.Peek("Sec-Websocket-Key")) { 73 | return nil, ErrBadHandshake 74 | } 75 | 76 | c, err := resp.Hijack() 77 | if err != nil { 78 | return nil, fmt.Errorf("Hijack response connection err: %w", err) 79 | } 80 | 81 | c.SetDeadline(time.Time{}) 82 | conn := newConn(c, false, p.ReadBufferSize, p.WriteBufferSize, p.WriteBufferPool, nil, nil) 83 | 84 | // can not use p.EnableCompression, always follow ext returned from server 85 | compress := false 86 | extensions := parseDataHeader(resp.Header.Peek("Sec-WebSocket-Extensions")) 87 | for _, ext := range extensions { 88 | if 
bytes.HasPrefix(ext, strPermessageDeflate) { 89 | compress = true 90 | } 91 | } 92 | if compress { 93 | conn.newCompressionWriter = compressNoContextTakeover 94 | conn.newDecompressionReader = decompressNoContextTakeover 95 | } 96 | conn.resp = resp 97 | return conn, nil 98 | } 99 | -------------------------------------------------------------------------------- /client_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | // 5 | // This file may have been modified by CloudWeGo authors. All CloudWeGo 6 | // Modifications are Copyright 2022 CloudWeGo Authors. 7 | 8 | package websocket 9 | 10 | import ( 11 | "context" 12 | "fmt" 13 | "log" 14 | "time" 15 | 16 | "github.com/cloudwego/hertz/pkg/app" 17 | "github.com/cloudwego/hertz/pkg/app/client" 18 | "github.com/cloudwego/hertz/pkg/app/server" 19 | "github.com/cloudwego/hertz/pkg/network/standard" 20 | "github.com/cloudwego/hertz/pkg/protocol" 21 | ) 22 | 23 | const ( 24 | testaddr = "localhost:10012" 25 | testpath = "/echo" 26 | ) 27 | 28 | func ExampleClient() { 29 | runServer(testaddr) 30 | time.Sleep(50 * time.Millisecond) // await server running 31 | 32 | c, err := client.NewClient(client.WithDialer(standard.NewDialer())) 33 | if err != nil { 34 | panic(err) 35 | } 36 | 37 | req, resp := protocol.AcquireRequest(), protocol.AcquireResponse() 38 | req.SetRequestURI("http://" + testaddr + testpath) 39 | req.SetMethod("GET") 40 | 41 | u := &ClientUpgrader{} 42 | u.PrepareRequest(req) 43 | err = c.Do(context.Background(), req, resp) 44 | if err != nil { 45 | panic(err) 46 | } 47 | conn, err := u.UpgradeResponse(req, resp) 48 | if err != nil { 49 | panic(err) 50 | } 51 | 52 | conn.WriteMessage(TextMessage, []byte("hello")) 53 | m, b, err := conn.ReadMessage() 54 | if err != nil { 55 | panic(err) 
56 | } 57 | fmt.Println(m, string(b)) 58 | // Output: 1 hello 59 | } 60 | 61 | func runServer(addr string) { 62 | upgrader := HertzUpgrader{} // use default options 63 | h := server.Default(server.WithHostPorts(addr)) 64 | // https://github.com/cloudwego/hertz/issues/121 65 | h.NoHijackConnPool = true 66 | h.GET(testpath, func(_ context.Context, c *app.RequestContext) { 67 | err := upgrader.Upgrade(c, func(conn *Conn) { 68 | for { 69 | mt, message, err := conn.ReadMessage() 70 | if err != nil { 71 | log.Println("read:", err) 72 | break 73 | } 74 | log.Printf("[server] recv: %v %s", mt, message) 75 | err = conn.WriteMessage(mt, message) 76 | if err != nil { 77 | log.Println("write:", err) 78 | break 79 | } 80 | } 81 | }) 82 | if err != nil { 83 | log.Print("upgrade:", err) 84 | return 85 | } 86 | }) 87 | go func() { 88 | if err := h.Run(); err != nil { 89 | log.Fatal(err) 90 | } 91 | }() 92 | } 93 | -------------------------------------------------------------------------------- /compression.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | // 5 | // This file may have been modified by CloudWeGo authors. All CloudWeGo 6 | // Modifications are Copyright 2022 CloudWeGo Authors. 
7 | 8 | package websocket 9 | 10 | import ( 11 | "compress/flate" 12 | "errors" 13 | "io" 14 | "strings" 15 | "sync" 16 | ) 17 | 18 | const ( 19 | minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6 20 | maxCompressionLevel = flate.BestCompression 21 | defaultCompressionLevel = 1 22 | ) 23 | 24 | var ( 25 | flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool 26 | flateReaderPool = sync.Pool{New: func() interface{} { 27 | return flate.NewReader(nil) 28 | }} 29 | ) 30 | 31 | func decompressNoContextTakeover(r io.Reader) io.ReadCloser { 32 | const tail = 33 | // Add four bytes as specified in RFC 34 | "\x00\x00\xff\xff" + 35 | // Add final block to squelch unexpected EOF error from flate reader. 36 | "\x01\x00\x00\xff\xff" 37 | 38 | fr, _ := flateReaderPool.Get().(io.ReadCloser) 39 | fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) 40 | return &flateReadWrapper{fr} 41 | } 42 | 43 | func isValidCompressionLevel(level int) bool { 44 | return minCompressionLevel <= level && level <= maxCompressionLevel 45 | } 46 | 47 | func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser { 48 | p := &flateWriterPools[level-minCompressionLevel] 49 | tw := &truncWriter{w: w} 50 | fw, _ := p.Get().(*flate.Writer) 51 | if fw == nil { 52 | fw, _ = flate.NewWriter(tw, level) 53 | } else { 54 | fw.Reset(tw) 55 | } 56 | return &flateWriteWrapper{fw: fw, tw: tw, p: p} 57 | } 58 | 59 | // truncWriter is an io.Writer that writes all but the last four bytes of the 60 | // stream to another io.Writer. 61 | type truncWriter struct { 62 | w io.WriteCloser 63 | n int 64 | p [4]byte 65 | } 66 | 67 | func (w *truncWriter) Write(p []byte) (int, error) { 68 | n := 0 69 | 70 | // fill buffer first for simplicity. 
71 | if w.n < len(w.p) { 72 | n = copy(w.p[w.n:], p) 73 | p = p[n:] 74 | w.n += n 75 | if len(p) == 0 { 76 | return n, nil 77 | } 78 | } 79 | 80 | m := len(p) 81 | if m > len(w.p) { 82 | m = len(w.p) 83 | } 84 | 85 | if nn, err := w.w.Write(w.p[:m]); err != nil { 86 | return n + nn, err 87 | } 88 | 89 | copy(w.p[:], w.p[m:]) 90 | copy(w.p[len(w.p)-m:], p[len(p)-m:]) 91 | nn, err := w.w.Write(p[:len(p)-m]) 92 | return n + nn, err 93 | } 94 | 95 | type flateWriteWrapper struct { 96 | fw *flate.Writer 97 | tw *truncWriter 98 | p *sync.Pool 99 | } 100 | 101 | func (w *flateWriteWrapper) Write(p []byte) (int, error) { 102 | if w.fw == nil { 103 | return 0, errWriteClosed 104 | } 105 | return w.fw.Write(p) 106 | } 107 | 108 | func (w *flateWriteWrapper) Close() error { 109 | if w.fw == nil { 110 | return errWriteClosed 111 | } 112 | err1 := w.fw.Flush() 113 | w.p.Put(w.fw) 114 | w.fw = nil 115 | if w.tw.p != [4]byte{0, 0, 0xff, 0xff} { 116 | return errors.New("websocket: internal error, unexpected bytes at end of flate stream") 117 | } 118 | err2 := w.tw.w.Close() 119 | if err1 != nil { 120 | return err1 121 | } 122 | return err2 123 | } 124 | 125 | type flateReadWrapper struct { 126 | fr io.ReadCloser 127 | } 128 | 129 | func (r *flateReadWrapper) Read(p []byte) (int, error) { 130 | if r.fr == nil { 131 | return 0, io.ErrClosedPipe 132 | } 133 | n, err := r.fr.Read(p) 134 | if err == io.EOF { 135 | // Preemptively place the reader back in the pool. This helps with 136 | // scenarios where the application does not call NextReader() soon after 137 | // this final read. 
138 | r.Close() 139 | } 140 | return n, err 141 | } 142 | 143 | func (r *flateReadWrapper) Close() error { 144 | if r.fr == nil { 145 | return io.ErrClosedPipe 146 | } 147 | err := r.fr.Close() 148 | flateReaderPool.Put(r.fr) 149 | r.fr = nil 150 | return err 151 | } 152 | -------------------------------------------------------------------------------- /compression_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | // 5 | // This file may have been modified by CloudWeGo authors. All CloudWeGo 6 | // Modifications are Copyright 2022 CloudWeGo Authors. 7 | 8 | package websocket 9 | 10 | import ( 11 | "bytes" 12 | "fmt" 13 | "io" 14 | "io/ioutil" 15 | "testing" 16 | ) 17 | 18 | type nopCloser struct{ io.Writer } 19 | 20 | func (nopCloser) Close() error { return nil } 21 | 22 | func TestTruncWriter(t *testing.T) { 23 | const data = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijlkmnopqrstuvwxyz987654321" 24 | for n := 1; n <= 10; n++ { 25 | var b bytes.Buffer 26 | w := &truncWriter{w: nopCloser{&b}} 27 | p := []byte(data) 28 | for len(p) > 0 { 29 | m := len(p) 30 | if m > n { 31 | m = n 32 | } 33 | w.Write(p[:m]) 34 | p = p[m:] 35 | } 36 | if b.String() != data[:len(data)-len(w.p)] { 37 | t.Errorf("%d: %q", n, b.String()) 38 | } 39 | } 40 | } 41 | 42 | func textMessages(num int) [][]byte { 43 | messages := make([][]byte, num) 44 | for i := 0; i < num; i++ { 45 | msg := fmt.Sprintf("planet: %d, country: %d, city: %d, street: %d", i, i, i, i) 46 | messages[i] = []byte(msg) 47 | } 48 | return messages 49 | } 50 | 51 | func BenchmarkWriteNoCompression(b *testing.B) { 52 | w := ioutil.Discard 53 | c := newTestConn(nil, w, false) 54 | messages := textMessages(100) 55 | b.ResetTimer() 56 | for i := 0; i < b.N; i++ { 57 | c.WriteMessage(TextMessage, 
messages[i%len(messages)]) 58 | } 59 | b.ReportAllocs() 60 | } 61 | 62 | func BenchmarkWriteWithCompression(b *testing.B) { 63 | w := ioutil.Discard 64 | c := newTestConn(nil, w, false) 65 | messages := textMessages(100) 66 | c.enableWriteCompression = true 67 | c.newCompressionWriter = compressNoContextTakeover 68 | b.ResetTimer() 69 | for i := 0; i < b.N; i++ { 70 | c.WriteMessage(TextMessage, messages[i%len(messages)]) 71 | } 72 | b.ReportAllocs() 73 | } 74 | 75 | func TestValidCompressionLevel(t *testing.T) { 76 | c := newTestConn(nil, nil, false) 77 | for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} { 78 | if err := c.SetCompressionLevel(level); err == nil { 79 | t.Errorf("no error for level %d", level) 80 | } 81 | } 82 | for _, level := range []int{minCompressionLevel, maxCompressionLevel} { 83 | if err := c.SetCompressionLevel(level); err != nil { 84 | t.Errorf("error for level %d", level) 85 | } 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /conn.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | // 5 | // This file may have been modified by CloudWeGo authors. All CloudWeGo 6 | // Modifications are Copyright 2022 CloudWeGo Authors. 
package websocket

import (
	"bufio"
	"encoding/binary"
	"errors"
	"io"
	"io/ioutil"
	"math/rand"
	"net"
	"strconv"
	"strings"
	"sync"
	"time"
	"unicode/utf8"
)

const (
	// Frame header byte 0 bits from Section 5.2 of RFC 6455
	finalBit = 1 << 7
	rsv1Bit  = 1 << 6
	rsv2Bit  = 1 << 5
	rsv3Bit  = 1 << 4

	// Frame header byte 1 bits from Section 5.2 of RFC 6455
	maskBit = 1 << 7

	maxFrameHeaderSize         = 2 + 8 + 4 // Fixed header + length + mask
	maxControlFramePayloadSize = 125

	// writeWait is added to time.Now() to form the deadline of the control
	// frames this package sends on its own (close replies, pongs, protocol
	// error and read-limit close frames).
	writeWait = time.Second

	defaultReadBufferSize  = 4096
	defaultWriteBufferSize = 4096

	// continuationFrame is the opcode of a continuation frame. noFrame is
	// the frame type returned by the read path when there is no data frame
	// to report (for example together with an error).
	continuationFrame = 0
	noFrame           = -1
)

// Close codes defined in RFC 6455, section 11.7.
const (
	CloseNormalClosure           = 1000
	CloseGoingAway               = 1001
	CloseProtocolError           = 1002
	CloseUnsupportedData         = 1003
	CloseNoStatusReceived        = 1005
	CloseAbnormalClosure         = 1006
	CloseInvalidFramePayloadData = 1007
	ClosePolicyViolation         = 1008
	CloseMessageTooBig           = 1009
	CloseMandatoryExtension      = 1010
	CloseInternalServerErr       = 1011
	CloseServiceRestart          = 1012
	CloseTryAgainLater           = 1013
	CloseTLSHandshake            = 1015
)

// The message types are defined in RFC 6455, section 11.8.
const (
	// TextMessage denotes a text data message. The text message payload is
	// interpreted as UTF-8 encoded text data.
	TextMessage = 1

	// BinaryMessage denotes a binary data message.
	BinaryMessage = 2

	// CloseMessage denotes a close control message. The optional message
	// payload contains a numeric code and text. Use the FormatCloseMessage
	// function to format a close message payload.
	CloseMessage = 8

	// PingMessage denotes a ping control message. The optional message payload
	// is UTF-8 encoded text.
	PingMessage = 9

	// PongMessage denotes a pong control message. The optional message payload
	// is UTF-8 encoded text.
	PongMessage = 10
)

// ErrCloseSent is returned when the application writes a message to the
// connection after sending a close message.
var ErrCloseSent = errors.New("websocket: close sent")

// ErrReadLimit is returned when reading a message that is larger than the
// read limit set for the connection.
var ErrReadLimit = errors.New("websocket: read limit exceeded")

// netError satisfies the net Error interface.
type netError struct {
	msg       string
	temporary bool
	timeout   bool
}

// Error, Temporary and Timeout simply report the stored fields.
func (e *netError) Error() string   { return e.msg }
func (e *netError) Temporary() bool { return e.temporary }
func (e *netError) Timeout() bool   { return e.timeout }

// CloseError represents a close message.
type CloseError struct {
	// Code is defined in RFC 6455, section 11.7.
	Code int

	// Text is the optional text payload.
	Text string
}

// Error formats the close error as "websocket: close <code>", followed by a
// short parenthesised description for the codes listed in the switch below,
// followed by ": <Text>" when Text is not empty. Codes without a case
// (including CloseServiceRestart and CloseTryAgainLater) get no description.
func (e *CloseError) Error() string {
	s := []byte("websocket: close ")
	s = strconv.AppendInt(s, int64(e.Code), 10)
	switch e.Code {
	case CloseNormalClosure:
		s = append(s, " (normal)"...)
	case CloseGoingAway:
		s = append(s, " (going away)"...)
	case CloseProtocolError:
		s = append(s, " (protocol error)"...)
	case CloseUnsupportedData:
		s = append(s, " (unsupported data)"...)
	case CloseNoStatusReceived:
		s = append(s, " (no status)"...)
	case CloseAbnormalClosure:
		s = append(s, " (abnormal closure)"...)
	case CloseInvalidFramePayloadData:
		s = append(s, " (invalid payload data)"...)
	case ClosePolicyViolation:
		s = append(s, " (policy violation)"...)
	case CloseMessageTooBig:
		s = append(s, " (message too big)"...)
	case CloseMandatoryExtension:
		s = append(s, " (mandatory extension missing)"...)
	case CloseInternalServerErr:
		s = append(s, " (internal server error)"...)
	case CloseTLSHandshake:
		s = append(s, " (TLS handshake error)"...)
	}
	if e.Text != "" {
		s = append(s, ": "...)
		s = append(s, e.Text...)
	}
	return string(s)
}

// IsCloseError returns boolean indicating whether the error is a *CloseError
// with one of the specified codes.
func IsCloseError(err error, codes ...int) bool {
	// A plain type assertion is used, so a wrapped *CloseError is not matched.
	if e, ok := err.(*CloseError); ok {
		for _, code := range codes {
			if e.Code == code {
				return true
			}
		}
	}
	return false
}

// IsUnexpectedCloseError returns boolean indicating whether the error is a
// *CloseError with a code not in the list of expected codes.
func IsUnexpectedCloseError(err error, expectedCodes ...int) bool {
	if e, ok := err.(*CloseError); ok {
		for _, code := range expectedCodes {
			if e.Code == code {
				return false
			}
		}
		return true
	}
	// Errors that are not a *CloseError are never reported as unexpected.
	return false
}

// Internal sentinel errors shared by the read and write paths.
var (
	errWriteTimeout        = &netError{msg: "websocket: write timeout", timeout: true, temporary: true}
	errUnexpectedEOF       = &CloseError{Code: CloseAbnormalClosure, Text: io.ErrUnexpectedEOF.Error()}
	errBadWriteOpCode      = errors.New("websocket: bad write message type")
	errWriteClosed         = errors.New("websocket: write closed")
	errInvalidControlFrame = errors.New("websocket: invalid control frame")
)

// newMaskKey returns a 4-byte masking key taken from a single math/rand
// Uint32 value, least significant byte first.
func newMaskKey() [4]byte {
	n := rand.Uint32()
	return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}
}

// hideTempErr returns err unchanged unless it is a net.Error that reports
// Temporary(). In that case it is replaced by a netError with the same
// message and timeout flag whose temporary flag is left false.
func hideTempErr(err error) error {
	if e, ok := err.(net.Error); ok && e.Temporary() {
		err = &netError{msg: e.Error(), timeout: e.Timeout()}
	}
	return err
}

// isControl reports whether frameType is a close, ping or pong message.
func isControl(frameType int) bool {
	return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage
}

// isData reports whether frameType is a text or binary message.
func isData(frameType int) bool {
	return frameType == TextMessage || frameType == BinaryMessage
}

// validReceivedCloseCodes marks which of the registered close codes are
// accepted in a close frame received from the peer. Codes mapped to false
// are rejected by isValidReceivedCloseCode.
var validReceivedCloseCodes = map[int]bool{
	// see http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
	CloseNormalClosure:           true,
	CloseGoingAway:               true,
	CloseProtocolError:           true,
	CloseUnsupportedData:         true,
	CloseNoStatusReceived:        false,
	CloseAbnormalClosure:         false,
	CloseInvalidFramePayloadData: true,
	ClosePolicyViolation:         true,
	CloseMessageTooBig:           true,
	CloseMandatoryExtension:      true,
	CloseInternalServerErr:       true,
	CloseServiceRestart:          true,
	CloseTryAgainLater:           true,
	CloseTLSHandshake:            false,
}

// isValidReceivedCloseCode reports whether code is marked true in
// validReceivedCloseCodes or lies in the range 3000-4999 inclusive.
func isValidReceivedCloseCode(code int) bool {
	return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999)
}

// BufferPool represents a pool of buffers. The *sync.Pool type satisfies this
// interface.  The type of the value stored in a pool is not specified.
type BufferPool interface {
	// Get gets a value from the pool or returns nil if the pool is empty.
	Get() interface{}
	// Put adds a value to the pool.
	Put(interface{})
}

// writePoolData is the type added to the write buffer pool. This wrapper is
// used to prevent applications from peeking at and depending on the values
// added to the pool.
type writePoolData struct{ buf []byte }

// The Conn type represents a WebSocket connection.
type Conn struct {
	conn        net.Conn
	isServer    bool
	subprotocol string

	// Write fields
	mu            chan struct{} // used as mutex to protect write to conn; holding the single token means holding the lock
	writeBuf      []byte        // frame is constructed in this buffer.
	writePool     BufferPool
	writeBufSize  int
	writeDeadline time.Time
	writer        io.WriteCloser // the current writer returned to the application
	isWriting     bool           // for best-effort concurrent write detection

	writeErrMu sync.Mutex
	writeErr   error // first fatal write error; once set, later writes return it

	enableWriteCompression bool
	compressionLevel       int
	newCompressionWriter   func(io.WriteCloser, int) io.WriteCloser

	// Read fields
	reader  io.ReadCloser // the current reader returned to the application
	readErr error
	br      *bufio.Reader
	// bytes remaining in current frame.
	// use setReadRemaining to safely update this value and prevent overflow
	readRemaining int64
	readFinal     bool  // true if the last data frame read had FIN set, i.e. the current message has no more frames.
	readLength    int64 // Message size.
	readLimit     int64 // Maximum message size.
	readMaskPos   int
	readMaskKey   [4]byte
	handlePong    func(string) error
	handlePing    func(string) error
	handleClose   func(int, string) error
	readErrCount  int
	messageReader *messageReader // the current low-level reader

	readDecompress         bool // whether last read frame had RSV1 set
	newDecompressionReader func(io.Reader) io.ReadCloser

	// keep reference to the resp to make sure the underlying conn will not be closed.
	// see: https://github.com/cloudwego/hertz/pull/1214 for the details.
	resp interface{} // *protocol.Response
}

// newConn builds a Conn around conn.
//
// If br is nil a bufio.Reader is created on conn; a readBufferSize of zero
// selects defaultReadBufferSize and any other value below
// maxControlFramePayloadSize is raised to that minimum. A writeBufferSize of
// zero or less selects defaultWriteBufferSize, and maxFrameHeaderSize is
// always added on top so that the header fits in front of the payload. The
// write buffer is allocated here only when neither writeBuf nor
// writeBufferPool is supplied. The returned Conn starts with
// enableWriteCompression set and with the default close, ping and pong
// handlers installed.
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBufferPool BufferPool, br *bufio.Reader, writeBuf []byte) *Conn {
	if br == nil {
		if readBufferSize == 0 {
			readBufferSize = defaultReadBufferSize
		} else if readBufferSize < maxControlFramePayloadSize {
			// must be large enough for control frame
			readBufferSize = maxControlFramePayloadSize
		}
		br = bufio.NewReaderSize(conn, readBufferSize)
	}

	if writeBufferSize <= 0 {
		writeBufferSize = defaultWriteBufferSize
	}
	writeBufferSize += maxFrameHeaderSize

	if writeBuf == nil && writeBufferPool == nil {
		writeBuf = make([]byte, writeBufferSize)
	}

	// The channel is pre-loaded with its single token so that the first
	// writer can take the lock immediately.
	mu := make(chan struct{}, 1)
	mu <- struct{}{}
	c := &Conn{
		isServer:               isServer,
		br:                     br,
		conn:                   conn,
		mu:                     mu,
		readFinal:              true,
		writeBuf:               writeBuf,
		writePool:              writeBufferPool,
		writeBufSize:           writeBufferSize,
		enableWriteCompression: true,
		compressionLevel:       defaultCompressionLevel,
	}
	// Passing nil installs the default handlers.
	c.SetCloseHandler(nil)
	c.SetPingHandler(nil)
	c.SetPongHandler(nil)
	return c
}

// setReadRemaining tracks the number of bytes remaining on the connection. If n
// overflows, an ErrReadLimit is returned.
func (c *Conn) setReadRemaining(n int64) error {
	// A negative n is what an overflowed length looks like after conversion
	// to int64; readRemaining is left unchanged in that case.
	if n < 0 {
		return ErrReadLimit
	}

	c.readRemaining = n
	return nil
}

// Subprotocol returns the negotiated protocol for the connection.
func (c *Conn) Subprotocol() string {
	return c.subprotocol
}

// Close closes the underlying network connection without sending or waiting
// for a close message.
func (c *Conn) Close() error {
	return c.conn.Close()
}

// LocalAddr returns the local network address.
354 | func (c *Conn) LocalAddr() net.Addr { 355 | return c.conn.LocalAddr() 356 | } 357 | 358 | // RemoteAddr returns the remote network address. 359 | func (c *Conn) RemoteAddr() net.Addr { 360 | return c.conn.RemoteAddr() 361 | } 362 | 363 | // Write methods 364 | 365 | func (c *Conn) writeFatal(err error) error { 366 | err = hideTempErr(err) 367 | c.writeErrMu.Lock() 368 | if c.writeErr == nil { 369 | c.writeErr = err 370 | } 371 | c.writeErrMu.Unlock() 372 | return err 373 | } 374 | 375 | func (c *Conn) read(n int) ([]byte, error) { 376 | p, err := c.br.Peek(n) 377 | if err == io.EOF { 378 | err = errUnexpectedEOF 379 | } 380 | c.br.Discard(len(p)) 381 | return p, err 382 | } 383 | 384 | func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error { 385 | <-c.mu 386 | defer func() { c.mu <- struct{}{} }() 387 | 388 | c.writeErrMu.Lock() 389 | err := c.writeErr 390 | c.writeErrMu.Unlock() 391 | if err != nil { 392 | return err 393 | } 394 | 395 | c.conn.SetWriteDeadline(deadline) 396 | if len(buf1) == 0 { 397 | _, err = c.conn.Write(buf0) 398 | } else { 399 | err = c.writeBufs(buf0, buf1) 400 | } 401 | if err != nil { 402 | return c.writeFatal(err) 403 | } 404 | if frameType == CloseMessage { 405 | c.writeFatal(ErrCloseSent) 406 | } 407 | return nil 408 | } 409 | 410 | func (c *Conn) writeBufs(bufs ...[]byte) error { 411 | b := net.Buffers(bufs) 412 | _, err := b.WriteTo(c.conn) 413 | return err 414 | } 415 | 416 | // WriteControl writes a control message with the given deadline. The allowed 417 | // message types are CloseMessage, PingMessage and PongMessage. 
func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error {
	if !isControl(messageType) {
		return errBadWriteOpCode
	}
	if len(data) > maxControlFramePayloadSize {
		return errInvalidControlFrame
	}

	// Control frames are always final and their length fits in the 7-bit
	// length field, so the header is two bytes plus the mask key for clients.
	b0 := byte(messageType) | finalBit
	b1 := byte(len(data))
	if !c.isServer {
		b1 |= maskBit
	}

	buf := make([]byte, 0, maxFrameHeaderSize+maxControlFramePayloadSize)
	buf = append(buf, b0, b1)

	if c.isServer {
		buf = append(buf, data...)
	} else {
		key := newMaskKey()
		buf = append(buf, key[:]...)
		buf = append(buf, data...)
		// The payload starts at offset 6: 2 header bytes + 4 key bytes.
		maskBytes(key, 0, buf[6:])
	}

	// A zero deadline means "no timeout"; it is emulated with a very long
	// wait for the write lock.
	d := 1000 * time.Hour
	if !deadline.IsZero() {
		d = time.Until(deadline)
		if d < 0 {
			return errWriteTimeout
		}
	}

	// Acquire the write lock, giving up once the deadline passes.
	timer := time.NewTimer(d)
	select {
	case <-c.mu:
		timer.Stop()
	case <-timer.C:
		return errWriteTimeout
	}
	defer func() { c.mu <- struct{}{} }()

	c.writeErrMu.Lock()
	err := c.writeErr
	c.writeErrMu.Unlock()
	if err != nil {
		return err
	}

	c.conn.SetWriteDeadline(deadline)
	_, err = c.conn.Write(buf)
	if err != nil {
		return c.writeFatal(err)
	}
	if messageType == CloseMessage {
		// After a close frame every later write fails with ErrCloseSent.
		c.writeFatal(ErrCloseSent)
	}
	// err is nil here.
	return err
}

// beginMessage prepares a connection and message writer for a new message.
func (c *Conn) beginMessage(mw *messageWriter, messageType int) error {
	// Close previous writer if not already closed by the application. It's
	// probably better to return an error in this situation, but we cannot
	// change this without breaking existing applications.
	if c.writer != nil {
		c.writer.Close()
		c.writer = nil
	}

	if !isControl(messageType) && !isData(messageType) {
		return errBadWriteOpCode
	}

	c.writeErrMu.Lock()
	err := c.writeErr
	c.writeErrMu.Unlock()
	if err != nil {
		return err
	}

	mw.c = c
	mw.frameType = messageType
	// Payload bytes start after the space reserved for the largest header.
	mw.pos = maxFrameHeaderSize

	if c.writeBuf == nil {
		// The buffer was handed back to the pool by endMessage; take one
		// from the pool or allocate when the pool has nothing usable.
		wpd, ok := c.writePool.Get().(writePoolData)
		if ok {
			c.writeBuf = wpd.buf
		} else {
			c.writeBuf = make([]byte, c.writeBufSize)
		}
	}
	return nil
}

// NextWriter returns a writer for the next message to send. The writer's Close
// method flushes the complete message to the network.
//
// There can be at most one open writer on a connection. NextWriter closes the
// previous writer if the application has not already done so.
//
// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and
// PongMessage) are supported.
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
	var mw messageWriter
	if err := c.beginMessage(&mw, messageType); err != nil {
		return nil, err
	}
	c.writer = &mw
	// Only data messages are compressed; the compression writer wraps the
	// message writer and becomes the writer handed to the application.
	if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
		w := c.newCompressionWriter(c.writer, c.compressionLevel)
		mw.compress = true
		c.writer = w
	}
	return c.writer, nil
}

// messageWriter buffers the payload of one outgoing message in the
// connection's write buffer and emits it as one or more frames.
type messageWriter struct {
	c         *Conn
	compress  bool // whether next call to flushFrame should set RSV1
	pos       int  // end of data in writeBuf.
	frameType int  // type of the current frame.
	err       error
}

// endMessage marks the writer as finished with err, detaches it from the
// connection and, when a write pool is configured, returns the write buffer
// to the pool. If the writer was already finished it only returns err.
func (w *messageWriter) endMessage(err error) error {
	if w.err != nil {
		return err
	}
	c := w.c
	w.err = err
	c.writer = nil
	if c.writePool != nil {
		c.writePool.Put(writePoolData{buf: c.writeBuf})
		c.writeBuf = nil
	}
	return err
}

// flushFrame writes buffered data and extra as a frame to the network. The
// final argument indicates that this is the last frame in the message.
func (w *messageWriter) flushFrame(final bool, extra []byte) error {
	c := w.c
	length := w.pos - maxFrameHeaderSize + len(extra)

	// Check for invalid control frames.
	if isControl(w.frameType) &&
		(!final || length > maxControlFramePayloadSize) {
		return w.endMessage(errInvalidControlFrame)
	}

	b0 := byte(w.frameType)
	if final {
		b0 |= finalBit
	}
	if w.compress {
		b0 |= rsv1Bit
	}
	// RSV1 is only set on the first frame flushed after compress was set.
	w.compress = false

	b1 := byte(0)
	if !c.isServer {
		b1 |= maskBit
	}

	// The payload always begins at maxFrameHeaderSize, so the header is
	// written right-aligned in front of it: framePos is moved up by the
	// number of header bytes that this frame does not need.

	// Assume that the frame starts at beginning of c.writeBuf.
	framePos := 0
	if c.isServer {
		// Adjust up if mask not included in the header.
		framePos = 4
	}

	switch {
	case length >= 65536:
		// 8-byte extended length.
		c.writeBuf[framePos] = b0
		c.writeBuf[framePos+1] = b1 | 127
		binary.BigEndian.PutUint64(c.writeBuf[framePos+2:], uint64(length))
	case length > 125:
		// 2-byte extended length; 6 of the 8 length bytes are unused.
		framePos += 6
		c.writeBuf[framePos] = b0
		c.writeBuf[framePos+1] = b1 | 126
		binary.BigEndian.PutUint16(c.writeBuf[framePos+2:], uint16(length))
	default:
		// Length fits in the second header byte.
		framePos += 8
		c.writeBuf[framePos] = b0
		c.writeBuf[framePos+1] = b1 | byte(length)
	}

	if !c.isServer {
		// The mask key occupies the last four header bytes and the buffered
		// payload is masked in place.
		key := newMaskKey()
		copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
		maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos])
		// extra is never masked here, so it must not be used by clients.
		if len(extra) > 0 {
			return w.endMessage(c.writeFatal(errors.New("websocket: internal error, extra used in client mode")))
		}
	}

	// Write the buffers to the connection with best-effort detection of
	// concurrent writes. See the concurrency section in the package
	// documentation for more info.

	if c.isWriting {
		panic("concurrent write to websocket connection")
	}
	c.isWriting = true

	err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra)

	if !c.isWriting {
		panic("concurrent write to websocket connection")
	}
	c.isWriting = false

	if err != nil {
		return w.endMessage(err)
	}

	if final {
		// The message is complete; later use of this writer reports
		// errWriteClosed.
		w.endMessage(errWriteClosed)
		return nil
	}

	// Setup for next frame.
	w.pos = maxFrameHeaderSize
	w.frameType = continuationFrame
	return nil
}

// ncopy returns how many bytes, at most max, can be copied into the write
// buffer at the current position. When the buffer is full it first flushes
// the buffered data as a non-final frame.
func (w *messageWriter) ncopy(max int) (int, error) {
	n := len(w.c.writeBuf) - w.pos
	if n <= 0 {
		if err := w.flushFrame(false, nil); err != nil {
			return 0, err
		}
		n = len(w.c.writeBuf) - w.pos
	}
	if n > max {
		n = max
	}
	return n, nil
}

// Write buffers p as message payload, flushing non-final frames whenever the
// write buffer fills up. On error the returned count is 0 even if part of p
// was already buffered or sent.
func (w *messageWriter) Write(p []byte) (int, error) {
	if w.err != nil {
		return 0, w.err
	}

	// Only servers take this path: p is sent unmasked as the extra buffer,
	// which flushFrame rejects in client mode.
	if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
		// Don't buffer large messages.
		err := w.flushFrame(false, p)
		if err != nil {
			return 0, err
		}
		return len(p), nil
	}

	nn := len(p)
	for len(p) > 0 {
		n, err := w.ncopy(len(p))
		if err != nil {
			return 0, err
		}
		copy(w.c.writeBuf[w.pos:], p[:n])
		w.pos += n
		p = p[n:]
	}
	return nn, nil
}

// WriteString is the string counterpart of Write; it always goes through the
// write buffer and has no large-message shortcut.
func (w *messageWriter) WriteString(p string) (int, error) {
	if w.err != nil {
		return 0, w.err
	}

	nn := len(p)
	for len(p) > 0 {
		n, err := w.ncopy(len(p))
		if err != nil {
			return 0, err
		}
		copy(w.c.writeBuf[w.pos:], p[:n])
		w.pos += n
		p = p[n:]
	}
	return nn, nil
}

// ReadFrom reads r directly into the write buffer until r returns an error,
// flushing a non-final frame each time the buffer is full. io.EOF from r ends
// the copy and is not reported as an error.
func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
	if w.err != nil {
		return 0, w.err
	}
	for {
		if w.pos == len(w.c.writeBuf) {
			err = w.flushFrame(false, nil)
			if err != nil {
				break
			}
		}
		var n int
		n, err = r.Read(w.c.writeBuf[w.pos:])
		w.pos += n
		nn += int64(n)
		if err != nil {
			if err == io.EOF {
				err = nil
			}
			break
		}
	}
	return nn, err
}

// Close flushes the buffered data as the final frame of the message. Calling
// it on a writer that has already ended returns the stored error.
func (w *messageWriter) Close() error {
	if w.err != nil {
		return w.err
	}
	return w.flushFrame(true, nil)
}

// WritePreparedMessage writes prepared message into connection.
func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error {
	// The prepared frame is selected by the properties of this connection
	// that affect the encoded bytes.
	frameType, frameData, err := pm.frame(prepareKey{
		isServer:         c.isServer,
		compress:         c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType),
		compressionLevel: c.compressionLevel,
	})
	if err != nil {
		return err
	}
	// Same best-effort concurrent write detection as in flushFrame.
	if c.isWriting {
		panic("concurrent write to websocket connection")
	}
	c.isWriting = true
	err = c.write(frameType, c.writeDeadline, frameData, nil)
	if !c.isWriting {
		panic("concurrent write to websocket connection")
	}
	c.isWriting = false
	return err
}

// WriteMessage is a helper method for getting a writer using NextWriter,
// writing the message and closing the writer.
func (c *Conn) WriteMessage(messageType int, data []byte) error {
	if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) {
		// Fast path with no allocations and single frame.

		var mw messageWriter
		if err := c.beginMessage(&mw, messageType); err != nil {
			return err
		}
		// Fill the write buffer first; whatever does not fit is passed to
		// flushFrame as the extra buffer of the same frame.
		n := copy(c.writeBuf[mw.pos:], data)
		mw.pos += n
		data = data[n:]
		return mw.flushFrame(true, data)
	}

	w, err := c.NextWriter(messageType)
	if err != nil {
		return err
	}
	if _, err = w.Write(data); err != nil {
		return err
	}
	return w.Close()
}

// SetWriteDeadline sets the write deadline on the underlying network
// connection. After a write has timed out, the websocket state is corrupt and
// all future writes will return an error. A zero value for t means writes will
// not time out.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	// The deadline is only recorded here; it is applied to the network
	// connection each time a data frame is written.
	c.writeDeadline = t
	return nil
}

// Read methods

// advanceFrame reads and validates the next frame header. For data and
// continuation frames it returns the frame type and leaves the payload to be
// read by the caller (readRemaining holds its size). Control frames are read
// completely and dispatched to the ping, pong or close handler. On error the
// returned frame type is noFrame.
func (c *Conn) advanceFrame() (int, error) {
	// 1. Skip remainder of previous frame.

	if c.readRemaining > 0 {
		if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil {
			return noFrame, err
		}
	}

	// 2. Read and parse first two bytes of frame header.
	// To aid debugging, collect and report all errors in the first two bytes
	// of the header.

	// NOTE: this local shadows the errors package for the rest of the function.
	var errors []string

	p, err := c.read(2)
	if err != nil {
		return noFrame, err
	}

	frameType := int(p[0] & 0xf)
	final := p[0]&finalBit != 0
	rsv1 := p[0]&rsv1Bit != 0
	rsv2 := p[0]&rsv2Bit != 0
	rsv3 := p[0]&rsv3Bit != 0
	mask := p[1]&maskBit != 0
	// The 7-bit value is never negative, so the error result is not needed.
	c.setReadRemaining(int64(p[1] & 0x7f))

	c.readDecompress = false
	if rsv1 {
		// RSV1 is only acceptable when a decompression reader is configured.
		if c.newDecompressionReader != nil {
			c.readDecompress = true
		} else {
			errors = append(errors, "RSV1 set")
		}
	}

	if rsv2 {
		errors = append(errors, "RSV2 set")
	}

	if rsv3 {
		errors = append(errors, "RSV3 set")
	}

	switch frameType {
	case CloseMessage, PingMessage, PongMessage:
		if c.readRemaining > maxControlFramePayloadSize {
			errors = append(errors, "len > 125 for control")
		}
		if !final {
			errors = append(errors, "FIN not set on control")
		}
	case TextMessage, BinaryMessage:
		// A new data frame may only start after the previous message ended.
		if !c.readFinal {
			errors = append(errors, "data before FIN")
		}
		c.readFinal = final
	case continuationFrame:
		// A continuation is only valid while a message is still open.
		if c.readFinal {
			errors = append(errors, "continuation after FIN")
		}
		c.readFinal = final
	default:
		errors = append(errors, "bad opcode "+strconv.Itoa(frameType))
	}

	// Frames received by a server must be masked; frames received by a
	// client must not be.
	if mask != c.isServer {
		errors = append(errors, "bad MASK")
	}

	if len(errors) > 0 {
		return noFrame, c.handleProtocolError(strings.Join(errors, ", "))
	}

	// 3. Read and parse frame length as per
	// https://tools.ietf.org/html/rfc6455#section-5.2
	//
	// The length of the "Payload data", in bytes: if 0-125, that is the payload
	// length.
	// - If 126, the following 2 bytes interpreted as a 16-bit unsigned
	// integer are the payload length.
	// - If 127, the following 8 bytes interpreted as
	// a 64-bit unsigned integer (the most significant bit MUST be 0) are the
	// payload length. Multibyte length quantities are expressed in network byte
	// order.

	switch c.readRemaining {
	case 126:
		p, err := c.read(2)
		if err != nil {
			return noFrame, err
		}

		if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil {
			return noFrame, err
		}
	case 127:
		p, err := c.read(8)
		if err != nil {
			return noFrame, err
		}

		// A length with the top bit set becomes negative as int64 and is
		// rejected by setReadRemaining.
		if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil {
			return noFrame, err
		}
	}

	// 4. Handle frame masking.

	if mask {
		c.readMaskPos = 0
		p, err := c.read(len(c.readMaskKey))
		if err != nil {
			return noFrame, err
		}
		copy(c.readMaskKey[:], p)
	}

	// 5. For text and binary messages, enforce read limit and return.

	if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {

		c.readLength += c.readRemaining
		// Don't allow readLength to overflow in the presence of a large readRemaining
		// counter.
		if c.readLength < 0 {
			return noFrame, ErrReadLimit
		}

		if c.readLimit > 0 && c.readLength > c.readLimit {
			// Tell the peer why before reporting the error; the result of
			// this write is ignored.
			c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
			return noFrame, ErrReadLimit
		}

		return frameType, nil
	}

	// 6. Read control frame payload.

	var payload []byte
	if c.readRemaining > 0 {
		payload, err = c.read(int(c.readRemaining))
		c.setReadRemaining(0)
		if err != nil {
			return noFrame, err
		}
		if c.isServer {
			maskBytes(c.readMaskKey, 0, payload)
		}
	}

	// 7. Process control frame payload.

	switch frameType {
	case PongMessage:
		if err := c.handlePong(string(payload)); err != nil {
			return noFrame, err
		}
	case PingMessage:
		if err := c.handlePing(string(payload)); err != nil {
			return noFrame, err
		}
	case CloseMessage:
		// A payload shorter than two bytes carries no status code.
		closeCode := CloseNoStatusReceived
		closeText := ""
		if len(payload) >= 2 {
			closeCode = int(binary.BigEndian.Uint16(payload))
			if !isValidReceivedCloseCode(closeCode) {
				return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode))
			}
			closeText = string(payload[2:])
			if !utf8.ValidString(closeText) {
				return noFrame, c.handleProtocolError("invalid utf8 payload in close frame")
			}
		}
		if err := c.handleClose(closeCode, closeText); err != nil {
			return noFrame, err
		}
		// A close frame always ends reading with a *CloseError.
		return noFrame, &CloseError{Code: closeCode, Text: closeText}
	}

	return frameType, nil
}

// handleProtocolError sends a close frame with CloseProtocolError and the
// given message (cut to the control frame payload limit) and returns an error
// carrying the message prefixed with "websocket: ". The result of the write
// is ignored.
func (c *Conn) handleProtocolError(message string) error {
	data := FormatCloseMessage(CloseProtocolError, message)
	if len(data) > maxControlFramePayloadSize {
		data = data[:maxControlFramePayloadSize]
	}
	c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
	return errors.New("websocket: " + message)
}

// NextReader returns the next data message received from the peer. The
// returned messageType is either TextMessage or BinaryMessage.
//
// There can be at most one open reader on a connection. NextReader discards
// the previous message if the application has not already consumed it.
//
// Applications must break out of the application's read loop when this method
// returns a non-nil error value. Errors returned from this method are
// permanent. Once this method returns a non-nil error, all subsequent calls to
// this method return the same error.
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
	// Close previous reader, only relevant for decompression.
	if c.reader != nil {
		c.reader.Close()
		c.reader = nil
	}

	// Detaching the message reader makes any reader handed out earlier
	// return io.EOF.
	c.messageReader = nil
	c.readLength = 0

	for c.readErr == nil {
		frameType, err := c.advanceFrame()
		if err != nil {
			c.readErr = hideTempErr(err)
			break
		}

		// Control frames were already handled inside advanceFrame; keep
		// looping until a data frame starts a new message.
		if frameType == TextMessage || frameType == BinaryMessage {
			c.messageReader = &messageReader{c}
			c.reader = c.messageReader
			if c.readDecompress {
				c.reader = c.newDecompressionReader(c.reader)
			}
			return frameType, c.reader, nil
		}
	}

	// Applications that do not handle the error returned from this method
	// spin in tight loop on connection failure. To help application developers
	// detect this error, panic on repeated reads to the failed connection.
	c.readErrCount++
	if c.readErrCount >= 1000 {
		panic("repeated read on failed websocket connection")
	}

	return noFrame, nil, c.readErr
}

// messageReader reads the payload of the current message frame by frame.
type messageReader struct{ c *Conn }

// Read reads payload bytes of the current message into b, never crossing a
// frame boundary in a single call. It returns io.EOF once the final frame is
// consumed, or immediately if this reader is no longer the connection's
// current message reader.
func (r *messageReader) Read(b []byte) (int, error) {
	c := r.c
	if c.messageReader != r {
		return 0, io.EOF
	}

	for c.readErr == nil {

		if c.readRemaining > 0 {
			// Never read past the end of the current frame.
			if int64(len(b)) > c.readRemaining {
				b = b[:c.readRemaining]
			}
			n, err := c.br.Read(b)
			c.readErr = hideTempErr(err)
			if c.isServer {
				// Unmask in place, carrying the key position across reads.
				c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
			}
			rem := c.readRemaining
			rem -= int64(n)
			c.setReadRemaining(rem)
			// EOF in the middle of a frame means the stream was cut short.
			if c.readRemaining > 0 && c.readErr == io.EOF {
				c.readErr = errUnexpectedEOF
			}
			return n, c.readErr
		}

		if c.readFinal {
			c.messageReader = nil
			return 0, io.EOF
		}

		// The current frame is exhausted but the message continues: move to
		// the next frame, which must be a continuation frame.
		frameType, err := c.advanceFrame()
		switch {
		case err != nil:
			c.readErr = hideTempErr(err)
		case frameType == TextMessage || frameType == BinaryMessage:
			c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
		}
	}

	err := c.readErr
	if err == io.EOF && c.messageReader == r {
		err = errUnexpectedEOF
	}
	return 0, err
}

// Close is a no-op; it exists so that messageReader satisfies io.ReadCloser.
func (r *messageReader) Close() error {
	return nil
}

// ReadMessage is a helper method for getting a reader using NextReader and
// reading from that reader to a buffer.
1095 | func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { 1096 | var r io.Reader 1097 | messageType, r, err = c.NextReader() 1098 | if err != nil { 1099 | return messageType, nil, err 1100 | } 1101 | p, err = ioutil.ReadAll(r) 1102 | return messageType, p, err 1103 | } 1104 | 1105 | // SetReadDeadline sets the read deadline on the underlying network connection. 1106 | // After a read has timed out, the websocket connection state is corrupt and 1107 | // all future reads will return an error. A zero value for t means reads will 1108 | // not time out. 1109 | func (c *Conn) SetReadDeadline(t time.Time) error { 1110 | return c.conn.SetReadDeadline(t) 1111 | } 1112 | 1113 | // SetReadLimit sets the maximum size in bytes for a message read from the peer. If a 1114 | // message exceeds the limit, the connection sends a close message to the peer 1115 | // and returns ErrReadLimit to the application. 1116 | func (c *Conn) SetReadLimit(limit int64) { 1117 | c.readLimit = limit 1118 | } 1119 | 1120 | // CloseHandler returns the current close handler 1121 | func (c *Conn) CloseHandler() func(code int, text string) error { 1122 | return c.handleClose 1123 | } 1124 | 1125 | // SetCloseHandler sets the handler for close messages received from the peer. 1126 | // The code argument to h is the received close code or CloseNoStatusReceived 1127 | // if the close message is empty. The default close handler sends a close 1128 | // message back to the peer. 1129 | // 1130 | // The handler function is called from the NextReader, ReadMessage and message 1131 | // reader Read methods. The application must read the connection to process 1132 | // close messages as described in the section on Control Messages above. 1133 | // 1134 | // The connection read methods return a CloseError when a close message is 1135 | // received. Most applications should handle close messages as part of their 1136 | // normal error handling. 
Applications should only set a close handler when the 1137 | // application must perform some action before sending a close message back to 1138 | // the peer. 1139 | func (c *Conn) SetCloseHandler(h func(code int, text string) error) { 1140 | if h == nil { 1141 | h = func(code int, text string) error { 1142 | message := FormatCloseMessage(code, "") 1143 | c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) 1144 | return nil 1145 | } 1146 | } 1147 | c.handleClose = h 1148 | } 1149 | 1150 | // PingHandler returns the current ping handler 1151 | func (c *Conn) PingHandler() func(appData string) error { 1152 | return c.handlePing 1153 | } 1154 | 1155 | // SetPingHandler sets the handler for ping messages received from the peer. 1156 | // The appData argument to h is the PING message application data. The default 1157 | // ping handler sends a pong to the peer. 1158 | // 1159 | // The handler function is called from the NextReader, ReadMessage and message 1160 | // reader Read methods. The application must read the connection to process 1161 | // ping messages as described in the section on Control Messages above. 1162 | func (c *Conn) SetPingHandler(h func(appData string) error) { 1163 | if h == nil { 1164 | h = func(message string) error { 1165 | err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait)) 1166 | if err == ErrCloseSent { 1167 | return nil 1168 | } else if e, ok := err.(net.Error); ok && e.Temporary() { 1169 | return nil 1170 | } 1171 | return err 1172 | } 1173 | } 1174 | c.handlePing = h 1175 | } 1176 | 1177 | // PongHandler returns the current pong handler 1178 | func (c *Conn) PongHandler() func(appData string) error { 1179 | return c.handlePong 1180 | } 1181 | 1182 | // SetPongHandler sets the handler for pong messages received from the peer. 1183 | // The appData argument to h is the PONG message application data. The default 1184 | // pong handler does nothing. 
1185 | // 1186 | // The handler function is called from the NextReader, ReadMessage and message 1187 | // reader Read methods. The application must read the connection to process 1188 | // pong messages as described in the section on Control Messages above. 1189 | func (c *Conn) SetPongHandler(h func(appData string) error) { 1190 | if h == nil { 1191 | h = func(string) error { return nil } 1192 | } 1193 | c.handlePong = h 1194 | } 1195 | 1196 | // NetConn returns the underlying connection that is wrapped by c. 1197 | // Note that writing to or reading from this connection directly will corrupt the 1198 | // WebSocket connection. 1199 | func (c *Conn) NetConn() net.Conn { 1200 | return c.conn 1201 | } 1202 | 1203 | // UnderlyingConn returns the internal net.Conn. This can be used to further 1204 | // modifications to connection specific flags. 1205 | // Deprecated: Use the NetConn method. 1206 | func (c *Conn) UnderlyingConn() net.Conn { 1207 | return c.conn 1208 | } 1209 | 1210 | // EnableWriteCompression enables and disables write compression of 1211 | // subsequent text and binary messages. This function is a noop if 1212 | // compression was not negotiated with the peer. 1213 | func (c *Conn) EnableWriteCompression(enable bool) { 1214 | c.enableWriteCompression = enable 1215 | } 1216 | 1217 | // SetCompressionLevel sets the flate compression level for subsequent text and 1218 | // binary messages. This function is a noop if compression was not negotiated 1219 | // with the peer. See the compress/flate package for a description of 1220 | // compression levels. 1221 | func (c *Conn) SetCompressionLevel(level int) error { 1222 | if !isValidCompressionLevel(level) { 1223 | return errors.New("websocket: invalid compression level") 1224 | } 1225 | c.compressionLevel = level 1226 | return nil 1227 | } 1228 | 1229 | // FormatCloseMessage formats closeCode and text as a WebSocket close message. 1230 | // An empty message is returned for code CloseNoStatusReceived. 
1231 | func FormatCloseMessage(closeCode int, text string) []byte { 1232 | if closeCode == CloseNoStatusReceived { 1233 | // Return empty message because it's illegal to send 1234 | // CloseNoStatusReceived. Return non-nil value in case application 1235 | // checks for nil. 1236 | return []byte{} 1237 | } 1238 | buf := make([]byte, 2+len(text)) 1239 | binary.BigEndian.PutUint16(buf, uint16(closeCode)) 1240 | copy(buf[2:], text) 1241 | return buf 1242 | } 1243 | -------------------------------------------------------------------------------- /conn_broadcast_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | // 5 | // This file may have been modified by CloudWeGo authors. All CloudWeGo 6 | // Modifications are Copyright 2022 CloudWeGo Authors. 7 | 8 | package websocket 9 | 10 | import ( 11 | "io" 12 | "io/ioutil" 13 | "sync/atomic" 14 | "testing" 15 | ) 16 | 17 | // broadcastBench allows to run broadcast benchmarks. 18 | // In every broadcast benchmark we create many connections, then send the same 19 | // message into every connection and wait for all writes complete. This emulates 20 | // an application where many connections listen to the same data - i.e. PUB/SUB 21 | // scenarios with many subscribers in one channel. 
type broadcastBench struct {
	w           io.Writer        // sink handed to every test connection
	closeCh     chan struct{}    // closed by close() to stop the writer goroutines
	doneCh      chan struct{}    // signalled once per completed broadcast round
	count       int32            // total writes performed, updated atomically
	conns       []*broadcastConn // connections created by makeConns
	compression bool
	usePrepared bool
}

// broadcastMessage is one message to broadcast; prepared, when non-nil, is
// written instead of payload.
type broadcastMessage struct {
	payload  []byte
	prepared *PreparedMessage
}

// broadcastConn pairs a connection with the channel feeding its writer
// goroutine.
type broadcastConn struct {
	conn  *Conn
	msgCh chan *broadcastMessage
}

// newBroadcastConn wraps c with a message channel of capacity one.
func newBroadcastConn(c *Conn) *broadcastConn {
	return &broadcastConn{
		conn:  c,
		msgCh: make(chan *broadcastMessage, 1),
	}
}

// newBroadcastBench builds a bench that discards all output and starts 10000
// connections.
func newBroadcastBench(usePrepared, compression bool) *broadcastBench {
	bench := &broadcastBench{
		w:           ioutil.Discard,
		doneCh:      make(chan struct{}),
		closeCh:     make(chan struct{}),
		usePrepared: usePrepared,
		compression: compression,
	}
	bench.makeConns(10000)
	return bench
}

// makeConns creates numConns server-side test connections and starts one
// writer goroutine per connection. Each goroutine runs until closeCh is
// closed.
func (b *broadcastBench) makeConns(numConns int) {
	conns := make([]*broadcastConn, numConns)

	for i := 0; i < numConns; i++ {
		c := newTestConn(nil, b.w, true)
		if b.compression {
			c.enableWriteCompression = true
			c.newCompressionWriter = compressNoContextTakeover
		}
		conns[i] = newBroadcastConn(c)
		go func(c *broadcastConn) {
			for {
				select {
				case msg := <-c.msgCh:
					if msg.prepared != nil {
						c.conn.WritePreparedMessage(msg.prepared)
					} else {
						c.conn.WriteMessage(TextMessage, msg.payload)
					}
					// The write that brings the running total to a multiple of
					// numConns completes a round and signals the broadcaster.
					val := atomic.AddInt32(&b.count, 1)
					if val%int32(numConns) == 0 {
						b.doneCh <- struct{}{}
					}
				case <-b.closeCh:
					return
				}
			}
		}(conns[i])
	}
	b.conns = conns
}

// close stops all writer goroutines by closing closeCh.
func (b *broadcastBench) close() {
	close(b.closeCh)
}

// broadcastOnce queues msg on every connection and blocks until a round is
// reported complete on doneCh.
func (b *broadcastBench) broadcastOnce(msg *broadcastMessage) {
	for _, c := range b.conns {
		c.msgCh <- msg
	}
	<-b.doneCh
}

// BenchmarkBroadcast measures one broadcast round per iteration, with and
// without compression and prepared messages.
func BenchmarkBroadcast(b *testing.B) {
	benchmarks := []struct {
		name        string
		usePrepared bool
		compression bool
	}{
		{"NoCompression", false, false},
		{"Compression", false, true},
		{"NoCompressionPrepared", true, false},
		{"CompressionPrepared", true, true},
	}
	payload := textMessages(1)[0]
	for _, bm := range benchmarks {
		b.Run(bm.name, func(b *testing.B) {
			bench := newBroadcastBench(bm.usePrepared, bm.compression)
			defer bench.close()
			// Exclude connection setup from the measurement.
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				message := &broadcastMessage{
					payload: payload,
				}
				if bench.usePrepared {
					// The prepared message is built inside the timed loop, so
					// its cost is part of each iteration.
					pm, _ := NewPreparedMessage(TextMessage, message.payload)
					message.prepared = pm
				}
				bench.broadcastOnce(message)
			}
			b.ReportAllocs()
		})
	}
}

--------------------------------------------------------------------------------
/conn_test.go:
--------------------------------------------------------------------------------
// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//
// This file may have been modified by CloudWeGo authors. All CloudWeGo
// Modifications are Copyright 2022 CloudWeGo Authors.

package websocket

import (
	"bufio"
	"bytes"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"reflect"
	"sync"
	"testing"
	"testing/iotest"
	"time"
)

// Compile-time check that errWriteTimeout satisfies net.Error.
var _ net.Error = errWriteTimeout

// fakeNetConn is a net.Conn stand-in for tests: reads come from the embedded
// Reader, writes go to the embedded Writer, and every other method is a stub.
type fakeNetConn struct {
	io.Reader
	io.Writer
}

func (c fakeNetConn) Close() error                       { return nil }
func (c fakeNetConn) LocalAddr() net.Addr                { return localAddr }
func (c fakeNetConn) RemoteAddr() net.Addr               { return remoteAddr }
func (c fakeNetConn) SetDeadline(t time.Time) error      { return nil }
func (c fakeNetConn) SetReadDeadline(t time.Time) error  { return nil }
func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil }

// fakeAddr is a net.Addr with fixed Network and String values.
type fakeAddr int

var (
	localAddr  = fakeAddr(1)
	remoteAddr = fakeAddr(2)
)

// Network always returns "net".
func (a fakeAddr) Network() string {
	return "net"
}

// String always returns "str".
func (a fakeAddr) String() string {
	return "str"
}

// newTestConn creates a connection backed by a fake network connection using
// default values for buffering.
56 | func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn { 57 | return newConn(fakeNetConn{Reader: r, Writer: w}, isServer, 1024, 1024, nil, nil, nil) 58 | } 59 | 60 | func TestFraming(t *testing.T) { 61 | frameSizes := []int{ 62 | 0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 63 | // 65536, 65537 64 | } 65 | readChunkers := []struct { 66 | name string 67 | f func(io.Reader) io.Reader 68 | }{ 69 | {"half", iotest.HalfReader}, 70 | {"one", iotest.OneByteReader}, 71 | {"asis", func(r io.Reader) io.Reader { return r }}, 72 | } 73 | writeBuf := make([]byte, 65537) 74 | for i := range writeBuf { 75 | writeBuf[i] = byte(i) 76 | } 77 | writers := []struct { 78 | name string 79 | f func(w io.Writer, n int) (int, error) 80 | }{ 81 | {"iocopy", func(w io.Writer, n int) (int, error) { 82 | nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n])) 83 | return int(nn), err 84 | }}, 85 | {"write", func(w io.Writer, n int) (int, error) { 86 | return w.Write(writeBuf[:n]) 87 | }}, 88 | {"string", func(w io.Writer, n int) (int, error) { 89 | return io.WriteString(w, string(writeBuf[:n])) // nolint: staticcheck 90 | }}, 91 | } 92 | 93 | for _, compress := range []bool{false, true} { 94 | for _, isServer := range []bool{true, false} { 95 | for _, chunker := range readChunkers { 96 | 97 | var connBuf bytes.Buffer 98 | wc := newTestConn(nil, &connBuf, isServer) 99 | rc := newTestConn(chunker.f(&connBuf), nil, !isServer) 100 | if compress { 101 | wc.newCompressionWriter = compressNoContextTakeover 102 | rc.newDecompressionReader = decompressNoContextTakeover 103 | } 104 | for _, n := range frameSizes { 105 | for _, writer := range writers { 106 | name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name) 107 | 108 | w, err := wc.NextWriter(TextMessage) 109 | if err != nil { 110 | t.Errorf("%s: wc.NextWriter() returned %v", name, err) 111 | continue 112 | } 113 | nn, err := writer.f(w, n) 114 | if err != nil || nn != n { 115 | 
t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err) 116 | continue 117 | } 118 | err = w.Close() 119 | if err != nil { 120 | t.Errorf("%s: w.Close() returned %v", name, err) 121 | continue 122 | } 123 | 124 | opCode, r, err := rc.NextReader() 125 | if err != nil || opCode != TextMessage { 126 | t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err) 127 | continue 128 | } 129 | 130 | t.Logf("frame size: %d", n) 131 | rbuf, err := ioutil.ReadAll(r) 132 | if err != nil { 133 | t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) 134 | continue 135 | } 136 | 137 | if len(rbuf) != n { 138 | t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n) 139 | continue 140 | } 141 | 142 | for i, b := range rbuf { 143 | if byte(i) != b { 144 | t.Errorf("%s: bad byte at offset %d", name, i) 145 | break 146 | } 147 | } 148 | } 149 | } 150 | } 151 | } 152 | } 153 | } 154 | 155 | func TestControl(t *testing.T) { 156 | const message = "this is a ping/pong message" 157 | for _, isServer := range []bool{true, false} { 158 | for _, isWriteControl := range []bool{true, false} { 159 | name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl) 160 | var connBuf bytes.Buffer 161 | wc := newTestConn(nil, &connBuf, isServer) 162 | rc := newTestConn(&connBuf, nil, !isServer) 163 | if isWriteControl { 164 | wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)) 165 | } else { 166 | w, err := wc.NextWriter(PongMessage) 167 | if err != nil { 168 | t.Errorf("%s: wc.NextWriter() returned %v", name, err) 169 | continue 170 | } 171 | if _, err := w.Write([]byte(message)); err != nil { 172 | t.Errorf("%s: w.Write() returned %v", name, err) 173 | continue 174 | } 175 | if err := w.Close(); err != nil { 176 | t.Errorf("%s: w.Close() returned %v", name, err) 177 | continue 178 | } 179 | var actualMessage string 180 | rc.SetPongHandler(func(s string) error { actualMessage = s; return nil }) 181 | rc.NextReader() 182 | if actualMessage != message { 
183 | t.Errorf("%s: pong=%q, want %q", name, actualMessage, message) 184 | continue 185 | } 186 | } 187 | } 188 | } 189 | } 190 | 191 | // simpleBufferPool is an implementation of BufferPool for TestWriteBufferPool. 192 | type simpleBufferPool struct { 193 | v interface{} 194 | } 195 | 196 | func (p *simpleBufferPool) Get() interface{} { 197 | v := p.v 198 | p.v = nil 199 | return v 200 | } 201 | 202 | func (p *simpleBufferPool) Put(v interface{}) { 203 | p.v = v 204 | } 205 | 206 | func TestWriteBufferPool(t *testing.T) { 207 | const message = "Now is the time for all good people to come to the aid of the party." 208 | 209 | var buf bytes.Buffer 210 | var pool simpleBufferPool 211 | rc := newTestConn(&buf, nil, false) 212 | 213 | // Specify writeBufferSize smaller than message size to ensure that pooling 214 | // works with fragmented messages. 215 | wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, len(message)-1, &pool, nil, nil) 216 | 217 | if wc.writeBuf != nil { 218 | t.Fatal("writeBuf not nil after create") 219 | } 220 | 221 | // Part 1: test NextWriter/Write/Close 222 | 223 | w, err := wc.NextWriter(TextMessage) 224 | if err != nil { 225 | t.Fatalf("wc.NextWriter() returned %v", err) 226 | } 227 | 228 | if wc.writeBuf == nil { 229 | t.Fatal("writeBuf is nil after NextWriter") 230 | } 231 | 232 | writeBufAddr := &wc.writeBuf[0] 233 | 234 | if _, err := io.WriteString(w, message); err != nil { 235 | t.Fatalf("io.WriteString(w, message) returned %v", err) 236 | } 237 | 238 | if err := w.Close(); err != nil { 239 | t.Fatalf("w.Close() returned %v", err) 240 | } 241 | 242 | if wc.writeBuf != nil { 243 | t.Fatal("writeBuf not nil after w.Close()") 244 | } 245 | 246 | if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { 247 | t.Fatal("writeBuf not returned to pool") 248 | } 249 | 250 | opCode, p, err := rc.ReadMessage() 251 | if opCode != TextMessage || err != nil { 252 | t.Fatalf("ReadMessage() returned %d, p, %v", 
opCode, err) 253 | } 254 | 255 | if s := string(p); s != message { 256 | t.Fatalf("message is %s, want %s", s, message) 257 | } 258 | 259 | // Part 2: Test WriteMessage. 260 | 261 | if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil { 262 | t.Fatalf("wc.WriteMessage() returned %v", err) 263 | } 264 | 265 | if wc.writeBuf != nil { 266 | t.Fatal("writeBuf not nil after wc.WriteMessage()") 267 | } 268 | 269 | if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { 270 | t.Fatal("writeBuf not returned to pool after WriteMessage") 271 | } 272 | 273 | opCode, p, err = rc.ReadMessage() 274 | if opCode != TextMessage || err != nil { 275 | t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) 276 | } 277 | 278 | if s := string(p); s != message { 279 | t.Fatalf("message is %s, want %s", s, message) 280 | } 281 | } 282 | 283 | // TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool. 284 | func TestWriteBufferPoolSync(t *testing.T) { 285 | var buf bytes.Buffer 286 | var pool sync.Pool 287 | wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil) 288 | rc := newTestConn(&buf, nil, false) 289 | 290 | const message = "Hello World!" 291 | for i := 0; i < 3; i++ { 292 | if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil { 293 | t.Fatalf("wc.WriteMessage() returned %v", err) 294 | } 295 | opCode, p, err := rc.ReadMessage() 296 | if opCode != TextMessage || err != nil { 297 | t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) 298 | } 299 | if s := string(p); s != message { 300 | t.Fatalf("message is %s, want %s", s, message) 301 | } 302 | } 303 | } 304 | 305 | // errorWriter is an io.Writer than returns an error on all writes. 306 | type errorWriter struct{} 307 | 308 | func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("error") } 309 | 310 | // TestWriteBufferPoolError ensures that buffer is returned to pool after error 311 | // on write. 
func TestWriteBufferPoolError(t *testing.T) {
	// Part 1: Test NextWriter/Write/Close

	var pool simpleBufferPool
	wc := newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)

	w, err := wc.NextWriter(TextMessage)
	if err != nil {
		t.Fatalf("wc.NextWriter() returned %v", err)
	}

	if wc.writeBuf == nil {
		t.Fatal("writeBuf is nil after NextWriter")
	}

	// Remember the buffer's identity so its return to the pool can be checked.
	writeBufAddr := &wc.writeBuf[0]

	if _, err := io.WriteString(w, "Hello"); err != nil {
		t.Fatalf("io.WriteString(w, message) returned %v", err)
	}

	if err := w.Close(); err == nil {
		t.Fatalf("w.Close() did not return error")
	}

	if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
		t.Fatal("writeBuf not returned to pool")
	}

	// Part 2: Test WriteMessage

	wc = newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)

	if err := wc.WriteMessage(TextMessage, []byte("Hello")); err == nil {
		t.Fatalf("wc.WriteMessage did not return error")
	}

	if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
		t.Fatal("writeBuf not returned to pool")
	}
}

// TestCloseFrameBeforeFinalMessageFrame checks that a close frame arriving in
// the middle of a fragmented message is reported as a CloseError both by the
// message reader and by the next NextReader call.
func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
	const bufSize = 512

	expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}

	var b1, b2 bytes.Buffer
	wc := newConn(&fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
	rc := newTestConn(&b1, &b2, true)

	w, _ := wc.NextWriter(BinaryMessage)
	w.Write(make([]byte, bufSize+bufSize/2))
	wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second))
	w.Close()

	op, r, err := rc.NextReader()
	if op != BinaryMessage || err != nil {
		t.Fatalf("NextReader() returned %d, %v", op, err)
	}
	_, err = io.Copy(ioutil.Discard, r)
	if !reflect.DeepEqual(err, expectedErr) {
		t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
	}
	_, _, err = rc.NextReader()
	if !reflect.DeepEqual(err, expectedErr) {
		t.Fatalf("NextReader() returned %v, want %v", err, expectedErr)
	}
}

// TestEOFWithinFrame truncates a written message at every possible length and
// checks that reading the truncated stream yields errUnexpectedEOF.
func TestEOFWithinFrame(t *testing.T) {
	const bufSize = 64

	for n := 0; ; n++ {
		var b bytes.Buffer
		wc := newTestConn(nil, &b, false)
		rc := newTestConn(&b, nil, true)

		w, _ := wc.NextWriter(BinaryMessage)
		w.Write(make([]byte, bufSize))
		w.Close()

		// Stop once n reaches the full encoded length: nothing left to cut.
		if n >= b.Len() {
			break
		}
		b.Truncate(n)

		op, r, err := rc.NextReader()
		if err == errUnexpectedEOF {
			continue
		}
		if op != BinaryMessage || err != nil {
			t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
		}
		_, err = io.Copy(ioutil.Discard, r)
		if err != errUnexpectedEOF {
			t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
		}
		_, _, err = rc.NextReader()
		if err != errUnexpectedEOF {
			t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF)
		}
	}
}

// TestEOFBeforeFinalFrame checks that a stream ending before the final frame
// of a message yields errUnexpectedEOF. The writer is deliberately never
// closed.
func TestEOFBeforeFinalFrame(t *testing.T) {
	const bufSize = 512

	var b1, b2 bytes.Buffer
	wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
	rc := newTestConn(&b1, &b2, true)

	w, _ := wc.NextWriter(BinaryMessage)
	w.Write(make([]byte, bufSize+bufSize/2))

	op, r, err := rc.NextReader()
	if op != BinaryMessage || err != nil {
		t.Fatalf("NextReader() returned %d, %v", op, err)
	}
	_, err = io.Copy(ioutil.Discard, r)
	if err != errUnexpectedEOF {
		t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
	}
	_, _, err = rc.NextReader()
	if err != errUnexpectedEOF {
		t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF)
	}
}

// TestWriteAfterMessageWriterClose checks that writing to a message writer
// fails after it was closed explicitly or implicitly by a later NextWriter.
func TestWriteAfterMessageWriterClose(t *testing.T) {
	wc := newTestConn(nil, &bytes.Buffer{}, false)
	w, _ := wc.NextWriter(BinaryMessage)
	io.WriteString(w, "hello")
	if err := w.Close(); err != nil {
		t.Fatalf("unexpected error closing message writer, %v", err)
	}

	if _, err := io.WriteString(w, "world"); err == nil {
		t.Fatalf("no error writing after close")
	}

	w, _ = wc.NextWriter(BinaryMessage)
	io.WriteString(w, "hello")

	// close w by getting next writer
	_, err := wc.NextWriter(BinaryMessage)
	if err != nil {
		t.Fatalf("unexpected error getting next writer, %v", err)
	}

	if _, err := io.WriteString(w, "world"); err == nil {
		t.Fatalf("no error writing after close")
	}
}

// TestReadLimit checks enforcement of SetReadLimit, including a crafted frame
// sequence that tries to overflow the length accounting.
func TestReadLimit(t *testing.T) {
	t.Run("Test ReadLimit is enforced", func(t *testing.T) {
		const readLimit = 512
		message := make([]byte, readLimit+1)

		var b1, b2 bytes.Buffer
		wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil)
		rc := newTestConn(&b1, &b2, true)
		rc.SetReadLimit(readLimit)

		// Send message at the limit with interleaved pong.
		w, _ := wc.NextWriter(BinaryMessage)
		w.Write(message[:readLimit-1])
		wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
		w.Write(message[:1])
		w.Close()

		// Send message larger than the limit.
		wc.WriteMessage(BinaryMessage, message[:readLimit+1])

		op, _, err := rc.NextReader()
		if op != BinaryMessage || err != nil {
			t.Fatalf("1: NextReader() returned %d, %v", op, err)
		}
		op, r, err := rc.NextReader()
		if op != BinaryMessage || err != nil {
			t.Fatalf("2: NextReader() returned %d, %v", op, err)
		}
		_, err = io.Copy(ioutil.Discard, r)
		if err != ErrReadLimit {
			t.Fatalf("io.Copy() returned %v", err)
		}
	})

	t.Run("Test that ReadLimit cannot be overflowed", func(t *testing.T) {
		const readLimit = 1

		var b1, b2 bytes.Buffer
		rc := newTestConn(&b1, &b2, true)
		rc.SetReadLimit(readLimit)

		// First, send a non-final binary message
		b1.Write([]byte("\x02\x81"))

		// Mask key
		b1.Write([]byte("\x00\x00\x00\x00"))

		// First payload
		b1.Write([]byte("A"))

		// Next, send a negative-length, non-final continuation frame
		b1.Write([]byte("\x00\xFF\x80\x00\x00\x00\x00\x00\x00\x00"))

		// Mask key
		b1.Write([]byte("\x00\x00\x00\x00"))

		// Next, send a too long, final continuation frame
		b1.Write([]byte("\x80\xFF\x00\x00\x00\x00\x00\x00\x00\x05"))

		// Mask key
		b1.Write([]byte("\x00\x00\x00\x00"))

		// Too-long payload
		b1.Write([]byte("BCDEF"))

		op, r, err := rc.NextReader()
		if op != BinaryMessage || err != nil {
			t.Fatalf("1: NextReader() returned %d, %v", op, err)
		}

		var buf [10]byte
		var read int
		n, err := r.Read(buf[:])
		if err != nil && err != ErrReadLimit {
			t.Fatalf("unexpected error testing read limit: %v", err)
		}
		read += n

		n, err = r.Read(buf[:])
		if err != nil && err != ErrReadLimit {
			t.Fatalf("unexpected error testing read limit: %v", err)
		}
		read += n

		if err == nil && read > readLimit {
			t.Fatalf("read limit exceeded: limit %d, read %d", readLimit, read)
		}
	})
}

// TestAddrs checks that LocalAddr and RemoteAddr are forwarded from the
// underlying connection.
func TestAddrs(t *testing.T) {
	c := newTestConn(nil, nil, true)
	if c.LocalAddr() != localAddr {
		t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr)
	}
	if c.RemoteAddr() != remoteAddr {
		t.Errorf("RemoteAddr = %v, want %v", c.RemoteAddr(), remoteAddr)
	}
}

// TestDeprecatedUnderlyingConn checks that UnderlyingConn returns the wrapped
// connection.
func TestDeprecatedUnderlyingConn(t *testing.T) {
	var b1, b2 bytes.Buffer
	fc := fakeNetConn{Reader: &b1, Writer: &b2}
	c := newConn(fc, true, 1024, 1024, nil, nil, nil)
	ul := c.UnderlyingConn()
	if ul != fc {
		t.Fatalf("Underlying conn is not what it should be.")
	}
}

// TestNetConn checks that NetConn returns the wrapped connection.
func TestNetConn(t *testing.T) {
	var b1, b2 bytes.Buffer
	fc := fakeNetConn{Reader: &b1, Writer: &b2}
	c := newConn(fc, true, 1024, 1024, nil, nil, nil)
	ul := c.NetConn()
	if ul != fc {
		t.Fatalf("Underlying conn is not what it should be.")
	}
}

func TestBufioReadBytes(t *testing.T) {
	// Test calling bufio.ReadBytes for value longer than read buffer size.

	m := make([]byte, 512)
	m[len(m)-1] = '\n'

	var b1, b2 bytes.Buffer
	wc := newConn(fakeNetConn{Writer: &b1}, false, len(m)+64, len(m)+64, nil, nil, nil)
	rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil)

	w, _ := wc.NextWriter(BinaryMessage)
	w.Write(m)
	w.Close()

	op, r, err := rc.NextReader()
	if op != BinaryMessage || err != nil {
		t.Fatalf("NextReader() returned %d, %v", op, err)
	}

	br := bufio.NewReader(r)
	p, err := br.ReadBytes('\n')
	if err != nil {
		t.Fatalf("ReadBytes() returned %v", err)
	}
	if len(p) != len(m) {
		t.Fatalf("read returned %d bytes, want %d bytes", len(p), len(m))
	}
}

// closeErrorTests lists inputs for IsCloseError with the expected result.
var closeErrorTests = []struct {
	err   error
	codes []int
	ok    bool
}{
	{&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true},
	{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false},
	{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true},
	{errors.New("hello"), []int{CloseNormalClosure}, false},
}

func TestCloseError(t *testing.T) {
	for _, tt := range closeErrorTests {
		ok := IsCloseError(tt.err, tt.codes...)
		if ok != tt.ok {
			t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
		}
	}
}

// unexpectedCloseErrorTests lists inputs for IsUnexpectedCloseError with the
// expected result.
var unexpectedCloseErrorTests = []struct {
	err   error
	codes []int
	ok    bool
}{
	{&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false},
	{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true},
	{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false},
	{errors.New("hello"), []int{CloseNormalClosure}, false},
}

func TestUnexpectedCloseErrors(t *testing.T) {
	for _, tt := range unexpectedCloseErrorTests {
		ok := IsUnexpectedCloseError(tt.err, tt.codes...)
		if ok != tt.ok {
			t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
		}
	}
}

// blockingWriter is a writer whose Write signals on c1 and then blocks until
// c2 is closed.
type blockingWriter struct {
	c1, c2 chan struct{}
}

func (w blockingWriter) Write(p []byte) (int, error) {
	// Allow main to continue
	close(w.c1)
	// Wait for panic in main
	<-w.c2
	return len(p), nil
}

// TestConcurrentWritePanic checks that a second WriteMessage issued while
// another write is still in progress panics.
func TestConcurrentWritePanic(t *testing.T) {
	w := blockingWriter{make(chan struct{}), make(chan struct{})}
	c := newTestConn(nil, w, false)
	go func() {
		c.WriteMessage(TextMessage, []byte{})
	}()

	// wait for goroutine to block in write.
	<-w.c1

	defer func() {
		// Release the blocked writer goroutine, then swallow the expected panic.
		close(w.c2)
		if v := recover(); v != nil {
			return
		}
	}()

	c.WriteMessage(TextMessage, []byte{})
	t.Fatal("should not get here")
}

// failingReader is a reader that always reports io.EOF.
type failingReader struct{}

func (r failingReader) Read(p []byte) (int, error) {
	return 0, io.EOF
}

// TestFailedConnectionReadPanic checks that repeatedly reading a failed
// connection eventually panics.
func TestFailedConnectionReadPanic(t *testing.T) {
	c := newTestConn(failingReader{}, nil, false)

	defer func() {
		if v := recover(); v != nil {
			return
		}
	}()

	for i := 0; i < 20000; i++ {
		c.ReadMessage()
	}
	t.Fatal("should not get here")
}

--------------------------------------------------------------------------------
/examples/autobahn/README.md:
--------------------------------------------------------------------------------
# Test Server

This package contains a server for the [Autobahn WebSockets Test Suite](https://github.com/crossbario/autobahn-testsuite).

To test the server, run

    go run server.go

and start the client test driver

    mkdir -p reports
    docker run -it --rm \
        -v ${PWD}/config:/config \
        -v ${PWD}/reports:/reports \
        crossbario/autobahn-testsuite \
        wstest -m fuzzingclient -s /config/fuzzingclient.json

When the client completes, it writes a report to reports/index.html.
19 | -------------------------------------------------------------------------------- /examples/autobahn/config/fuzzingclient.json: -------------------------------------------------------------------------------- 1 | { 2 | "cases": ["*"], 3 | "exclude-cases": [], 4 | "exclude-agent-cases": {}, 5 | "outdir": "/reports", 6 | "options": {"failByDrop": false}, 7 | "servers": [ 8 | { 9 | "agent": "ReadAllWriteMessage", 10 | "url": "ws://host.docker.internal:9000/m" 11 | }, 12 | { 13 | "agent": "ReadAllWritePreparedMessage", 14 | "url": "ws://host.docker.internal:9000/p" 15 | }, 16 | { 17 | "agent": "CopyFull", 18 | "url": "ws://host.docker.internal:9000/f" 19 | }, 20 | { 21 | "agent": "ReadAllWrite", 22 | "url": "ws://host.docker.internal:9000/r" 23 | }, 24 | { 25 | "agent": "CopyWriterOnly", 26 | "url": "ws://host.docker.internal:9000/c" 27 | } 28 | ] 29 | } 30 | -------------------------------------------------------------------------------- /examples/autobahn/server.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | // 5 | // This file may have been modified by CloudWeGo authors. All CloudWeGo 6 | // Modifications are Copyright 2022 CloudWeGo Authors. 7 | 8 | // Command server is a test server for the Autobahn WebSockets Test Suite. 
9 | package main 10 | 11 | import ( 12 | "context" 13 | "errors" 14 | "io" 15 | "log" 16 | "net/http" 17 | "time" 18 | "unicode/utf8" 19 | 20 | "github.com/cloudwego/hertz/pkg/app" 21 | "github.com/cloudwego/hertz/pkg/app/server" 22 | "github.com/cloudwego/hertz/pkg/protocol/consts" 23 | "github.com/hertz-contrib/websocket" 24 | ) 25 | 26 | var upgrader = websocket.HertzUpgrader{ 27 | ReadBufferSize: 4096, 28 | WriteBufferSize: 4096, 29 | EnableCompression: true, 30 | CheckOrigin: func(ctx *app.RequestContext) bool { 31 | return true 32 | }, 33 | } 34 | 35 | // echoCopy echoes messages from the client using io.Copy. 36 | func echoCopy(ctx *app.RequestContext, writerOnly bool) { 37 | err := upgrader.Upgrade(ctx, func(conn *websocket.Conn) { 38 | defer conn.Close() 39 | for { 40 | mt, r, err := conn.NextReader() 41 | if err != nil { 42 | if err != io.EOF { 43 | log.Println("NextReader:", err) 44 | } 45 | return 46 | } 47 | if mt == websocket.TextMessage { 48 | r = &validator{r: r} 49 | } 50 | w, err := conn.NextWriter(mt) 51 | if err != nil { 52 | log.Println("NextWriter:", err) 53 | return 54 | } 55 | if mt == websocket.TextMessage { 56 | r = &validator{r: r} 57 | } 58 | if writerOnly { 59 | _, err = io.Copy(struct{ io.Writer }{w}, r) 60 | } else { 61 | _, err = io.Copy(w, r) 62 | } 63 | if err != nil { 64 | if err == errInvalidUTF8 { 65 | conn.WriteControl(websocket.CloseMessage, 66 | websocket.FormatCloseMessage(websocket.CloseInvalidFramePayloadData, ""), 67 | time.Time{}) 68 | } 69 | log.Println("Copy:", err) 70 | return 71 | } 72 | err = w.Close() 73 | if err != nil { 74 | log.Println("Close:", err) 75 | return 76 | } 77 | } 78 | }) 79 | if err != nil { 80 | log.Println("Upgrade:", err) 81 | return 82 | } 83 | } 84 | 85 | func echoCopyWriterOnly(_ context.Context, c *app.RequestContext) { 86 | echoCopy(c, true) 87 | } 88 | 89 | func echoCopyFull(_ context.Context, c *app.RequestContext) { 90 | echoCopy(c, false) 91 | } 92 | 93 | // echoReadAll echoes messages 
from the client by reading the entire message 94 | // with ioutil.ReadAll. 95 | func echoReadAll(ctx *app.RequestContext, writeMessage, writePrepared bool) { 96 | err := upgrader.Upgrade(ctx, func(conn *websocket.Conn) { 97 | defer conn.Close() 98 | for { 99 | mt, b, err := conn.ReadMessage() 100 | if err != nil { 101 | if err != io.EOF { 102 | log.Println("NextReader:", err) 103 | } 104 | return 105 | } 106 | if mt == websocket.TextMessage { 107 | if !utf8.Valid(b) { 108 | conn.WriteControl(websocket.CloseMessage, 109 | websocket.FormatCloseMessage(websocket.CloseInvalidFramePayloadData, ""), 110 | time.Time{}) 111 | log.Println("ReadAll: invalid utf8") 112 | } 113 | } 114 | if writeMessage { 115 | if !writePrepared { 116 | err = conn.WriteMessage(mt, b) 117 | if err != nil { 118 | log.Println("WriteMessage:", err) 119 | } 120 | } else { 121 | pm, err := websocket.NewPreparedMessage(mt, b) 122 | if err != nil { 123 | log.Println("NewPreparedMessage:", err) 124 | return 125 | } 126 | err = conn.WritePreparedMessage(pm) 127 | if err != nil { 128 | log.Println("WritePreparedMessage:", err) 129 | } 130 | } 131 | } else { 132 | w, err := conn.NextWriter(mt) 133 | if err != nil { 134 | log.Println("NextWriter:", err) 135 | return 136 | } 137 | if _, err := w.Write(b); err != nil { 138 | log.Println("Writer:", err) 139 | return 140 | } 141 | if err := w.Close(); err != nil { 142 | log.Println("Close:", err) 143 | return 144 | } 145 | } 146 | } 147 | }) 148 | if err != nil { 149 | log.Println("Upgrade:", err) 150 | return 151 | } 152 | } 153 | 154 | func echoReadAllWriter(_ context.Context, c *app.RequestContext) { 155 | echoReadAll(c, false, false) 156 | } 157 | 158 | func echoReadAllWriteMessage(_ context.Context, c *app.RequestContext) { 159 | echoReadAll(c, true, false) 160 | } 161 | 162 | func echoReadAllWritePreparedMessage(_ context.Context, c *app.RequestContext) { 163 | echoReadAll(c, true, true) 164 | } 165 | 166 | func serveHome(_ context.Context, c 
*app.RequestContext) { 167 | if string(c.URI().Path()) != "/" { 168 | _ = c.AbortWithError(http.StatusNotFound, errors.New("not found")) 169 | return 170 | } 171 | if !c.IsGet() { 172 | _ = c.AbortWithError(http.StatusMethodNotAllowed, errors.New("method not allowed")) 173 | return 174 | } 175 | c.Response.Header.Set("Content-Type", "text/html; charset=utf-8") 176 | c.String(consts.StatusOK, "Echo Server") 177 | } 178 | 179 | var addr = ":9000" 180 | 181 | func main() { 182 | // server.Default() creates a Hertz with recovery middleware. 183 | // If you need a pure hertz, you can use server.New() 184 | h := server.Default(server.WithHostPorts(addr)) 185 | h.GET("/", serveHome) 186 | h.GET("/c", echoCopyWriterOnly) 187 | h.GET("/f", echoCopyFull) 188 | h.GET("/r", echoReadAllWriter) 189 | h.GET("/m", echoReadAllWriteMessage) 190 | h.GET("/p", echoReadAllWritePreparedMessage) 191 | 192 | h.NoRoute(func(c context.Context, ctx *app.RequestContext) { 193 | ctx.AbortWithMsg("Unsupported path", consts.StatusNotFound) 194 | }) 195 | 196 | h.Spin() 197 | } 198 | 199 | type validator struct { 200 | state int 201 | x rune 202 | r io.Reader 203 | } 204 | 205 | var errInvalidUTF8 = errors.New("invalid utf8") 206 | 207 | func (r *validator) Read(p []byte) (int, error) { 208 | n, err := r.r.Read(p) 209 | state := r.state 210 | x := r.x 211 | for _, b := range p[:n] { 212 | state, x = decode(state, x, b) 213 | if state == utf8Reject { 214 | break 215 | } 216 | } 217 | r.state = state 218 | r.x = x 219 | if state == utf8Reject || (err == io.EOF && state != utf8Accept) { 220 | return n, errInvalidUTF8 221 | } 222 | return n, err 223 | } 224 | 225 | // UTF-8 decoder from http://bjoern.hoehrmann.de/utf-8/decoder/dfa/ 226 | // 227 | // Copyright (c) 2008-2009 Bjoern Hoehrmann 228 | // 229 | // Permission is hereby granted, free of charge, to any person obtaining a copy 230 | // of this software and associated documentation files (the "Software"), to 231 | // deal in the Software without 
restriction, including without limitation the 232 | // rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 233 | // sell copies of the Software, and to permit persons to whom the Software is 234 | // furnished to do so, subject to the following conditions: 235 | // 236 | // The above copyright notice and this permission notice shall be included in 237 | // all copies or substantial portions of the Software. 238 | // 239 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 240 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 241 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 242 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 243 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 244 | // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 245 | // IN THE SOFTWARE. 246 | var utf8d = [...]byte{ 247 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 00..1f 248 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 20..3f 249 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 40..5f 250 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 60..7f 251 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, // 80..9f 252 | 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // a0..bf 253 | 8, 8, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // c0..df 254 | 0xa, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x4, 0x3, 0x3, // e0..ef 255 | 0xb, 0x6, 0x6, 0x6, 0x5, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, // f0..ff 256 | 0x0, 0x1, 0x2, 0x3, 
0x5, 0x8, 0x7, 0x1, 0x1, 0x1, 0x4, 0x6, 0x1, 0x1, 0x1, 0x1, // s0..s0 257 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, // s1..s2 258 | 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, // s3..s4 259 | 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 3, 1, 1, 1, 1, 1, 1, // s5..s6 260 | 1, 3, 1, 1, 1, 1, 1, 3, 1, 3, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // s7..s8 261 | } 262 | 263 | const ( 264 | utf8Accept = 0 265 | utf8Reject = 1 266 | ) 267 | 268 | func decode(state int, x rune, b byte) (int, rune) { 269 | t := utf8d[b] 270 | if state != utf8Accept { 271 | x = rune(b&0x3f) | (x << 6) 272 | } else { 273 | x = rune((0xff >> t) & b) 274 | } 275 | state = int(utf8d[256+state*16+int(t)]) 276 | return state, x 277 | } 278 | -------------------------------------------------------------------------------- /examples/chat/README.md: -------------------------------------------------------------------------------- 1 | # Chat Example 2 | 3 | This application shows how to use the 4 | [websocket](https://github.com/hertz-contrib/websocket) package to implement a simple 5 | web chat application. 6 | 7 | ## Running the example 8 | 9 | The example requires a working Go development environment. The [Getting 10 | Started](http://golang.org/doc/install) page describes how to install the 11 | development environment. 12 | 13 | Once you have Go up and running, you can download, build and run the example 14 | using the following commands. 15 | 16 | $ go get https://github.com/hertz-contrib/websocket 17 | $ cd `go list -f '{{.Dir}}' https://github.com/hertz-contrib/websocket/examples/chat` 18 | $ go run *.go 19 | 20 | To use the chat example, open http://localhost:8080/ in your browser. 21 | 22 | ## Server 23 | 24 | The server application defines two types, `Client` and `Hub`. 
The server 25 | creates an instance of the `Client` type for each websocket connection. A 26 | `Client` acts as an intermediary between the websocket connection and a single 27 | instance of the `Hub` type. The `Hub` maintains a set of registered clients and 28 | broadcasts messages to the clients. 29 | 30 | The application runs one goroutine for the `Hub` and two goroutines for each 31 | `Client`. The goroutines communicate with each other using channels. The `Hub` 32 | has channels for registering clients, unregistering clients and broadcasting 33 | messages. A `Client` has a buffered channel of outbound messages. One of the 34 | client's goroutines reads messages from this channel and writes the messages to 35 | the websocket. The other client goroutine reads messages from the websocket and 36 | sends them to the hub. 37 | 38 | ### Hub 39 | 40 | The code for the `Hub` type is in 41 | [hub.go](https://github.com/hertz-contrib/websocket/blob/master/examples/chat/hub.go). 42 | The application's `main` function starts the hub's `run` method as a goroutine. 43 | Clients send requests to the hub using the `register`, `unregister` and 44 | `broadcast` channels. 45 | 46 | The hub registers clients by adding the client pointer as a key in the 47 | `clients` map. The map value is an empty struct. 48 | 49 | The unregister code is a little more complicated. In addition to deleting the 50 | client pointer from the `clients` map, the hub closes the client's `send` 51 | channel to signal the client that no more messages will be sent to the client. 52 | 53 | The hub handles messages by looping over the registered clients and sending the 54 | message to the client's `send` channel. If the client's `send` buffer is full, 55 | then the hub assumes that the client is dead or stuck. In this case, the hub 56 | unregisters the client and closes the websocket.
57 | 58 | ### Client 59 | 60 | The code for the `Client` type is in [client.go](https://github.com/hertz-contrib/websocket/blob/master/examples/chat/client.go). 61 | 62 | The `serveWs` function is registered by the application's `main` function as 63 | an HTTP handler. The handler upgrades the HTTP connection to the WebSocket 64 | protocol, creates a client, registers the client with the hub and schedules the 65 | client to be unregistered using a defer statement. 66 | 67 | Next, the HTTP handler starts the client's `writePump` method as a goroutine. 68 | This method transfers messages from the client's send channel to the websocket 69 | connection. The writer method exits when the channel is closed by the hub or 70 | there's an error writing to the websocket connection. 71 | 72 | Finally, the HTTP handler calls the client's `readPump` method. This method 73 | transfers inbound messages from the websocket to the hub. 74 | 75 | WebSocket connections [support one concurrent reader and one concurrent 76 | writer](https://godoc.org/github.com/gorilla/websocket#hdr-Concurrency). The 77 | application ensures that these concurrency requirements are met by executing 78 | all reads from the `readPump` goroutine and all writes from the `writePump` 79 | goroutine. 80 | 81 | To improve efficiency under high load, the `writePump` function coalesces 82 | pending chat messages in the `send` channel to a single WebSocket message. This 83 | reduces the number of system calls and the amount of data sent over the 84 | network. 85 | 86 | ## Frontend 87 | 88 | The frontend code is in [home.html](https://github.com/hertz-contrib/websocket/blob/master/examples/chat/home.html). 89 | 90 | On document load, the script checks for websocket functionality in the browser. 91 | If websocket functionality is available, then the script opens a connection to 92 | the server and registers a callback to handle messages from the server.
The 93 | callback appends the message to the chat log using the appendLog function. 94 | 95 | To allow the user to manually scroll through the chat log without interruption 96 | from new messages, the `appendLog` function checks the scroll position before 97 | adding new content. If the chat log is scrolled to the bottom, then the 98 | function scrolls new content into view after adding the content. Otherwise, the 99 | scroll position is not changed. 100 | 101 | The form handler writes the user input to the websocket and clears the input 102 | field. 103 | -------------------------------------------------------------------------------- /examples/chat/client.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | // 5 | // This file may have been modified by CloudWeGo authors. All CloudWeGo 6 | // Modifications are Copyright 2022 CloudWeGo Authors. 7 | 8 | package main 9 | 10 | import ( 11 | "bytes" 12 | "log" 13 | "time" 14 | 15 | "github.com/cloudwego/hertz/pkg/app" 16 | "github.com/hertz-contrib/websocket" 17 | ) 18 | 19 | const ( 20 | // Time allowed to write a message to the peer. 21 | writeWait = 10 * time.Second 22 | 23 | // Time allowed to read the next pong message from the peer. 24 | pongWait = 60 * time.Second 25 | 26 | // Send pings to peer with this period. Must be less than pongWait. 27 | pingPeriod = (pongWait * 9) / 10 28 | 29 | // Maximum message size allowed from peer. 30 | maxMessageSize = 512 31 | ) 32 | 33 | var ( 34 | newline = []byte{'\n'} 35 | space = []byte{' '} 36 | ) 37 | 38 | // Client is a middleman between the websocket connection and the hub. 39 | type Client struct { 40 | // The websocket connection. 41 | conn *websocket.Conn 42 | 43 | // Buffered channel of outbound messages. 
44 | send chan []byte 45 | } 46 | 47 | // readPump pumps messages from the websocket connection to the hub. 48 | // 49 | // The application runs readPump in a per-connection goroutine. The application 50 | // ensures that there is at most one reader on a connection by executing all 51 | // reads from this goroutine. 52 | func (c *Client) readPump() { 53 | defer func() { 54 | hub.unregister <- c 55 | c.conn.Close() 56 | }() 57 | c.conn.SetReadLimit(maxMessageSize) 58 | c.conn.SetReadDeadline(time.Now().Add(pongWait)) 59 | c.conn.SetPongHandler(func(string) error { c.conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) 60 | for { 61 | _, message, err := c.conn.ReadMessage() 62 | if err != nil { 63 | if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { 64 | log.Printf("error: %v", err) 65 | } 66 | break 67 | } 68 | message = bytes.TrimSpace(bytes.Replace(message, newline, space, -1)) 69 | hub.broadcast <- message 70 | } 71 | } 72 | 73 | // writePump pumps messages from the hub to the websocket connection. 74 | // 75 | // A goroutine running writePump is started for each connection. The 76 | // application ensures that there is at most one writer to a connection by 77 | // executing all writes from this goroutine. 78 | func (c *Client) writePump() { 79 | ticker := time.NewTicker(pingPeriod) 80 | defer func() { 81 | ticker.Stop() 82 | c.conn.Close() 83 | }() 84 | for { 85 | select { 86 | case message, ok := <-c.send: 87 | c.conn.SetWriteDeadline(time.Now().Add(writeWait)) 88 | if !ok { 89 | // The hub closed the channel. 90 | c.conn.WriteMessage(websocket.CloseMessage, []byte{}) 91 | return 92 | } 93 | 94 | w, err := c.conn.NextWriter(websocket.TextMessage) 95 | if err != nil { 96 | return 97 | } 98 | w.Write(message) 99 | 100 | // Add queued chat messages to the current websocket message. 
101 | n := len(c.send) 102 | for i := 0; i < n; i++ { 103 | w.Write(newline) 104 | w.Write(<-c.send) 105 | } 106 | 107 | if err := w.Close(); err != nil { 108 | return 109 | } 110 | case <-ticker.C: 111 | c.conn.SetWriteDeadline(time.Now().Add(writeWait)) 112 | if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { 113 | return 114 | } 115 | } 116 | } 117 | } 118 | 119 | // serveWs handles websocket requests from the peer. 120 | func serveWs(ctx *app.RequestContext) { 121 | err := upgrader.Upgrade(ctx, func(conn *websocket.Conn) { 122 | client := &Client{conn: conn, send: make(chan []byte, 256)} 123 | hub.register <- client 124 | 125 | go client.writePump() 126 | client.readPump() 127 | }) 128 | if err != nil { 129 | log.Println(err) 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /examples/chat/home.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Chat Example 5 | 53 | 90 | 91 | 92 |
93 |
94 | 95 | 96 |
97 | 98 | 99 | -------------------------------------------------------------------------------- /examples/chat/hub.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | // 5 | // This file may have been modified by CloudWeGo authors. All CloudWeGo 6 | // Modifications are Copyright 2022 CloudWeGo Authors. 7 | 8 | package main 9 | 10 | import ( 11 | "sync" 12 | ) 13 | 14 | // Hub maintains the set of active clients and broadcasts messages to the 15 | // clients. 16 | type Hub struct { 17 | // Registered clients. 18 | clients map[*Client]struct{} 19 | clientsLock sync.RWMutex 20 | 21 | // Inbound messages from the clients. 22 | broadcast chan []byte 23 | 24 | // Register requests from the clients. 25 | register chan *Client 26 | 27 | // Unregister requests from clients. 28 | unregister chan *Client 29 | } 30 | 31 | var hub = newHub() 32 | 33 | func newHub() *Hub { 34 | return &Hub{ 35 | broadcast: make(chan []byte), 36 | register: make(chan *Client), 37 | unregister: make(chan *Client), 38 | clients: make(map[*Client]struct{}), 39 | } 40 | } 41 | 42 | func (h *Hub) run() { 43 | for { 44 | select { 45 | case client := <-h.register: 46 | h.Register(client) 47 | case client := <-h.unregister: 48 | h.Unregister(client) 49 | case message := <-h.broadcast: 50 | for client := range h.clients { 51 | select { 52 | case client.send <- message: 53 | default: 54 | close(client.send) 55 | delete(h.clients, client) 56 | } 57 | } 58 | } 59 | } 60 | } 61 | 62 | func (h *Hub) Register(client *Client) { 63 | h.AddClient(client) 64 | } 65 | 66 | func (h *Hub) AddClient(client *Client) { 67 | h.clientsLock.Lock() 68 | defer h.clientsLock.Unlock() 69 | h.clients[client] = struct{}{} 70 | } 71 | 72 | func (h *Hub) Unregister(client *Client) { 73 | h.DelClient(client) 74 | 
} 75 | 76 | func (h *Hub) DelClient(client *Client) { 77 | h.clientsLock.Lock() 78 | defer h.clientsLock.Unlock() 79 | if _, ok := h.clients[client]; ok { 80 | delete(h.clients, client) 81 | close(client.send) 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /examples/chat/main.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | // 5 | // This file may have been modified by CloudWeGo authors. All CloudWeGo 6 | // Modifications are Copyright 2022 CloudWeGo Authors. 7 | 8 | package main 9 | 10 | import ( 11 | "context" 12 | "net/http" 13 | 14 | "github.com/cloudwego/hertz/pkg/app" 15 | "github.com/cloudwego/hertz/pkg/app/server" 16 | "github.com/cloudwego/hertz/pkg/common/hlog" 17 | "github.com/hertz-contrib/websocket" 18 | ) 19 | 20 | var upgrader = websocket.HertzUpgrader{ 21 | ReadBufferSize: 1024, 22 | WriteBufferSize: 1024, 23 | CheckOrigin: func(ctx *app.RequestContext) bool { 24 | return true 25 | }, 26 | } 27 | 28 | var addr = ":8080" 29 | 30 | func serveHome(_ context.Context, c *app.RequestContext) { 31 | if string(c.URI().Path()) != "/" { 32 | hlog.Error("Not found", http.StatusNotFound) 33 | return 34 | } 35 | if !c.IsGet() { 36 | hlog.Error("Method not allowed", http.StatusMethodNotAllowed) 37 | return 38 | } 39 | c.HTML(http.StatusOK, "home.html", nil) 40 | } 41 | 42 | func main() { 43 | go hub.run() 44 | // server.Default() creates a Hertz with recovery middleware. 
45 | // If you need a pure hertz, you can use server.New() 46 | h := server.Default(server.WithHostPorts(addr)) 47 | h.LoadHTMLGlob("home.html") 48 | 49 | h.GET("/", serveHome) 50 | h.GET("/ws", func(c context.Context, ctx *app.RequestContext) { 51 | serveWs(ctx) 52 | }) 53 | 54 | h.Spin() 55 | } 56 | -------------------------------------------------------------------------------- /examples/command/README.md: -------------------------------------------------------------------------------- 1 | # Command example 2 | 3 | This example connects a websocket connection to stdin and stdout of a command. 4 | Received messages are written to stdin followed by a `\n`. Each line read from 5 | standard out is sent as a message to the client. 6 | 7 | $ go run main.go 8 | # Open http://localhost:8080/ . 9 | 10 | Try the following commands. 11 | 12 | # Echo sent messages to the output area. 13 | $ go run main.go cat 14 | 15 | # Run a shell.Try sending "ls" and "cat main.go". 16 | $ go run main.go sh 17 | 18 | -------------------------------------------------------------------------------- /examples/command/home.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Command Example 5 | 53 | 94 | 95 | 96 |
97 |
98 | 99 | 100 |
101 | 102 | 103 | -------------------------------------------------------------------------------- /examples/command/main.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | // 5 | // This file may have been modified by CloudWeGo authors. All CloudWeGo 6 | // Modifications are Copyright 2022 CloudWeGo Authors. 7 | 8 | package main 9 | 10 | import ( 11 | "bufio" 12 | "context" 13 | "flag" 14 | "io" 15 | "log" 16 | "net/http" 17 | "os" 18 | "os/exec" 19 | "time" 20 | 21 | "github.com/cloudwego/hertz/pkg/app" 22 | "github.com/cloudwego/hertz/pkg/app/server" 23 | "github.com/cloudwego/hertz/pkg/common/hlog" 24 | "github.com/hertz-contrib/websocket" 25 | ) 26 | 27 | var ( 28 | addr = flag.String("addr", "127.0.0.1:8080", "http service address") 29 | cmdPath string 30 | ) 31 | 32 | const ( 33 | // Time allowed to write a message to the peer. 34 | writeWait = 10 * time.Second 35 | 36 | // Maximum message size allowed from peer. 37 | maxMessageSize = 8192 38 | 39 | // Time allowed to read the next pong message from the peer. 40 | pongWait = 60 * time.Second 41 | 42 | // Send pings to peer with this period. Must be less than pongWait. 43 | pingPeriod = (pongWait * 9) / 10 44 | 45 | // Time to wait before force close on connection. 
46 | closeGracePeriod = 10 * time.Second 47 | ) 48 | 49 | func pumpStdin(ws *websocket.Conn, w io.Writer) { 50 | defer ws.Close() 51 | ws.SetReadLimit(maxMessageSize) 52 | ws.SetReadDeadline(time.Now().Add(pongWait)) 53 | ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(pongWait)); return nil }) 54 | for { 55 | _, message, err := ws.ReadMessage() 56 | if err != nil { 57 | break 58 | } 59 | message = append(message, '\n') 60 | if _, err := w.Write(message); err != nil { 61 | break 62 | } 63 | } 64 | } 65 | 66 | func pumpStdout(ws *websocket.Conn, r io.Reader, done chan struct{}) { 67 | defer func() { 68 | }() 69 | s := bufio.NewScanner(r) 70 | for s.Scan() { 71 | ws.SetWriteDeadline(time.Now().Add(writeWait)) 72 | if err := ws.WriteMessage(websocket.TextMessage, s.Bytes()); err != nil { 73 | ws.Close() 74 | break 75 | } 76 | } 77 | if s.Err() != nil { 78 | log.Println("scan:", s.Err()) 79 | } 80 | close(done) 81 | 82 | ws.SetWriteDeadline(time.Now().Add(writeWait)) 83 | ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) 84 | time.Sleep(closeGracePeriod) 85 | ws.Close() 86 | } 87 | 88 | func ping(ws *websocket.Conn, done chan struct{}) { 89 | ticker := time.NewTicker(pingPeriod) 90 | defer ticker.Stop() 91 | for { 92 | select { 93 | case <-ticker.C: 94 | if err := ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil { 95 | log.Println("ping:", err) 96 | } 97 | case <-done: 98 | return 99 | } 100 | } 101 | } 102 | 103 | func internalError(ws *websocket.Conn, msg string, err error) { 104 | log.Println(msg, err) 105 | ws.WriteMessage(websocket.TextMessage, []byte("Internal server error.")) 106 | } 107 | 108 | var upgrader = websocket.HertzUpgrader{} 109 | 110 | func serveWs(c context.Context, ctx *app.RequestContext) { 111 | err := upgrader.Upgrade(ctx, func(ws *websocket.Conn) { 112 | defer ws.Close() 113 | 114 | outr, outw, err := os.Pipe() 115 | if err != 
nil { 116 | internalError(ws, "stdout:", err) 117 | return 118 | } 119 | defer outr.Close() 120 | defer outw.Close() 121 | 122 | inr, inw, err := os.Pipe() 123 | if err != nil { 124 | internalError(ws, "stdin:", err) 125 | return 126 | } 127 | defer inr.Close() 128 | defer inw.Close() 129 | 130 | proc, err := os.StartProcess(cmdPath, flag.Args(), &os.ProcAttr{ 131 | Files: []*os.File{inr, outw, outw}, 132 | }) 133 | if err != nil { 134 | internalError(ws, "start:", err) 135 | return 136 | } 137 | 138 | inr.Close() 139 | outw.Close() 140 | 141 | stdoutDone := make(chan struct{}) 142 | go pumpStdout(ws, outr, stdoutDone) 143 | go ping(ws, stdoutDone) 144 | 145 | pumpStdin(ws, inw) 146 | 147 | // Some commands will exit when stdin is closed. 148 | inw.Close() 149 | 150 | // Other commands need a bonk on the head. 151 | if err := proc.Signal(os.Interrupt); err != nil { 152 | log.Println("inter:", err) 153 | } 154 | 155 | select { 156 | case <-stdoutDone: 157 | case <-time.After(time.Second): 158 | // A bigger bonk on the head. 
159 | if err := proc.Signal(os.Kill); err != nil { 160 | log.Println("term:", err) 161 | } 162 | <-stdoutDone 163 | } 164 | 165 | if _, err := proc.Wait(); err != nil { 166 | log.Println("wait:", err) 167 | } 168 | }) 169 | if err != nil { 170 | log.Println("upgrade:", err) 171 | return 172 | } 173 | } 174 | 175 | func serveHome(_ context.Context, c *app.RequestContext) { 176 | log.Println(string(c.URI().FullURI())) 177 | if string(c.URI().Path()) != "/" { 178 | hlog.Error("Not found", http.StatusNotFound) 179 | return 180 | } 181 | if !c.IsGet() { 182 | hlog.Error("Method not allowed", http.StatusMethodNotAllowed) 183 | return 184 | } 185 | c.HTML(http.StatusOK, "home.html", nil) 186 | } 187 | 188 | func main() { 189 | flag.Parse() 190 | if len(flag.Args()) < 1 { 191 | log.Fatal("must specify at least one argument") 192 | } 193 | var err error 194 | cmdPath, err = exec.LookPath(flag.Args()[0]) 195 | if err != nil { 196 | log.Fatal(err) 197 | } 198 | h := server.Default(server.WithHostPorts(*addr)) 199 | h.LoadHTMLGlob("home.html") 200 | h.GET("/", serveHome) 201 | h.GET("/ws", serveWs) 202 | h.Spin() 203 | } 204 | -------------------------------------------------------------------------------- /examples/echo/README.md: -------------------------------------------------------------------------------- 1 | # Client and server example 2 | 3 | This example shows a simple client and server. 4 | 5 | The server echoes messages sent to it. The client sends a message every second 6 | and prints all messages received. 7 | 8 | To run the example, start the server: 9 | 10 | $ go run server.go 11 | 12 | Next, start the client: 13 | 14 | $ go run client.go 15 | 16 | The server includes a simple web client. To use the client, open 17 | http://127.0.0.1:8080 in the browser and follow the instructions on the page. 
18 | -------------------------------------------------------------------------------- /examples/echo/client.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | // 5 | // This file may have been modified by CloudWeGo authors. All CloudWeGo 6 | // Modifications are Copyright 2022 CloudWeGo Authors. 7 | 8 | //go:build ignore 9 | // +build ignore 10 | 11 | package main 12 | 13 | import ( 14 | "flag" 15 | "log" 16 | "net/url" 17 | "os" 18 | "os/signal" 19 | "time" 20 | 21 | "github.com/gorilla/websocket" 22 | ) 23 | 24 | var addr = flag.String("addr", "localhost:8080", "http service address") 25 | 26 | func main() { 27 | flag.Parse() 28 | log.SetFlags(0) 29 | 30 | interrupt := make(chan os.Signal, 1) 31 | signal.Notify(interrupt, os.Interrupt) 32 | 33 | u := url.URL{Scheme: "ws", Host: *addr, Path: "/echo"} 34 | log.Printf("connecting to %s", u.String()) 35 | 36 | c, _, err := websocket.DefaultDialer.Dial(u.String(), nil) 37 | if err != nil { 38 | log.Fatal("dial:", err) 39 | } 40 | defer c.Close() 41 | 42 | done := make(chan struct{}) 43 | 44 | go func() { 45 | defer close(done) 46 | for { 47 | _, message, err := c.ReadMessage() 48 | if err != nil { 49 | log.Println("read:", err) 50 | return 51 | } 52 | log.Printf("recv: %s", message) 53 | } 54 | }() 55 | 56 | ticker := time.NewTicker(time.Second) 57 | defer ticker.Stop() 58 | 59 | for { 60 | select { 61 | case <-done: 62 | return 63 | case t := <-ticker.C: 64 | err := c.WriteMessage(websocket.TextMessage, []byte(t.String())) 65 | if err != nil { 66 | log.Println("write:", err) 67 | return 68 | } 69 | case <-interrupt: 70 | log.Println("interrupt") 71 | 72 | // Cleanly close the connection by sending a close message and then 73 | // waiting (with timeout) for the server to close the connection. 
74 | err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) 75 | if err != nil { 76 | log.Println("write close:", err) 77 | return 78 | } 79 | select { 80 | case <-done: 81 | case <-time.After(time.Second): 82 | } 83 | return 84 | } 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /examples/echo/server.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | // 5 | // This file may have been modified by CloudWeGo authors. All CloudWeGo 6 | // Modifications are Copyright 2022 CloudWeGo Authors. 7 | 8 | //go:build ignore 9 | // +build ignore 10 | 11 | package main 12 | 13 | import ( 14 | "context" 15 | "flag" 16 | "html/template" 17 | "log" 18 | 19 | "github.com/cloudwego/hertz/pkg/app" 20 | "github.com/cloudwego/hertz/pkg/app/server" 21 | "github.com/hertz-contrib/websocket" 22 | ) 23 | 24 | var addr = flag.String("addr", "localhost:8080", "http service address") 25 | 26 | var upgrader = websocket.HertzUpgrader{} // use default options 27 | 28 | func echo(_ context.Context, c *app.RequestContext) { 29 | err := upgrader.Upgrade(c, func(conn *websocket.Conn) { 30 | for { 31 | mt, message, err := conn.ReadMessage() 32 | if err != nil { 33 | log.Println("read:", err) 34 | break 35 | } 36 | log.Printf("recv: %s", message) 37 | err = conn.WriteMessage(mt, message) 38 | if err != nil { 39 | log.Println("write:", err) 40 | break 41 | } 42 | } 43 | }) 44 | if err != nil { 45 | log.Print("upgrade:", err) 46 | return 47 | } 48 | } 49 | 50 | func home(_ context.Context, c *app.RequestContext) { 51 | c.SetContentType("text/html; charset=utf-8") 52 | homeTemplate.Execute(c, "ws://"+string(c.Host())+"/echo") 53 | } 54 | 55 | func main() { 56 | flag.Parse() 57 | h 
:= server.Default(server.WithHostPorts(*addr)) 58 | // https://github.com/cloudwego/hertz/issues/121 59 | h.NoHijackConnPool = true 60 | h.GET("/", home) 61 | h.GET("/echo", echo) 62 | h.Spin() 63 | } 64 | 65 | var homeTemplate = template.Must(template.New("").Parse(` 66 | 67 | 68 | 69 | 70 | 124 | 125 | 126 | 127 |
128 |

Click "Open" to create a connection to the server, 129 | "Send" to send a message to the server and "Close" to close the connection. 130 | You can change the message and send multiple times. 131 |

132 |

133 | 134 | 135 |

136 | 137 |

138 |
139 |
140 |
141 | 142 | 143 | `)) 144 | -------------------------------------------------------------------------------- /examples/filewatch/README.md: -------------------------------------------------------------------------------- 1 | # File Watch example. 2 | 3 | This example sends a file to the browser client for display whenever the file is modified. 4 | 5 | $ go run main.go 6 | # Open http://localhost:8080/ . 7 | # Modify the file to see it update in the browser. 8 | -------------------------------------------------------------------------------- /examples/filewatch/main.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | // 5 | // This file may have been modified by CloudWeGo authors. All CloudWeGo 6 | // Modifications are Copyright 2022 CloudWeGo Authors. 7 | 8 | package main 9 | 10 | import ( 11 | "context" 12 | "flag" 13 | "html/template" 14 | "io/ioutil" 15 | "log" 16 | "os" 17 | "strconv" 18 | "time" 19 | 20 | "github.com/cloudwego/hertz/pkg/app" 21 | "github.com/cloudwego/hertz/pkg/app/server" 22 | "github.com/cloudwego/hertz/pkg/protocol/consts" 23 | "github.com/hertz-contrib/websocket" 24 | ) 25 | 26 | const ( 27 | // Time allowed to write the file to the client. 28 | writeWait = 10 * time.Second 29 | 30 | // Time allowed to read the next pong message from the client. 31 | pongWait = 60 * time.Second 32 | 33 | // Send pings to client with this period. Must be less than pongWait. 34 | pingPeriod = (pongWait * 9) / 10 35 | 36 | // Poll file for changes with this period. 
37 | filePeriod = 10 * time.Second 38 | ) 39 | 40 | var ( 41 | addr = flag.String("addr", ":8080", "http service address") 42 | homeTempl = template.Must(template.New("").Parse(homeHTML)) 43 | filename string 44 | upgrader = websocket.HertzUpgrader{ 45 | ReadBufferSize: 1024, 46 | WriteBufferSize: 1024, 47 | } 48 | ) 49 | 50 | func readFileIfModified(lastMod time.Time) ([]byte, time.Time, error) { 51 | fi, err := os.Stat(filename) 52 | if err != nil { 53 | return nil, lastMod, err 54 | } 55 | if !fi.ModTime().After(lastMod) { 56 | return nil, lastMod, nil 57 | } 58 | p, err := ioutil.ReadFile(filename) 59 | if err != nil { 60 | return nil, fi.ModTime(), err 61 | } 62 | return p, fi.ModTime(), nil 63 | } 64 | 65 | func reader(ws *websocket.Conn) { 66 | defer ws.Close() 67 | ws.SetReadLimit(512) 68 | ws.SetReadDeadline(time.Now().Add(pongWait)) 69 | ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(pongWait)); return nil }) 70 | for { 71 | _, _, err := ws.ReadMessage() 72 | if err != nil { 73 | break 74 | } 75 | } 76 | } 77 | 78 | func writer(ws *websocket.Conn, lastMod time.Time) { 79 | lastError := "" 80 | pingTicker := time.NewTicker(pingPeriod) 81 | fileTicker := time.NewTicker(filePeriod) 82 | defer func() { 83 | pingTicker.Stop() 84 | fileTicker.Stop() 85 | ws.Close() 86 | }() 87 | for { 88 | select { 89 | case <-fileTicker.C: 90 | var p []byte 91 | var err error 92 | 93 | p, lastMod, err = readFileIfModified(lastMod) 94 | 95 | if err != nil { 96 | if s := err.Error(); s != lastError { 97 | lastError = s 98 | p = []byte(lastError) 99 | } 100 | } else { 101 | lastError = "" 102 | } 103 | 104 | if p != nil { 105 | ws.SetWriteDeadline(time.Now().Add(writeWait)) 106 | if err := ws.WriteMessage(websocket.TextMessage, p); err != nil { 107 | return 108 | } 109 | } 110 | case <-pingTicker.C: 111 | ws.SetWriteDeadline(time.Now().Add(writeWait)) 112 | if err := ws.WriteMessage(websocket.PingMessage, []byte{}); err != nil { 113 | return 114 | } 115 
| } 116 | } 117 | } 118 | 119 | func serveWs(c context.Context, ctx *app.RequestContext) { 120 | err := upgrader.Upgrade(ctx, func(ws *websocket.Conn) { 121 | var lastMod time.Time 122 | if n, err := strconv.ParseInt(string(ctx.FormValue("lastMod")), 16, 64); err == nil { 123 | lastMod = time.Unix(0, n) 124 | } 125 | 126 | go writer(ws, lastMod) 127 | reader(ws) 128 | }) 129 | if err != nil { 130 | if _, ok := err.(websocket.HandshakeError); ok { 131 | log.Println(err) 132 | } 133 | return 134 | } 135 | } 136 | 137 | func serveHome(c context.Context, ctx *app.RequestContext) { 138 | if !ctx.IsGet() { 139 | ctx.AbortWithMsg("Method not allowed", consts.StatusMethodNotAllowed) 140 | return 141 | } 142 | 143 | ctx.SetContentType("text/html; charset=utf-8") 144 | 145 | p, lastMod, err := readFileIfModified(time.Time{}) 146 | if err != nil { 147 | p = []byte(err.Error()) 148 | lastMod = time.Unix(0, 0) 149 | } 150 | v := struct { 151 | Host string 152 | Data string 153 | LastMod string 154 | }{ 155 | string(ctx.Host()), 156 | string(p), 157 | strconv.FormatInt(lastMod.UnixNano(), 16), 158 | } 159 | homeTempl.Execute(ctx, &v) 160 | } 161 | 162 | func main() { 163 | flag.Parse() 164 | if flag.NArg() != 1 { 165 | log.Fatal("filename not specified") 166 | } 167 | filename = flag.Args()[0] 168 | 169 | h := server.New(server.WithHostPorts(*addr)) 170 | 171 | h.GET("/", serveHome) 172 | 173 | h.GET("/ws", serveWs) 174 | 175 | h.NoRoute(func(c context.Context, ctx *app.RequestContext) { 176 | ctx.AbortWithMsg("Unsupported path", consts.StatusNotFound) 177 | }) 178 | 179 | h.Spin() 180 | } 181 | 182 | const homeHTML = ` 183 | 184 | 185 | WebSocket Example 186 | 187 | 188 |
{{.Data}}
189 | 202 | 203 | 204 | ` 205 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/hertz-contrib/websocket 2 | 3 | go 1.16 4 | 5 | require ( 6 | github.com/bytedance/sonic v1.13.2 7 | github.com/cloudwego/hertz v0.9.7 8 | ) 9 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/bytedance/gopkg v0.1.0 h1:aAxB7mm1qms4Wz4sp8e1AtKDOeFLtdqvGiUe7aonRJs= 2 | github.com/bytedance/gopkg v0.1.0/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= 3 | github.com/bytedance/mockey v1.2.12 h1:aeszOmGw8CPX8CRx1DZ/Glzb1yXvhjDh6jdFBNZjsU4= 4 | github.com/bytedance/mockey v1.2.12/go.mod h1:3ZA4MQasmqC87Tw0w7Ygdy7eHIc2xgpZ8Pona5rsYIk= 5 | github.com/bytedance/sonic v1.13.2 h1:8/H1FempDZqC4VqjptGo14QQlJx8VdZJegxs6wwfqpQ= 6 | github.com/bytedance/sonic v1.13.2/go.mod h1:o68xyaF9u2gvVBuGHPlUVCy+ZfmNNO5ETf1+KgkJhz4= 7 | github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= 8 | github.com/bytedance/sonic/loader v0.2.4 h1:ZWCw4stuXUsn1/+zQDqeE7JKP+QO47tz7QCNan80NzY= 9 | github.com/bytedance/sonic/loader v0.2.4/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= 10 | github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4= 11 | github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= 12 | github.com/cloudwego/hertz v0.9.7 h1:tAVaiO+vTf+ZkQhvNhKbDJ0hmC4oJ7bzwDi1KhvhHy4= 13 | github.com/cloudwego/hertz v0.9.7/go.mod h1:t6d7NcoQxPmETvzPMMIVPHMn5C5QzpqIiFsaavoLJYQ= 14 | github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= 15 | github.com/cloudwego/netpoll v0.6.4 h1:z/dA4sOTUQof6zZIO4QNnLBXsDFFFEos9OOGloR6kno= 16 | github.com/cloudwego/netpoll v0.6.4/go.mod 
h1:BtM+GjKTdwKoC8IOzD08/+8eEn2gYoiNLipFca6BVXQ= 17 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 18 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 19 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 20 | github.com/fsnotify/fsnotify v1.5.4 h1:jRbGcIw6P2Meqdwuo0H1p6JVLbL5DHKAKlYndzMwVZI= 21 | github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU= 22 | github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 23 | github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4= 24 | github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= 25 | github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= 26 | github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 27 | github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= 28 | github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= 29 | github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= 30 | github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= 31 | github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= 32 | github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= 33 | github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= 34 | github.com/nyaruka/phonenumbers v1.0.55 h1:bj0nTO88Y68KeUQ/n3Lo2KgK7lM1hF7L9NFuwcCl3yg= 35 | github.com/nyaruka/phonenumbers v1.0.55/go.mod h1:sDaTZ/KPX5f8qyV9qN+hIm+4ZBARJrupC6LuhshJq1U= 36 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 37 | github.com/pmezard/go-difflib 
v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 38 | github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM= 39 | github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= 40 | github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= 41 | github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= 42 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 43 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 44 | github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= 45 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 46 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 47 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 48 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 49 | github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= 50 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 51 | github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= 52 | github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= 53 | github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= 54 | github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= 55 | github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= 56 | github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= 57 | github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= 58 | 
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= 59 | golang.org/x/arch v0.0.0-20201008161808-52c3e6f60cff/go.mod h1:flIaEI6LNU6xOCD5PaJvn9wGP0agmIOqjrtsKGRguv4= 60 | golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= 61 | golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= 62 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 63 | golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 64 | golang.org/x/net v0.0.0-20221014081412-f15817d10f9b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= 65 | golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 66 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 67 | golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 68 | golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 69 | golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 70 | golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= 71 | golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 72 | golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= 73 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 74 | golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= 75 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 76 | golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod 
h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= 77 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= 78 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 79 | google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= 80 | google.golang.org/protobuf v1.27.1 h1:SnqbnDw1V7RiZcXPx5MEeqPv2s79L9i7BJUlG/+RurQ= 81 | google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= 82 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 83 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 84 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 85 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 86 | nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= 87 | rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= 88 | -------------------------------------------------------------------------------- /json.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | // 5 | // This file may have been modified by CloudWeGo authors. All CloudWeGo 6 | // Modifications are Copyright 2022 CloudWeGo Authors. 7 | 8 | package websocket 9 | 10 | import ( 11 | "io" 12 | 13 | "github.com/bytedance/sonic" 14 | ) 15 | 16 | // WriteJSON writes the JSON encoding of v as a message. 17 | // 18 | // See the documentation for encoding/json Marshal for details about the 19 | // conversion of Go values to JSON. 
20 | func (c *Conn) WriteJSON(v interface{}) error { 21 | w, err := c.NextWriter(TextMessage) 22 | if err != nil { 23 | return err 24 | } 25 | err1 := sonic.ConfigDefault.NewEncoder(w).Encode(v) 26 | err2 := w.Close() 27 | if err1 != nil { 28 | return err1 29 | } 30 | return err2 31 | } 32 | 33 | // ReadJSON reads the next JSON-encoded message from the connection and stores 34 | // it in the value pointed to by v. 35 | // 36 | // See the documentation for the encoding/json Unmarshal function for details 37 | // about the conversion of JSON to a Go value. 38 | func (c *Conn) ReadJSON(v interface{}) error { 39 | _, r, err := c.NextReader() 40 | if err != nil { 41 | return err 42 | } 43 | err = sonic.ConfigDefault.NewDecoder(r).Decode(v) 44 | if err == io.EOF { 45 | // One value is expected in the message. 46 | err = io.ErrUnexpectedEOF 47 | } 48 | return err 49 | } 50 | -------------------------------------------------------------------------------- /json_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | // 5 | // This file may have been modified by CloudWeGo authors. All CloudWeGo 6 | // Modifications are Copyright 2022 CloudWeGo Authors. 
7 | 8 | package websocket 9 | 10 | import ( 11 | "bytes" 12 | "errors" 13 | "io" 14 | "reflect" 15 | "testing" 16 | 17 | "github.com/bytedance/sonic" 18 | "github.com/bytedance/sonic/decoder" 19 | ) 20 | 21 | func TestJSON(t *testing.T) { 22 | var buf bytes.Buffer 23 | wc := newTestConn(nil, &buf, true) 24 | rc := newTestConn(&buf, nil, false) 25 | 26 | var actual, expect struct { 27 | A int 28 | B string 29 | } 30 | expect.A = 1 31 | expect.B = "hello" 32 | 33 | if err := wc.WriteJSON(&expect); err != nil { 34 | t.Fatal("write", err) 35 | } 36 | 37 | if err := rc.ReadJSON(&actual); err != nil { 38 | t.Fatal("read", err) 39 | } 40 | 41 | if !reflect.DeepEqual(&actual, &expect) { 42 | t.Fatal("equal", actual, expect) 43 | } 44 | } 45 | 46 | func TestPartialJSONRead(t *testing.T) { 47 | var buf0, buf1 bytes.Buffer 48 | wc := newTestConn(nil, &buf0, true) 49 | rc := newTestConn(&buf0, &buf1, false) 50 | 51 | var v struct { 52 | A int 53 | B string 54 | } 55 | v.A = 1 56 | v.B = "hello" 57 | 58 | messageCount := 0 59 | 60 | // Partial JSON values. 61 | 62 | data, err := sonic.Marshal(v) 63 | if err != nil { 64 | t.Fatal(err) 65 | } 66 | for i := len(data) - 1; i >= 0; i-- { 67 | if err := wc.WriteMessage(TextMessage, data[:i]); err != nil { 68 | t.Fatal(err) 69 | } 70 | messageCount++ 71 | } 72 | 73 | // Whitespace. 74 | 75 | if err := wc.WriteMessage(TextMessage, []byte(" ")); err != nil { 76 | t.Fatal(err) 77 | } 78 | messageCount++ 79 | 80 | // Close. 
81 | 82 | if err := wc.WriteMessage(CloseMessage, FormatCloseMessage(CloseNormalClosure, "")); err != nil { 83 | t.Fatal(err) 84 | } 85 | 86 | for i := 0; i < messageCount; i++ { 87 | err := rc.ReadJSON(&v) 88 | if err != io.ErrUnexpectedEOF && !errors.As(err, &decoder.SyntaxError{}) { 89 | t.Error("read", i, err) 90 | } 91 | } 92 | 93 | err = rc.ReadJSON(&v) 94 | if _, ok := err.(*CloseError); !ok { 95 | t.Error("final", err) 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /licenses/LICENSE-gotils: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2020-present Sergio Andres Virviescas Santana 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /licenses/LICENSE-websocket: -------------------------------------------------------------------------------- 1 | Copyright (c) 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are met: 5 | 6 | Redistributions of source code must retain the above copyright notice, this 7 | list of conditions and the following disclaimer. 8 | 9 | Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 12 | 13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 14 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 15 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 16 | DISCLAIMED. 
// wordSize is the size in bytes of a uintptr on the target platform.
const wordSize = int(unsafe.Sizeof(uintptr(0)))

// maskBytes XORs b in place with the 4-byte masking key. The first byte of b
// is combined with key[pos&3] and the key is then walked cyclically. The
// return value is the key offset (0-3) that applies to the byte following b,
// so a caller can continue masking the next chunk with it.
//
// Buffers shorter than two machine words are masked one byte at a time.
// Longer buffers are masked bytewise up to a word-aligned address, then one
// word at a time, then bytewise for whatever is left. The appengine build of
// this package uses the plain byte loop for every size.
func maskBytes(key [4]byte, pos int, b []byte) int {
	// Short input: the alignment work below would cost more than it saves.
	if len(b) < 2*wordSize {
		for i := 0; i < len(b); i++ {
			b[i] ^= key[pos&3]
			pos++
		}
		return pos & 3
	}

	// Run-up: mask single bytes until the data pointer is word aligned.
	if misalign := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; misalign != 0 {
		lead := wordSize - misalign
		for i := 0; i < lead; i++ {
			b[i] ^= key[pos&3]
			pos++
		}
		b = b[lead:]
	}

	// Repeat the key, rotated to the current offset, across one full word.
	var keyBytes [wordSize]byte
	for i := 0; i < wordSize; i++ {
		keyBytes[i] = key[(pos+i)&3]
	}
	keyWord := *(*uintptr)(unsafe.Pointer(&keyBytes))

	// Bulk: XOR whole words. wordSize is a multiple of 4, so the key offset
	// is unchanged by this loop and pos does not need to advance.
	wordBytes := len(b) - len(b)%wordSize
	base := unsafe.Pointer(&b[0])
	for off := 0; off < wordBytes; off += wordSize {
		*(*uintptr)(unsafe.Pointer(uintptr(base) + uintptr(off))) ^= keyWord
	}

	// Tail: the bytes that did not fill a whole word.
	tail := b[wordBytes:]
	for i := 0; i < len(tail); i++ {
		tail[i] ^= key[pos&3]
		pos++
	}

	return pos & 3
}
7 | 8 | // !appengine 9 | 10 | package websocket 11 | 12 | import ( 13 | "fmt" 14 | "testing" 15 | ) 16 | 17 | func maskBytesByByte(key [4]byte, pos int, b []byte) int { 18 | for i := range b { 19 | b[i] ^= key[pos&3] 20 | pos++ 21 | } 22 | return pos & 3 23 | } 24 | 25 | func notzero(b []byte) int { 26 | for i := range b { 27 | if b[i] != 0 { 28 | return i 29 | } 30 | } 31 | return -1 32 | } 33 | 34 | func TestMaskBytes(t *testing.T) { 35 | key := [4]byte{1, 2, 3, 4} 36 | for size := 1; size <= 1024; size++ { 37 | for align := 0; align < wordSize; align++ { 38 | for pos := 0; pos < 4; pos++ { 39 | b := make([]byte, size+align)[align:] 40 | maskBytes(key, pos, b) 41 | maskBytesByByte(key, pos, b) 42 | if i := notzero(b); i >= 0 { 43 | t.Errorf("size:%d, align:%d, pos:%d, offset:%d", size, align, pos, i) 44 | } 45 | } 46 | } 47 | } 48 | } 49 | 50 | func BenchmarkMaskBytes(b *testing.B) { 51 | for _, size := range []int{2, 4, 8, 16, 32, 512, 1024} { 52 | b.Run(fmt.Sprintf("size-%d", size), func(b *testing.B) { 53 | for _, align := range []int{wordSize / 2} { 54 | b.Run(fmt.Sprintf("align-%d", align), func(b *testing.B) { 55 | for _, fn := range []struct { 56 | name string 57 | fn func(key [4]byte, pos int, b []byte) int 58 | }{ 59 | {"byte", maskBytesByByte}, 60 | {"word", maskBytes}, 61 | } { 62 | b.Run(fn.name, func(b *testing.B) { 63 | key := newMaskKey() 64 | data := make([]byte, size+align)[align:] 65 | for i := 0; i < b.N; i++ { 66 | fn.fn(key, 0, data) 67 | } 68 | b.SetBytes(int64(len(data))) 69 | }) 70 | } 71 | }) 72 | } 73 | }) 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /prepared.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | // 5 | // This file may have been modified by CloudWeGo authors. All CloudWeGo 6 | // Modifications are Copyright 2022 CloudWeGo Authors. 7 | 8 | package websocket 9 | 10 | import ( 11 | "bytes" 12 | "net" 13 | "sync" 14 | "time" 15 | ) 16 | 17 | // PreparedMessage caches on the wire representations of a message payload. 18 | // Use PreparedMessage to efficiently send a message payload to multiple 19 | // connections. PreparedMessage is especially useful when compression is used 20 | // because the CPU and memory expensive compression operation can be executed 21 | // once for a given set of compression options. 22 | type PreparedMessage struct { 23 | messageType int 24 | data []byte 25 | mu sync.Mutex 26 | frames map[prepareKey]*preparedFrame 27 | } 28 | 29 | // prepareKey defines a unique set of options to cache prepared frames in PreparedMessage. 30 | type prepareKey struct { 31 | isServer bool 32 | compress bool 33 | compressionLevel int 34 | } 35 | 36 | // preparedFrame contains data in wire representation. 37 | type preparedFrame struct { 38 | once sync.Once 39 | data []byte 40 | } 41 | 42 | // NewPreparedMessage returns an initialized PreparedMessage. You can then send 43 | // it to connection using WritePreparedMessage method. Valid wire 44 | // representation will be calculated lazily only once for a set of current 45 | // connection options. 46 | func NewPreparedMessage(messageType int, data []byte) (*PreparedMessage, error) { 47 | pm := &PreparedMessage{ 48 | messageType: messageType, 49 | frames: make(map[prepareKey]*preparedFrame), 50 | data: data, 51 | } 52 | 53 | // Prepare a plain server frame. 54 | _, frameData, err := pm.frame(prepareKey{isServer: true, compress: false}) 55 | if err != nil { 56 | return nil, err 57 | } 58 | 59 | // To protect against caller modifying the data argument, remember the data 60 | // copied to the plain server frame. 
61 | pm.data = frameData[len(frameData)-len(data):] 62 | return pm, nil 63 | } 64 | 65 | func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) { 66 | pm.mu.Lock() 67 | frame, ok := pm.frames[key] 68 | if !ok { 69 | frame = &preparedFrame{} 70 | pm.frames[key] = frame 71 | } 72 | pm.mu.Unlock() 73 | 74 | var err error 75 | frame.once.Do(func() { 76 | // Prepare a frame using a 'fake' connection. 77 | // TODO: Refactor code in conn.go to allow more direct construction of 78 | // the frame. 79 | mu := make(chan struct{}, 1) 80 | mu <- struct{}{} 81 | var nc prepareConn 82 | c := &Conn{ 83 | conn: &nc, 84 | mu: mu, 85 | isServer: key.isServer, 86 | compressionLevel: key.compressionLevel, 87 | enableWriteCompression: true, 88 | writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize), 89 | } 90 | if key.compress { 91 | c.newCompressionWriter = compressNoContextTakeover 92 | } 93 | err = c.WriteMessage(pm.messageType, pm.data) 94 | frame.data = nc.buf.Bytes() 95 | }) 96 | return pm.messageType, frame.data, err 97 | } 98 | 99 | type prepareConn struct { 100 | buf bytes.Buffer 101 | net.Conn 102 | } 103 | 104 | func (pc *prepareConn) Write(p []byte) (int, error) { return pc.buf.Write(p) } 105 | func (pc *prepareConn) SetWriteDeadline(t time.Time) error { return nil } 106 | -------------------------------------------------------------------------------- /prepared_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | // 5 | // This file may have been modified by CloudWeGo authors. All CloudWeGo 6 | // Modifications are Copyright 2022 CloudWeGo Authors. 
7 | 8 | package websocket 9 | 10 | import ( 11 | "bytes" 12 | "compress/flate" 13 | "math/rand" 14 | "testing" 15 | ) 16 | 17 | var preparedMessageTests = []struct { 18 | messageType int 19 | isServer bool 20 | enableWriteCompression bool 21 | compressionLevel int 22 | }{ 23 | // Server 24 | {TextMessage, true, false, flate.BestSpeed}, 25 | {TextMessage, true, true, flate.BestSpeed}, 26 | {TextMessage, true, true, flate.BestCompression}, 27 | {PingMessage, true, false, flate.BestSpeed}, 28 | {PingMessage, true, true, flate.BestSpeed}, 29 | 30 | // Client 31 | {TextMessage, false, false, flate.BestSpeed}, 32 | {TextMessage, false, true, flate.BestSpeed}, 33 | {TextMessage, false, true, flate.BestCompression}, 34 | {PingMessage, false, false, flate.BestSpeed}, 35 | {PingMessage, false, true, flate.BestSpeed}, 36 | } 37 | 38 | func TestPreparedMessage(t *testing.T) { 39 | for _, tt := range preparedMessageTests { 40 | data := []byte("this is a test") 41 | var buf bytes.Buffer 42 | c := newTestConn(nil, &buf, tt.isServer) 43 | if tt.enableWriteCompression { 44 | c.newCompressionWriter = compressNoContextTakeover 45 | } 46 | c.SetCompressionLevel(tt.compressionLevel) 47 | 48 | // Seed random number generator for consistent frame mask. 49 | rand.Seed(1234) 50 | 51 | if err := c.WriteMessage(tt.messageType, data); err != nil { 52 | t.Fatal(err) 53 | } 54 | want := buf.String() 55 | 56 | pm, err := NewPreparedMessage(tt.messageType, data) 57 | if err != nil { 58 | t.Fatal(err) 59 | } 60 | 61 | // Scribble on data to ensure that NewPreparedMessage takes a snapshot. 62 | copy(data, "hello world") 63 | 64 | // Seed random number generator for consistent frame mask. 
65 | rand.Seed(1234) 66 | 67 | buf.Reset() 68 | if err := c.WritePreparedMessage(pm); err != nil { 69 | t.Fatal(err) 70 | } 71 | got := buf.String() 72 | 73 | if got != want { 74 | t.Errorf("write message != prepared message for %+v", tt) 75 | } 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | // 5 | // This file may have been modified by CloudWeGo authors. All CloudWeGo 6 | // Modifications are Copyright 2022 CloudWeGo Authors. 7 | 8 | package websocket 9 | 10 | import ( 11 | "bytes" 12 | "fmt" 13 | "net/url" 14 | "sync" 15 | "time" 16 | 17 | "github.com/cloudwego/hertz/pkg/app" 18 | "github.com/cloudwego/hertz/pkg/network" 19 | "github.com/cloudwego/hertz/pkg/protocol/consts" 20 | ) 21 | 22 | const badHandshake = "websocket: the client is not using the websocket protocol: " 23 | 24 | var strPermessageDeflate = []byte("permessage-deflate") 25 | 26 | // HandshakeError describes an error with the handshake from the peer. 27 | type HandshakeError struct { 28 | message string 29 | } 30 | 31 | func (e HandshakeError) Error() string { return e.message } 32 | 33 | var poolWriteBuffer = sync.Pool{ 34 | New: func() interface{} { 35 | var buf []byte 36 | return buf 37 | }, 38 | } 39 | 40 | // HertzHandler receives a websocket connection after the handshake has been 41 | // completed. This must be provided. 42 | type HertzHandler func(*Conn) 43 | 44 | // HertzUpgrader specifies parameters for upgrading an HTTP connection to a 45 | // WebSocket connection. 46 | type HertzUpgrader struct { 47 | // HandshakeTimeout specifies the duration for the handshake to complete. 
48 | HandshakeTimeout time.Duration 49 | 50 | // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer 51 | // size is zero, then buffers allocated by the HTTP server are used. The 52 | // I/O buffer sizes do not limit the size of the messages that can be sent 53 | // or received. 54 | ReadBufferSize, WriteBufferSize int 55 | 56 | // WriteBufferPool is a pool of buffers for write operations. If the value 57 | // is not set, then write buffers are allocated to the connection for the 58 | // lifetime of the connection. 59 | // 60 | // A pool is most useful when the application has a modest volume of writes 61 | // across a large number of connections. 62 | // 63 | // Applications should use a single pool for each unique value of 64 | // WriteBufferSize. 65 | WriteBufferPool BufferPool 66 | 67 | // Subprotocols specifies the server's supported protocols in order of 68 | // preference. If this field is not nil, then the Upgrade method negotiates a 69 | // subprotocol by selecting the first match in this list with a protocol 70 | // requested by the client. If there's no match, then no protocol is 71 | // negotiated (the Sec-Websocket-Protocol header is not included in the 72 | // handshake response). 73 | Subprotocols []string 74 | 75 | // Error specifies the function for generating HTTP error responses. If Error 76 | // is nil, then http.Error is used to generate the HTTP response. 77 | Error func(ctx *app.RequestContext, status int, reason error) 78 | 79 | // CheckOrigin returns true if the request Origin header is acceptable. If 80 | // CheckOrigin is nil, then a safe default is used: return false if the 81 | // Origin request header is present and the origin host is not equal to 82 | // request Host header. 83 | // 84 | // A CheckOrigin function should carefully validate the request origin to 85 | // prevent cross-site request forgery. 
86 | CheckOrigin func(ctx *app.RequestContext) bool 87 | 88 | // EnableCompression specify if the server should attempt to negotiate per 89 | // message compression (RFC 7692). Setting this value to true does not 90 | // guarantee that compression will be supported. Currently only "no context 91 | // takeover" modes are supported. 92 | EnableCompression bool 93 | } 94 | 95 | func (u *HertzUpgrader) returnError(ctx *app.RequestContext, status int, reason string) error { 96 | err := HandshakeError{reason} 97 | if u.Error != nil { 98 | u.Error(ctx, status, err) 99 | } else { 100 | ctx.Response.Header.Set("Sec-Websocket-Version", "13") 101 | ctx.AbortWithMsg(consts.StatusMessage(status), status) 102 | } 103 | 104 | return err 105 | } 106 | 107 | func (u *HertzUpgrader) selectSubprotocol(ctx *app.RequestContext) []byte { 108 | if u.Subprotocols != nil { 109 | clientProtocols := parseDataHeader(ctx.Request.Header.Peek("Sec-Websocket-Protocol")) 110 | 111 | for _, serverProtocol := range u.Subprotocols { 112 | for _, clientProtocol := range clientProtocols { 113 | if b2s(clientProtocol) == serverProtocol { 114 | return clientProtocol 115 | } 116 | } 117 | } 118 | } else if ctx.Response.Header.Len() > 0 { 119 | return ctx.Response.Header.Peek("Sec-Websocket-Protocol") 120 | } 121 | 122 | return nil 123 | } 124 | 125 | func (u *HertzUpgrader) isCompressionEnable(ctx *app.RequestContext) bool { 126 | extensions := parseDataHeader(ctx.Request.Header.Peek("Sec-WebSocket-Extensions")) 127 | 128 | // Negotiate PMCE 129 | if u.EnableCompression { 130 | for _, ext := range extensions { 131 | if bytes.HasPrefix(ext, strPermessageDeflate) { 132 | return true 133 | } 134 | } 135 | } 136 | 137 | return false 138 | } 139 | 140 | // Upgrade upgrades the HTTP server connection to the WebSocket protocol. 141 | // 142 | // The responseHeader is included in the response to the client's upgrade 143 | // request. 
Use the responseHeader to specify cookies (Set-Cookie) and the 144 | // application negotiated subprotocol (Sec-WebSocket-Protocol). 145 | // 146 | // If the upgrade fails, then Upgrade replies to the client with an HTTP error 147 | // response. 148 | func (u *HertzUpgrader) Upgrade(ctx *app.RequestContext, handler HertzHandler) error { 149 | if !ctx.IsGet() { 150 | return u.returnError(ctx, consts.StatusMethodNotAllowed, fmt.Sprintf("%s request method is not GET", badHandshake)) 151 | } 152 | 153 | if !tokenContainsValue(b2s(ctx.Request.Header.Peek("Connection")), "Upgrade") { 154 | return u.returnError(ctx, consts.StatusBadRequest, fmt.Sprintf("%s 'upgrade' token not found in 'Connection' header", badHandshake)) 155 | } 156 | 157 | if !tokenContainsValue(b2s(ctx.Request.Header.Peek("Upgrade")), "Websocket") { 158 | return u.returnError(ctx, consts.StatusBadRequest, fmt.Sprintf("%s 'websocket' token not found in 'Upgrade' header", badHandshake)) 159 | } 160 | 161 | if !tokenContainsValue(b2s(ctx.Request.Header.Peek("Sec-Websocket-Version")), "13") { 162 | return u.returnError(ctx, consts.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header") 163 | } 164 | 165 | if len(ctx.Response.Header.Peek("Sec-Websocket-Extensions")) > 0 { 166 | return u.returnError(ctx, consts.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported") 167 | } 168 | 169 | checkOrigin := u.CheckOrigin 170 | if checkOrigin == nil { 171 | checkOrigin = fastHTTPCheckSameOrigin 172 | } 173 | if !checkOrigin(ctx) { 174 | return u.returnError(ctx, consts.StatusForbidden, "websocket: request origin not allowed by HertzUpgrader.CheckOrigin") 175 | } 176 | 177 | challengeKey := ctx.Request.Header.Peek("Sec-Websocket-Key") 178 | if len(challengeKey) == 0 { 179 | return u.returnError(ctx, consts.StatusBadRequest, "websocket: not a websocket handshake: `Sec-WebSocket-Key' header is missing or blank") 180 | } 
181 | 182 | subprotocol := u.selectSubprotocol(ctx) 183 | compress := u.isCompressionEnable(ctx) 184 | 185 | ctx.SetStatusCode(consts.StatusSwitchingProtocols) 186 | ctx.Response.Header.Set("Upgrade", "websocket") 187 | ctx.Response.Header.Set("Connection", "Upgrade") 188 | ctx.Response.Header.Set("Sec-WebSocket-Accept", computeAcceptKeyBytes(challengeKey)) 189 | if compress { 190 | ctx.Response.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover") 191 | } 192 | if subprotocol != nil { 193 | ctx.Response.Header.SetBytesV("Sec-WebSocket-Protocol", subprotocol) 194 | } 195 | 196 | ctx.Hijack(func(netConn network.Conn) { 197 | writeBuf := poolWriteBuffer.Get().([]byte) 198 | c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, nil, writeBuf) 199 | if subprotocol != nil { 200 | c.subprotocol = b2s(subprotocol) 201 | } 202 | 203 | if compress { 204 | c.newCompressionWriter = compressNoContextTakeover 205 | c.newDecompressionReader = decompressNoContextTakeover 206 | } 207 | 208 | // Clear deadlines set by HTTP server. 209 | netConn.SetDeadline(time.Time{}) 210 | 211 | handler(c) 212 | 213 | writeBuf = writeBuf[0:0] 214 | 215 | // FIXME: argument should be pointer-like to avoid allocations (staticcheck) 216 | poolWriteBuffer.Put(writeBuf) // nolint: staticcheck 217 | }) 218 | 219 | return nil 220 | } 221 | 222 | // fastHTTPCheckSameOrigin returns true if the origin is not set or is equal to the request host. 
223 | func fastHTTPCheckSameOrigin(ctx *app.RequestContext) bool { 224 | origin := ctx.Request.Header.Peek("Origin") 225 | if len(origin) == 0 { 226 | return true 227 | } 228 | u, err := url.Parse(b2s(origin)) 229 | if err != nil { 230 | return false 231 | } 232 | return equalASCIIFold(u.Host, b2s(ctx.Host())) 233 | } 234 | -------------------------------------------------------------------------------- /util.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | // 5 | // This file may have been modified by CloudWeGo authors. All CloudWeGo 6 | // Modifications are Copyright 2022 CloudWeGo Authors. 7 | 8 | package websocket 9 | 10 | import ( 11 | "bytes" 12 | "crypto/sha1" 13 | "encoding/base64" 14 | "encoding/binary" 15 | "math/rand" 16 | "unicode/utf8" 17 | "unsafe" 18 | ) 19 | 20 | var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") 21 | 22 | func generateChallengeKey() string { 23 | b := make([]byte, 16) 24 | binary.BigEndian.PutUint64(b, rand.Uint64()) 25 | binary.BigEndian.PutUint64(b[8:], rand.Uint64()) 26 | return base64.StdEncoding.EncodeToString(b) 27 | } 28 | 29 | // Token octets per RFC 2616. 
// isTokenOctet marks the octets that may appear in a token per RFC 2616:
// ASCII letters, digits and the punctuation characters listed below.
var isTokenOctet = func() (table [256]bool) {
	const tokenOctets = "!#$%&'*+-.^_`|~" +
		"0123456789" +
		"ABCDEFGHIJKLMNOPQRSTUVWXYZ" +
		"abcdefghijklmnopqrstuvwxyz"
	for i := 0; i < len(tokenOctets); i++ {
		table[tokenOctets[i]] = true
	}
	return table
}()

// skipSpace returns s without its leading RFC 2616 linear whitespace, that
// is, without any leading spaces and horizontal tabs.
func skipSpace(s string) (rest string) {
	for len(s) > 0 && (s[0] == ' ' || s[0] == '\t') {
		s = s[1:]
	}
	return s
}
124 | func nextToken(s string) (token, rest string) { 125 | i := 0 126 | for ; i < len(s); i++ { 127 | if !isTokenOctet[s[i]] { 128 | break 129 | } 130 | } 131 | return s[:i], s[i:] 132 | } 133 | 134 | // equalASCIIFold returns true if s is equal to t with ASCII case folding as 135 | // defined in RFC 4790. 136 | func equalASCIIFold(s, t string) bool { 137 | for s != "" && t != "" { 138 | sr, size := utf8.DecodeRuneInString(s) 139 | s = s[size:] 140 | tr, size := utf8.DecodeRuneInString(t) 141 | t = t[size:] 142 | if sr == tr { 143 | continue 144 | } 145 | if 'A' <= sr && sr <= 'Z' { 146 | sr = sr + 'a' - 'A' 147 | } 148 | if 'A' <= tr && tr <= 'Z' { 149 | tr = tr + 'a' - 'A' 150 | } 151 | if sr != tr { 152 | return false 153 | } 154 | } 155 | return s == t 156 | } 157 | 158 | // parseDataHeader returns a list with values if header value is comma-separated 159 | func parseDataHeader(headerValue []byte) [][]byte { 160 | h := bytes.TrimSpace(headerValue) 161 | if bytes.Equal(h, []byte("")) { 162 | return nil 163 | } 164 | 165 | values := bytes.Split(h, []byte(",")) 166 | for i := range values { 167 | values[i] = bytes.TrimSpace(values[i]) 168 | } 169 | return values 170 | } 171 | 172 | // tokenContainsValue returns true if the 1#token header with the given 173 | // name contains a token equal to value with ASCII case folding. 
174 | func tokenContainsValue(s, value string) bool { 175 | for { 176 | var t string 177 | t, s = nextToken(skipSpace(s)) 178 | if t == "" { 179 | return false 180 | } 181 | s = skipSpace(s) 182 | if s != "" && s[0] != ',' { 183 | return false 184 | } 185 | if equalASCIIFold(t, value) { 186 | return true 187 | } 188 | if s == "" { 189 | return false 190 | } 191 | 192 | s = s[1:] 193 | } 194 | } 195 | 196 | func computeAcceptKeyBytes(challengeKey []byte) string { 197 | h := sha1.New() 198 | h.Write(challengeKey) 199 | h.Write(keyGUID) 200 | return base64.StdEncoding.EncodeToString(h.Sum(nil)) 201 | } 202 | 203 | // b2s converts byte slice to a string without memory allocation. 204 | // See https://groups.google.com/forum/#!msg/Golang-Nuts/ENgbUzYvCuU/90yGx7GUAgAJ . 205 | // 206 | // Note it may break if string and/or slice header will change 207 | // in the future go versions. 208 | func b2s(b []byte) string { 209 | return *(*string)(unsafe.Pointer(&b)) 210 | } 211 | --------------------------------------------------------------------------------