├── autobahn ├── run_server.sh ├── run_client.sh ├── config │ ├── fuzzingserver.json │ └── fuzzingclient.json ├── server │ ├── csr.pem │ ├── public.crt │ ├── privatekey.pem │ └── autobahn-server.go ├── script │ └── run.sh └── client │ └── autobahn-client.go ├── go.mod ├── go.sum ├── .gitignore ├── Makefile ├── parse_mode.go ├── server_options.go ├── benchmark_rand_test.go ├── opcode.go ├── .github └── workflows │ ├── autobahn.yml │ └── go.yml ├── client_options.go ├── utils.go ├── callback.go ├── err.go ├── proxy.go ├── permessage_deflate.go ├── utils_test.go ├── server_option_test.go ├── config.go ├── server_profile_test.go ├── callback_test.go ├── config_test.go ├── status_codes.go ├── server_handshake_test.go ├── upgrade.go ├── server_handshake.go ├── benchmark_read_write_message_test.go ├── client_test.go ├── proxy_test.go ├── client.go ├── common_options.go ├── LICENSE ├── README.md ├── client_option_test.go ├── server_test.go ├── conn.go └── conn_test.go /autobahn/run_server.sh: -------------------------------------------------------------------------------- 1 | docker run -it --rm --net=host -v ${PWD}/config:/config -v ${PWD}/report:/report crossbario/autobahn-testsuite wstest -m fuzzingclient -s /config/fuzzingclient.json -------------------------------------------------------------------------------- /autobahn/run_client.sh: -------------------------------------------------------------------------------- 1 | docker run -it --rm --network host -v ${PWD}/config:/config -v ${PWD}/report2:/report2 crossbario/autobahn-testsuite wstest -m fuzzingserver -s /config/fuzzingserver.json -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/antlabs/quickws 2 | 3 | go 1.21 4 | 5 | require ( 6 | github.com/antlabs/wsutil v0.1.11 7 | golang.org/x/net v0.23.0 8 | ) 9 | 10 | require github.com/klauspost/compress v1.17.8 // indirect 11 | 
-------------------------------------------------------------------------------- /autobahn/config/fuzzingserver.json: -------------------------------------------------------------------------------- 1 | { 2 | "url": "ws://127.0.0.1:9003", 3 | "outdir": "./report2", 4 | "cases": [ 5 | "*" 6 | ], 7 | "exclude-cases": [], 8 | "exclude-agent-cases": {} 9 | } 10 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/antlabs/wsutil v0.1.11 h1:bIVZ3Hxdq5ByZKu5OXL/cMtanEw6YlxdtUDiySI77Q0= 2 | github.com/antlabs/wsutil v0.1.11/go.mod h1:Pk7xYOw3o5iEB6ukiOu+2uJMLYeMVVjJLazFD3okI2A= 3 | github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU= 4 | github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= 5 | golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= 6 | golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *swp 2 | /.idea 3 | /coverage.out 4 | /cover.cov 5 | /autobahn/report 6 | autobahn-testsuite 7 | autobahn-client-testsuite-darwin-arm64 8 | autobahn-client-testsuite-linux-amd64 9 | autobahn-server-testsuite-darwin-arm64 10 | autobahn-server-testsuite-linux-amd64 11 | autobahn/autobahn-server-testsuite-darwin-arm64-arena 12 | autobahn-client-darwin-arm64 13 | autobahn-client-linux-amd64 14 | autobahn-server-darwin-arm64 15 | autobahn-server-darwin-arm64-arena 16 | autobahn-server-linux-amd64 17 | /cpu.profile 18 | /mem.profile 19 | /autobahn/autobahn-server 20 | /autobahn-server 21 | /.vscode 22 | /autobahn/report2/ -------------------------------------------------------------------------------- /Makefile: 
-------------------------------------------------------------------------------- 1 | # dump: 2 | #go build -o /dev/null -gcflags -S * 3 | 4 | all: 5 | # mac, arm64 6 | GOOS=darwin GOARCH=arm64 go build -o autobahn-server-darwin-arm64 ./autobahn/server/autobahn-server.go 7 | # linux amd64 8 | GOOS=linux GOARCH=amd64 go build -o autobahn-server-linux-amd64 ./autobahn/server/autobahn-server.go 9 | 10 | # mac, arm64 11 | GOOS=darwin GOARCH=arm64 go build -o autobahn-client-darwin-arm64 ./autobahn/client/autobahn-client.go 12 | # linux amd64 13 | GOOS=linux GOARCH=amd64 go build -o autobahn-client-linux-amd64 ./autobahn/client/autobahn-client.go 14 | 15 | key: 16 | openssl genrsa 2048 > privatekey.pem 17 | openssl req -new -key privatekey.pem -out csr.pem 18 | openssl x509 -req -days 36500 -in csr.pem -signkey privatekey.pem -out public.crt -------------------------------------------------------------------------------- /parse_mode.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
package quickws

// parseMode enumerates the values of the exported ParseMode* constants below.
// Because the type itself is unexported, code outside the package can use the
// constants but cannot name the type in its own declarations.
// NOTE(review): the names suggest this selects how a connection reads/parses
// incoming data; its consumer is not visible here — confirm in conn.go.
type parseMode int32

const (
	// ParseModeBufio is 0 (iota), i.e. the zero value of parseMode, so an
	// unset parseMode field means this mode.
	// NOTE(review): presumably bufio.Reader-based parsing — confirm.
	ParseModeBufio parseMode = iota
	// ParseModeWindows is 1.
	// NOTE(review): "Windows" presumably refers to a sliding read
	// window/buffer rather than the operating system — confirm at its use site.
	ParseModeWindows
)
-------------------------------------------------------------------------------- /server_options.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
package quickws

// ServerOption is a functional option that mutates a *ConnOption.
// NOTE(review): per its name and file, intended for server-side connections.
type ServerOption func(*ConnOption)

// ConnOption holds per-connection options. It consists solely of an embedded
// Config, so Config's members are promoted onto ConnOption.
type ConnOption struct {
	Config
}
-------------------------------------------------------------------------------- /benchmark_rand_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | package quickws 15 | 16 | import ( 17 | "math/rand" 18 | "testing" 19 | ) 20 | 21 | func Benchmark_Rand_Uint32(t *testing.B) { 22 | for i := 0; i < t.N; i++ { 23 | _ = rand.Uint32() 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /autobahn/server/csr.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE REQUEST----- 2 | MIICeDCCAWACAQAwMzELMAkGA1UEBhMCY24xETAPBgNVBAgMCHNoYW5naGFpMREw 3 | DwYDVQQHDAhzaGFuZ2hhaTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB 4 | AL+UkOo23qO2W9r1fHcxk0wtXvmnbx1Arf5IMaoJXN6prCr8NDKjt9Y2/7PyPBwB 5 | hRewGhGut7EmCr/WnHT9rKa4pSYwk+fAEMY3KQbVyjA6v6rID87K5JEN132vqN6g 6 | 3VrTwV7yS5vhx43HmASUutpwtf2f9GM4ptQcmzdUvMALiPx515J7szCAUVh1R+4v 7 | my5eSZ+fqhcemLE1xAzNsUn0EMa56eF7zwfVpuQn6OWBUC5+WpR6rVnSgMABLhY/ 8 | Bct3OAk0jsftG7SCZPb6v0YTMMpHD8PT/uHXe/VNbSMJJ+s+4L8zAjJW8zQXtkOh 9 | //0rba4khzqF++ffPsK3kLcCAwEAAaAAMA0GCSqGSIb3DQEBCwUAA4IBAQCn6WhR 10 | sj9NI6XD0S8PwYCmSDfDYq+DKFUP3UQujm6Wlaj5Cd0qURmaj/Zonh+fDOGke9AK 11 | iKwIBom/3BWviAlIeRUQpdvBR8nCcQImlan0ttiFzNn72GqlAcYAVo9VZzLZAnxS 12 | jVV7+RQlttQ4zmCPJ1P3xz1sC81c3Pt6f89N1MZbtk/EfDFplLaEKuCOHn8CasRp 13 | KSLeidChYB5uw2ZzEN473huN0iVSZmFVF4gXB7OWYrCahsltJdBTBLCfNgeTuxY0 14 | br0wfFi/TbQCOxsvEe07HbX58NDB0zwL2PnNGpzpbaNVUa09/JJ1v9mv8xkoFAe9 15 | dzrWlhVI1/YLJ8oG 16 | -----END CERTIFICATE REQUEST----- 17 | -------------------------------------------------------------------------------- /opcode.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package quickws 16 | 17 | import "github.com/antlabs/wsutil/opcode" 18 | 19 | type Opcode = opcode.Opcode 20 | 21 | const ( 22 | Continuation = opcode.Continuation 23 | Text = opcode.Text 24 | Binary = opcode.Binary 25 | // 3 - 7保留 26 | _ // 3 27 | _ 28 | _ // 5 29 | _ 30 | _ // 7 31 | Close = opcode.Close 32 | Ping = opcode.Ping 33 | Pong = opcode.Pong 34 | ) 35 | -------------------------------------------------------------------------------- /autobahn/script/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir -p ./autobahn/bin 4 | go build -o ./autobahn/bin/autobahn_server ./autobahn/autobahn-server.go 5 | # go build -o ./autobahn/bin/autobahn_reporter ./autobahn/reporter/ 6 | 7 | echo "pwd:" $(pwd) 8 | ./autobahn/bin/autobahn_server & 9 | 10 | rm -rf ${PWD}/autobahn/report 11 | mkdir -p ${PWD}/autobahn/report/ 12 | 13 | docker pull crossbario/autobahn-testsuite 14 | 15 | docker run -i --rm \ 16 | -v ${PWD}/autobahn/config:/config \ 17 | -v ${PWD}/autobahn/report:/report \ 18 | --network host \ 19 | --name=autobahn \ 20 | crossbario/autobahn-testsuite \ 21 | wstest -m fuzzingclient -s /config/fuzzingclient.json 22 | 23 | trap ctrl_c INT 24 | ctrl_c() { 25 | echo "SIGINT received; cleaning up" 26 | docker kill --signal INT "autobahn" >/dev/null 27 | rm -rf ${PWD}/autobahn/bin 28 | rm -rf ${PWD}/autobahn/report 29 | cleanup 30 | exit 130 31 | } 32 | 33 | cleanup() { 34 | killall autobahn_server 35 | } 36 | 37 | # TODO 38 | # 
./autobahn/bin/autobahn_reporter ${PWD}/autobahn/report/index.json 39 | 40 | cleanup 41 | 42 | -------------------------------------------------------------------------------- /.github/workflows/autobahn.yml: -------------------------------------------------------------------------------- 1 | name: Autobahn 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | - dev 8 | pull_request: 9 | branches: 10 | - master 11 | - dev 12 | 13 | jobs: 14 | Autobahn: 15 | strategy: 16 | matrix: 17 | os: [ ubuntu-latest ] 18 | go: [ 1.18.x ] 19 | runs-on: ${{ matrix.os }} 20 | steps: 21 | - name: Checkout 22 | uses: actions/checkout@v2 23 | - name: Setup Go 24 | uses: actions/setup-go@v2 25 | with: 26 | go-version: ${{ matrix.go }} 27 | - name: Autobahn Test 28 | env: 29 | CRYPTOGRAPHY_ALLOW_OPENSSL_102: yes 30 | run: | 31 | chmod +x ./autobahn/script/run.sh && ./autobahn/script/run.sh 32 | - name: Autobahn Report Artifact 33 | if: >- 34 | startsWith(matrix.os, 'ubuntu') 35 | uses: actions/upload-artifact@v2 36 | 37 | with: 38 | name: autobahn report ${{ matrix.go }} ${{ matrix.os }} 39 | path: autobahn/report 40 | retention-days: 7 41 | -------------------------------------------------------------------------------- /autobahn/server/public.crt: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIC5DCCAcwCCQDjmz+Ai7+gajANBgkqhkiG9w0BAQUFADAzMQswCQYDVQQGEwJj 3 | bjERMA8GA1UECAwIc2hhbmdoYWkxETAPBgNVBAcMCHNoYW5naGFpMCAXDTIzMDYy 4 | NzA1MTYxMloYDzIxMjMwNjAzMDUxNjEyWjAzMQswCQYDVQQGEwJjbjERMA8GA1UE 5 | CAwIc2hhbmdoYWkxETAPBgNVBAcMCHNoYW5naGFpMIIBIjANBgkqhkiG9w0BAQEF 6 | AAOCAQ8AMIIBCgKCAQEAv5SQ6jbeo7Zb2vV8dzGTTC1e+advHUCt/kgxqglc3qms 7 | Kvw0MqO31jb/s/I8HAGFF7AaEa63sSYKv9acdP2sprilJjCT58AQxjcpBtXKMDq/ 8 | qsgPzsrkkQ3Xfa+o3qDdWtPBXvJLm+HHjceYBJS62nC1/Z/0Yzim1BybN1S8wAuI 9 | /HnXknuzMIBRWHVH7i+bLl5Jn5+qFx6YsTXEDM2xSfQQxrnp4XvPB9Wm5Cfo5YFQ 10 | Ln5alHqtWdKAwAEuFj8Fy3c4CTSOx+0btIJk9vq/RhMwykcPw9P+4dd79U1tIwkn 11 | 
6z7gvzMCMlbzNBe2Q6H//SttriSHOoX7598+wreQtwIDAQABMA0GCSqGSIb3DQEB 12 | BQUAA4IBAQCZpP2FJwSM/BlahiptIUJPtXY3cjq25v2fU3KRDyA46gW0rOLBKFkV 13 | IdpqTsp6YGiz3ELFkxqJ548PmzJLqSuYm4gsLBx8AKDDC9TQ9f7w9URj5Am+rbW8 14 | pQYCrdm/IE6KhmxajlO3Ef8DIRWq6vrkCx/Au2HQIo4P2ZaRo2Ts6st1aKK2/qzB 15 | 56Iex6iuCL/5sn7gwBggXH+1FDjqbmBYDmzPsE0wjOQrjN6QMwBZFvcE4XuRNRTr 16 | AwPpnADRLl/HVFmGKwHOHljgQDsVdDv1m3k3Fm2HrDu47ooSXXK8sIy5WA/iqJdR 17 | Op2HFiQI8Q/mg9F08Xaqh5BscUG7aiJV 18 | -----END CERTIFICATE----- 19 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | pull_request: 6 | 7 | jobs: 8 | 9 | build: 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | go: [ '1.20', '1.21'] 14 | name: Go ${{ matrix.go }} sample 15 | 16 | steps: 17 | 18 | - name: Set up Go 1.20 19 | uses: actions/setup-go@v5 20 | with: 21 | go-version: ${{ matrix.go }} 22 | id: go 23 | 24 | - name: Check out code into the Go module directory 25 | uses: actions/checkout@main 26 | 27 | - name: Get dependencies 28 | run: | 29 | go get -v -t -d ./... 30 | - name: Build 31 | run: go build -v . 32 | 33 | - name: Test-386 34 | run: env GOARCH=386 go test -test.run=Test_Retry_sleep -v 35 | #run: env GOARCH=386 go test -v -coverprofile='coverage.out' -covermode=count ./... 36 | 37 | - name: Test-amd64 38 | run: env GOARCH=amd64 go test -race -v -coverprofile='coverage.out' -covermode=atomic ./... 39 | 40 | - name: Upload Coverage report 41 | uses: codecov/codecov-action@v3 42 | with: 43 | token: ${{secrets.CODECOV_TOKEN}} 44 | file: ./coverage.out 45 | verbose: true 46 | -------------------------------------------------------------------------------- /client_options.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 
2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package quickws 16 | 17 | import ( 18 | "crypto/tls" 19 | "net/http" 20 | "time" 21 | ) 22 | 23 | type ClientOption func(*DialOption) 24 | 25 | // 1.配置tls.config 26 | func WithClientTLSConfig(tls *tls.Config) ClientOption { 27 | return func(o *DialOption) { 28 | o.tlsConfig = tls 29 | } 30 | } 31 | 32 | // 2.配置http.Header 33 | func WithClientHTTPHeader(h http.Header) ClientOption { 34 | return func(o *DialOption) { 35 | o.Header = h 36 | } 37 | } 38 | 39 | // 3.配置握手时的timeout 40 | func WithClientDialTimeout(t time.Duration) ClientOption { 41 | return func(o *DialOption) { 42 | o.dialTimeout = t 43 | } 44 | } 45 | 46 | // 4.配置压缩 47 | func WithClientCompression() ClientOption { 48 | return func(o *DialOption) { 49 | o.Compression = true 50 | } 51 | } 52 | 53 | // 6.获取http header 54 | func WithClientBindHTTPHeader(h *http.Header) ClientOption { 55 | return func(o *DialOption) { 56 | o.bindClientHttpHeader = h 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /autobahn/server/privatekey.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEpAIBAAKCAQEAv5SQ6jbeo7Zb2vV8dzGTTC1e+advHUCt/kgxqglc3qmsKvw0 3 | MqO31jb/s/I8HAGFF7AaEa63sSYKv9acdP2sprilJjCT58AQxjcpBtXKMDq/qsgP 4 | zsrkkQ3Xfa+o3qDdWtPBXvJLm+HHjceYBJS62nC1/Z/0Yzim1BybN1S8wAuI/HnX 5 | 
knuzMIBRWHVH7i+bLl5Jn5+qFx6YsTXEDM2xSfQQxrnp4XvPB9Wm5Cfo5YFQLn5a 6 | lHqtWdKAwAEuFj8Fy3c4CTSOx+0btIJk9vq/RhMwykcPw9P+4dd79U1tIwkn6z7g 7 | vzMCMlbzNBe2Q6H//SttriSHOoX7598+wreQtwIDAQABAoIBAQCKVKLCizX9Ldpr 8 | YqApjIFYGtaeG1iu3ZoEpmo95Z7KI+dt7kdeXTqLkZDWhM0ER9CrBvv70pVOczKF 9 | zFeSXezBQUf2KFNTnio+hWu5RLtGUdU9YlGPto6NclorpZ+giLTsNURF41vWxZMK 10 | e5j3jdDRk1rFNC8JScmkFLe6nxPe8fPvCjLGQx2yuMYHlnvgmrIEuWiPHaziP4yU 11 | tBKQpCxmruljJvQ503BZtW3g/XlktwR4lexuIOn624nLWla8+ryn+cquEV3FTiKp 12 | WwHJYAVlVaaJbsefmiz9tDkUV2PHDgov3u/Eth/C3HdW0qJC0X+/uiLY0/3Z0O3w 13 | h0M7Y8wpAoGBAPRNymCrG0DKKJLIBB0B+++SNRsR5D1rE74Ig043ax836lVy+2wS 14 | 5aJu7wSbfIAEei2iGhvcV6ePHK0LyDXBaGUOF2/QnhCz7nI7D++E9K7hYRNzUllk 15 | L8JgEj9BrIoruED/Y81TQ3HNi9JXp82Oqnpa9/dvd3tGWT1UqIFKGNRrAoGBAMjA 16 | l2F6cGEOgBOEyrbt60h2ErlXPMIOaFUAaLePOdmv3RQQkWoYOjouPcXUbQ22Bc2Z 17 | edSLip2OTm4hCSbcIYpKViH+NxiMMQnvWfTtO/rdlvPQKYBi9BaNTvvxTcIONMT6 18 | 9v+D5u1mbnIPv/PuDvW5hKUCArrUHfzXDsuFv+flAoGAcwx7RODveah6SP12qm53 19 | xY+WAMSBNsdJSdHafCgvA0miylDWxEN17vPNDd9nVyZEn17aaspuYRNNTtTgmSgW 20 | 0Jg9Q0P8XCNQJG1aCNMVI5Ix1CYX3s8GisQRc8aqyXrjT4C18EjI1zwUH591/6Cy 21 | +eIDKnxMyToM5owKurA5VzcCgYBPnlJrjqvTUnTpSNk9A880xd9XMooeTKiETc06 22 | P8up0l3T/14svb8aJAzL0RwPPAnBKQVwjodDRZVFiESg7N1Ag4r1oGUpjzBDyHHc 23 | +dm3/PpJaF2NVbGI4DJbKbC1Lf0vwnkDjcSgkudqxWRT0i6Mti8tYkbC4i2igYiU 24 | n08lIQKBgQCt4vxZ6nHEd/uIhYal/+dla75rob44/itn+ASELpyyZ3FNYlcnbkLu 25 | JsIRiBLPerUX70HHn49eA3ilGTT3YKmy52rhfg/bTy+elDPyVcevXtLzeNLc/xYX 26 | Mw2fsZ/vXG9rQcP3oPfHKijMywLahXgaVshMHM9Rtm8+cG1OO75k7g== 27 | -----END RSA PRIVATE KEY----- 28 | -------------------------------------------------------------------------------- /autobahn/config/fuzzingclient.json: -------------------------------------------------------------------------------- 1 | { 2 | "outdir": "./report/", 3 | "servers": [ 4 | { 5 | "agent": "global", 6 | "url": "ws://localhost:9001/global", 7 | "options": { 8 | "version": 18 9 | } 10 | }, 11 | { 12 | "agent": 
"no-context-takeover-decompression-and-compression-no-tls", 13 | "url": "ws://localhost:9001/no-context-takeover-decompression-and-compression", 14 | "options": { 15 | "version": 18 16 | } 17 | }, 18 | { 19 | "agent": "no-context-takeover-decompression-no-tls", 20 | "url": "ws://localhost:9001/no-context-takeover-decompression", 21 | "options": { 22 | "version": 18 23 | } 24 | }, 25 | { 26 | "agent": "context-takeover-decompression-and-compression-no-tls", 27 | "url": "ws://localhost:9001/context-takeover-decompression-and-compression", 28 | "options": { 29 | "version": 18 30 | } 31 | }, 32 | { 33 | "agent": "context-takeover-decompression-no-tls", 34 | "url": "ws://localhost:9001/context-takeover-decompression", 35 | "options": { 36 | "version": 18 37 | } 38 | } 39 | ], 40 | "cases": [ 41 | "*" 42 | ], 43 | "exclude-cases": [ 44 | "" 45 | ], 46 | "exclude-agent-cases": {} 47 | } -------------------------------------------------------------------------------- /utils.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package quickws 16 | 17 | import ( 18 | "crypto/sha1" 19 | "encoding/base64" 20 | "errors" 21 | "fmt" 22 | "math/rand" 23 | "net/http" 24 | "sync" 25 | "time" 26 | "unsafe" 27 | ) 28 | 29 | var rng = rand.New(rand.NewSource(time.Now().UnixNano())) 30 | 31 | var mu sync.Mutex 32 | var uuid = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") 33 | 34 | func StringToBytes(s string) []byte { 35 | return unsafe.Slice(unsafe.StringData(s), len(s)) 36 | } 37 | 38 | // // StringToBytes 没有内存开销的转换 39 | // func StringToBytes(s string) (b []byte) { 40 | // bh := (*reflect.SliceHeader)(unsafe.Pointer(&b)) 41 | // sh := (*reflect.StringHeader)(unsafe.Pointer(&s)) 42 | // bh.Data = sh.Data 43 | // bh.Len = sh.Len 44 | // bh.Cap = sh.Len 45 | // return b 46 | // } 47 | 48 | // func BytesToString(b []byte) string { 49 | // return *(*string)(unsafe.Pointer(&b)) 50 | // } 51 | 52 | func secWebSocketAccept() string { 53 | // rfc规定是16字节 54 | var key [16]byte 55 | mu.Lock() 56 | rng.Read(key[:]) 57 | mu.Unlock() 58 | return base64.StdEncoding.EncodeToString(key[:]) 59 | } 60 | 61 | func secWebSocketAcceptVal(val string) string { 62 | s := sha1.New() 63 | s.Write(StringToBytes(val)) 64 | s.Write(uuid) 65 | r := s.Sum(nil) 66 | return base64.StdEncoding.EncodeToString(r) 67 | } 68 | 69 | func getHttpErrMsg(statusCode int) error { 70 | errMsg := http.StatusText(statusCode) 71 | if errMsg != "" { 72 | return errors.New(errMsg) 73 | } 74 | 75 | return fmt.Errorf("status code:%d", statusCode) 76 | } 77 | -------------------------------------------------------------------------------- /callback.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | package quickws 15 | 16 | type ( 17 | Callback interface { 18 | OnOpen(*Conn) 19 | OnMessage(*Conn, Opcode, []byte) 20 | OnClose(*Conn, error) 21 | } 22 | ) 23 | 24 | type ( 25 | OnOpenFunc func(*Conn) 26 | ) 27 | 28 | type DefCallback struct{} 29 | 30 | func (defcallback *DefCallback) OnOpen(_ *Conn) { 31 | } 32 | 33 | func (defcallback *DefCallback) OnMessage(_ *Conn, _ Opcode, _ []byte) { 34 | } 35 | 36 | func (defcallback *DefCallback) OnClose(_ *Conn, _ error) { 37 | } 38 | 39 | // 只设置OnMessage, 和OnClose互斥 40 | type OnMessageFunc func(*Conn, Opcode, []byte) 41 | 42 | func (o OnMessageFunc) OnOpen(_ *Conn) { 43 | } 44 | 45 | func (o OnMessageFunc) OnMessage(c *Conn, op Opcode, data []byte) { 46 | o(c, op, data) 47 | } 48 | 49 | func (o OnMessageFunc) OnClose(_ *Conn, _ error) { 50 | } 51 | 52 | // 只设置OnClose, 和OnMessage互斥 53 | type OnCloseFunc func(*Conn, error) 54 | 55 | func (o OnCloseFunc) OnOpen(_ *Conn) { 56 | } 57 | 58 | func (o OnCloseFunc) OnMessage(_ *Conn, _ Opcode, _ []byte) { 59 | } 60 | 61 | func (o OnCloseFunc) OnClose(c *Conn, err error) { 62 | o(c, err) 63 | } 64 | 65 | type funcToCallback struct { 66 | onOpen func(*Conn) 67 | onMessage func(*Conn, Opcode, []byte) 68 | onClose func(*Conn, error) 69 | } 70 | 71 | func (f *funcToCallback) OnOpen(c *Conn) { 72 | if f.onOpen != nil { 73 | f.onOpen(c) 74 | } 75 | } 76 | 77 | func (f *funcToCallback) OnMessage(c *Conn, op Opcode, data []byte) { 78 | if f.onMessage != nil { 79 | f.onMessage(c, op, data) 80 | } 81 | } 82 | 83 | func (f 
*funcToCallback) OnClose(c *Conn, err error) { 84 | if f.onClose != nil { 85 | f.onClose(c, err) 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /err.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package quickws 16 | 17 | import "errors" 18 | 19 | var ( 20 | // conn已经被关闭 21 | ErrClosed = errors.New("closed") 22 | 23 | ErrWrongStatusCode = errors.New("Wrong status code") 24 | ErrUpgradeFieldValue = errors.New("The value of the upgrade field is not 'websocket'") 25 | ErrConnectionFieldValue = errors.New("The value of the connection field is not 'upgrade'") 26 | ErrSecWebSocketAccept = errors.New("The value of Sec-WebSocketAaccept field is invalid") 27 | 28 | ErrHostCannotBeEmpty = errors.New("Host cannot be empty") 29 | ErrSecWebSocketKey = errors.New("The value of SEC websocket key field is wrong") 30 | ErrSecWebSocketVersion = errors.New("The value of SEC websocket version field is wrong, not 13") 31 | 32 | ErrHTTPProtocolNotSupported = errors.New("HTTP protocol not supported") 33 | 34 | ErrOnlyGETSupported = errors.New("error:Only get methods are supported") 35 | ErrMaxControlFrameSize = errors.New("error:max control frame size > 125, need <= 125") 36 | ErrRsv123 = errors.New("error:rsv1 or rsv2 or rsv3 has a 
value") 37 | ErrOpcode = errors.New("error:wrong opcode") 38 | ErrNOTBeFragmented = errors.New("error:since control message MUST NOT be fragmented") 39 | ErrFrameOpcode = errors.New("error:since all data frames after the initial data frame must have opcode 0.") 40 | ErrTextNotUTF8 = errors.New("error:text is not utf8 data") 41 | ErrClosePayloadTooSmall = errors.New("error:close payload too small") 42 | ErrCloseValue = errors.New("error:close value is wrong") // close值不对 43 | ErrEmptyClose = errors.New("error:close value is empty") // close的值是空的 44 | ErrWriteClosed = errors.New("write close") 45 | ) 46 | 47 | var ( 48 | ErrUnexpectedFlateStream = errors.New("quickws: internal error, unexpected bytes at end of flate stream") 49 | ) 50 | -------------------------------------------------------------------------------- /proxy.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package quickws 16 | 17 | import ( 18 | "bufio" 19 | "encoding/base64" 20 | "net" 21 | "net/http" 22 | "net/url" 23 | "time" 24 | 25 | "github.com/antlabs/wsutil/hostname" 26 | ) 27 | 28 | type ( 29 | dialFunc func(network, addr string, timeout time.Duration) (c net.Conn, err error) 30 | httpProxy struct { 31 | proxyAddr *url.URL 32 | dialTimeout func(network, addr string, timeout time.Duration) (c net.Conn, err error) 33 | timeout time.Duration 34 | } 35 | ) 36 | 37 | var _ DialerTimeout = (*httpProxy)(nil) 38 | 39 | func newhttpProxy(u *url.URL, dial dialFunc) *httpProxy { 40 | return &httpProxy{proxyAddr: u, dialTimeout: dial} 41 | } 42 | 43 | func (h *httpProxy) DialTimeout(network, addr string, timeout time.Duration) (c net.Conn, err error) { 44 | if h.proxyAddr == nil { 45 | return h.dialTimeout(network, addr, h.timeout) 46 | } 47 | 48 | hostName := hostname.GetHostName(h.proxyAddr) 49 | c, err = h.dialTimeout(network, hostName, h.timeout) 50 | if err != nil { 51 | return nil, err 52 | } 53 | 54 | header := make(http.Header) 55 | 56 | if u := h.proxyAddr.User; u != nil { 57 | user := u.Username() 58 | if pass, ok := u.Password(); ok { 59 | credential := base64.StdEncoding.EncodeToString([]byte(user + ":" + pass)) 60 | header.Set("Proxy-Authorization", "Basic "+credential) 61 | } 62 | } 63 | 64 | req := &http.Request{ 65 | Method: http.MethodConnect, 66 | URL: &url.URL{Opaque: hostName}, 67 | Host: hostName, 68 | Header: header, 69 | } 70 | 71 | if err := req.Write(c); err != nil { 72 | c.Close() 73 | return nil, err 74 | } 75 | 76 | br := bufio.NewReader(c) 77 | resp, err := http.ReadResponse(br, req) 78 | if err != nil { 79 | c.Close() 80 | return nil, err 81 | } 82 | 83 | if resp.StatusCode != 200 { 84 | c.Close() 85 | return nil, getHttpErrMsg(resp.StatusCode) 86 | } 87 | return c, nil 88 | } 89 | -------------------------------------------------------------------------------- /permessage_deflate.go: 
-------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
package quickws

import (
	"sync/atomic"
	"unsafe"

	"github.com/antlabs/wsutil/deflate"
)

// encoode is the entry point for compressing an outgoing payload.
//
// It picks this side's window-bits value (ClientMaxWindowBits on a client,
// ServerMaxWindowBits on a server). When this side uses context takeover and
// compression is enabled, it lazily creates the shared compressor c.enCtx on
// first use. The Compress call always runs with c.wmu held, so callers must
// not already hold c.wmu: sync.Mutex is not reentrant.
func (c *Conn) encoode(payload *[]byte) (encodePayload *[]byte, err error) {

	ct := (c.pd.ClientContextTakeover && c.client || !c.client && c.pd.ServerContextTakeover) && c.pd.Compression
	// Context takeover.
	bit := uint8(0)
	if c.client {
		bit = c.pd.ClientMaxWindowBits
	} else {
		bit = c.pd.ServerMaxWindowBits
	}
	if ct {
		// Original comment: "reads here are single-goroutine, so no lock is
		// needed". NOTE(review): this is the write path and it does lock
		// below; the remark seems to apply only to the unlocked atomic
		// fast-path check that follows.
		// Double-checked lazy init: atomic load outside the lock, then a
		// second load under c.wmu before creating the compressor.
		if atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&c.enCtx))) == nil {

			c.wmu.Lock() // take the lock
			if atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&c.enCtx))) != nil {
				// enCtx was set between the two loads. Jump to the Compress
				// call with c.wmu still held (skipping the second Lock below);
				// the deferred Unlock after the label releases it.
				goto decode
			}
			enCtx, err := deflate.NewCompressContextTakeover(bit)
			if err != nil {
				c.wmu.Unlock()
				return nil, err
			}
			atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&c.enCtx)), unsafe.Pointer(enCtx))
			c.wmu.Unlock()
		}

	}
	c.wmu.Lock()
	// NOTE(review): despite its name, this label sits on the compression path.
decode:
	defer c.wmu.Unlock()
	// Handles both the context-takeover and the no-context-takeover case.
	// bit is passed as an argument because the window bits must be honoured
	// even without context takeover.
	// NOTE(review): when ct is false nothing in this function initialises
	// c.enCtx, so this call relies on it having been set elsewhere or on
	// Compress tolerating a nil receiver — confirm.
	return c.enCtx.Compress(payload, bit)
}

// decode is the entry point for decompressing an incoming payload.
// Per the original author's note, decompression runs sequentially on a single
// goroutine, so no lock is taken here.
//
// When this side uses context takeover and decompression is enabled, the
// shared decompressor c.deCtx is created lazily on first use.
// c.readMaxMessage is passed through to Decompress (presumably as a size
// cap — confirm in wsutil/deflate).
func (c *Conn) decode(payload *[]byte) (decodePayload *[]byte, err error) {
	ct := (c.pd.ClientContextTakeover && c.client || !c.client && c.pd.ServerContextTakeover) && c.pd.Decompression
	// Context takeover.
	if ct {
		if c.deCtx == nil {

			bit := uint8(0)

			// NOTE(review): under RFC 7692 the peer compresses within its own
			// window limit (server_max_window_bits for data a server sends),
			// yet a client sizes its decompressor with ClientMaxWindowBits
			// here. Likewise, ct above gates a client's decompression context
			// on ClientContextTakeover. Confirm which side's parameters the
			// pd fields describe.
			if c.client {
				bit = c.pd.ClientMaxWindowBits
			} else {
				bit = c.pd.ServerMaxWindowBits
			}
			c.deCtx, err = deflate.NewDecompressContextTakeover(bit)
			if err != nil {
				return nil, err
			}
		}

	}

	// Original note: "with context takeover, deCtx is nil; without context
	// takeover, deCtx is non-nil". NOTE(review): this reads inverted relative
	// to the code above. When ct is true, deCtx has just been ensured non-nil;
	// when ct is false, nothing here sets it. Confirm what initialises deCtx
	// in the no-takeover case.

	return c.deCtx.Decompress(payload, c.readMaxMessage)
}
-------------------------------------------------------------------------------- /autobahn/client/autobahn-client.go: --------------------------------------------------------------------------------
// autobahn-client drives the quickws client through the Autobahn TestSuite
// fuzzing server (wstest -m fuzzingserver) listening at host.
package main

import (
	"fmt"
	"strconv"
	"time"

	"github.com/antlabs/quickws"
)

// Modeled on:
// https://github.com/snapview/tokio-tungstenite/blob/master/examples/autobahn-client.rs

const (
	// host = "ws://192.168.128.44:9003"
	// host is the fuzzing server; it matches "url" in
	// autobahn/config/fuzzingserver.json.
	host  = "ws://127.0.0.1:9003"
	// agent identifies this client in runCase/updateReports requests (the
	// name its results are reported under).
	agent = "quickws"
)

// echoHandler is the quickws.Callback used for one test case. It echoes every
// Text/Binary message back and closes done when the connection closes.
type echoHandler struct {
	done chan struct{}
}

func (e *echoHandler) OnOpen(c *quickws.Conn) {
	fmt.Printf("OnOpen::%p\n", c)
}

// OnMessage echoes data frames back with a one-minute write timeout; other
// opcodes are only logged.
func (e *echoHandler) OnMessage(c *quickws.Conn, op quickws.Opcode, msg []byte) {
	fmt.Printf("OnMessage: opcode:%s, msg.size:%d\n", op, len(msg))
	if op == quickws.Text || op == quickws.Binary {
		// os.WriteFile("./debug.dat", msg, 0o644)
		// if err := c.WriteMessage(op, msg); err != nil {
		// 	fmt.Println("write fail:", err)
		// }
		if err := c.WriteTimeout(op, msg, 1*time.Minute); err != nil {
			fmt.Println("write fail:", err)
		}
	}
}

// OnClose unblocks runTest.
// NOTE(review): close panics if OnClose is ever invoked twice on the same
// handler; this assumes exactly one OnClose per connection.
func (e *echoHandler) OnClose(c *quickws.Conn, err error) {
	fmt.Println("OnClose:", c, err)
	close(e.done)
}

// getCaseCount asks the fuzzing server how many test cases it has. The reply
// message is parsed as a decimal integer, after which the handler closes the
// connection so ReadLoop returns. Panics if dialing or parsing fails.
// NOTE(review): c.Close runs both in the handler and via defer, so this
// relies on Close being safe to call twice. When ReadLoop returns nil, the
// "%s" verb prints "%!s(<nil>)".
func getCaseCount() int {
	var count int
	c, err := quickws.Dial(fmt.Sprintf("%s/getCaseCount", host), quickws.WithClientOnMessageFunc(func() quickws.OnMessageFunc {
		return func(c *quickws.Conn, op quickws.Opcode, msg []byte) {
			var err error
			fmt.Printf("msg(%s)\n", msg)
			count, err = strconv.Atoi(string(msg))
			if err != nil {
				panic(err)
			}
			c.Close()
		}
	}()))
	if err != nil {
		panic(err)
	}
	defer c.Close()

	err = c.ReadLoop()
	fmt.Printf("readloop rv:%s\n", err)
	return count
}

// runTest runs Autobahn case caseNo with automatic ping replies, UTF-8
// validation, compression/decompression with context takeover and 10 window
// bits, then blocks until the echo handler sees OnClose. A dial failure is
// logged and the case is skipped.
// NOTE(review): if ReadLoop can return without OnClose firing, <-done blocks
// forever — confirm against Conn.ReadLoop.
func runTest(caseNo int) {
	done := make(chan struct{})
	c, err := quickws.Dial(fmt.Sprintf("%s/runCase?case=%d&agent=%s", host, caseNo, agent),
		quickws.WithClientReplyPing(),
		quickws.WithClientEnableUTF8Check(),
		quickws.WithClientDecompressAndCompress(),
		quickws.WithClientContextTakeover(),
		quickws.WithClientMaxWindowsBits(10),
		quickws.WithClientCallback(&echoHandler{done: done}),
	)
	if err != nil {
		fmt.Println("Dial fail:", err)
		return
	}

	go func() {
		_ = c.ReadLoop()
	}()
	<-done
}

// updateReports connects to the server's updateReports endpoint for agent and
// closes immediately. The endpoint name indicates that this makes the server
// (re)generate its reports.
func updateReports() {
	c, err := quickws.Dial(fmt.Sprintf("%s/updateReports?agent=%s", host, agent))
	if err != nil {
		fmt.Println("Dial fail:", err)
		return
	}

	c.Close()
}

// main:
// 1. asks the server for the total number of test cases,
// 2. runs every case (1..total) sequentially with the test client,
// 3. asks the server to update its reports.
func main() {
	total := getCaseCount()
	fmt.Println("total case:", total)
	for i := 1; i <= total; i++ {
		runTest(i)
	}
	updateReports()
}
-------------------------------------------------------------------------------- /utils_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved.
2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package quickws 16 | 17 | import ( 18 | "reflect" 19 | "testing" 20 | ) 21 | 22 | func Test_SecWebSocketAcceptVal(t *testing.T) { 23 | need := "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" 24 | got := secWebSocketAcceptVal("dGhlIHNhbXBsZSBub25jZQ==") 25 | if got != need { 26 | t.Errorf("need %s, got %s", need, got) 27 | } 28 | } 29 | 30 | func Test_getHttpErrMsg(t *testing.T) { 31 | t.Run("test 1", func(t *testing.T) { 32 | err := getHttpErrMsg(111) 33 | if err == nil { 34 | t.Errorf("err should not be nil") 35 | } 36 | }) 37 | 38 | t.Run("test 2", func(t *testing.T) { 39 | err := getHttpErrMsg(400) 40 | if err == nil { 41 | t.Errorf("err should not be nil") 42 | } 43 | }) 44 | } 45 | 46 | func TestStringToBytes(t *testing.T) { 47 | type args struct { 48 | s string 49 | } 50 | tests := []struct { 51 | name string 52 | args args 53 | wantB []byte 54 | }{ 55 | { 56 | name: "test1", 57 | args: args{s: "test1"}, 58 | wantB: []byte("test1"), 59 | }, 60 | { 61 | name: "test2", 62 | args: args{s: "test2"}, 63 | wantB: []byte("test2"), 64 | }, 65 | } 66 | for _, tt := range tests { 67 | t.Run(tt.name, func(t *testing.T) { 68 | if gotB := StringToBytes(tt.args.s); !reflect.DeepEqual(gotB, tt.wantB) { 69 | t.Errorf("StringToBytes() = %v, want %v", gotB, tt.wantB) 70 | } 71 | }) 72 | } 73 | } 74 | 75 | func Test_secWebSocketAccept(t *testing.T) { 76 | tests := []struct { 77 | name 
string 78 | want string 79 | }{ 80 | {name: ">0"}, 81 | } 82 | for _, tt := range tests { 83 | t.Run(tt.name, func(t *testing.T) { 84 | if got := secWebSocketAccept(); len(got) == 0 { 85 | t.Errorf("secWebSocketAccept() = %v, want %v", got, tt.want) 86 | } 87 | }) 88 | } 89 | } 90 | 91 | func Test_secWebSocketAcceptVal(t *testing.T) { 92 | type args struct { 93 | val string 94 | } 95 | tests := []struct { 96 | name string 97 | args args 98 | want string 99 | }{ 100 | {name: "test1", args: args{val: "test1"}}, 101 | } 102 | for _, tt := range tests { 103 | t.Run(tt.name, func(t *testing.T) { 104 | if got := secWebSocketAcceptVal(tt.args.val); len(got) == 0 { 105 | t.Errorf("secWebSocketAcceptVal() = %v, want %v", got, tt.want) 106 | } 107 | }) 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /server_option_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | package quickws 15 | 16 | import ( 17 | "net/http" 18 | "net/http/httptest" 19 | "strings" 20 | "testing" 21 | ) 22 | 23 | func Test_ServerOption(t *testing.T) { 24 | t.Run("2.1 Subprotocol", func(t *testing.T) { 25 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 26 | _, err := Upgrade(w, r, WithServerSubprotocols([]string{"crud", "im"})) 27 | if err != nil { 28 | t.Error(err) 29 | } 30 | })) 31 | 32 | defer ts.Close() 33 | 34 | url := strings.ReplaceAll(ts.URL, "http", "ws") 35 | h := make(http.Header) 36 | con, err := DialConf(url, ClientOptionToConf(WithClientBindHTTPHeader(&h), WithClientHTTPHeader(http.Header{ 37 | "Sec-WebSocket-Protocol": []string{"crud"}, 38 | }))) 39 | if err != nil { 40 | t.Error(err) 41 | } 42 | defer con.Close() 43 | 44 | if h["Sec-Websocket-Protocol"][0] != "crud" { 45 | t.Error("header fail") 46 | } 47 | }) 48 | 49 | t.Run("2.2 Subprotocol", func(t *testing.T) { 50 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 51 | _, err := Upgrade(w, r, WithServerSubprotocols([]string{"crud", "im"})) 52 | if err != nil { 53 | t.Error(err) 54 | } 55 | })) 56 | 57 | defer ts.Close() 58 | 59 | url := strings.ReplaceAll(ts.URL, "http", "ws") 60 | h := make(http.Header) 61 | con, err := Dial(url, WithClientBindHTTPHeader(&h), WithClientHTTPHeader(http.Header{ 62 | "Sec-WebSocket-Protocol": []string{"crud"}, 63 | })) 64 | if err != nil { 65 | t.Error(err) 66 | } 67 | defer con.Close() 68 | 69 | if h["Sec-Websocket-Protocol"][0] != "crud" { 70 | t.Error("header fail") 71 | } 72 | }) 73 | 74 | t.Run("2.3 Subprotocol", func(t *testing.T) { 75 | upgrade := NewUpgrade(WithServerSubprotocols([]string{"crud", "im"})) 76 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 77 | _, err := upgrade.Upgrade(w, r) 78 | if err != nil { 79 | t.Error(err) 80 | } 81 | })) 82 | 83 | defer ts.Close() 84 | 85 | url := strings.ReplaceAll(ts.URL, 
"http", "ws") 86 | h := make(http.Header) 87 | con, err := Dial(url, WithClientBindHTTPHeader(&h), WithClientHTTPHeader(http.Header{ 88 | "Sec-WebSocket-Protocol": []string{"crud"}, 89 | })) 90 | if err != nil { 91 | t.Error(err) 92 | } 93 | defer con.Close() 94 | 95 | if h["Sec-Websocket-Protocol"][0] != "crud" { 96 | t.Error("header fail") 97 | } 98 | }) 99 | } 100 | -------------------------------------------------------------------------------- /config.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package quickws 16 | 17 | import ( 18 | "errors" 19 | "net" 20 | "net/http" 21 | "net/url" 22 | "time" 23 | 24 | "github.com/antlabs/wsutil/deflate" 25 | "github.com/antlabs/wsutil/enum" 26 | ) 27 | 28 | var ErrDialFuncAndProxyFunc = errors.New("dialFunc and proxyFunc can't be set at the same time") 29 | 30 | // 握手 31 | type Dialer interface { 32 | Dial(network, addr string) (c net.Conn, err error) 33 | } 34 | 35 | // 带超时时间的握手 36 | type DialerTimeout interface { 37 | DialTimeout(network, addr string, timeout time.Duration) (c net.Conn, err error) 38 | } 39 | 40 | // Config的配置,有两个种用法 41 | // 一种是声明一个全局的配置,后面不停使用。 42 | // 另外一种是局部声明一个配置,然后使用WithXXX函数设置配置 43 | type Config struct { 44 | cb Callback 45 | deflate.PermessageDeflateConf // 静态配置, 从WithXXX函数中获取 46 | tcpNoDelay bool 47 | replyPing bool // 开启自动回复 48 | ignorePong bool // 忽略pong消息 49 | disableBufioClearHack bool // 关闭bufio的clear hack优化 50 | utf8Check func([]byte) bool // utf8检查 51 | readTimeout time.Duration // 读超时时间 52 | windowsMultipleTimesPayloadSize float32 // 设置几倍(1024+14)的payload大小 53 | bufioMultipleTimesPayloadSize float32 // 设置几倍(1024)的payload大小 54 | parseMode parseMode // 解析模式 55 | maxDelayWriteNum int32 // 最大延迟包的个数, 默认值为10 56 | delayWriteInitBufferSize int32 // 延迟写入的初始缓冲区大小, 默认值是8k 57 | maxDelayWriteDuration time.Duration // 最大延迟时间, 默认值是10ms 58 | subProtocols []string // 设置支持的子协议 59 | readMaxMessage int64 //最大消息大小 60 | dialFunc func() (Dialer, error) 61 | proxyFunc func(*http.Request) (*url.URL, error) // 62 | } 63 | 64 | func (c *Config) initPayloadSize() int { 65 | return int((1024.0 + float32(enum.MaxFrameHeaderSize)) * c.windowsMultipleTimesPayloadSize) 66 | } 67 | 68 | // 默认设置 69 | func (c *Config) defaultSetting() error { 70 | c.cb = &DefCallback{} 71 | c.maxDelayWriteNum = 10 72 | c.windowsMultipleTimesPayloadSize = 1.0 73 | c.delayWriteInitBufferSize = 8 * 1024 74 | c.maxDelayWriteDuration = 10 * time.Millisecond 75 | c.tcpNoDelay = true 76 | c.parseMode = ParseModeWindows 77 | // 
对于text消息,默认不检查text是utf8字符 78 | c.utf8Check = func(b []byte) bool { return true } 79 | 80 | if c.dialFunc != nil && c.proxyFunc != nil { 81 | return ErrDialFuncAndProxyFunc 82 | } 83 | return nil 84 | } 85 | -------------------------------------------------------------------------------- /server_profile_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | package quickws 15 | 16 | import ( 17 | "bytes" 18 | "fmt" 19 | "net/http" 20 | "net/http/httptest" 21 | "strings" 22 | "sync" 23 | "testing" 24 | 25 | "github.com/antlabs/wsutil/deflate" 26 | //"os" 27 | ) 28 | 29 | type echoHandler struct{} 30 | 31 | func (e *echoHandler) OnOpen(c *Conn) { 32 | } 33 | 34 | func (e *echoHandler) OnMessage(c *Conn, op Opcode, msg []byte) { 35 | // if err := c.WriteTimeout(op, msg, 3*time.Second); err != nil { 36 | // fmt.Println("write fail:", err) 37 | // } 38 | if err := c.WriteMessage(op, msg); err != nil { 39 | fmt.Println("write fail:", err) 40 | } 41 | } 42 | 43 | func (e *echoHandler) OnClose(c *Conn, err error) { 44 | } 45 | 46 | type echoClientHandler struct { 47 | DefCallback 48 | Count int 49 | } 50 | 51 | func (e *echoClientHandler) OnMessage(c *Conn, op Opcode, msg []byte) { 52 | e.Count++ 53 | if e.Count == 100 { 54 | c.Close() 55 | } 56 | _ = c.WriteMessage(op, msg) 57 | } 58 | 59 | func newProfileServrEcho(t *testing.T, data []byte) *httptest.Server { 60 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 61 | c, err := Upgrade(w, r, WithServerCallback(&echoHandler{})) 62 | if err != nil { 63 | t.Error(err) 64 | return 65 | } 66 | 67 | c.StartReadLoop() 68 | })) 69 | 70 | ts.URL = "ws" + strings.TrimPrefix(ts.URL, "http") 71 | return ts 72 | } 73 | 74 | // 跑profile的echo服务 75 | func Test_ServerProfile(t *testing.T) { 76 | payload := make([]byte, 1024) 77 | for i := 0; i < len(payload); i++ { 78 | payload[i] = byte(i) 79 | } 80 | 81 | maxGo := 10 82 | ts := newProfileServrEcho(t, payload) 83 | var wg sync.WaitGroup 84 | defer wg.Wait() 85 | 86 | wg.Add(maxGo) 87 | for i := 0; i < maxGo; i++ { 88 | go func() { 89 | defer wg.Done() 90 | c, err := Dial(ts.URL, WithClientCallback(&echoClientHandler{})) 91 | if err != nil { 92 | t.Error(err) 93 | return 94 | } 95 | err = c.WriteMessage(Binary, payload) 96 | if err != nil { 97 | t.Error(err) 98 | return 99 | } 100 | 
c.StartReadLoop() 101 | }() 102 | } 103 | } 104 | 105 | func Test_Upgrade(t *testing.T) { 106 | r, err := http.NewRequest("GET", "http://test.com", nil) 107 | if err != nil { 108 | t.Error(err) 109 | return 110 | } 111 | 112 | var out bytes.Buffer 113 | err = prepareWriteResponse(r, &out, &Config{}, deflate.PermessageDeflateConf{}) 114 | if err != nil { 115 | t.Error(err) 116 | return 117 | } 118 | fmt.Printf("%s\n %d", out.Bytes(), out.Len()) 119 | } 120 | -------------------------------------------------------------------------------- /callback_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | package quickws 15 | 16 | import ( 17 | "net/http" 18 | "net/http/httptest" 19 | "strings" 20 | "sync/atomic" 21 | "testing" 22 | "time" 23 | ) 24 | 25 | type testDefaultCallback struct { 26 | DefCallback 27 | } 28 | 29 | func Test_DefaultCallback(t *testing.T) { 30 | t.Run("local: default callback", func(t *testing.T) { 31 | run := int32(0) 32 | done := make(chan bool, 1) 33 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 34 | c, err := Upgrade(w, r, WithServerCallback(&testDefaultCallback{})) 35 | if err != nil { 36 | t.Error(err) 37 | return 38 | } 39 | if c == nil { 40 | t.Error("conn is nil") 41 | return 42 | } 43 | defer c.Close() 44 | c.StartReadLoop() 45 | 46 | err = c.WriteMessage(Binary, []byte("hello")) 47 | if err != nil { 48 | t.Error(err) 49 | return 50 | } 51 | 52 | atomic.AddInt32(&run, int32(1)) 53 | done <- true 54 | })) 55 | 56 | defer ts.Close() 57 | 58 | url := strings.ReplaceAll(ts.URL, "http", "ws") 59 | con, err := Dial(url, WithClientCallback(&testDefaultCallback{})) 60 | if err != nil { 61 | t.Error(err) 62 | return 63 | } 64 | if con.client != true { 65 | t.Error("con.client must be true") 66 | } 67 | defer con.Close() 68 | 69 | err = con.WriteMessage(Binary, []byte("hello")) 70 | 71 | if err != nil { 72 | t.Error(err) 73 | return 74 | } 75 | select { 76 | case <-done: 77 | case <-time.After(1000 * time.Millisecond): 78 | } 79 | if atomic.LoadInt32(&run) != 1 { 80 | t.Error("not run server:method fail") 81 | } 82 | }) 83 | 84 | t.Run("global: default callback", func(t *testing.T) { 85 | run := int32(0) 86 | done := make(chan bool, 1) 87 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 88 | c, err := Upgrade(w, r, WithServerCallback(&testDefaultCallback{})) 89 | if err != nil { 90 | t.Error(err) 91 | } 92 | c.StartReadLoop() 93 | err = c.WriteMessage(Binary, []byte("hello")) 94 | if err != nil { 95 | t.Error(err) 96 | return 97 | } 98 | 
atomic.AddInt32(&run, int32(1)) 99 | done <- true 100 | })) 101 | 102 | defer ts.Close() 103 | 104 | url := strings.ReplaceAll(ts.URL, "http", "ws") 105 | con, err := Dial(url, WithClientCallback(&testDefaultCallback{})) 106 | if err != nil { 107 | t.Error(err) 108 | } 109 | defer con.Close() 110 | 111 | err = con.WriteMessage(Binary, []byte("hello")) 112 | if err != nil { 113 | t.Error(err) 114 | return 115 | } 116 | select { 117 | case <-done: 118 | case <-time.After(100 * time.Millisecond): 119 | } 120 | if atomic.LoadInt32(&run) != 1 { 121 | t.Error("not run server:method fail") 122 | } 123 | }) 124 | } 125 | -------------------------------------------------------------------------------- /config_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | package quickws 15 | 16 | import ( 17 | "net/http" 18 | "net/url" 19 | "testing" 20 | "time" 21 | ) 22 | 23 | func Test_InitPayloadSize(t *testing.T) { 24 | t.Run("InitPayload", func(t *testing.T) { 25 | var c Config 26 | for i := 1; i < 32; i++ { 27 | c.windowsMultipleTimesPayloadSize = float32(i) 28 | if c.initPayloadSize() != i*(1024+14) { 29 | t.Errorf("initPayloadSize() = %d, want %d", c.initPayloadSize(), i*(1024+14)) 30 | } 31 | } 32 | }) 33 | } 34 | 35 | func TestConfig_defaultSetting(t *testing.T) { 36 | type fields struct { 37 | Callback Callback 38 | tcpNoDelay bool 39 | replyPing bool 40 | // decompression bool 41 | // compression bool 42 | ignorePong bool 43 | disableBufioClearHack bool 44 | utf8Check func([]byte) bool 45 | readTimeout time.Duration 46 | windowsMultipleTimesPayloadSize float32 47 | bufioMultipleTimesPayloadSize float32 48 | parseMode parseMode 49 | maxDelayWriteNum int32 50 | delayWriteInitBufferSize int32 51 | maxDelayWriteDuration time.Duration 52 | subProtocols []string 53 | dialFunc func() (Dialer, error) 54 | proxyFunc func(*http.Request) (*url.URL, error) 55 | } 56 | tests := []struct { 57 | name string 58 | fields fields 59 | wantErr bool 60 | }{ 61 | // TODO: Add test cases. 
62 | {name: "fail", fields: fields{ 63 | dialFunc: func() (Dialer, error) { return nil, nil }, 64 | proxyFunc: func(*http.Request) (*url.URL, error) { return nil, nil }, 65 | }, wantErr: true}, 66 | } 67 | for _, tt := range tests { 68 | t.Run(tt.name, func(t *testing.T) { 69 | c := &Config{ 70 | cb: tt.fields.Callback, 71 | tcpNoDelay: tt.fields.tcpNoDelay, 72 | replyPing: tt.fields.replyPing, 73 | ignorePong: tt.fields.ignorePong, 74 | disableBufioClearHack: tt.fields.disableBufioClearHack, 75 | utf8Check: tt.fields.utf8Check, 76 | readTimeout: tt.fields.readTimeout, 77 | windowsMultipleTimesPayloadSize: tt.fields.windowsMultipleTimesPayloadSize, 78 | bufioMultipleTimesPayloadSize: tt.fields.bufioMultipleTimesPayloadSize, 79 | parseMode: tt.fields.parseMode, 80 | maxDelayWriteNum: tt.fields.maxDelayWriteNum, 81 | delayWriteInitBufferSize: tt.fields.delayWriteInitBufferSize, 82 | maxDelayWriteDuration: tt.fields.maxDelayWriteDuration, 83 | subProtocols: tt.fields.subProtocols, 84 | dialFunc: tt.fields.dialFunc, 85 | proxyFunc: tt.fields.proxyFunc, 86 | } 87 | if err := c.defaultSetting(); (err != nil) != tt.wantErr { 88 | t.Errorf("Config.defaultSetting() error = %v, wantErr %v", err, tt.wantErr) 89 | } 90 | }) 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /status_codes.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package quickws 16 | 17 | import ( 18 | "encoding/binary" 19 | "strconv" 20 | "strings" 21 | ) 22 | 23 | // https://datatracker.ietf.org/doc/html/rfc6455#section-7.4.1 24 | // 这里记录了各种状态码的含义 25 | type StatusCode int16 26 | 27 | const ( 28 | // NormalClosure 正常关闭 29 | NormalClosure StatusCode = 1000 30 | // EndpointGoingAway 对端正在消失 31 | EndpointGoingAway StatusCode = 1001 32 | // ProtocolError 表示对端由于协议错误正在终止连接 33 | ProtocolError StatusCode = 1002 34 | // DataCannotAccept 收到一个不能接受的数据类型 35 | DataCannotAccept StatusCode = 1003 36 | // NotConsistentMessageType 表示对端正在终止连接, 消息类型不一致 37 | NotConsistentMessageType StatusCode = 1007 38 | // TerminatingConnection 表示对端正在终止连接, 没有好用的错误, 可以用这个错误码表示 39 | TerminatingConnection StatusCode = 1008 40 | // TooBigMessage 消息太大, 不能处理, 关闭连接 41 | TooBigMessage StatusCode = 1009 42 | // NoExtensions 只用于客户端, 服务端返回扩展消息 43 | NoExtensions StatusCode = 1010 44 | // ServerTerminating 服务端遇到意外情况, 中止请求 45 | ServerTerminating StatusCode = 1011 46 | ) 47 | 48 | func (s StatusCode) String() string { 49 | switch s { 50 | case NormalClosure: 51 | return "NormalClosure" 52 | case EndpointGoingAway: 53 | return "EndpointGoingAway" 54 | case ProtocolError: 55 | return "ProtocolError" 56 | case DataCannotAccept: 57 | return "DataCannotAccept" 58 | case NotConsistentMessageType: 59 | return "NotConsistentMessageType" 60 | case TerminatingConnection: 61 | return "TerminatingConnection" 62 | case TooBigMessage: 63 | return "TooBigMessage" 64 | case NoExtensions: 65 | return "NoExtensions" 66 | case ServerTerminating: 67 | return "ServerTerminating" 68 | } 69 | 70 | return "unknown" 71 | } 72 | 73 | func (s StatusCode) Error() string { 74 | return s.String() 75 | } 76 | 77 | func (s StatusCode) toBytes() (rv []byte) { 78 | rv = make([]byte, 2+len(s.String())) 79 | binary.BigEndian.PutUint16(rv, uint16(s)) 80 | copy(rv[2:], s.String()) 81 | 
return 82 | } 83 | 84 | type CloseErrMsg struct { 85 | Code StatusCode 86 | Msg string 87 | } 88 | 89 | func (c CloseErrMsg) Error() string { 90 | var out strings.Builder 91 | 92 | out.WriteString(" 0 { 101 | out.WriteString(c.Msg) 102 | } 103 | 104 | out.WriteString(">") 105 | return out.String() 106 | } 107 | 108 | func bytesToCloseErrMsg(payload []byte) *CloseErrMsg { 109 | var ce CloseErrMsg 110 | if len(payload) >= 2 { 111 | ce.Code = StatusCode(binary.BigEndian.Uint16(payload)) 112 | } 113 | 114 | if len(payload) >= 3 { 115 | ce.Msg = string(payload[3:]) 116 | } 117 | return &ce 118 | } 119 | 120 | func validCode(code uint16) bool { 121 | switch code { 122 | case 1004, 1005, 1006, 1015: 123 | return false 124 | } 125 | 126 | if code >= 1000 && code <= 1015 { 127 | return true 128 | } 129 | 130 | if code >= 3000 && code <= 4999 { 131 | return true 132 | } 133 | 134 | return false 135 | } 136 | -------------------------------------------------------------------------------- /server_handshake_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package quickws 16 | 17 | import ( 18 | "errors" 19 | "fmt" 20 | "io" 21 | "net/http" 22 | "testing" 23 | 24 | "github.com/antlabs/wsutil/deflate" 25 | ) 26 | 27 | type failWriter struct { 28 | count int 29 | failCount int 30 | } 31 | 32 | func (f *failWriter) Write(p []byte) (n int, err error) { 33 | f.count++ 34 | if f.count == f.failCount { 35 | return 0, errors.New("fail") 36 | } 37 | return len(p), nil 38 | } 39 | func Test_writeHeaderVal(t *testing.T) { 40 | type args struct { 41 | val []byte 42 | } 43 | tests := []struct { 44 | w io.Writer 45 | name string 46 | args args 47 | wantW string 48 | wantErr bool 49 | }{ 50 | {w: &failWriter{failCount: 1}, wantErr: true}, 51 | {w: &failWriter{failCount: 2}, wantErr: true}, 52 | } 53 | for _, tt := range tests { 54 | t.Run(tt.name, func(t *testing.T) { 55 | if err := writeHeaderVal(tt.w, tt.args.val); (err != nil) != tt.wantErr { 56 | t.Errorf("writeHeaderVal() error = %v, wantErr %v", err, tt.wantErr) 57 | return 58 | } 59 | }) 60 | } 61 | } 62 | 63 | func Test_HttpGet(t *testing.T) { 64 | h := make(http.Header) 65 | h.Set(strGetSecWebSocketProtocolKey, "token") 66 | fmt.Printf("%#v\n", h) 67 | if h.Get(strGetSecWebSocketProtocolKey) == "" { 68 | panic("error") 69 | } 70 | } 71 | func Test_prepareWriteResponse(t *testing.T) { 72 | type args struct { 73 | r *http.Request 74 | cnf *Config 75 | } 76 | tests := []struct { 77 | w io.Writer 78 | name string 79 | args args 80 | wantW string 81 | wantErr bool 82 | }{ 83 | {w: &failWriter{failCount: 1}, wantErr: true, args: args{r: &http.Request{Header: http.Header{}}, cnf: &Config{}}}, 84 | {w: &failWriter{failCount: 2}, wantErr: true, args: args{r: &http.Request{Header: http.Header{}}, cnf: &Config{}}}, 85 | {w: &failWriter{failCount: 3}, wantErr: true, args: args{r: &http.Request{Header: http.Header{}}, cnf: &Config{}}}, 86 | {w: &failWriter{failCount: 4}, wantErr: true, args: args{r: &http.Request{Header: http.Header{}}, cnf: &Config{PermessageDeflateConf: 
deflate.PermessageDeflateConf{Decompression: true, Compression: true}}}}, 87 | // {w: &failWriter{failCount: 5}, wantErr: true, args: args{r: &http.Request{Header: http.Header{}}, cnf: &Config{PermessageDeflateConf: deflate.PermessageDeflateConf{Decompression: true, Compression: true}}}}, 88 | // {w: &failWriter{failCount: 6}, wantErr: true, args: args{r: &http.Request{Header: http.Header{}}, cnf: &Config{PermessageDeflateConf: deflate.PermessageDeflateConf{Decompression: true, Compression: true}}}}, 89 | // {w: &failWriter{failCount: 7}, wantErr: true, args: args{r: &http.Request{Header: http.Header{"Sec-Websocket-Protocol": []string{"token"}}}, cnf: &Config{PermessageDeflateConf: deflate.PermessageDeflateConf{Decompression: true, Compression: true}}}}, 90 | // {w: &failWriter{failCount: 8}, wantErr: true, args: args{r: &http.Request{Header: http.Header{"Sec-Websocket-Protocol": []string{"token"}}}, cnf: &Config{PermessageDeflateConf: deflate.PermessageDeflateConf{Decompression: true, Compression: true}}}}, 91 | } 92 | for i, tt := range tests { 93 | t.Run(tt.name, func(t *testing.T) { 94 | if err := prepareWriteResponse(tt.args.r, tt.w, tt.args.cnf, deflate.PermessageDeflateConf{}); (err != nil) != tt.wantErr { 95 | t.Errorf("index:%d, prepareWriteResponse() error = %v, wantErr %v, count= %d", i, err, tt.wantErr, tt.w.(*failWriter).count) 96 | return 97 | } 98 | }) 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /upgrade.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package quickws 16 | 17 | import ( 18 | "bufio" 19 | "bytes" 20 | "net" 21 | "net/http" 22 | "time" 23 | 24 | "github.com/antlabs/wsutil/bufio2" 25 | "github.com/antlabs/wsutil/bytespool" 26 | "github.com/antlabs/wsutil/deflate" 27 | "github.com/antlabs/wsutil/fixedreader" 28 | ) 29 | 30 | type UpgradeServer struct { 31 | config Config 32 | } 33 | 34 | func NewUpgrade(opts ...ServerOption) *UpgradeServer { 35 | var conf ConnOption 36 | if err := conf.defaultSetting(); err != nil { 37 | panic(err.Error()) 38 | } 39 | for _, o := range opts { 40 | o(&conf) 41 | } 42 | return &UpgradeServer{config: conf.Config} 43 | } 44 | 45 | func (u *UpgradeServer) Upgrade(w http.ResponseWriter, r *http.Request) (c *Conn, err error) { 46 | return upgradeInner(w, r, &u.config, nil) 47 | } 48 | 49 | func (u *UpgradeServer) UpgradeV2(w http.ResponseWriter, r *http.Request, cb Callback) (c *Conn, err error) { 50 | return upgradeInner(w, r, &u.config, cb) 51 | } 52 | 53 | func Upgrade(w http.ResponseWriter, r *http.Request, opts ...ServerOption) (c *Conn, err error) { 54 | var conf ConnOption 55 | if err := conf.defaultSetting(); err != nil { 56 | return nil, err 57 | } 58 | for _, o := range opts { 59 | o(&conf) 60 | } 61 | return upgradeInner(w, r, &conf.Config, nil) 62 | } 63 | 64 | func upgradeInner(w http.ResponseWriter, r *http.Request, conf *Config, cb Callback) (wsCon *Conn, err error) { 65 | if ecode, err := checkRequest(r); err != nil { 66 | http.Error(w, err.Error(), ecode) 67 | return nil, err 68 | } 69 | 70 | hi, ok 
:= w.(http.Hijacker) 71 | if !ok { 72 | return nil, ErrNotFoundHijacker 73 | } 74 | 75 | var br *bufio.Reader 76 | var conn net.Conn 77 | var rw *bufio.ReadWriter 78 | if conf.parseMode == ParseModeWindows { 79 | // 这里不需要rw,直接使用conn 80 | conn, rw, err = hi.Hijack() 81 | if !conf.disableBufioClearHack { 82 | bufio2.ClearReadWriter(rw) 83 | } 84 | // TODO 85 | // rsp.ClearRsp(w) 86 | // rw = nil 87 | } else { 88 | var rw *bufio.ReadWriter 89 | conn, rw, err = hi.Hijack() 90 | br = rw.Reader 91 | rw = nil 92 | } 93 | if err != nil { 94 | return nil, err 95 | } 96 | 97 | // 是否打开解压缩 98 | // 外层接收压缩, 并且客户端发送扩展过来 99 | var pd deflate.PermessageDeflateConf 100 | if conf.Decompression { 101 | pd, err = deflate.GetConnPermessageDeflate(r.Header) 102 | if err != nil { 103 | return nil, err 104 | } 105 | } 106 | 107 | buf := bytespool.GetUpgradeRespBytes() 108 | 109 | tmpWriter := bytes.NewBuffer((*buf)[:0]) 110 | defer func() { 111 | bytespool.PutUpgradeRespBytes(buf) 112 | tmpWriter = nil 113 | }() 114 | 115 | resetPermessageDeflate(&pd, conf) 116 | if err = prepareWriteResponse(r, tmpWriter, conf, pd); err != nil { 117 | return 118 | } 119 | 120 | if _, err := conn.Write(tmpWriter.Bytes()); err != nil { 121 | return nil, err 122 | } 123 | 124 | var fr fixedreader.FixedReader 125 | if conf.parseMode == ParseModeWindows { 126 | fr.Init(conn, bytespool.GetBytes(conf.initPayloadSize())) 127 | } 128 | 129 | if err := conn.SetDeadline(time.Time{}); err != nil { 130 | return nil, err 131 | } 132 | if wsCon, err = newConn(conn, false, conf, fr, br); err != nil { 133 | return nil, err 134 | } 135 | 136 | wsCon.pd = pd 137 | wsCon.Callback = cb 138 | if cb == nil { 139 | wsCon.Callback = conf.cb 140 | } 141 | return wsCon, nil 142 | } 143 | 144 | func resetPermessageDeflate(pd *deflate.PermessageDeflateConf, conf *Config) { 145 | pd.Decompression = pd.Enable && conf.Decompression 146 | pd.Compression = pd.Enable && conf.Compression 147 | pd.ServerContextTakeover = pd.Enable && 
pd.ServerContextTakeover && conf.ServerContextTakeover 148 | pd.ClientContextTakeover = pd.Enable && pd.ClientContextTakeover && conf.ClientContextTakeover 149 | } 150 | -------------------------------------------------------------------------------- /server_handshake.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package quickws 16 | 17 | import ( 18 | "errors" 19 | "fmt" 20 | "io" 21 | "net/http" 22 | "strings" 23 | 24 | "github.com/antlabs/wsutil/deflate" 25 | ) 26 | 27 | var ( 28 | ErrNotFoundHijacker = errors.New("not found Hijacker") 29 | bytesHeaderUpgrade = []byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ") 30 | // bytesHeaderExtensions = []byte("Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n") 31 | bytesSecWebSocketExtensionsKey = []byte("Sec-WebSocket-Extensions: ") 32 | bytesCRLF = []byte("\r\n") 33 | bytesPutSecWebSocketProtocolKey = []byte("Sec-WebSocket-Protocol: ") 34 | strGetSecWebSocketProtocolKey = "Sec-WebSocket-Protocol" 35 | strWebSocketKey = "Sec-WebSocket-Key" 36 | ) 37 | 38 | func writeHeaderVal(w io.Writer, val []byte) (err error) { 39 | if _, err = w.Write(val); err != nil { 40 | return 41 | } 42 | 43 | _, err = w.Write(bytesCRLF) 44 | return err 45 | } 46 | 47 | func subProtocol(subProtocol string, cnf *Config) string { 48 | if subProtocol == "" { 49 | return "" 50 | } 51 | 52 | subProtocols := strings.Split(subProtocol, ",") 53 | // 如果配置了subProtocols, 则检查客户端的subProtocols是否在配置的subProtocols中 54 | // 为什么要这么做,可以看下这个issue 55 | // https://github.com/antlabs/quickws/issues/12 56 | if len(cnf.subProtocols) > 0 { 57 | for _, clientSubProtocols := range subProtocols { 58 | clientSubProtocols = strings.TrimSpace(clientSubProtocols) 59 | for _, serverSubProtocols := range cnf.subProtocols { 60 | if clientSubProtocols == serverSubProtocols { 61 | return clientSubProtocols 62 | } 63 | } 64 | } 65 | } 66 | // echo Secf-WebSocket-Protocol 的值 67 | return subProtocol 68 | } 69 | 70 | // https://datatracker.ietf.org/doc/html/rfc6455#section-4.2.2 71 | // 第5小点 72 | func prepareWriteResponse(r *http.Request, w io.Writer, cnf *Config, pd deflate.PermessageDeflateConf) (err error) { 73 | // 写入响应头 74 | // 写入Sec-WebSocket-Accept key 75 | if _, 
err = w.Write(bytesHeaderUpgrade); err != nil { 76 | return 77 | } 78 | 79 | v := secWebSocketAcceptVal(r.Header.Get(strWebSocketKey)) 80 | // 写入Sec-WebSocket-Accept 81 | if err = writeHeaderVal(w, StringToBytes(v)); err != nil { 82 | return err 83 | } 84 | 85 | // 给客户端回个信, 表示支持解压缩模式 86 | if pd.Decompression { 87 | if _, err = w.Write(bytesSecWebSocketExtensionsKey); err != nil { 88 | return err 89 | } 90 | if _, err = w.Write([]byte(deflate.GenSecWebSocketExtensions(pd))); err != nil { 91 | return err 92 | } 93 | if _, err = w.Write(bytesCRLF); err != nil { 94 | return err 95 | } 96 | } 97 | 98 | v = r.Header.Get(strGetSecWebSocketProtocolKey) 99 | v = subProtocol(v, cnf) 100 | if len(v) > 0 { 101 | if _, err = w.Write(bytesPutSecWebSocketProtocolKey); err != nil { 102 | return 103 | } 104 | 105 | if err = writeHeaderVal(w, StringToBytes(v)); err != nil { 106 | return err 107 | } 108 | } 109 | 110 | _, err = w.Write(bytesCRLF) 111 | return err 112 | } 113 | 114 | // https://datatracker.ietf.org/doc/html/rfc6455#section-4.2.1 115 | // 按rfc标准, 先来一顿if else判断, 检查发的request是否满足标准 116 | func checkRequest(r *http.Request) (ecode int, err error) { 117 | // 不是get方法的 118 | if r.Method != http.MethodGet { 119 | // TODO错误消息 120 | return http.StatusMethodNotAllowed, fmt.Errorf("%w :%s", ErrOnlyGETSupported, r.Method) 121 | } 122 | // http版本低于1.1 123 | if !r.ProtoAtLeast(1, 1) { 124 | // TODO错误消息 125 | return http.StatusHTTPVersionNotSupported, ErrHTTPProtocolNotSupported 126 | } 127 | 128 | // 没有host字段的 129 | if r.Host == "" { 130 | return http.StatusBadRequest, ErrHostCannotBeEmpty 131 | } 132 | 133 | // Upgrade值不等于websocket的 134 | if upgrade := r.Header.Get("Upgrade"); !strings.EqualFold(upgrade, "websocket") { 135 | return http.StatusBadRequest, ErrUpgradeFieldValue 136 | } 137 | 138 | // Connection值不是Upgrade 139 | if conn := r.Header.Get("Connection"); !strings.EqualFold(conn, "Upgrade") { 140 | return http.StatusBadRequest, ErrConnectionFieldValue 141 | } 142 | 143 | // 
Sec-WebSocket-Key解码之后是16字节长度 144 | // TODO后续优化 145 | if len(r.Header.Get("Sec-WebSocket-Key")) == 0 { 146 | return http.StatusBadRequest, ErrSecWebSocketKey 147 | } 148 | 149 | // Sec-WebSocket-Version的版本不是13的 150 | if r.Header.Get("Sec-WebSocket-Version") != "13" { 151 | return http.StatusUpgradeRequired, ErrSecWebSocketVersion 152 | } 153 | 154 | // TODO Sec-WebSocket-Extensions 155 | return 0, nil 156 | } 157 | -------------------------------------------------------------------------------- /benchmark_read_write_message_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | package quickws 15 | 16 | import ( 17 | "bytes" 18 | "net" 19 | "testing" 20 | "time" 21 | 22 | "github.com/antlabs/wsutil/enum" 23 | "github.com/antlabs/wsutil/frame" 24 | "github.com/antlabs/wsutil/opcode" 25 | ) 26 | 27 | var noMaskData = []byte{0x81, 0x05, 0x48, 0x65, 0x6c, 0x6c, 0x6f} 28 | 29 | // Read reads data from the connection. 30 | // Read can be made to time out and return an error after a fixed 31 | // time limit; see SetDeadline and SetReadDeadline. 32 | func (testconn *testConn) Read(b []byte) (n int, err error) { 33 | return testconn.buf.Read(b) 34 | } 35 | 36 | // Write writes data to the connection. 
37 | // Write can be made to time out and return an error after a fixed 38 | // time limit; see SetDeadline and SetWriteDeadline. 39 | func (testconn *testConn) Write(b []byte) (n int, err error) { 40 | return testconn.buf.Write(b) 41 | } 42 | 43 | // Close closes the connection. 44 | // Any blocked Read or Write operations will be unblocked and return errors. 45 | func (testconn *testConn) Close() error { 46 | return nil 47 | } 48 | 49 | // LocalAddr returns the local network address, if known. 50 | func (testconn *testConn) LocalAddr() net.Addr { 51 | panic("not implemented") // TODO: Implement 52 | } 53 | 54 | // RemoteAddr returns the remote network address, if known. 55 | func (testconn *testConn) RemoteAddr() net.Addr { 56 | panic("not implemented") // TODO: Implement 57 | } 58 | 59 | // SetDeadline sets the read and write deadlines associated 60 | // with the connection. It is equivalent to calling both 61 | // SetReadDeadline and SetWriteDeadline. 62 | // 63 | // A deadline is an absolute time after which I/O operations 64 | // fail instead of blocking. The deadline applies to all future 65 | // and pending I/O, not just the immediately following call to 66 | // Read or Write. After a deadline has been exceeded, the 67 | // connection can be refreshed by setting a deadline in the future. 68 | // 69 | // If the deadline is exceeded a call to Read or Write or to other 70 | // I/O methods will return an error that wraps os.ErrDeadlineExceeded. 71 | // This can be tested using errors.Is(err, os.ErrDeadlineExceeded). 72 | // The error's Timeout method will return true, but note that there 73 | // are other possible errors for which the Timeout method will 74 | // return true even if the deadline has not been exceeded. 75 | // 76 | // An idle timeout can be implemented by repeatedly extending 77 | // the deadline after successful Read or Write calls. 78 | // 79 | // A zero value for t means I/O operations will not time out. 
80 | func (testconn *testConn) SetDeadline(t time.Time) error { 81 | panic("not implemented") // TODO: Implement 82 | } 83 | 84 | // SetReadDeadline sets the deadline for future Read calls 85 | // and any currently-blocked Read call. 86 | // A zero value for t means Read will not time out. 87 | func (testconn *testConn) SetReadDeadline(t time.Time) error { 88 | panic("not implemented") // TODO: Implement 89 | } 90 | 91 | // SetWriteDeadline sets the deadline for future Write calls 92 | // and any currently-blocked Write call. 93 | // Even if write times out, it may return n > 0, indicating that 94 | // some of the data was successfully written. 95 | // A zero value for t means Write will not time out. 96 | func (testconn *testConn) SetWriteDeadline(t time.Time) error { 97 | panic("not implemented") // TODO: Implement 98 | } 99 | 100 | type testConn struct { 101 | buf *bytes.Buffer 102 | } 103 | 104 | func Benchmark_WriteMessage(b *testing.B) { 105 | b.Run("1.case", func(b *testing.B) { 106 | var c Conn 107 | buf2 := bytes.NewBuffer(make([]byte, 0, 1024)) 108 | c.c = &testConn{buf: buf2} 109 | buf := make([]byte, 1024) 110 | for i := range buf { 111 | buf[i] = 1 112 | } 113 | 114 | b.ResetTimer() 115 | for i := 0; i < b.N; i++ { 116 | _ = c.WriteMessage(opcode.Binary, buf) 117 | buf2.Reset() 118 | } 119 | }) 120 | } 121 | 122 | func Benchmark_ReadMessage(b *testing.B) { 123 | b.Run("bufio-TODO", func(b *testing.B) { 124 | }) 125 | 126 | b.Run("windows", func(b *testing.B) { 127 | var c Conn 128 | buf2 := bytes.NewBuffer(make([]byte, 0, 1024+enum.MaxFrameHeaderSize)) 129 | 130 | c.c = &testConn{buf: buf2} 131 | 132 | windows := make([]byte, 0, 1024) 133 | 134 | c.fr.Init(c.c, &windows) 135 | c.Callback = &DefCallback{} 136 | 137 | wbuf := make([]byte, 1024) 138 | for i := range wbuf { 139 | wbuf[i] = 1 140 | } 141 | 142 | b.ResetTimer() 143 | for i := 0; i < b.N; i++ { 144 | _ = c.WriteMessage(opcode.Binary, wbuf) 145 | _ = c.ReadLoop() 146 | buf2.Reset() 147 | } 148 
| }) 149 | } 150 | 151 | func Benchmark_ReadFrame(b *testing.B) { 152 | r := bytes.NewReader(noMaskData) 153 | var headArray [enum.MaxFrameHeaderSize]byte 154 | for i := 0; i < b.N; i++ { 155 | 156 | r.Reset(noMaskData) 157 | _, _, err := frame.ReadHeader(r, &headArray) 158 | if err != nil { 159 | b.Fatal(err) 160 | } 161 | } 162 | } 163 | -------------------------------------------------------------------------------- /client_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | package quickws 15 | 16 | import ( 17 | "fmt" 18 | "net/http" 19 | "net/http/httptest" 20 | "strings" 21 | "sync/atomic" 22 | "testing" 23 | ) 24 | 25 | // 测试客户端Dial, 返回的http.Header 26 | func Test_Client_Dial_Check_Header(t *testing.T) { 27 | t.Run("Dial: valid resp: status code fail", func(t *testing.T) { 28 | done := make(chan bool, 1) 29 | run := int32(0) 30 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 31 | atomic.AddInt32(&run, int32(1)) 32 | done <- true 33 | })) 34 | 35 | defer ts.Close() 36 | 37 | rawURL := strings.ReplaceAll(ts.URL, "http", "ws") 38 | _, err := Dial(rawURL) 39 | if err == nil { 40 | t.Fatal("should be error") 41 | } 42 | }) 43 | 44 | t.Run("DialConf: valid resp : status code fail", func(t *testing.T) { 45 | done := make(chan bool, 1) 46 | run := int32(0) 47 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 48 | atomic.AddInt32(&run, int32(1)) 49 | done <- true 50 | })) 51 | 52 | defer ts.Close() 53 | 54 | cnf := ClientOptionToConf() 55 | rawURL := strings.ReplaceAll(ts.URL, "http", "ws") 56 | _, err := DialConf(rawURL, cnf) 57 | if err == nil { 58 | t.Fatal("should be error") 59 | } 60 | }) 61 | 62 | t.Run("Dial: valid resp: Upgrade field fail", func(t *testing.T) { 63 | done := make(chan bool, 1) 64 | run := int32(0) 65 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 66 | atomic.AddInt32(&run, int32(1)) 67 | w.WriteHeader(101) 68 | w.Header().Set("Upgrade", "xx") 69 | done <- true 70 | })) 71 | 72 | defer ts.Close() 73 | 74 | rawURL := strings.ReplaceAll(ts.URL, "http", "ws") 75 | _, err := Dial(rawURL) 76 | if err == nil { 77 | t.Fatal("should be error") 78 | } 79 | }) 80 | 81 | t.Run("DialConf: valid resp: Upgrade field fail", func(t *testing.T) { 82 | done := make(chan bool, 1) 83 | run := int32(0) 84 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 85 | 
atomic.AddInt32(&run, int32(1)) 86 | w.WriteHeader(101) 87 | w.Header().Set("Upgrade", "xx") 88 | done <- true 89 | })) 90 | 91 | defer ts.Close() 92 | 93 | cnf := ClientOptionToConf() 94 | rawURL := strings.ReplaceAll(ts.URL, "http", "ws") 95 | _, err := DialConf(rawURL, cnf) 96 | if err == nil { 97 | t.Fatal("should be error") 98 | } 99 | }) 100 | 101 | t.Run("Dial: valid resp: Connection fail", func(t *testing.T) { 102 | done := make(chan bool, 1) 103 | run := int32(0) 104 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 105 | atomic.AddInt32(&run, int32(1)) 106 | w.Header().Set("Upgrade", "websocket") 107 | w.Header().Set("Connection", "xx") 108 | w.WriteHeader(101) 109 | done <- true 110 | })) 111 | 112 | defer ts.Close() 113 | 114 | rawURL := strings.ReplaceAll(ts.URL, "http", "ws") 115 | _, err := Dial(rawURL) 116 | if err == nil { 117 | t.Fatal("should be error") 118 | } 119 | }) 120 | 121 | t.Run("DialConf: valid resp: Connection fail", func(t *testing.T) { 122 | done := make(chan bool, 1) 123 | run := int32(0) 124 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 125 | atomic.AddInt32(&run, int32(1)) 126 | w.Header().Set("Upgrade", "websocket") 127 | w.Header().Set("Connection", "xx") 128 | w.WriteHeader(101) 129 | done <- true 130 | })) 131 | 132 | defer ts.Close() 133 | 134 | cnf := ClientOptionToConf() 135 | rawURL := strings.ReplaceAll(ts.URL, "http", "ws") 136 | _, err := DialConf(rawURL, cnf) 137 | if err == nil { 138 | t.Fatal("should be error") 139 | } else { 140 | fmt.Printf("err: %v\n", err) 141 | } 142 | }) 143 | 144 | t.Run("Dial: valid resp: Sec-WebSocket-Accept fail", func(t *testing.T) { 145 | done := make(chan bool, 1) 146 | run := int32(0) 147 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 148 | atomic.AddInt32(&run, int32(1)) 149 | w.Header().Set("Upgrade", "websocket") 150 | w.Header().Set("Connection", 
"Upgrade") 151 | w.WriteHeader(101) 152 | done <- true 153 | })) 154 | 155 | defer ts.Close() 156 | 157 | rawURL := strings.ReplaceAll(ts.URL, "http", "ws") 158 | _, err := Dial(rawURL) 159 | if err == nil { 160 | t.Fatal("should be error") 161 | } 162 | }) 163 | 164 | t.Run("DialConf: valid resp: Sec-WebSocket-Accept fail", func(t *testing.T) { 165 | done := make(chan bool, 1) 166 | run := int32(0) 167 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 168 | atomic.AddInt32(&run, int32(1)) 169 | w.Header().Set("Upgrade", "websocket") 170 | w.Header().Set("Connection", "Upgrade") 171 | w.WriteHeader(101) 172 | done <- true 173 | })) 174 | 175 | defer ts.Close() 176 | 177 | cnf := ClientOptionToConf() 178 | rawURL := strings.ReplaceAll(ts.URL, "http", "ws") 179 | _, err := DialConf(rawURL, cnf) 180 | if err == nil { 181 | t.Fatal("should be error") 182 | } else { 183 | fmt.Printf("err: %v\n", err) 184 | } 185 | }) 186 | } 187 | -------------------------------------------------------------------------------- /autobahn/server/autobahn-server.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "crypto/tls" 5 | _ "embed" 6 | "fmt" 7 | "log" 8 | "net" 9 | "net/http" 10 | "sync" 11 | "time" 12 | 13 | "github.com/antlabs/quickws" 14 | //"os" 15 | ) 16 | 17 | //go:embed public.crt 18 | var certPEMBlock []byte 19 | 20 | //go:embed privatekey.pem 21 | var keyPEMBlock []byte 22 | 23 | type echoHandler struct { 24 | openWriteTimeout bool 25 | } 26 | 27 | func (e *echoHandler) OnOpen(c *quickws.Conn) { 28 | fmt.Printf("OnOpen: %p\n", c) 29 | } 30 | 31 | func (e *echoHandler) OnMessage(c *quickws.Conn, op quickws.Opcode, msg []byte) { 32 | // fmt.Println("OnMessage:", c, msg, op) 33 | // if err := c.WriteTimeout(op, msg, 3*time.Second); err != nil { 34 | // fmt.Println("write fail:", err) 35 | // } 36 | 37 | if e.openWriteTimeout { 38 | if err := c.WriteTimeout(op, msg, 
50*time.Second); err != nil { 39 | fmt.Println("write fail:", err) 40 | } 41 | return 42 | } 43 | if err := c.WriteMessage(op, msg); err != nil { 44 | fmt.Println("write fail:", err) 45 | } 46 | } 47 | 48 | func (e *echoHandler) OnClose(c *quickws.Conn, err error) { 49 | fmt.Printf("OnClose:%p, %v\n", c, err) 50 | } 51 | 52 | // 1.测试不接管上下文,只解压 53 | func echoNoContextDecompression(w http.ResponseWriter, r *http.Request) { 54 | c, err := quickws.Upgrade(w, r, 55 | quickws.WithServerReplyPing(), 56 | quickws.WithServerDecompression(), 57 | quickws.WithServerIgnorePong(), 58 | quickws.WithServerCallback(&echoHandler{}), 59 | quickws.WithServerEnableUTF8Check(), 60 | // quickws.WithServerReadTimeout(5*time.Second), 61 | ) 62 | if err != nil { 63 | fmt.Println("Upgrade fail:", err) 64 | return 65 | } 66 | 67 | _ = c.ReadLoop() 68 | } 69 | 70 | // 2.测试不接管上下文,压缩和解压 71 | func echoNoContextDecompressionAndCompression(w http.ResponseWriter, r *http.Request) { 72 | c, err := quickws.Upgrade(w, r, 73 | quickws.WithServerReplyPing(), 74 | quickws.WithServerDecompressAndCompress(), 75 | quickws.WithServerIgnorePong(), 76 | quickws.WithServerCallback(&echoHandler{}), 77 | quickws.WithServerEnableUTF8Check(), 78 | ) 79 | if err != nil { 80 | fmt.Println("Upgrade fail:", err) 81 | return 82 | } 83 | 84 | _ = c.ReadLoop() 85 | } 86 | 87 | // 3.测试接管上下文,解压 88 | func echoContextTakeoverDecompression(w http.ResponseWriter, r *http.Request) { 89 | c, err := quickws.Upgrade(w, r, 90 | quickws.WithServerReplyPing(), 91 | quickws.WithServerDecompression(), 92 | quickws.WithServerIgnorePong(), 93 | quickws.WithServerContextTakeover(), 94 | quickws.WithServerCallback(&echoHandler{}), 95 | quickws.WithServerEnableUTF8Check(), 96 | ) 97 | if err != nil { 98 | fmt.Println("Upgrade fail:", err) 99 | return 100 | } 101 | 102 | _ = c.ReadLoop() 103 | } 104 | 105 | // 4.测试接管上下文,压缩/解压缩 106 | func echoContextTakeoverDecompressionAndCompression(w http.ResponseWriter, r *http.Request) { 107 | c, err := 
quickws.Upgrade(w, r, 108 | quickws.WithServerReplyPing(), 109 | quickws.WithServerDecompressAndCompress(), 110 | quickws.WithServerIgnorePong(), 111 | quickws.WithServerContextTakeover(), 112 | quickws.WithServerCallback(&echoHandler{}), 113 | quickws.WithServerEnableUTF8Check(), 114 | ) 115 | if err != nil { 116 | fmt.Println("Upgrade fail:", err) 117 | return 118 | } 119 | 120 | _ = c.ReadLoop() 121 | } 122 | func echoReadTime(w http.ResponseWriter, r *http.Request) { 123 | c, err := quickws.Upgrade(w, r, 124 | quickws.WithServerReplyPing(), 125 | quickws.WithServerDecompression(), 126 | quickws.WithServerIgnorePong(), 127 | quickws.WithServerCallback(&echoHandler{openWriteTimeout: true}), 128 | quickws.WithServerEnableUTF8Check(), 129 | quickws.WithServerReadTimeout(5*time.Second), 130 | ) 131 | if err != nil { 132 | fmt.Println("Upgrade fail:", err) 133 | return 134 | } 135 | 136 | _ = c.ReadLoop() 137 | } 138 | 139 | var upgrade = quickws.NewUpgrade( 140 | quickws.WithServerReplyPing(), 141 | quickws.WithServerDecompression(), 142 | quickws.WithServerIgnorePong(), 143 | quickws.WithServerEnableUTF8Check(), 144 | quickws.WithServerReadTimeout(5*time.Second), 145 | ) 146 | 147 | func global(w http.ResponseWriter, r *http.Request) { 148 | c, err := upgrade.UpgradeV2(w, r, &echoHandler{openWriteTimeout: true}) 149 | if err != nil { 150 | fmt.Println("Upgrade fail:", err) 151 | return 152 | } 153 | 154 | _ = c.ReadLoop() 155 | } 156 | 157 | func startTLSServer(mux *http.ServeMux) { 158 | 159 | cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock) 160 | if err != nil { 161 | log.Fatalf("tls.X509KeyPair failed: %v", err) 162 | } 163 | tlsConfig := &tls.Config{ 164 | Certificates: []tls.Certificate{cert}, 165 | InsecureSkipVerify: true, 166 | } 167 | lnTLS, err := tls.Listen("tcp", "localhost:9002", tlsConfig) 168 | if err != nil { 169 | panic(err) 170 | } 171 | log.Println("tls server exit:", http.Serve(lnTLS, mux)) 172 | } 173 | 174 | func startServer(mux 
*http.ServeMux) { 175 | 176 | rawTCP, err := net.Listen("tcp", ":9001") 177 | if err != nil { 178 | fmt.Println("Listen fail:", err) 179 | return 180 | } 181 | 182 | log.Println("non-tls server exit:", http.Serve(rawTCP, mux)) 183 | } 184 | 185 | func main() { 186 | mux := &http.ServeMux{} 187 | mux.HandleFunc("/timeout", echoReadTime) 188 | mux.HandleFunc("/global", global) 189 | mux.HandleFunc("/no-context-takeover-decompression", echoNoContextDecompression) 190 | mux.HandleFunc("/no-context-takeover-decompression-and-compression", echoNoContextDecompressionAndCompression) 191 | mux.HandleFunc("/context-takeover-decompression", echoContextTakeoverDecompression) 192 | mux.HandleFunc("/context-takeover-decompression-and-compression", echoContextTakeoverDecompressionAndCompression) 193 | 194 | var wg sync.WaitGroup 195 | wg.Add(2) 196 | 197 | defer wg.Wait() 198 | 199 | go func() { 200 | defer wg.Done() 201 | startServer(mux) 202 | }() 203 | 204 | go func() { 205 | startTLSServer(mux) 206 | }() 207 | 208 | } 209 | -------------------------------------------------------------------------------- /proxy_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package quickws 16 | 17 | import ( 18 | "errors" 19 | "net" 20 | "net/http" 21 | "net/http/httptest" 22 | "net/url" 23 | "strings" 24 | "testing" 25 | "time" 26 | ) 27 | 28 | type testServer struct { 29 | path string 30 | rawQuery string 31 | requestURL string 32 | subprotos []string 33 | *testing.T 34 | } 35 | 36 | func newTestServer(t *testing.T) *testServer { 37 | return &testServer{path: "/test", rawQuery: "a=1&b=2", requestURL: "/test?a=1&b=2", T: t, subprotos: []string{"proto1", "proto2"}} 38 | } 39 | 40 | func (t *testServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { 41 | if r.URL.Path != t.path { 42 | t.Errorf("path error: %s", r.URL.Path) 43 | return 44 | } 45 | 46 | if r.URL.RawQuery != t.rawQuery { 47 | t.Errorf("raw query error: %s", r.URL.RawQuery) 48 | return 49 | } 50 | 51 | sub := subProtocol(r.Header.Get("Sec-Websocket-Protocol"), &Config{subProtocols: t.subprotos}) 52 | if sub != "proto1" { 53 | t.Errorf("sub protocol error: (%s)", sub) 54 | return 55 | } 56 | 57 | conn, err := Upgrade(w, r, WithServerOnMessageFunc(func(c *Conn, o Opcode, b []byte) { 58 | err := c.WriteMessage(o, b) 59 | if err != nil { 60 | t.Error(err) 61 | return 62 | } 63 | })) 64 | if err != nil { 65 | t.Error(err) 66 | return 67 | } 68 | _ = conn.ReadLoop() 69 | } 70 | 71 | func (t *testServer) clientSend(c *Conn) { 72 | _ = c.WriteMessage(Text, []byte("hello world")) 73 | } 74 | 75 | func HTTPToWS(u string) string { 76 | return strings.ReplaceAll(u, "http://", "ws://") 77 | } 78 | 79 | func WsToHTTP(u string) string { 80 | return strings.ReplaceAll(u, "ws://", "http://") 81 | } 82 | 83 | func Test_Proxy(t *testing.T) { 84 | t.Run("test proxy dial.1", func(t *testing.T) { 85 | connect := false 86 | s := newTestServer(t) 87 | 88 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 89 | t.Logf("method: %s, url: %s", r.Method, r.URL.String()) 90 | if r.Method == http.MethodConnect { 91 | connect = true 92 | 
w.WriteHeader(http.StatusOK) 93 | return 94 | } 95 | 96 | if !connect { 97 | t.Error("test proxy dial fail: not connect") 98 | http.Error(w, "not connect", http.StatusMethodNotAllowed) 99 | return 100 | } 101 | s.ServeHTTP(w, r) 102 | })) 103 | 104 | defer ts.Close() 105 | 106 | proxy := func(*http.Request) (*url.URL, error) { 107 | return url.Parse(HTTPToWS(ts.URL)) 108 | } 109 | 110 | got := make(chan string, 1) 111 | dstURL := HTTPToWS(ts.URL + s.requestURL) 112 | con, err := Dial(dstURL, 113 | WithClientProxyFunc(proxy), 114 | WithClientSubprotocols(s.subprotos), 115 | WithClientOnMessageFunc(func(c *Conn, o Opcode, b []byte) { 116 | got <- string(b) 117 | })) 118 | if err != nil { 119 | t.Error(err) 120 | return 121 | } 122 | con.StartReadLoop() 123 | s.clientSend(con) 124 | 125 | defer con.Close() 126 | gotValue := <-got 127 | if gotValue != "hello world" { 128 | t.Errorf("got: %s, want: %s", gotValue, "hello world") 129 | return 130 | } 131 | }) 132 | } 133 | 134 | func Test_httpProxy_Dial(t *testing.T) { 135 | type fields struct { 136 | proxyAddr *url.URL 137 | dial func(network, addr string, timeout time.Duration) (c net.Conn, err error) 138 | } 139 | type args struct { 140 | network string 141 | addr string 142 | } 143 | tests := []struct { 144 | name string 145 | fields fields 146 | args args 147 | wantC net.Conn 148 | wantErr bool 149 | }{ 150 | // 0 151 | { 152 | name: "No proxy address", 153 | fields: fields{ 154 | proxyAddr: nil, 155 | dial: func(network, addr string, timeout time.Duration) (c net.Conn, err error) { 156 | // Simulate successful dialing 157 | return &net.TCPConn{}, errors.New("fail") 158 | }, 159 | }, 160 | args: args{ 161 | network: "tcp", 162 | addr: "example.com:80", 163 | }, 164 | wantC: &net.TCPConn{}, 165 | wantErr: true, 166 | }, 167 | // 1 168 | { 169 | name: "Proxy address", 170 | fields: fields{ 171 | proxyAddr: &url.URL{Host: "1.2.3:8080", User: url.UserPassword("user", "password")}, 172 | dial: func(network, addr string, 
timeout time.Duration) (c net.Conn, err error) { 173 | // Simulate successful dialing 174 | return &net.TCPConn{}, errors.New("fail") 175 | }, 176 | }, 177 | args: args{ 178 | network: "tcp", 179 | addr: "a.b.c:80", 180 | }, 181 | wantC: &net.TCPConn{}, 182 | wantErr: true, 183 | }, 184 | // 2 185 | { 186 | name: "Proxy address", 187 | fields: fields{ 188 | proxyAddr: &url.URL{Host: "1.2.3:8080", User: url.UserPassword("user", "password")}, 189 | dial: func(network, addr string, timeout time.Duration) (c net.Conn, err error) { 190 | // Simulate successful dialing 191 | return &net.TCPConn{}, nil 192 | }, 193 | }, 194 | args: args{ 195 | network: "tcp", 196 | addr: "a.b.c:80", 197 | }, 198 | wantC: &net.TCPConn{}, 199 | wantErr: false, 200 | }, 201 | } 202 | for i, tt := range tests { 203 | t.Run(tt.name, func(t *testing.T) { 204 | h := &httpProxy{ 205 | proxyAddr: tt.fields.proxyAddr, 206 | dialTimeout: tt.fields.dial, 207 | } 208 | _, err := h.dialTimeout(tt.args.network, tt.args.addr, 0) 209 | if (err != nil) != tt.wantErr { 210 | t.Errorf("index:%d, httpProxy.Dial() error = %v, wantErr %v", i, err, tt.wantErr) 211 | return 212 | } 213 | // if !reflect.DeepEqual(gotC, tt.wantC) { 214 | // t.Errorf("httpProxy.Dial() = %v, want %v", gotC, tt.wantC) 215 | // } 216 | }) 217 | } 218 | } 219 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package quickws 16 | 17 | import ( 18 | "bufio" 19 | "crypto/tls" 20 | "fmt" 21 | "net" 22 | "net/http" 23 | "net/url" 24 | "strings" 25 | "time" 26 | 27 | "github.com/antlabs/wsutil/bufio2" 28 | "github.com/antlabs/wsutil/bytespool" 29 | "github.com/antlabs/wsutil/deflate" 30 | "github.com/antlabs/wsutil/enum" 31 | "github.com/antlabs/wsutil/fixedreader" 32 | "github.com/antlabs/wsutil/hostname" 33 | ) 34 | // defaultTimeout is the dial timeout that Dial and DialConf assign before dialing. 35 | var ( 36 | defaultTimeout = time.Minute * 30 37 | ) 38 | // DialOption holds the client-side handshake settings; Header becomes the header of the upgrade request. 39 | type DialOption struct { 40 | Header http.Header 41 | u *url.URL 42 | tlsConfig *tls.Config 43 | dialTimeout time.Duration 44 | bindClientHttpHeader *http.Header // after a successful handshake, the client obtains the response http.Header through this pointer 45 | Config 46 | } 47 | // ClientOptionToConf builds a DialOption from the default settings plus opts, for use with DialConf; it panics if the defaults cannot be applied. 48 | func ClientOptionToConf(opts ...ClientOption) *DialOption { 49 | var dial DialOption 50 | if err := dial.defaultSetting(); err != nil { 51 | panic(err.Error()) 52 | } 53 | for _, o := range opts { 54 | o(&dial) 55 | } 56 | return &dial 57 | } 58 | // DialConf parses rawUrl and dials with conf; it overwrites conf.u and conf.dialTimeout and allocates conf.Header when nil, so conf is mutated. 59 | func DialConf(rawUrl string, conf *DialOption) (*Conn, error) { 60 | u, err := url.Parse(rawUrl) 61 | if err != nil { 62 | return nil, err 63 | } 64 | 65 | conf.u = u 66 | conf.dialTimeout = defaultTimeout 67 | if conf.Header == nil { 68 | conf.Header = make(http.Header) 69 | } 70 | return conf.Dial() 71 | } 72 | 73 | // https://datatracker.ietf.org/doc/html/rfc6455#section-4.1 74 | // Dial parses rawUrl, applies opts on top of the default settings and dials; the handshake follows the RFC section linked above. 75 | func Dial(rawUrl string, opts ...ClientOption) (*Conn, error) { 76 | var dial DialOption 77 | u, err := url.Parse(rawUrl) 78 | if err != nil { 79 | return nil, err 80 |
} 81 | 82 | dial.u = u 83 | dial.dialTimeout = defaultTimeout 84 | if dial.Header == nil { 85 | dial.Header = make(http.Header) 86 | } 87 | 88 | if err := dial.defaultSetting(); err != nil { 89 | return nil, err 90 | } 91 | 92 | for _, o := range opts { 93 | o(&dial) 94 | } 95 | 96 | return dial.Dial() 97 | } 98 | 99 | // handshake rewrites the ws/wss scheme to http/https and builds the upgrade GET request, returning it together with the generated Sec-WebSocket-Key. NOTE(review): headers are appended to d.Header with Add, so dialing twice with the same DialOption (e.g. through DialConf) would send duplicate handshake headers; confirm whether reuse is supported. 100 | func (d *DialOption) handshake() (*http.Request, string, error) { 101 | switch { 102 | case d.u.Scheme == "wss": 103 | d.u.Scheme = "https" 104 | case d.u.Scheme == "ws": 105 | d.u.Scheme = "http" 106 | default: 107 | return nil, "", fmt.Errorf("Unknown scheme, only supports ws:// or wss://: got %s", d.u.Scheme) 108 | } 109 | 110 | // Opening-handshake requirements of RFC 6455 section 4.1: 111 | // item 2: the method must be GET and the HTTP version at least 1.1 112 | req, err := http.NewRequest("GET", d.u.String(), nil) 113 | if err != nil { 114 | return nil, "", err 115 | } 116 | // item 5: Upgrade header 117 | d.Header.Add("Upgrade", "websocket") 118 | // item 6: Connection header 119 | d.Header.Add("Connection", "Upgrade") 120 | // item 7: Sec-WebSocket-Key; the value is also returned so the server's Sec-WebSocket-Accept can be verified 121 | secWebSocket := secWebSocketAccept() 122 | d.Header.Add("Sec-WebSocket-Key", secWebSocket) 123 | // TODO item 8: Origin header (required only for browser clients) 124 | // item 9: Sec-WebSocket-Version must be 13 125 | d.Header.Add("Sec-WebSocket-Version", "13") 126 | if d.Decompression && d.Compression { 127 | // d.Header.Add("Sec-WebSocket-Extensions", genSecWebSocketExtensions(d.Pd)) 128 | d.Header.Add("Sec-WebSocket-Extensions", deflate.GenSecWebSocketExtensions(d.PermessageDeflateConf)) 129 | } 130 | 131 | if len(d.subProtocols) > 0 { 132 | d.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.subProtocols, ", ")} 133 | } 134 | 135 | req.Header = d.Header 136 | return req, secWebSocket, nil 137 | } 138 | 139 | // validateRsp checks the server's handshake response; secWebSocket is presumably the Sec-WebSocket-Key that handshake generated. 140 | // 4.2.2.5; the item numbers below follow the client-side response checks in RFC 6455 section 4.1 141 | func (d *DialOption) validateRsp(rsp *http.Response, secWebSocket string) error { 142 | if rsp.StatusCode != 101 { 143 | return fmt.Errorf("%w %d", ErrWrongStatusCode, rsp.StatusCode) 144 | } 145 | 146 | // item 2: Upgrade must match websocket (ASCII case-insensitive) 147 | if !strings.EqualFold(rsp.Header.Get("Upgrade"), "websocket") { 148 | return ErrUpgradeFieldValue 149 | } 150 | 151 | // item 3: Connection must match Upgrade (ASCII case-insensitive). NOTE(review): the whole value is compared, so a token list such as "keep-alive, Upgrade" is rejected although RFC 6455 only requires the token to be present 152 | if
!strings.EqualFold(rsp.Header.Get("Connection"), "Upgrade") { 153 | return ErrConnectionFieldValue 154 | } 155 | 156 | // item 4: Sec-WebSocket-Accept must equal the value derived from secWebSocket. NOTE(review): compared case-insensitively, although base64 is case-sensitive 157 | if !strings.EqualFold(rsp.Header.Get("Sec-WebSocket-Accept"), secWebSocketAcceptVal(secWebSocket)) { 158 | return ErrSecWebSocketAccept 159 | } 160 | 161 | // TODO item 5: fail on a Sec-WebSocket-Extensions value that was not requested 162 | 163 | // TODO item 6: fail on a Sec-WebSocket-Protocol value that was not requested 164 | return nil 165 | } 166 | 167 | // tlsConn wraps c in a TLS client when the scheme is https (handshake has already rewritten wss to https), using a clone of d.tlsConfig and defaulting ServerName to the URL host without its port; otherwise c is returned unchanged. NOTE(review): the port is cut at the first colon, which breaks IPv6 literals such as [::1]:443; d.u.Hostname() would handle them 168 | func (d *DialOption) tlsConn(c net.Conn) net.Conn { 169 | if d.u.Scheme == "https" { 170 | cfg := d.tlsConfig 171 | if cfg == nil { 172 | cfg = &tls.Config{} 173 | } else { 174 | cfg = cfg.Clone() 175 | } 176 | 177 | if cfg.ServerName == "" { 178 | host := d.u.Host 179 | if pos := strings.Index(host, ":"); pos != -1 { 180 | host = host[:pos] 181 | } 182 | cfg.ServerName = host 183 | } 184 | return tls.Client(c, cfg) 185 | } 186 | 187 | return c 188 | } 189 | 190 | func (d *DialOption) Dial() (wsCon *Conn, err error) { 191 | // scheme ws -> http 192 | // scheme wss -> https 193 | req, secWebSocket, err := d.handshake() 194 | if err != nil { 195 | return nil, err 196 | } 197 | 198 | var conn net.Conn 199 | 200 | hostName := hostname.GetHostName(d.u) 201 | dialFunc := net.DialTimeout 202 | if d.dialFunc != nil { 203 | dialInterface, err := d.dialFunc() 204 | if err != nil { 205 | return nil, err 206 | } 207 | dialFunc = func(network, address string, timeout time.Duration) (net.Conn, error) { 208 | return dialInterface.Dial(network, address) 209 | } 210 | } 211 | 212 | if d.proxyFunc != nil { 213 | proxyURL, err := d.proxyFunc(req) 214 | if err != nil { 215 | return nil, err 216 | } 217 | dialFunc = newhttpProxy(proxyURL, dialFunc).DialTimeout 218 | } 219 | 220 | conn, err = dialFunc("tcp", hostName, d.dialTimeout) 221 | if err != nil { 222 | return nil, err 223 | } 224 | 225 | conn = d.tlsConn(conn) 226 | defer func() { 227 | if err != nil && conn != nil { 228 | conn.Close() 229 | conn = nil 230 | } 231 | }() 232 | 233 | err = conn.SetDeadline(time.Time{}) 234 | if err = req.Write(conn); err
!= nil { 235 | return 236 | } 237 | 238 | br := bufio.NewReader(bufio.NewReader(conn)) 239 | rsp, err := http.ReadResponse(br, req) 240 | if err != nil { 241 | return nil, err 242 | } 243 | 244 | if d.bindClientHttpHeader != nil { 245 | *d.bindClientHttpHeader = rsp.Header.Clone() 246 | } 247 | 248 | pd, err := deflate.GetConnPermessageDeflate(rsp.Header) 249 | if err != nil { 250 | return nil, err 251 | } 252 | if d.Decompression { 253 | pd.Decompression = pd.Enable && d.Decompression 254 | } 255 | if d.Compression { 256 | pd.Compression = pd.Enable && d.Compression 257 | } 258 | 259 | if err = d.validateRsp(rsp, secWebSocket); err != nil { 260 | return 261 | } 262 | 263 | // 处理下已经在bufio里面的数据,后面都是直接操作net.Conn,所以需要取出bufio里面已读取的数据 264 | var fr fixedreader.FixedReader 265 | if d.parseMode == ParseModeWindows { 266 | fr.Init(conn, bytespool.GetBytes(1024+enum.MaxFrameHeaderSize)) 267 | if br.Buffered() > 0 { 268 | b, err := br.Peek(br.Buffered()) 269 | if err != nil { 270 | return nil, err 271 | } 272 | 273 | buf := fr.BufPtr() 274 | if len(b) > 1024+enum.MaxFrameHeaderSize { 275 | bytespool.PutBytes(buf) 276 | buf = bytespool.GetBytes(len(b) + enum.MaxFrameHeaderSize) 277 | 278 | fr.Reset(buf) 279 | } 280 | 281 | copy(*buf, b) 282 | fr.W = len(b) 283 | } 284 | bufio2.ClearReader(br) 285 | br = nil 286 | } 287 | if err := conn.SetDeadline(time.Time{}); err != nil { 288 | return nil, err 289 | } 290 | if wsCon, err = newConn(conn, true /* client is true*/, &d.Config, fr, br); err != nil { 291 | return nil, err 292 | } 293 | wsCon.pd = pd 294 | wsCon.Callback = d.cb 295 | return wsCon, nil 296 | } 297 | -------------------------------------------------------------------------------- /common_options.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 
2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | package quickws 15 | 16 | import ( 17 | "net/http" 18 | "net/url" 19 | "time" 20 | "unicode/utf8" 21 | ) 22 | 23 | // 0. CallbackFunc 24 | func WithClientCallbackFunc(open OnOpenFunc, m OnMessageFunc, c OnCloseFunc) ClientOption { 25 | return func(o *DialOption) { 26 | o.cb = &funcToCallback{ 27 | onOpen: open, 28 | onMessage: m, 29 | onClose: c, 30 | } 31 | } 32 | } 33 | 34 | // 配置服务端回调函数 35 | func WithServerCallbackFunc(open OnOpenFunc, m OnMessageFunc, c OnCloseFunc) ServerOption { 36 | return func(o *ConnOption) { 37 | o.cb = &funcToCallback{ 38 | onOpen: open, 39 | onMessage: m, 40 | onClose: c, 41 | } 42 | } 43 | } 44 | 45 | // 1. callback 46 | // 配置客户端callback 47 | func WithClientCallback(cb Callback) ClientOption { 48 | return func(o *DialOption) { 49 | o.cb = cb 50 | } 51 | } 52 | 53 | // 配置服务端回调函数 54 | func WithServerCallback(cb Callback) ServerOption { 55 | return func(o *ConnOption) { 56 | o.cb = cb 57 | } 58 | } 59 | 60 | // 2. 
设置TCP_NODELAY 61 | // 设置客户端TCP_NODELAY 62 | func WithClientTCPDelay() ClientOption { 63 | return func(o *DialOption) { 64 | o.tcpNoDelay = false 65 | } 66 | } 67 | 68 | // 设置TCP_NODELAY 为false, 开启nagle算法 69 | // 设置服务端TCP_NODELAY 70 | func WithServerTCPDelay() ServerOption { 71 | return func(o *ConnOption) { 72 | o.tcpNoDelay = false 73 | } 74 | } 75 | 76 | // 3.关闭utf8检查 77 | func WithServerEnableUTF8Check() ServerOption { 78 | return func(o *ConnOption) { 79 | o.utf8Check = utf8.Valid 80 | } 81 | } 82 | 83 | func WithClientEnableUTF8Check() ClientOption { 84 | return func(o *DialOption) { 85 | o.utf8Check = utf8.Valid 86 | } 87 | } 88 | 89 | // 4.仅仅配置OnMessae函数 90 | // 仅仅配置OnMessae函数 91 | func WithServerOnMessageFunc(cb OnMessageFunc) ServerOption { 92 | return func(o *ConnOption) { 93 | o.cb = OnMessageFunc(cb) 94 | } 95 | } 96 | 97 | // 仅仅配置OnMessae函数 98 | func WithClientOnMessageFunc(cb OnMessageFunc) ClientOption { 99 | return func(o *DialOption) { 100 | o.cb = OnMessageFunc(cb) 101 | } 102 | } 103 | 104 | // 5. 105 | // 配置自动回应ping frame, 当收到ping, 回一个pong 106 | func WithServerReplyPing() ServerOption { 107 | return func(o *ConnOption) { 108 | o.replyPing = true 109 | } 110 | } 111 | 112 | // 配置自动回应ping frame, 当收到ping, 回一个pong 113 | func WithClientReplyPing() ClientOption { 114 | return func(o *DialOption) { 115 | o.replyPing = true 116 | } 117 | } 118 | 119 | // 6 配置忽略pong消息 120 | func WithClientIgnorePong() ClientOption { 121 | return func(o *DialOption) { 122 | o.ignorePong = true 123 | } 124 | } 125 | 126 | func WithServerIgnorePong() ServerOption { 127 | return func(o *ConnOption) { 128 | o.ignorePong = true 129 | } 130 | } 131 | 132 | // 7. 
133 | // 设置几倍payload的缓冲区 134 | // 只有解析方式是窗口的时候才有效 135 | // 如果为1.0就是1024 + 14, 如果是2.0就是2048 + 14 136 | func WithServerWindowsMultipleTimesPayloadSize(mt float32) ServerOption { 137 | return func(o *ConnOption) { 138 | if mt < 1.0 { 139 | mt = 1.0 140 | } 141 | o.windowsMultipleTimesPayloadSize = mt 142 | } 143 | } 144 | 145 | func WithClientWindowsMultipleTimesPayloadSize(mt float32) ClientOption { 146 | return func(o *DialOption) { 147 | if mt < 1.0 { 148 | mt = 1.0 149 | } 150 | o.windowsMultipleTimesPayloadSize = mt 151 | } 152 | } 153 | 154 | // 8 配置windows解析方式 155 | // 默认使用窗口解析方式, 以后以后默认解析方式改变过,才有必要使用这个选项 156 | func WithServerWindowsParseMode() ServerOption { 157 | return func(o *ConnOption) { 158 | o.parseMode = ParseModeWindows 159 | } 160 | } 161 | 162 | // 默认使用窗口解析方式, 以后以后默认解析方式改变过,才有必要使用这个选项 163 | func WithClientWindowsParseMode() ClientOption { 164 | return func(o *DialOption) { 165 | o.parseMode = ParseModeWindows 166 | } 167 | } 168 | 169 | // 9. 170 | // 171 | // 使用基于bufio的解析方式 172 | func WithServerBufioParseMode() ServerOption { 173 | return func(o *ConnOption) { 174 | o.parseMode = ParseModeBufio 175 | } 176 | } 177 | 178 | func WithClientBufioParseMode() ClientOption { 179 | return func(o *DialOption) { 180 | o.parseMode = ParseModeBufio 181 | } 182 | } 183 | 184 | // 10 配置解压缩 185 | func WithClientDecompression() ClientOption { 186 | return func(o *DialOption) { 187 | o.Decompression = true 188 | } 189 | } 190 | 191 | func WithServerDecompression() ServerOption { 192 | return func(o *ConnOption) { 193 | o.Decompression = true 194 | } 195 | } 196 | 197 | // 11 关闭bufio clear hack优化 198 | func WithServerDisableBufioClearHack() ServerOption { 199 | return func(o *ConnOption) { 200 | o.disableBufioClearHack = true 201 | } 202 | } 203 | 204 | func WithClientDisableBufioClearHack() ClientOption { 205 | return func(o *DialOption) { 206 | o.disableBufioClearHack = true 207 | } 208 | } 209 | 210 | // 12 配置多倍payload缓冲区, 1.是1024 2。是2048 211 | // 
为何不让用户自己配置呢,可以和底层的buffer池结合起来,/1024就知道命中哪个缓冲区了, 不需要维护index命中的哪个sync.Pool 212 | // 如果用户传些奇奇怪怪的数字,就不好办了 213 | func WithServerBufioMultipleTimesPayloadSize(mt float32) ServerOption { 214 | return func(o *ConnOption) { 215 | if mt <= 0 { 216 | mt = 1.0 217 | } 218 | o.bufioMultipleTimesPayloadSize = mt 219 | } 220 | } 221 | 222 | func WithClientBufioMultipleTimesPayloadSize(mt float32) ClientOption { 223 | return func(o *DialOption) { 224 | if mt <= 0 { 225 | mt = 1.0 226 | } 227 | o.bufioMultipleTimesPayloadSize = mt 228 | } 229 | } 230 | 231 | // 13. 配置延迟发送 232 | // 配置延迟最大发送时间 233 | func WithServerMaxDelayWriteDuration(d time.Duration) ServerOption { 234 | return func(o *ConnOption) { 235 | o.maxDelayWriteDuration = d 236 | } 237 | } 238 | 239 | // 13. 配置延迟发送 240 | // 配置延迟最大发送时间 241 | func WithClientMaxDelayWriteDuration(d time.Duration) ClientOption { 242 | return func(o *DialOption) { 243 | o.maxDelayWriteDuration = d 244 | } 245 | } 246 | 247 | // 14.1 配置最大延迟个数.server 248 | func WithServerMaxDelayWriteNum(n int32) ServerOption { 249 | return func(o *ConnOption) { 250 | o.maxDelayWriteNum = n 251 | } 252 | } 253 | 254 | // 14.2 配置最大延迟个数.client 255 | func WithClientMaxDelayWriteNum(n int32) ClientOption { 256 | return func(o *DialOption) { 257 | o.maxDelayWriteNum = n 258 | } 259 | } 260 | 261 | // 15.1 配置延迟包的初始化buffer大小 262 | func WithServerDelayWriteInitBufferSize(n int32) ServerOption { 263 | return func(o *ConnOption) { 264 | o.delayWriteInitBufferSize = n 265 | } 266 | } 267 | 268 | // 15.2 配置延迟包的初始化buffer大小 269 | func WithClientDelayWriteInitBufferSize(n int32) ClientOption { 270 | return func(o *DialOption) { 271 | o.delayWriteInitBufferSize = n 272 | } 273 | } 274 | 275 | // 16. 
配置读超时时间 276 | // 277 | // 16.1 .设置服务端读超时时间 278 | func WithServerReadTimeout(t time.Duration) ServerOption { 279 | return func(o *ConnOption) { 280 | o.readTimeout = t 281 | } 282 | } 283 | 284 | // 16.2 .设置客户端读超时时间 285 | func WithClientReadTimeout(t time.Duration) ClientOption { 286 | return func(o *DialOption) { 287 | o.readTimeout = t 288 | } 289 | } 290 | 291 | // 17。 只配置OnClose 292 | // 17.1 配置服务端OnClose 293 | func WithServerOnCloseFunc(onClose func(c *Conn, err error)) ServerOption { 294 | return func(o *ConnOption) { 295 | o.cb = OnCloseFunc(onClose) 296 | } 297 | } 298 | 299 | // 17.2 配置客户端OnClose 300 | func WithClientOnCloseFunc(onClose func(c *Conn, err error)) ClientOption { 301 | return func(o *DialOption) { 302 | o.cb = OnCloseFunc(onClose) 303 | } 304 | } 305 | 306 | // 18. 配置新的dial函数, 这里可以配置socks5代理地址 307 | func WithClientDialFunc(dialFunc func() (Dialer, error)) ClientOption { 308 | return func(o *DialOption) { 309 | o.dialFunc = dialFunc 310 | } 311 | } 312 | 313 | // 19. 配置proxy地址 314 | func WithClientProxyFunc(proxyFunc func(*http.Request) (*url.URL, error)) ClientOption { 315 | return func(o *DialOption) { 316 | o.proxyFunc = proxyFunc 317 | } 318 | } 319 | 320 | // 20. 
设置支持的子协议 321 | // 20.1 设置客户端支持的子协议 322 | func WithClientSubprotocols(subprotocols []string) ClientOption { 323 | return func(o *DialOption) { 324 | o.subProtocols = subprotocols 325 | } 326 | } 327 | 328 | // 20.2 设置服务端支持的子协议 329 | func WithServerSubprotocols(subprotocols []string) ServerOption { 330 | return func(o *ConnOption) { 331 | o.subProtocols = subprotocols 332 | } 333 | } 334 | 335 | // 21.1 设置客户端支持上下文接管, 默认不支持上下文接管 336 | func WithClientContextTakeover() ClientOption { 337 | return func(o *DialOption) { 338 | o.ClientContextTakeover = true 339 | } 340 | } 341 | 342 | // 21.2 设置服务端支持上下文接管, 默认不支持上下文接管 343 | func WithServerContextTakeover() ServerOption { 344 | return func(o *ConnOption) { 345 | o.ServerContextTakeover = true 346 | } 347 | } 348 | 349 | // 21.1 设置客户端最大窗口位数,使用上下文接管时,这个参数才有效 350 | func WithClientMaxWindowsBits(bits uint8) ClientOption { 351 | return func(o *DialOption) { 352 | if bits < 8 || bits > 15 { 353 | return 354 | } 355 | o.ClientMaxWindowBits = bits 356 | } 357 | } 358 | 359 | // 22.2 设置服务端最大窗口位数, 使用上下文接管时,这个参数才有效 360 | func WithServerMaxWindowBits(bits uint8) ServerOption { 361 | return func(o *ConnOption) { 362 | if bits < 8 || bits > 15 { 363 | return 364 | } 365 | o.ServerMaxWindowBits = bits 366 | } 367 | } 368 | 369 | // 22.1 设置客户端最大可以读取的message的大小, 默认没有限制 370 | func WithClientReadMaxMessage(size int64) ClientOption { 371 | return func(o *DialOption) { 372 | o.readMaxMessage = size 373 | } 374 | } 375 | 376 | // 22.2 设置服务端最大可以读取的message的大小,默认没有限制 377 | func WithServerReadMaxMessage(size int64) ServerOption { 378 | return func(o *ConnOption) { 379 | o.readMaxMessage = size 380 | } 381 | } 382 | 383 | // 22.1配置客户端压缩和解压缩 384 | func WithClientDecompressAndCompress() ClientOption { 385 | return func(o *DialOption) { 386 | o.Compression = true 387 | o.Decompression = true 388 | } 389 | } 390 | 391 | // 22.2配置服务端压缩和解压缩 392 | func WithServerDecompressAndCompress() ServerOption { 393 | return func(o *ConnOption) { 394 | o.Compression = 
true 395 | o.Decompression = true 396 | } 397 | } 398 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## 简介 2 | 3 | quickws是一个高性能的websocket库 4 | 5 | ![Go](https://github.com/antlabs/quickws/workflows/Go/badge.svg) 6 | [![codecov](https://codecov.io/gh/antlabs/quickws/branch/main/graph/badge.svg)](https://codecov.io/gh/antlabs/quickws) 7 | [![Go Report Card](https://goreportcard.com/badge/github.com/antlabs/quickws)](https://goreportcard.com/report/github.com/antlabs/quickws) 8 | 9 | ## 特性 10 | 11 | * 完整实现rfc6455 12 | * 完整实现rfc7692 13 | * 高tps 14 | * 低内存占用 15 | * 池化管理所有buffer 16 | 17 | ## 内容 18 | 19 | * [安装](#installation) 20 | * [例子](#example) 21 | * [net/http升级到websocket服务端](#net-http升级到websocket服务端) 22 | * [gin升级到websocket服务端](#gin升级到websocket服务端) 23 | * [客户端](#客户端) 24 | * [配置函数](#配置函数) 25 | * [客户端配置参数](#客户端配置) 26 | * [配置header](#配置header) 27 | * [配置握手时的超时时间](#配置握手时的超时时间) 28 | * [配置自动回复ping消息](#配置自动回复ping消息) 29 | * [配置socks5代理](#配置socks5代理) 30 | * [配置proxy代理](#配置proxy代理) 31 | * [配置客户端最大读取message](#配置客户端最大读message) 32 | * 
[配置客户端压缩和解压消息](#配置客户端压缩和解压消息) 33 | * [配置客户端上下文接管](#配置客户端上下文接管) 34 | * [服务端配置参数](#服务端配置参数) 35 | * [配置服务自动回复ping消息](#配置服务自动回复ping消息) 36 | * [配置服务端最大读取message](#配置服务端最大读message) 37 | * [配置服务端解压消息](#配置服务端解压消息) 38 | * [配置服务端压缩和解压消息](#配置服务端压缩和解压消息) 39 | * [配置服务端上下文接管](#配置服务端上下文接管) 40 | 41 | * [综合例子](#综合例子) 42 | 43 | ## 注意⚠️ 44 | 45 | quickws传给OnMessage的msg默认是read buffer的浅引用,如果要在OnMessage返回之后继续使用,需要先clone一份再使用 46 | 47 | ## Installation 48 | 49 | ```console 50 | go get github.com/antlabs/quickws 51 | ``` 52 | 53 | ## example 54 | 55 | ### net http升级到websocket服务端 56 | 57 | ```go 58 | 59 | package main 60 | 61 | import ( 62 | "fmt" 63 | "net/http" 64 | "time" 65 | 66 | "github.com/antlabs/quickws" 67 | ) 68 | 69 | type echoHandler struct{} 70 | 71 | func (e *echoHandler) OnOpen(c *quickws.Conn) { 72 | fmt.Println("OnOpen:") 73 | } 74 | 75 | func (e *echoHandler) OnMessage(c *quickws.Conn, op quickws.Opcode, msg []byte) { 76 | fmt.Printf("OnMessage: %s, %v\n", msg, op) 77 | if err := c.WriteTimeout(op, msg, 3*time.Second); err != nil { 78 | fmt.Println("write fail:", err) 79 | } 80 | } 81 | 82 | func (e *echoHandler) OnClose(c *quickws.Conn, err error) { 83 | fmt.Printf("OnClose: %v\n", err) 84 | } 85 | 86 | // echo测试服务 87 | func echo(w http.ResponseWriter, r *http.Request) { 88 | c, err := quickws.Upgrade(w, r, quickws.WithServerReplyPing(), 89 | // quickws.WithServerDecompression(), 90 | // quickws.WithServerIgnorePong(), 91 | quickws.WithServerCallback(&echoHandler{}), 92 | quickws.WithServerReadTimeout(5*time.Second), 93 | ) 94 | if err != nil { 95 | fmt.Println("Upgrade fail:", err) 96 | return 97 | } 98 | 99 | c.StartReadLoop() 100 | } 101 | 102 | func main() { 103 | http.HandleFunc("/", echo) 104 | 105 | http.ListenAndServe(":8080", nil) 106 | } 107 | 108 | ``` 109 | 110 | [返回](#内容) 111 | 112 | ### gin升级到websocket服务端 113 | 114 | ```go 115 | package main 116 | 117 | import ( 118 | "fmt" 119 | 120 | "github.com/antlabs/quickws" 121 | "github.com/gin-gonic/gin" 122 | ) 123 | 124 | type
handler struct{} 125 | 126 | func (h *handler) OnOpen(c *quickws.Conn) { 127 | fmt.Printf("服务端收到一个新的连接") 128 | } 129 | 130 | func (h *handler) OnMessage(c *quickws.Conn, op quickws.Opcode, msg []byte) { 131 | // 如果msg的生命周期不是在OnMessage中结束,需要拷贝一份 132 | // newMsg := make([]byte, len(msg)) 133 | // copy(newMsg, msg) 134 | 135 | fmt.Printf("收到客户端消息:%s\n", msg) 136 | c.WriteMessage(op, msg) 137 | // os.Stdout.Write(msg) 138 | } 139 | 140 | func (h *handler) OnClose(c *quickws.Conn, err error) { 141 | fmt.Printf("服务端连接关闭:%v\n", err) 142 | } 143 | 144 | func main() { 145 | r := gin.Default() 146 | r.GET("/", func(c *gin.Context) { 147 | con, err := quickws.Upgrade(c.Writer, c.Request, quickws.WithServerCallback(&handler{})) 148 | if err != nil { 149 | return 150 | } 151 | con.StartReadLoop() 152 | }) 153 | r.Run() 154 | } 155 | ``` 156 | 157 | [返回](#内容) 158 | 159 | ### 客户端 160 | 161 | ```go 162 | package main 163 | 164 | import ( 165 | "fmt" 166 | "time" 167 | 168 | "github.com/antlabs/quickws" 169 | ) 170 | 171 | type handler struct{} 172 | 173 | func (h *handler) OnOpen(c *quickws.Conn) { 174 | fmt.Printf("客户端连接成功\n") 175 | } 176 | 177 | func (h *handler) OnMessage(c *quickws.Conn, op quickws.Opcode, msg []byte) { 178 | // 如果msg的生命周期不是在OnMessage中结束,需要拷贝一份 179 | // newMsg := make([]byte, len(msg)) 180 | // copy(newMsg, msg) 181 | 182 | fmt.Printf("收到服务端消息:%s\n", msg) 183 | c.WriteMessage(op, msg) 184 | time.Sleep(time.Second) 185 | } 186 | 187 | func (h *handler) OnClose(c *quickws.Conn, err error) { 188 | fmt.Printf("客户端端连接关闭:%v\n", err) 189 | } 190 | 191 | func main() { 192 | c, err := quickws.Dial("ws://127.0.0.1:8080/", quickws.WithClientCallback(&handler{})) 193 | if err != nil { 194 | fmt.Printf("连接失败:%v\n", err) 195 | return 196 | } 197 | 198 | c.WriteMessage(opcode.Text, []byte("hello")) 199 | c.ReadLoop() 200 | } 201 | ``` 202 | 203 | [返回](#内容) 204 | 205 | ## 配置函数 206 | 207 | ### 客户端配置参数 208 | 209 | #### 配置header 210 | 211 | ```go 212 | func main() { 213 | 
quickws.Dial("ws://127.0.0.1:12345/test", quickws.WithClientHTTPHeader(http.Header{ 214 | "h1": []string{"v1"}, 215 | "h2": []string{"v2"}, 216 | })) 217 | } 218 | ``` 219 | 220 | [返回](#内容) 221 | 222 | #### 配置握手时的超时时间 223 | 224 | ```go 225 | func main() { 226 | quickws.Dial("ws://127.0.0.1:12345/test", quickws.WithClientDialTimeout(2 * time.Second)) 227 | } 228 | ``` 229 | 230 | [返回](#内容) 231 | 232 | #### 配置自动回复ping消息 233 | 234 | ```go 235 | func main() { 236 | quickws.Dial("ws://127.0.0.1:12345/test", quickws.WithClientReplyPing()) 237 | } 238 | ``` 239 | 240 | [返回](#内容) 241 | 242 | #### 配置socks5代理 243 | 244 | ```go 245 | import( 246 | "github.com/antlabs/quickws" 247 | "golang.org/x/net/proxy" 248 | ) 249 | 250 | func main() { 251 | quickws.Dial("ws://127.0.0.1:12345", quickws.WithClientDialFunc(func() (quickws.Dialer, error) { 252 | return proxy.SOCKS5("tcp", "socks5代理服务地址", nil, nil) 253 | })) 254 | } 255 | ``` 256 | 257 | [返回](#内容) 258 | 259 | #### 配置proxy代理 260 | 261 | ```go 262 | import( 263 | "github.com/antlabs/quickws" 264 | ) 265 | 266 | func main() { 267 | 268 | proxy := func(*http.Request) (*url.URL, error) { 269 | return url.Parse("http://127.0.0.1:1007") 270 | } 271 | 272 | quickws.Dial("ws://127.0.0.1:12345", quickws.WithClientProxyFunc(proxy)) 273 | } 274 | ``` 275 | 276 | [返回](#内容) 277 | 278 | #### 配置客户端最大读message 279 | 280 | ```go 281 | func main() { 282 | // 限制客户端读取服务端返回的单个消息最大为1024字节,超过这个大小会报错 283 | quickws.Dial("ws://127.0.0.1:12345/test", quickws.WithClientReadMaxMessage(1024)) 284 | } 285 | ``` 286 | 287 | [返回](#内容) 288 | 289 | #### 配置客户端压缩和解压消息 290 | 291 | ```go 292 | func main() { 293 | quickws.Dial("ws://127.0.0.1:12345/test", quickws.WithClientDecompressAndCompress()) 294 | } 295 | ``` 296 | 297 | [返回](#内容) 298 | 299 | #### 配置客户端上下文接管 300 | 301 | ```go 302 | func main() { 303 | quickws.Dial("ws://127.0.0.1:12345/test", quickws.WithClientContextTakeover()) 304 | } 305 | ``` 306 | 307 | [返回](#内容) 308 | 309 | ### 服务端配置参数 310 | 311 | #### 配置服务自动回复ping消息 312 |
```go 314 | func main() { 315 | c, err := quickws.Upgrade(w, r, quickws.WithServerReplyPing()) 316 | if err != nil { 317 | fmt.Println("Upgrade fail:", err) 318 | return 319 | } 320 | } 321 | ``` 322 | 323 | [返回](#内容) 324 | 325 | #### 配置服务端最大读message 326 | 327 | ```go 328 | func main() { 329 | // 配置服务端读取客户端最大的包是1024大小, 超过该值报错 330 | c, err := quickws.Upgrade(w, r, quickws.WithServerReadMaxMessage(1024)) 331 | if err != nil { 332 | fmt.Println("Upgrade fail:", err) 333 | return 334 | } 335 | } 336 | ``` 337 | 338 | [返回](#内容) 339 | 340 | #### 配置服务端解压消息 341 | 342 | ```go 343 | func main() { 344 | // 配置服务端解压客户端发来的压缩消息 345 | c, err := quickws.Upgrade(w, r, quickws.WithServerDecompression()) 346 | if err != nil { 347 | fmt.Println("Upgrade fail:", err) 348 | return 349 | } 350 | } 351 | ``` 352 | 353 | [返回](#内容) 354 | 355 | #### 配置服务端压缩和解压消息 356 | 357 | ```go 358 | func main() { 359 | c, err := quickws.Upgrade(w, r, quickws.WithServerDecompressAndCompress()) 360 | if err != nil { 361 | fmt.Println("Upgrade fail:", err) 362 | return 363 | } 364 | } 365 | ``` 366 | 367 | [返回](#内容) 368 | 369 | #### 配置服务端上下文接管 370 | 371 | ```go 372 | func main() { 373 | // 配置服务端开启上下文接管 374 | c, err := quickws.Upgrade(w, r, quickws.WithServerContextTakeover()) 375 | if err != nil { 376 | fmt.Println("Upgrade fail:", err) 377 | return 378 | } 379 | } 380 | ``` 381 | 382 | [返回](#内容) 383 | 384 | ## 综合例子 385 | 386 | 387 | 388 | ## 常见问题 389 | 390 | ### 1.为什么quickws不标榜zero upgrade?
391 | 392 | 第一:quickws 是基于 std 的方案实现的 websocket 协议。 393 | 394 | 第二:zero upgrade 对 websocket 的整体性能几乎没有影响(同步方式),所以 quickws 没有选择花时间优化 upgrade 过程, 395 | 396 | 而是直接基于 net/http。websocket 连接整体上符合大数定律:一个存活几秒的websocket连接由 upgrade(握手) frame(数据包) frame frame ……组成。 397 | 398 | 所以随着时间的增长, upgrade 对整体的影响接近于0,我们用数字代入下。 399 | 400 | A: 代表 upgrade 可能会慢点,但是 frame 的过程比较快,比如基于 net/http 方案的 websocket 401 | 402 | upgrade (100ms) frame(10ms) frame(10ms) frame(10ms) avg = 32.5ms 403 | 404 | B: 代表主打zero upgrade的库,假如frame的过程处理慢点, 405 | 406 | upgrade (90ms) frame(15ms) frame(15ms) frame(15ms) avg = 33.75ms 407 | 408 | 简单代入就能看出,决定 websocket 性能差距的是 frame 的处理过程,无论是tps还是内存占用,quickws 在实战中也会证明这一点。所以没有必要在 upgrade 上下功夫,常规优化就够了。 409 | 410 | ### 2.quickws tps如何 411 | 412 | 在5800h的cpu上面,tps稳定在47w/s,接近48w/s。比gorilla使用ReadMessage的38.9w/s,快了近9w/s 413 | 414 | ``` 415 | quickws.1: 416 | 1s:357999/s 2s:418860/s 3s:440650/s 4s:453360/s 5s:461108/s 6s:465898/s 7s:469211/s 8s:470780/s 9s:472923/s 10s:473821/s 11s:474525/s 12s:475463/s 13s:476021/s 14s:476410/s 15s:477593/s 16s:477943/s 17s:478038/s 417 | gorilla-linux-ReadMessage.4.1 418 | 1s:271126/s 2s:329367/s 3s:353468/s 4s:364842/s 5s:371908/s 6s:377633/s 7s:380870/s 8s:383271/s 9s:384646/s 10s:385986/s 11s:386448/s 12s:386554/s 13s:387573/s 14s:388263/s 15s:388701/s 16s:388867/s 17s:389383/s 419 | gorilla-linux-UseReader.4.2: 420 | 1s:293888/s 2s:377628/s 3s:399744/s 4s:413150/s 5s:421092/s 6s:426666/s 7s:430239/s 8s:432801/s 9s:434977/s 10s:436058/s 11s:436805/s 12s:437865/s 13s:438421/s 14s:438901/s 15s:439133/s 16s:439409/s 17s:439578/s 421 | gobwas.6: 422 | 1s:215995/s 2s:279405/s 3s:302249/s 4s:312545/s 5s:318922/s 6s:323800/s 7s:326908/s 8s:329977/s 9s:330959/s 10s:331510/s 11s:331911/s 12s:332396/s 13s:332418/s 14s:332887/s 15s:333198/s 16s:333390/s 17s:333550/s 423 | ``` 424 | 425 | ### 3.quickws 流量测试数据如何 ?
426 | 427 | 在5800h的cpu上面, 同尺寸read buffer(4k), 对比默认用法,quickws在30s处理119GB数据,gorilla处理48GB数据。 428 | 429 | * quickws 430 | 431 | ``` 432 | quickws.windows.tcp.delay.4x: 433 | Destination: [127.0.0.1]:9000 434 | Interface lo address [127.0.0.1]:0 435 | Using interface lo to connect to [127.0.0.1]:9000 436 | Ramped up to 10000 connections. 437 | Total data sent: 119153.9 MiB (124941915494 bytes) 438 | Total data received: 119594.6 MiB (125404036361 bytes) 439 | Bandwidth per channel: 6.625⇅ Mbps (828.2 kBps) 440 | Aggregate bandwidth: 33439.980↓, 33316.752↑ Mbps 441 | Packet rate estimate: 3174704.8↓, 2930514.7↑ (9↓, 34↑ TCP MSS/op) 442 | Test duration: 30.001 s. 443 | ``` 444 | 445 | * gorilla 使用ReadMessage取数据 446 | 447 | ``` 448 | gorilla-linux-ReadMessage.tcp.delay: 449 | WARNING: Dumb terminal, expect unglorified output. 450 | Destination: [127.0.0.1]:9003 451 | Interface lo address [127.0.0.1]:0 452 | Using interface lo to connect to [127.0.0.1]:9003 453 | Ramped up to 10000 connections. 454 | Total data sent: 48678.1 MiB (51042707521 bytes) 455 | Total data received: 50406.2 MiB (52854715802 bytes) 456 | Bandwidth per channel: 2.771⇅ Mbps (346.3 kBps) 457 | Aggregate bandwidth: 14094.587↓, 13611.385↑ Mbps 458 | Packet rate estimate: 1399915.6↓, 1190593.2↑ (6↓, 45↑ TCP MSS/op) 459 | Test duration: 30 s. 460 | ``` 461 | 462 | ### 4.内存占用如何 ? 
463 | 464 | quickws的特色之一是低内存占用。 465 | 466 | 1w连接的tps测试,1k payload 回写,初始内存占用约122MB,在240s-260s之后大约86MB。 467 | 468 | ## 百万长连接测试 469 | 470 | ``` 471 | BenchType : BenchEcho 472 | Framework : quickws 473 | TPS : 108143 474 | EER : -118.52 475 | Min : 32.99us 476 | Avg : 92.26ms 477 | Max : 1.03s 478 | TP50 : 48.37ms 479 | TP75 : 53.88ms 480 | TP90 : 215.18ms 481 | TP95 : 430.07ms 482 | TP99 : 502.95ms 483 | Used : 18.49s 484 | Total : 2000000 485 | Success : 2000000 486 | Failed : 0 487 | Conns : 1000000 488 | Concurrency: 10000 489 | Payload : 1024 490 | CPU Min : -520020.80% 491 | CPU Avg : -912.44% 492 | CPU Max : 220653.13% 493 | MEM Min : 8.46G 494 | MEM Avg : 8.47G 495 | MEM Max : 8.47G 496 | ``` 497 | -------------------------------------------------------------------------------- /client_option_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | package quickws 16 | 17 | import ( 18 | "bytes" 19 | "crypto/tls" 20 | "encoding/binary" 21 | "io" 22 | "net" 23 | "net/http" 24 | "net/http/httptest" 25 | "strings" 26 | "sync" 27 | "sync/atomic" 28 | "testing" 29 | "time" 30 | 31 | "golang.org/x/net/proxy" 32 | ) 33 | 34 | func Test_ClientOption(t *testing.T) { 35 | t.Run("ClientOption.WithClientHTTPHeader", func(t *testing.T) { 36 | done := make(chan string, 1) 37 | run := int32(0) 38 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 39 | v := r.Header.Get("A") 40 | done <- v 41 | con, err := Upgrade(w, r) 42 | if err != nil { 43 | t.Error(err) 44 | return 45 | } 46 | 47 | defer con.Close() 48 | atomic.AddInt32(&run, 1) 49 | })) 50 | 51 | defer ts.Close() 52 | 53 | url := strings.ReplaceAll(ts.URL, "http", "ws") 54 | con, err := Dial(url, WithClientHTTPHeader(http.Header{ 55 | "A": []string{"A"}, 56 | }), WithClientCallback(&testDefaultCallback{})) 57 | if err != nil { 58 | t.Error(err) 59 | return 60 | } 61 | defer con.Close() 62 | 63 | select { 64 | case v := <-done: 65 | if v != "A" { 66 | t.Error("header fail") 67 | } 68 | case <-time.After(1000 * time.Millisecond): 69 | } 70 | if atomic.LoadInt32(&run) != 1 { 71 | t.Error("not run server:method fail") 72 | } 73 | }) 74 | 75 | t.Run("ClientOption.WithClientTLSConfig", func(t *testing.T) { 76 | done := make(chan string, 1) 77 | run := int32(0) 78 | ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 79 | v := r.Header.Get("A") 80 | atomic.AddInt32(&run, 1) 81 | done <- v 82 | con, err := Upgrade(w, r) 83 | if err != nil { 84 | t.Error(err) 85 | return 86 | } 87 | 88 | defer con.Close() 89 | })) 90 | 91 | defer ts.Close() 92 | 93 | url := strings.ReplaceAll(ts.URL, "http", "ws") 94 | con, err := Dial(url, 95 | WithClientTLSConfig(&tls.Config{InsecureSkipVerify: true}), 96 | WithClientHTTPHeader(http.Header{ 97 | "A": []string{"A"}, 98 | }), 
WithClientCallback(&testDefaultCallback{})) 99 | if err != nil { 100 | t.Error(err) 101 | return 102 | } 103 | defer con.Close() 104 | 105 | select { 106 | case v := <-done: 107 | if v != "A" { 108 | t.Error("header fail") 109 | } 110 | case <-time.After(1000 * time.Millisecond): 111 | } 112 | if atomic.LoadInt32(&run) != 1 { 113 | t.Error("not run server:method fail") 114 | } 115 | }) 116 | 117 | t.Run("Dial.WithClientDialTimeout", func(t *testing.T) { 118 | done := make(chan string, 1) 119 | run := int32(0) 120 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 121 | v := r.Header.Get("A") 122 | done <- v 123 | con, err := Upgrade(w, r) 124 | if err != nil { 125 | t.Error(err) 126 | return 127 | } 128 | 129 | defer con.Close() 130 | atomic.AddInt32(&run, 1) 131 | })) 132 | 133 | defer ts.Close() 134 | 135 | url := strings.ReplaceAll(ts.URL, "http", "ws") 136 | con, err := Dial(url, WithClientHTTPHeader(http.Header{ 137 | "A": []string{"A"}, 138 | }), WithClientCallback(&testDefaultCallback{}), WithClientDialTimeout(time.Second)) 139 | if err != nil { 140 | t.Error(err) 141 | return 142 | } 143 | defer con.Close() 144 | 145 | select { 146 | case v := <-done: 147 | if v != "A" { 148 | t.Error("header fail") 149 | } 150 | case <-time.After(1000 * time.Millisecond): 151 | } 152 | if atomic.LoadInt32(&run) != 1 { 153 | t.Error("not run server:method fail") 154 | } 155 | }) 156 | t.Run("ClientOption.WithClientDialTimeout", func(t *testing.T) {}) 157 | t.Run("6.1 Dial: WithClientBindHTTPHeader and echo Sec-Websocket-Protocol", func(t *testing.T) { 158 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 159 | _, err := Upgrade(w, r) 160 | if err != nil { 161 | t.Error(err) 162 | } 163 | })) 164 | 165 | defer ts.Close() 166 | 167 | url := strings.ReplaceAll(ts.URL, "http", "ws") 168 | h := make(http.Header) 169 | con, err := Dial(url, WithClientBindHTTPHeader(&h), 
WithClientHTTPHeader(http.Header{ 170 | "Sec-WebSocket-Protocol": []string{"token"}, 171 | })) 172 | if err != nil { 173 | t.Error(err) 174 | } 175 | defer con.Close() 176 | 177 | if h["Sec-Websocket-Protocol"][0] != "token" { 178 | t.Error("header fail") 179 | } 180 | }) 181 | 182 | t.Run("6.2 DialConf: WithClientBindHTTPHeader and echo Sec-Websocket-Protocol", func(t *testing.T) { 183 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 184 | _, err := Upgrade(w, r) 185 | if err != nil { 186 | t.Error(err) 187 | } 188 | })) 189 | 190 | defer ts.Close() 191 | 192 | url := strings.ReplaceAll(ts.URL, "http", "ws") 193 | h := make(http.Header) 194 | con, err := DialConf(url, ClientOptionToConf(WithClientBindHTTPHeader(&h), WithClientHTTPHeader(http.Header{ 195 | "Sec-WebSocket-Protocol": []string{"token"}, 196 | }))) 197 | if err != nil { 198 | t.Error(err) 199 | } 200 | defer con.Close() 201 | 202 | if h["Sec-Websocket-Protocol"][0] != "token" { 203 | t.Error("header fail") 204 | } 205 | }) 206 | 207 | t.Run("18 Dial: WithClientDialFunc.1", func(t *testing.T) { 208 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 209 | conn, err := Upgrade(w, r, WithServerOnMessageFunc(func(c *Conn, o Opcode, b []byte) { 210 | err := c.WriteMessage(o, b) 211 | if err != nil { 212 | t.Error(err) 213 | return 214 | } 215 | c.Close() 216 | })) 217 | if err != nil { 218 | t.Error(err) 219 | } 220 | 221 | conn.StartReadLoop() 222 | })) 223 | 224 | proxyAddr, err := net.Listen("tcp", "127.0.0.1:0") 225 | if err != nil { 226 | t.Error(err) 227 | } 228 | defer ts.Close() 229 | 230 | go func() { 231 | newConn, err := proxyAddr.Accept() 232 | if err != nil { 233 | t.Error(err) 234 | } 235 | 236 | err = newConn.SetDeadline(time.Now().Add(30 * time.Second)) 237 | if err != nil { 238 | t.Error(err) 239 | return 240 | } 241 | 242 | buf := make([]byte, 128) 243 | if _, err := io.ReadFull(newConn, buf[:3]); err != nil { 
244 | t.Errorf("read failed: %v", err) 245 | return 246 | } 247 | 248 | // socks version 5, 1 authentication method, no auth 249 | if want := []byte{5, 1, 0}; !bytes.Equal(want, buf[:len(want)]) { 250 | t.Errorf("read %x, want %x", buf[:len(want)], want) 251 | } 252 | 253 | // socks version 5, connect command, reserved, ipv4 address, port 80 254 | if _, err := newConn.Write([]byte{5, 0}); err != nil { 255 | t.Errorf("write failed: %v", err) 256 | return 257 | } 258 | 259 | // ver cmd rsv atyp dst.addr dst.port 260 | if _, err := io.ReadFull(newConn, buf[:10]); err != nil { 261 | t.Errorf("read failed: %v", err) 262 | return 263 | } 264 | if want := []byte{5, 1, 0, 1}; !bytes.Equal(want, buf[:len(want)]) { 265 | t.Errorf("read %x, want %x", buf[:len(want)], want) 266 | return 267 | } 268 | buf[1] = 0 269 | if _, err := newConn.Write(buf[:10]); err != nil { 270 | t.Errorf("write failed: %v", err) 271 | return 272 | } 273 | 274 | // 提取ip 275 | ip := net.IP(buf[4:8]) 276 | port := binary.BigEndian.Uint16(buf[8:10]) 277 | 278 | c2, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: ip, Port: int(port)}) 279 | if err != nil { 280 | t.Errorf("dial failed; %v", err) 281 | return 282 | } 283 | defer c2.Close() 284 | done := make(chan struct{}) 285 | go func() { 286 | _, err := io.Copy(newConn, c2) 287 | if err != nil { 288 | t.Error(err) 289 | return 290 | } 291 | 292 | close(done) 293 | }() 294 | _, err = io.Copy(c2, newConn) 295 | if err != nil { 296 | t.Error(err) 297 | return 298 | } 299 | <-done 300 | }() 301 | 302 | got := make([]byte, 0, 128) 303 | url := strings.ReplaceAll(ts.URL, "http", "ws") 304 | c, err := Dial(url, WithClientDialFunc(func() (Dialer, error) { 305 | return proxy.SOCKS5("tcp", proxyAddr.Addr().String(), nil, nil) 306 | }), WithClientOnMessageFunc(func(c *Conn, o Opcode, b []byte) { 307 | got = append(got, b...) 
308 | c.Close() 309 | })) 310 | if err != nil { 311 | t.Error(err) 312 | } 313 | 314 | data := []byte("hello world") 315 | err = c.WriteMessage(Binary, data) 316 | if err != nil { 317 | t.Error(err) 318 | return 319 | } 320 | _ = c.ReadLoop() 321 | 322 | t.Log("got", string(got), "want", string(data)) 323 | if !bytes.Equal(got, data) { 324 | t.Errorf("got %s, want %s", got, data) 325 | } 326 | }) 327 | 328 | t.Run("18 Dial: WithClientDialFunc.2", func(t *testing.T) { 329 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 330 | conn, err := Upgrade(w, r, WithServerOnMessageFunc(func(c *Conn, o Opcode, b []byte) { 331 | err := c.WriteMessage(o, b) 332 | if err != nil { 333 | t.Error(err) 334 | return 335 | } 336 | c.Close() 337 | })) 338 | if err != nil { 339 | t.Error(err) 340 | } 341 | 342 | conn.StartReadLoop() 343 | })) 344 | 345 | proxyAddr, err := net.Listen("tcp", "127.0.0.1:0") 346 | if err != nil { 347 | t.Error(err) 348 | } 349 | defer ts.Close() 350 | 351 | go func() { 352 | newConn, err := proxyAddr.Accept() 353 | if err != nil { 354 | t.Error(err) 355 | } 356 | 357 | err = newConn.SetDeadline(time.Now().Add(30 * time.Second)) 358 | if err != nil { 359 | t.Error(err) 360 | return 361 | } 362 | 363 | buf := make([]byte, 128) 364 | if _, err := io.ReadFull(newConn, buf[:3]); err != nil { 365 | t.Errorf("read failed: %v", err) 366 | return 367 | } 368 | 369 | // socks version 5, 1 authentication method, no auth 370 | if want := []byte{5, 1, 0}; !bytes.Equal(want, buf[:len(want)]) { 371 | t.Errorf("read %x, want %x", buf[:len(want)], want) 372 | } 373 | 374 | // socks version 5, connect command, reserved, ipv4 address, port 80 375 | if _, err := newConn.Write([]byte{5, 0}); err != nil { 376 | t.Errorf("write failed: %v", err) 377 | return 378 | } 379 | 380 | // ver cmd rsv atyp dst.addr dst.port 381 | if _, err := io.ReadFull(newConn, buf[:10]); err != nil { 382 | t.Errorf("read failed: %v", err) 383 | return 384 | } 
385 | if want := []byte{5, 1, 0, 1}; !bytes.Equal(want, buf[:len(want)]) { 386 | t.Errorf("read %x, want %x", buf[:len(want)], want) 387 | return 388 | } 389 | buf[1] = 0 390 | if _, err := newConn.Write(buf[:10]); err != nil { 391 | t.Errorf("write failed: %v", err) 392 | return 393 | } 394 | 395 | // 提取ip 396 | ip := net.IP(buf[4:8]) 397 | port := binary.BigEndian.Uint16(buf[8:10]) 398 | 399 | c2, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: ip, Port: int(port)}) 400 | if err != nil { 401 | t.Errorf("dial failed; %v", err) 402 | return 403 | } 404 | defer c2.Close() 405 | 406 | // done := make(chan struct{}) 407 | // newConn = &safeConn{Conn: newConn} 408 | // c2 = &safeConn{Conn: c2} 409 | // go func() { 410 | // _, err = io.Copy(newConn, c2) 411 | // if err != nil { 412 | // t.Error(err) 413 | // return 414 | // } 415 | // close(done) 416 | // }() 417 | // _, err = io.Copy(c2, newConn) 418 | // if err != nil { 419 | // t.Error(err) 420 | // return 421 | // } 422 | // <-done 423 | 424 | var ( 425 | newConnMu sync.Mutex 426 | c2Mu sync.Mutex 427 | wg sync.WaitGroup 428 | ) 429 | 430 | wg.Add(2) 431 | 432 | go func() { 433 | defer wg.Done() 434 | buf := make([]byte, 4096) 435 | for { 436 | n, err := c2.Read(buf) 437 | if err != nil { 438 | if err != io.EOF { 439 | t.Error(err) 440 | } 441 | break 442 | } 443 | newConnMu.Lock() 444 | _, err = newConn.Write(buf[:n]) 445 | newConnMu.Unlock() 446 | if err != nil { 447 | t.Error(err) 448 | break 449 | } 450 | } 451 | }() 452 | 453 | go func() { 454 | defer wg.Done() 455 | buf := make([]byte, 4096) 456 | for { 457 | n, err := newConn.Read(buf) 458 | if err != nil { 459 | if err != io.EOF { 460 | t.Error(err) 461 | } 462 | break 463 | } 464 | c2Mu.Lock() 465 | _, err = c2.Write(buf[:n]) 466 | c2Mu.Unlock() 467 | if err != nil { 468 | t.Error(err) 469 | break 470 | } 471 | } 472 | }() 473 | 474 | wg.Wait() 475 | }() 476 | 477 | got := make([]byte, 0, 128) 478 | url := strings.ReplaceAll(ts.URL, "http", "ws") 479 | c, 
err := DialConf(url, ClientOptionToConf(WithClientDialFunc(func() (Dialer, error) { 480 | return proxy.SOCKS5("tcp", proxyAddr.Addr().String(), nil, nil) 481 | }), WithClientOnMessageFunc(func(c *Conn, o Opcode, b []byte) { 482 | got = append(got, b...) 483 | c.Close() 484 | }))) 485 | if err != nil { 486 | t.Error(err) 487 | } 488 | 489 | data := []byte("hello world") 490 | err = c.WriteMessage(Binary, data) 491 | if err != nil { 492 | t.Error(err) 493 | return 494 | } 495 | _ = c.ReadLoop() 496 | 497 | t.Log("got", string(got), "want", string(data)) 498 | if !bytes.Equal(got, data) { 499 | t.Errorf("got %s, want %s", got, data) 500 | } 501 | }) 502 | } 503 | -------------------------------------------------------------------------------- /server_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | package quickws 15 | 16 | import ( 17 | "fmt" 18 | "net" 19 | "net/http" 20 | "net/http/httptest" 21 | "strings" 22 | "sync/atomic" 23 | "testing" 24 | "time" 25 | ) 26 | 27 | // 测试服务端握手失败的情况 28 | func Test_Server_HandshakeFail(t *testing.T) { 29 | // u := NewUpgrade() 30 | t.Run("local config:case:method fail", func(t *testing.T) { 31 | run := int32(0) 32 | done := make(chan bool, 1) 33 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 34 | _, err := Upgrade(w, r) 35 | if err == nil { 36 | t.Error("upgrade method fail") 37 | } 38 | atomic.AddInt32(&run, int32(1)) 39 | done <- true 40 | })) 41 | 42 | defer ts.Close() 43 | 44 | url := ts.URL 45 | req, err := http.NewRequest("POST", url, nil) 46 | if err != nil { 47 | t.Error(err) 48 | } 49 | _, err = http.DefaultClient.Do(req) 50 | if err != nil { 51 | t.Error(err) 52 | return 53 | } 54 | select { 55 | case <-done: 56 | case <-time.After(100 * time.Millisecond): 57 | } 58 | if atomic.LoadInt32(&run) != 1 { 59 | t.Error("not run server:method fail") 60 | } 61 | }) 62 | 63 | t.Run("global config:case:method fail", func(t *testing.T) { 64 | run := int32(0) 65 | upgrade := NewUpgrade() 66 | done := make(chan bool, 1) 67 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 68 | _, err := upgrade.Upgrade(w, r) 69 | if err == nil { 70 | t.Error("upgrade method fail") 71 | } 72 | atomic.AddInt32(&run, int32(1)) 73 | done <- true 74 | })) 75 | 76 | defer ts.Close() 77 | 78 | url := ts.URL 79 | req, err := http.NewRequest("POST", url, nil) 80 | if err != nil { 81 | t.Error(err) 82 | } 83 | _, err = http.DefaultClient.Do(req) 84 | if err != nil { 85 | t.Error(err) 86 | return 87 | } 88 | select { 89 | case <-done: 90 | case <-time.After(100 * time.Millisecond): 91 | } 92 | if atomic.LoadInt32(&run) != 1 { 93 | t.Error("not run server:method fail") 94 | } 95 | }) 96 | 97 | t.Run("local config:case:http proto fail", func(t *testing.T) { 98 | run 
:= int32(0) 99 | done := make(chan bool, 1) 100 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 101 | _, err := Upgrade(w, r) 102 | if err == nil { 103 | t.Error("upgrade http proto fail") 104 | } 105 | atomic.AddInt32(&run, int32(1)) 106 | done <- true 107 | })) 108 | 109 | url := strings.ReplaceAll(ts.URL, "http://", "") 110 | defer ts.Close() 111 | c, err := net.Dial("tcp", url) 112 | if err != nil { 113 | t.Error(err) 114 | } 115 | _, err = c.Write([]byte("GET / HTTP/1.0\r\nHost: localhost:8080\r\n\r\n")) 116 | if err != nil { 117 | t.Error(err) 118 | return 119 | } 120 | c.Close() 121 | select { 122 | case <-done: 123 | case <-time.After(100 * time.Millisecond): 124 | } 125 | 126 | if atomic.LoadInt32(&run) != 1 { 127 | t.Error("not run server:http proto fail") 128 | } 129 | }) 130 | 131 | t.Run("global config:case:http proto fail", func(t *testing.T) { 132 | run := int32(0) 133 | upgrade := NewUpgrade() 134 | done := make(chan bool, 1) 135 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 136 | _, err := upgrade.Upgrade(w, r) 137 | if err == nil { 138 | t.Error("upgrade http proto fail") 139 | } 140 | atomic.AddInt32(&run, int32(1)) 141 | done <- true 142 | })) 143 | 144 | url := strings.ReplaceAll(ts.URL, "http://", "") 145 | defer ts.Close() 146 | c, err := net.Dial("tcp", url) 147 | if err != nil { 148 | t.Error(err) 149 | } 150 | _, err = c.Write([]byte("GET / HTTP/1.0\r\n\r\n")) 151 | if err != nil { 152 | t.Error(err) 153 | return 154 | } 155 | // c.Write([]byte("GET / HTTP/1.0\r\nHost: localhost:8080\r\n\r\n")) 156 | c.Close() 157 | 158 | select { 159 | case <-done: 160 | case <-time.After(100 * time.Millisecond): 161 | } 162 | if atomic.LoadInt32(&run) != 1 { 163 | t.Error("not run server:http proto fail") 164 | } 165 | }) 166 | 167 | t.Run("local config:case:host empty", func(t *testing.T) { 168 | run := int32(0) 169 | done := make(chan bool, 1) 170 | ts := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 171 | _, err := Upgrade(w, r) 172 | if err == nil { 173 | t.Error("upgrade host fail") 174 | } 175 | atomic.AddInt32(&run, int32(1)) 176 | done <- true 177 | })) 178 | 179 | defer ts.Close() 180 | 181 | url := strings.ReplaceAll(ts.URL, "http://", "") 182 | c, err := net.Dial("tcp", url) 183 | if err != nil { 184 | t.Error(err) 185 | } 186 | _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: \r\n\r\n")) 187 | if err != nil { 188 | t.Error(err) 189 | return 190 | } 191 | defer c.Close() 192 | select { 193 | case <-done: 194 | case <-time.After(100 * time.Millisecond): 195 | } 196 | if atomic.LoadInt32(&run) != 1 { 197 | t.Error("not run server:host empty") 198 | } 199 | }) 200 | 201 | t.Run("global config:case:upgrade fail", func(t *testing.T) { 202 | run := int32(0) 203 | upgrade := NewUpgrade() 204 | done := make(chan bool, 1) 205 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 206 | _, err := upgrade.Upgrade(w, r) 207 | if err == nil { 208 | t.Error("upgrade : upgrade field fail") 209 | } 210 | atomic.AddInt32(&run, int32(1)) 211 | done <- true 212 | })) 213 | 214 | url := strings.ReplaceAll(ts.URL, "http://", "") 215 | defer ts.Close() 216 | c, err := net.Dial("tcp", url) 217 | if err != nil { 218 | t.Error(err) 219 | } 220 | wbuf := []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: %s\r\nUpgrade: xx\r\n\r\n", url)) 221 | _, err = c.Write(wbuf) 222 | if err != nil { 223 | t.Error(err) 224 | return 225 | } 226 | c.Close() 227 | 228 | select { 229 | case <-done: 230 | case <-time.After(100 * time.Millisecond): 231 | } 232 | if atomic.LoadInt32(&run) != 1 { 233 | t.Error("not run server:upgrade field fail") 234 | } 235 | }) 236 | 237 | t.Run("local config:case:upgrade fail", func(t *testing.T) { 238 | run := int32(0) 239 | done := make(chan bool, 1) 240 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 241 | 
_, err := Upgrade(w, r) 242 | if err == nil { 243 | t.Error("upgrade : upgrade field fail") 244 | } 245 | atomic.AddInt32(&run, int32(1)) 246 | done <- true 247 | })) 248 | 249 | url := strings.ReplaceAll(ts.URL, "http://", "") 250 | defer ts.Close() 251 | c, err := net.Dial("tcp", url) 252 | if err != nil { 253 | t.Error(err) 254 | } 255 | wbuf := []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: %s\r\nUpgrade: xx\r\n\r\n", url)) 256 | _, err = c.Write(wbuf) 257 | if err != nil { 258 | t.Error(err) 259 | return 260 | } 261 | c.Close() 262 | 263 | select { 264 | case <-done: 265 | case <-time.After(100 * time.Millisecond): 266 | } 267 | if atomic.LoadInt32(&run) != 1 { 268 | t.Error("not run server:upgrade field fail") 269 | } 270 | }) 271 | 272 | t.Run("global config:case:Connection fail", func(t *testing.T) { 273 | run := int32(0) 274 | upgrade := NewUpgrade() 275 | done := make(chan bool, 1) 276 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 277 | _, err := upgrade.Upgrade(w, r) 278 | if err == nil { 279 | t.Error("upgrade : Connection field fail") 280 | } 281 | atomic.AddInt32(&run, int32(1)) 282 | done <- true 283 | })) 284 | 285 | url := strings.ReplaceAll(ts.URL, "http://", "") 286 | defer ts.Close() 287 | c, err := net.Dial("tcp", url) 288 | if err != nil { 289 | t.Error(err) 290 | } 291 | wbuf := []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: %s\r\nUpgrade: websocket\r\nConnection: xx\r\n\r\n", url)) 292 | _, err = c.Write(wbuf) 293 | if err != nil { 294 | t.Error(err) 295 | return 296 | } 297 | c.Close() 298 | 299 | select { 300 | case <-done: 301 | case <-time.After(100 * time.Millisecond): 302 | } 303 | if atomic.LoadInt32(&run) != 1 { 304 | t.Error("not run server:Connection field fail") 305 | } 306 | }) 307 | 308 | t.Run("local config:case:Connection fail", func(t *testing.T) { 309 | run := int32(0) 310 | done := make(chan bool, 1) 311 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r 
*http.Request) { 312 | _, err := Upgrade(w, r) 313 | if err == nil { 314 | t.Error("upgrade : Connection field fail") 315 | } 316 | atomic.AddInt32(&run, int32(1)) 317 | done <- true 318 | })) 319 | 320 | url := strings.ReplaceAll(ts.URL, "http://", "") 321 | defer ts.Close() 322 | c, err := net.Dial("tcp", url) 323 | if err != nil { 324 | t.Error(err) 325 | } 326 | wbuf := []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: %s\r\nUpgrade: websocket\r\nConnection: xx\r\n\r\n", url)) 327 | _, err = c.Write(wbuf) 328 | if err != nil { 329 | t.Error(err) 330 | return 331 | } 332 | c.Close() 333 | 334 | select { 335 | case <-done: 336 | case <-time.After(100 * time.Millisecond): 337 | } 338 | if atomic.LoadInt32(&run) != 1 { 339 | t.Error("not run server:Connection field fail") 340 | } 341 | }) 342 | 343 | t.Run("global config:case: Sec-WebSocket-Key fail", func(t *testing.T) { 344 | run := int32(0) 345 | upgrade := NewUpgrade() 346 | done := make(chan bool, 1) 347 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 348 | _, err := upgrade.Upgrade(w, r) 349 | if err == nil { 350 | t.Error("upgrade : Connection field fail") 351 | } 352 | atomic.AddInt32(&run, int32(1)) 353 | done <- true 354 | })) 355 | 356 | url := strings.ReplaceAll(ts.URL, "http://", "") 357 | defer ts.Close() 358 | c, err := net.Dial("tcp", url) 359 | if err != nil { 360 | t.Error(err) 361 | } 362 | wbuf := []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: %s\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n", url)) 363 | _, err = c.Write(wbuf) 364 | if err != nil { 365 | t.Error(err) 366 | return 367 | } 368 | c.Close() 369 | 370 | select { 371 | case <-done: 372 | case <-time.After(100 * time.Millisecond): 373 | } 374 | if atomic.LoadInt32(&run) != 1 { 375 | t.Error("not run server:Connection field fail") 376 | } 377 | }) 378 | 379 | t.Run("local config:case: Sec-WebSocket-Key fail", func(t *testing.T) { 380 | run := int32(0) 381 | done := make(chan bool, 1) 382 | ts := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 383 | _, err := Upgrade(w, r) 384 | if err == nil { 385 | t.Error("upgrade : Connection field fail") 386 | } 387 | atomic.AddInt32(&run, int32(1)) 388 | done <- true 389 | })) 390 | 391 | url := strings.ReplaceAll(ts.URL, "http://", "") 392 | defer ts.Close() 393 | c, err := net.Dial("tcp", url) 394 | if err != nil { 395 | t.Error(err) 396 | } 397 | wbuf := []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: %s\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n", url)) 398 | _, err = c.Write(wbuf) 399 | if err != nil { 400 | t.Error(err) 401 | return 402 | } 403 | c.Close() 404 | 405 | select { 406 | case <-done: 407 | case <-time.After(100 * time.Millisecond): 408 | } 409 | if atomic.LoadInt32(&run) != 1 { 410 | t.Error("not run server:Connection field fail") 411 | } 412 | }) 413 | 414 | t.Run("global config:case: Sec-WebSocket-Version fail", func(t *testing.T) { 415 | run := int32(0) 416 | upgrade := NewUpgrade() 417 | done := make(chan bool, 1) 418 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 419 | _, err := upgrade.Upgrade(w, r) 420 | if err == nil { 421 | t.Error("upgrade : Connection field fail") 422 | } 423 | atomic.AddInt32(&run, int32(1)) 424 | done <- true 425 | })) 426 | 427 | url := strings.ReplaceAll(ts.URL, "http://", "") 428 | defer ts.Close() 429 | c, err := net.Dial("tcp", url) 430 | if err != nil { 431 | t.Error(err) 432 | } 433 | wbuf := []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: %s\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: key\r\n\r\n", url)) 434 | _, err = c.Write(wbuf) 435 | if err != nil { 436 | t.Error(err) 437 | return 438 | } 439 | c.Close() 440 | 441 | select { 442 | case <-done: 443 | case <-time.After(100 * time.Millisecond): 444 | } 445 | if atomic.LoadInt32(&run) != 1 { 446 | t.Error("not run server:Connection field fail") 447 | } 448 | }) 449 | 450 | t.Run("local config:case: 
Sec-WebSocket-Version fail", func(t *testing.T) { 451 | run := int32(0) 452 | done := make(chan bool, 1) 453 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 454 | _, err := Upgrade(w, r) 455 | if err == nil { 456 | t.Error("upgrade : Connection field fail") 457 | } 458 | atomic.AddInt32(&run, int32(1)) 459 | done <- true 460 | })) 461 | 462 | url := strings.ReplaceAll(ts.URL, "http://", "") 463 | defer ts.Close() 464 | c, err := net.Dial("tcp", url) 465 | if err != nil { 466 | t.Error(err) 467 | } 468 | wbuf := []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: %s\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: key\r\n\r\n", url)) 469 | _, err = c.Write(wbuf) 470 | if err != nil { 471 | t.Error(err) 472 | return 473 | } 474 | c.Close() 475 | 476 | select { 477 | case <-done: 478 | case <-time.After(100 * time.Millisecond): 479 | } 480 | if atomic.LoadInt32(&run) != 1 { 481 | t.Error("not run server:Connection field fail") 482 | } 483 | }) 484 | } 485 | -------------------------------------------------------------------------------- /conn.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | package quickws 16 | 17 | import ( 18 | "bufio" 19 | "bytes" 20 | "crypto/tls" 21 | "encoding/binary" 22 | "errors" 23 | "fmt" 24 | "io" 25 | "math/rand" 26 | "net" 27 | "sync" 28 | "sync/atomic" 29 | "time" 30 | "unsafe" 31 | 32 | "github.com/antlabs/wsutil/bufio2" 33 | "github.com/antlabs/wsutil/bytespool" 34 | "github.com/antlabs/wsutil/deflate" 35 | "github.com/antlabs/wsutil/enum" 36 | "github.com/antlabs/wsutil/fixedreader" 37 | "github.com/antlabs/wsutil/fixedwriter" 38 | "github.com/antlabs/wsutil/frame" 39 | "github.com/antlabs/wsutil/limitreader" 40 | "github.com/antlabs/wsutil/myonce" 41 | "github.com/antlabs/wsutil/opcode" 42 | ) 43 | 44 | const ( 45 | maxControlFrameSize = 125 46 | ) 47 | 48 | // var _ net.Conn = (*Conn)(nil) 49 | 50 | // 延迟写, 基于次数和时间 合并数据写入, 实验功能 51 | type delayWrite struct { 52 | delayBuf *bytes.Buffer // 延迟写的缓冲区 53 | delayTimeout *time.Timer // 延迟写的定时器 54 | delayErr error // TODO 原子操作 55 | delayNum int32 // 控制延迟写的数量 56 | } 57 | 58 | type Conn struct { 59 | fr fixedreader.FixedReader // 默认使用windows 60 | c net.Conn // net.Conn 61 | Callback // callback移至conn中 62 | br *bufio.Reader // read和fr同时只能使用一个 63 | *Config // config 可能是全局,也可能是局部初始化得来的 64 | pd deflate.PermessageDeflateConf // permessageDeflate局部配置 65 | once sync.Once // 清理资源的once 66 | readHeadArray [enum.MaxFrameHeaderSize]byte // 读取数据的头部 67 | fragmentFramePayload *[]byte // 存放分段帧的缓冲区 68 | bufioPayload *[]byte // bufio模式下的缓冲区, 默认为nil 69 | fragmentFrameHeader *frame.FrameHeader // 存放分段帧的头部 70 | wmu sync.Mutex // 写的锁 71 | *delayWrite // 只有在需要的时候才初始化, 修改为指针是为了在海量连接的时候减少内存占用 72 | deCtx *deflate.DeCompressContextTakeover // 解压缩上下文 73 | enCtx *deflate.CompressContextTakeover // 压缩上下文 74 | closed int32 // 0: open, 1: closed 75 | mu2 sync.Mutex 76 | onCloseOnce myonce.MyOnce // 保证只调用一次OnClose函数 77 | client bool // client(true) or server(flase) 78 | } 79 | 80 | func setNoDelay(c net.Conn, noDelay bool) error { 81 | if tcp, ok := c.(*net.TCPConn); ok { 82 | return 
tcp.SetNoDelay(noDelay) 83 | } 84 | 85 | if tlsTCP, ok := c.(*tls.Conn); ok { 86 | return setNoDelay(tlsTCP.NetConn(), noDelay) 87 | } 88 | return nil 89 | } 90 | 91 | func newConn(c net.Conn, client bool, conf *Config, fr fixedreader.FixedReader, br *bufio.Reader) (wsCon *Conn, err error) { 92 | if err = setNoDelay(c, conf.tcpNoDelay); err != nil { 93 | return nil, err 94 | } 95 | 96 | wsCon = &Conn{ 97 | c: c, 98 | client: client, 99 | Config: conf, 100 | fr: fr, 101 | br: br, 102 | } 103 | 104 | return wsCon, err 105 | } 106 | 107 | // 返回标准库的net.Conn 108 | func (c *Conn) NetConn() net.Conn { 109 | return c.c 110 | } 111 | 112 | func (c *Conn) writeAndMaybeOnClose(err error) error { 113 | var sc *StatusCode 114 | defer func() { 115 | c.onCloseOnce.Do(&c.mu2, func() { 116 | c.Callback.OnClose(c, err) 117 | }) 118 | }() 119 | 120 | if errors.As(err, &sc) { 121 | if err := c.WriteTimeout(opcode.Close, sc.toBytes(), 2*time.Second); err != nil { 122 | return err 123 | } 124 | } 125 | return nil 126 | } 127 | 128 | func (c *Conn) writeErrAndOnClose(code StatusCode, userErr error) error { 129 | defer func() { 130 | c.onCloseOnce.Do(&c.mu2, func() { 131 | c.Callback.OnClose(c, userErr) 132 | }) 133 | }() 134 | if err := c.WriteTimeout(opcode.Close, code.toBytes(), 2*time.Second); err != nil { 135 | return err 136 | } 137 | 138 | return userErr 139 | } 140 | 141 | func (c *Conn) failRsv1(op opcode.Opcode) bool { 142 | // 解压缩没有开启 143 | if !c.pd.Decompression { 144 | return true 145 | } 146 | 147 | // 不是text和binary 148 | if op != opcode.Text && op != opcode.Binary { 149 | return true 150 | } 151 | 152 | return false 153 | } 154 | 155 | func (c *Conn) ReadLoop() (err error) { 156 | c.OnOpen(c) 157 | 158 | defer func() { 159 | // c.OnClose(c, err) 160 | c.Close() 161 | if c.fr.IsInit() { 162 | defer func() { 163 | if err1 := c.fr.Release(); err1 != nil { 164 | err = err1 165 | } 166 | c.fr.BufPtr() 167 | }() 168 | } 169 | }() 170 | 171 | if c.br != nil { 172 | newSize := 
int(1024 * c.bufioMultipleTimesPayloadSize) 173 | if newSize > 0 && c.br.Size() != newSize { 174 | // TODO sync.Pool管理 175 | (*bufio2.Reader2)(unsafe.Pointer(c.br)).ResetBuf(make([]byte, newSize)) 176 | } 177 | // bufio 模式才会使用payload 178 | c.bufioPayload = bytespool.GetBytes(1024 + enum.MaxFrameHeaderSize) 179 | } 180 | 181 | for { 182 | err = c.readMessage() 183 | if err != nil { 184 | return err 185 | } 186 | } 187 | } 188 | 189 | func (c *Conn) StartReadLoop() { 190 | go func() { 191 | _ = c.ReadLoop() 192 | }() 193 | } 194 | 195 | func (c *Conn) readDataFromNet(headArray *[enum.MaxFrameHeaderSize]byte, bufioPayload *[]byte) (f frame.Frame2, err error) { 196 | if c.readTimeout > 0 { 197 | err = c.c.SetReadDeadline(time.Now().Add(c.readTimeout)) 198 | if err != nil { 199 | 200 | c.onCloseOnce.Do(&c.mu2, func() { 201 | c.Callback.OnClose(c, err) 202 | }) 203 | return 204 | } 205 | } 206 | 207 | if c.fr.IsInit() { 208 | f, err = frame.ReadFrameFromWindowsV2(&c.fr, headArray, c.windowsMultipleTimesPayloadSize, c.readMaxMessage) 209 | if err == frame.ErrTooLargePayload { 210 | err = TooBigMessage 211 | } 212 | } else { 213 | r := io.Reader(c.br) 214 | var lr io.Reader 215 | if c.readMaxMessage > 0 { 216 | lr = limitreader.NewLimitReader(c.br, c.readMaxMessage) 217 | } 218 | f, err = frame.ReadFrameFromReaderV3(r, lr, headArray, bufioPayload) 219 | } 220 | if err != nil { 221 | c.writeAndMaybeOnClose(err) 222 | return 223 | } 224 | 225 | if c.readTimeout > 0 { 226 | if err = c.c.SetReadDeadline(time.Time{}); err != nil { 227 | c.onCloseOnce.Do(&c.mu2, func() { 228 | c.Callback.OnClose(c, err) 229 | }) 230 | } 231 | } 232 | return 233 | } 234 | 235 | // 读取websocket frame.Frame的循环 236 | func (c *Conn) readMessage() (err error) { 237 | // 从网络读取数据 238 | f, err := c.readDataFromNet(&c.readHeadArray, c.bufioPayload) 239 | if err != nil { 240 | return err 241 | } 242 | 243 | op := f.Opcode 244 | if c.fragmentFrameHeader != nil { 245 | op = c.fragmentFrameHeader.Opcode 246 | 
} 247 | 248 | rsv1 := f.GetRsv1() 249 | // 检查Rsv1 rsv2 Rfd, errsv3 250 | if rsv1 && c.failRsv1(op) || f.GetRsv2() || f.GetRsv3() { 251 | err = fmt.Errorf("%w:Rsv1(%t) Rsv2(%t) rsv2(%t) compression:%t", ErrRsv123, rsv1, f.GetRsv2(), f.GetRsv3(), c.Compression) 252 | return c.writeErrAndOnClose(ProtocolError, err) 253 | } 254 | 255 | fin := f.GetFin() 256 | if c.fragmentFrameHeader != nil && !f.Opcode.IsControl() { 257 | if f.Opcode == 0 { 258 | *c.fragmentFramePayload = append(*c.fragmentFramePayload, *f.Payload...) 259 | 260 | // 分段的在这返回 261 | if fin { 262 | // 解压缩 263 | if c.fragmentFrameHeader.GetRsv1() && c.pd.Decompression { 264 | tempBuf, err := c.decode(c.fragmentFramePayload) 265 | if err != nil { 266 | return err 267 | } 268 | // 释放未解压缩的buffer到池里面 269 | bytespool.PutBytes(c.fragmentFramePayload) 270 | c.fragmentFramePayload = tempBuf 271 | } 272 | // 这里的check按道理应该放到f.Fin前面, 会更符合rfc的标准, 前提是c.utf8Check修改成流式解析 273 | // TODO c.utf8Check 修改成流式解析 274 | if c.fragmentFrameHeader.Opcode == opcode.Text && !c.utf8Check(*c.fragmentFramePayload) { 275 | c.onCloseOnce.Do(&c.mu2, func() { 276 | c.Callback.OnClose(c, ErrTextNotUTF8) 277 | }) 278 | return ErrTextNotUTF8 279 | } 280 | 281 | c.Callback.OnMessage(c, c.fragmentFrameHeader.Opcode, *c.fragmentFramePayload) 282 | bytespool.PutBytes(c.fragmentFramePayload) 283 | c.fragmentFramePayload = nil 284 | c.fragmentFrameHeader = nil 285 | } 286 | return nil 287 | } 288 | 289 | c.writeErrAndOnClose(ProtocolError, ErrFrameOpcode) 290 | return ErrFrameOpcode 291 | } 292 | 293 | if f.Opcode == opcode.Text || f.Opcode == opcode.Binary { 294 | if !fin { 295 | prevFrame := f.FrameHeader 296 | // 第一次分段 297 | if c.fragmentFramePayload == nil { 298 | c.fragmentFramePayload = bytespool.GetBytes(len(*f.Payload)*2 + enum.MaxFrameHeaderSize) 299 | *c.fragmentFramePayload = (*c.fragmentFramePayload)[0:0] 300 | } 301 | 302 | newPayload := append(*c.fragmentFramePayload, *f.Payload...) 
303 | if unsafe.SliceData(newPayload) != unsafe.SliceData(*c.fragmentFramePayload) { 304 | bytespool.PutBytes(c.fragmentFramePayload) 305 | } 306 | c.fragmentFramePayload = &newPayload 307 | f.Payload = nil 308 | 309 | // 让fragmentFrame的Payload指向readBuf, readBuf 原引用直接丢弃 310 | c.fragmentFrameHeader = &prevFrame 311 | return 312 | } 313 | 314 | decompression := false 315 | if rsv1 && c.pd.Decompression { 316 | // 不分段的解压缩 317 | f.Payload, err = c.decode(f.Payload) 318 | if err != nil { 319 | return err 320 | } 321 | decompression = true 322 | } 323 | 324 | if f.Opcode == opcode.Text { 325 | if !c.utf8Check(*f.Payload) { 326 | c.c.Close() 327 | c.onCloseOnce.Do(&c.mu2, func() { 328 | c.Callback.OnClose(c, ErrTextNotUTF8) 329 | }) 330 | return ErrTextNotUTF8 331 | } 332 | } 333 | 334 | c.Callback.OnMessage(c, f.Opcode, *f.Payload) 335 | if decompression { 336 | bytespool.PutBytes(f.Payload) 337 | } 338 | return 339 | } 340 | 341 | if f.Opcode == Close || f.Opcode == Ping || f.Opcode == Pong { 342 | // 对方发的控制消息太大 343 | if f.PayloadLen > maxControlFrameSize { 344 | c.writeErrAndOnClose(ProtocolError, ErrMaxControlFrameSize) 345 | return ErrMaxControlFrameSize 346 | } 347 | // Close, Ping, Pong 不能分片 348 | if !fin { 349 | c.writeErrAndOnClose(ProtocolError, ErrNOTBeFragmented) 350 | return ErrNOTBeFragmented 351 | } 352 | 353 | if f.Opcode == Close { 354 | if len(*f.Payload) == 0 { 355 | c.writeErrAndOnClose(NormalClosure, &CloseErrMsg{Code: NormalClosure}) 356 | return nil 357 | } 358 | 359 | if len(*f.Payload) < 2 { 360 | return c.writeErrAndOnClose(ProtocolError, ErrClosePayloadTooSmall) 361 | } 362 | 363 | if !c.utf8Check((*f.Payload)[2:]) { 364 | return c.writeErrAndOnClose(ProtocolError, ErrTextNotUTF8) 365 | } 366 | 367 | code := binary.BigEndian.Uint16(*f.Payload) 368 | if !validCode(code) { 369 | return c.writeErrAndOnClose(ProtocolError, ErrCloseValue) 370 | } 371 | 372 | // 回敬一个close包 373 | if err := c.WriteTimeout(Close, *f.Payload, 2*time.Second); err != nil { 
374 | return err 375 | } 376 | 377 | err = bytesToCloseErrMsg(*f.Payload) 378 | c.onCloseOnce.Do(&c.mu2, func() { 379 | c.Callback.OnClose(c, err) 380 | }) 381 | return err 382 | } 383 | 384 | if f.Opcode == Ping { 385 | // 回一个pong包 386 | if c.replyPing { 387 | if err := c.WriteTimeout(Pong, *f.Payload, 2*time.Second); err != nil { 388 | c.onCloseOnce.Do(&c.mu2, func() { 389 | c.Callback.OnClose(c, err) 390 | }) 391 | return err 392 | } 393 | c.Callback.OnMessage(c, f.Opcode, *f.Payload) 394 | return 395 | } 396 | } 397 | 398 | if f.Opcode == Pong && c.ignorePong { 399 | return 400 | } 401 | 402 | c.Callback.OnMessage(c, f.Opcode, nil) 403 | return 404 | } 405 | // 检查Opcode 406 | c.writeErrAndOnClose(ProtocolError, ErrOpcode) 407 | return ErrOpcode 408 | } 409 | 410 | func (c *Conn) WriteMessage(op Opcode, writeBuf []byte) (err error) { 411 | if atomic.LoadInt32(&c.closed) == 1 { 412 | return ErrClosed 413 | } 414 | 415 | if op == opcode.Text { 416 | if !c.utf8Check(writeBuf) { 417 | return ErrTextNotUTF8 418 | } 419 | } 420 | 421 | rsv1 := c.pd.Compression && (op == opcode.Text || op == opcode.Binary) 422 | if rsv1 { 423 | writeBufPtr, err := c.encoode(&writeBuf) 424 | if err != nil { 425 | return err 426 | } 427 | 428 | defer bytespool.PutBytes(writeBufPtr) 429 | writeBuf = *writeBufPtr 430 | } 431 | 432 | // f.Opcode = op 433 | // f.PayloadLen = int64(len(writeBuf)) 434 | maskValue := uint32(0) 435 | if c.client { 436 | maskValue = rand.Uint32() 437 | } 438 | 439 | var fw fixedwriter.FixedWriter 440 | return frame.WriteFrame(&fw, c.c, writeBuf, true, rsv1, c.client, op, maskValue) 441 | } 442 | 443 | func (c *Conn) SetWriteDeadline(t time.Time) error { 444 | return c.c.SetWriteDeadline(t) 445 | } 446 | 447 | func (c *Conn) WriteTimeout(op Opcode, data []byte, t time.Duration) (err error) { 448 | if err = c.c.SetWriteDeadline(time.Now().Add(t)); err != nil { 449 | return 450 | } 451 | 452 | defer func() { _ = c.c.SetWriteDeadline(time.Time{}) }() 453 | return 
c.WriteMessage(op, data) 454 | } 455 | 456 | func (c *Conn) WriteCloseTimeout(sc StatusCode, t time.Duration) (err error) { 457 | buf := sc.toBytes() 458 | return c.WriteTimeout(opcode.Close, buf, t) 459 | } 460 | 461 | // data 不能超过125字节, rfc规定 462 | func (c *Conn) WritePing(data []byte) (err error) { 463 | return c.WriteControl(Ping, data[:]) 464 | } 465 | 466 | // data 不能超过125字节, rfc规定 467 | func (c *Conn) WritePong(data []byte) (err error) { 468 | return c.WriteControl(Pong, data[:]) 469 | } 470 | 471 | func (c *Conn) WriteControl(op Opcode, data []byte) (err error) { 472 | if len(data) > maxControlFrameSize { 473 | return ErrMaxControlFrameSize 474 | } 475 | return c.WriteMessage(op, data) 476 | } 477 | 478 | // 写分段数据, 目前主要是单元测试使用 479 | func (c *Conn) writeFragment(op Opcode, writeBuf []byte, maxFragment int /*单个段最大size*/) (err error) { 480 | if len(writeBuf) < maxFragment { 481 | return c.WriteMessage(op, writeBuf) 482 | } 483 | 484 | if op == opcode.Text { 485 | if !c.utf8Check(writeBuf) { 486 | return ErrTextNotUTF8 487 | } 488 | } 489 | 490 | rsv1 := c.pd.Compression && (op == opcode.Text || op == opcode.Binary) 491 | if rsv1 { 492 | writeBufPtr, err := c.encoode(&writeBuf) 493 | if err != nil { 494 | return err 495 | } 496 | defer bytespool.PutBytes(writeBufPtr) 497 | writeBuf = *writeBufPtr 498 | } 499 | 500 | // f.Opcode = op 501 | // f.PayloadLen = int64(len(writeBuf)) 502 | maskValue := uint32(0) 503 | if c.client { 504 | maskValue = rand.Uint32() 505 | } 506 | 507 | var fw fixedwriter.FixedWriter 508 | for len(writeBuf) > 0 { 509 | if len(writeBuf) > maxFragment { 510 | if err := frame.WriteFrame(&fw, c.c, writeBuf[:maxFragment], false, rsv1, c.client, op, maskValue); err != nil { 511 | return err 512 | } 513 | writeBuf = writeBuf[maxFragment:] 514 | op = Continuation 515 | continue 516 | } 517 | return frame.WriteFrame(&fw, c.c, writeBuf, true, rsv1, c.client, op, maskValue) 518 | } 519 | return nil 520 | } 521 | 522 | func (c *Conn) Close() (err 
error) { 523 | c.once.Do(func() { 524 | err = c.c.Close() 525 | c.wmu.Lock() 526 | if c.delayWrite != nil && c.delayTimeout != nil { 527 | c.delayTimeout.Stop() 528 | c.delayBuf = nil 529 | } 530 | c.wmu.Unlock() 531 | atomic.StoreInt32(&c.closed, 1) 532 | }) 533 | return 534 | } 535 | 536 | func (c *Conn) writerDelayBufSafe() { 537 | c.wmu.Lock() 538 | c.delayErr = c.writerDelayBufInner() 539 | c.wmu.Unlock() 540 | } 541 | 542 | func (c *Conn) writerDelayBufInner() (err error) { 543 | if c.delayBuf == nil || c.delayBuf.Len() == 0 || atomic.LoadInt32(&c.closed) == 1 { 544 | return nil 545 | } 546 | _, err = c.c.Write(c.delayBuf.Bytes()) 547 | if c.delayTimeout != nil { 548 | c.delayTimeout.Reset(c.maxDelayWriteDuration) 549 | } 550 | c.delayNum = 0 551 | c.delayBuf.Reset() 552 | return 553 | } 554 | 555 | // 对于流量场景这个版本推荐开启tcp delay 方法:WithClientTCPDelay() WithServerTCPDelay() 556 | 557 | // 该函数目前是研究性质的尝试 558 | // 延迟写消息, 对流量密集型的场景有用 或者开启tcp delay, 559 | // 1. 如果缓存的消息超过了多少条数 560 | // 2. 如果缓存的消费超过了多久的时间 561 | // 3. 
TODO: 最大缓存多少字节 562 | 563 | func (c *Conn) initDelayWrite() { 564 | if c.delayWrite == nil { 565 | c.wmu.Lock() 566 | if c.delayWrite == nil { 567 | c.delayWrite = &delayWrite{} 568 | } 569 | c.wmu.Unlock() 570 | } 571 | } 572 | func (c *Conn) isClosed() bool { 573 | return atomic.LoadInt32(&c.closed) == 1 574 | } 575 | 576 | func (c *Conn) WriteMessageDelay(op Opcode, writeBuf []byte) (err error) { 577 | if c.isClosed() { 578 | return ErrClosed 579 | } 580 | 581 | if op == opcode.Text { 582 | if !c.utf8Check(writeBuf) { 583 | return ErrTextNotUTF8 584 | } 585 | } 586 | 587 | // 初始化对应的资源 588 | c.initDelayWrite() 589 | rsv1 := c.pd.Compression && (op == opcode.Text || op == opcode.Binary) 590 | if rsv1 { 591 | writeBufPtr, err := c.encoode(&writeBuf) 592 | if err != nil { 593 | return err 594 | } 595 | defer bytespool.PutBytes(writeBufPtr) 596 | writeBuf = *writeBufPtr 597 | } 598 | 599 | c.wmu.Lock() 600 | if c.isClosed() { 601 | c.wmu.Unlock() 602 | return nil 603 | } 604 | // 初始化缓存 605 | if c.delayBuf == nil && c.delayWriteInitBufferSize > 0 { 606 | 607 | // TODO: sync.Pool管理下, 如果size是1k 2k 3k 608 | delayBuf := make([]byte, 0, c.delayWriteInitBufferSize) 609 | c.delayBuf = bytes.NewBuffer(delayBuf) 610 | } 611 | // 初始化定时器 612 | if c.delayTimeout == nil && c.maxDelayWriteDuration > 0 { 613 | c.delayTimeout = time.AfterFunc(c.maxDelayWriteDuration, c.writerDelayBufSafe) 614 | } 615 | c.wmu.Unlock() 616 | 617 | maskValue := uint32(0) 618 | if c.client { 619 | maskValue = rand.Uint32() 620 | } 621 | // 缓存的消息超过最大值, 则直接写入 622 | c.wmu.Lock() 623 | if c.isClosed() { 624 | c.wmu.Unlock() 625 | return 626 | } 627 | if c.delayNum+1 == c.maxDelayWriteNum { 628 | err = frame.WriteFrameToBytes(c.delayBuf, writeBuf, true, rsv1, c.client, op, maskValue) 629 | if err != nil { 630 | c.wmu.Unlock() 631 | return err 632 | } 633 | err = c.writerDelayBufInner() 634 | c.wmu.Unlock() 635 | return err 636 | } 637 | 638 | // 为了平衡生产者,消费者的速度,这里不使用协程 639 | if c.delayBuf != nil { 640 | err = 
frame.WriteFrameToBytes(c.delayBuf, writeBuf, true, rsv1, c.client, op, maskValue) 641 | } 642 | c.delayNum++ // 对记数计+1 643 | c.wmu.Unlock() 644 | return err 645 | } 646 | -------------------------------------------------------------------------------- /conn_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | package quickws 15 | 16 | import ( 17 | "bytes" 18 | "crypto/md5" 19 | "fmt" 20 | "math/rand" 21 | "net/http" 22 | "net/http/httptest" 23 | "strings" 24 | "sync/atomic" 25 | "testing" 26 | "time" 27 | 28 | "github.com/antlabs/wsutil/fixedwriter" 29 | "github.com/antlabs/wsutil/frame" 30 | "github.com/antlabs/wsutil/opcode" 31 | ) 32 | 33 | var ( 34 | testBinaryMessage64kb = bytes.Repeat([]byte("1"), 65535) 35 | testTextMessage64kb = bytes.Repeat([]byte("中"), 65535/len("中")) 36 | testBinaryMessage10 = bytes.Repeat([]byte("1"), 10) 37 | ) 38 | 39 | type testMessageHandler struct { 40 | DefCallback 41 | t *testing.T 42 | need []byte 43 | callbed int32 44 | callbedChan chan bool 45 | server bool 46 | count int 47 | done chan struct{} 48 | output bool 49 | } 50 | 51 | func (t *testMessageHandler) OnMessage(c *Conn, op opcode.Opcode, msg []byte) { 52 | need := append([]byte(nil), t.need...) 
53 | atomic.StoreInt32(&t.callbed, 1) 54 | t.callbedChan <- true 55 | if t.count == 0 { 56 | return 57 | } 58 | t.count-- 59 | 60 | message := "#client" 61 | if t.server { 62 | message = "#server" 63 | } 64 | // if t.output { 65 | // // fmt.Printf(">>>>>%p %s, %#v\n", &msg, message, msg) 66 | // } 67 | if len(msg) < 30 { 68 | if !bytes.Equal(msg, need) { 69 | t.t.Errorf(">>>>>%p %s, %#v\n", &msg, message, msg) 70 | } 71 | } else { 72 | 73 | md51 := md5.Sum(need) 74 | md52 := md5.Sum(msg) 75 | if !bytes.Equal(md51[:], md52[:]) { 76 | t.t.Errorf("md51 %x, md52 %x\n", md51, md52) 77 | } 78 | } 79 | err := c.WriteMessage(op, msg) 80 | if err != nil { 81 | t.t.Error(err) 82 | } 83 | // if !t.server { 84 | // // c.Close() 85 | // } 86 | } 87 | 88 | func (t *testMessageHandler) OnClose(c *Conn, err error) { 89 | message := "#client.OnClose" 90 | if t.server { 91 | message = "#server.OnClose" 92 | } 93 | 94 | fmt.Printf("OnClose: %s:%s\n", message, err) 95 | if t.done != nil { 96 | close(t.done) 97 | } 98 | } 99 | 100 | func newServrEcho(t *testing.T, data []byte, output bool) *httptest.Server { 101 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 102 | c, err := Upgrade(w, r, 103 | WithServerCallback(&testMessageHandler{t: t, need: data, server: true, count: -1, output: true, callbedChan: make(chan bool, 1)}), 104 | ) 105 | if err != nil { 106 | t.Error(err) 107 | return 108 | } 109 | 110 | _ = c.ReadLoop() 111 | })) 112 | 113 | ts.URL = "ws" + strings.TrimPrefix(ts.URL, "http") 114 | return ts 115 | } 116 | 117 | // 测试read message 118 | func Test_ReadMessage(t *testing.T) { 119 | t.Run("ReadMessage10", func(t *testing.T) { 120 | ts := newServrEcho(t, testBinaryMessage10, true) 121 | client := &testMessageHandler{t: t, need: append([]byte(nil), testBinaryMessage10...), count: 1, done: make(chan struct{}), output: true} 122 | client.callbedChan = make(chan bool, 1) 123 | c, err := Dial(ts.URL, WithClientCallback(client)) 124 | if 
err != nil { 125 | t.Error(err) 126 | return 127 | } 128 | c.StartReadLoop() 129 | 130 | tmp := append([]byte(nil), testBinaryMessage10...) 131 | 132 | err = c.WriteMessage(Binary, tmp) 133 | if err != nil { 134 | t.Error(err) 135 | return 136 | } 137 | // <-client.done 138 | select { 139 | case <-client.callbedChan: 140 | case <-time.After(time.Second / 3): 141 | } 142 | if atomic.LoadInt32(&client.callbed) != 1 { 143 | t.Error("not callbed") 144 | } 145 | }) 146 | 147 | t.Run("ReadMessage64K", func(t *testing.T) { 148 | ts := newServrEcho(t, testBinaryMessage64kb, false) 149 | client := &testMessageHandler{t: t, need: append([]byte(nil), testBinaryMessage64kb...), count: 1} 150 | client.callbedChan = make(chan bool, 1) 151 | c, err := Dial(ts.URL, WithClientCallback(client)) 152 | if err != nil { 153 | t.Error(err) 154 | return 155 | } 156 | c.StartReadLoop() 157 | 158 | tmp := append([]byte(nil), testBinaryMessage64kb...) 159 | err = c.WriteMessage(Binary, tmp) 160 | select { 161 | case <-client.callbedChan: 162 | case <-time.After(time.Second / 3): 163 | } 164 | if err != nil { 165 | t.Error(err) 166 | return 167 | } 168 | 169 | if atomic.LoadInt32(&client.callbed) != 1 { 170 | t.Error("not callbed") 171 | } 172 | }) 173 | 174 | t.Run("ReadMessage64K_Text", func(t *testing.T) { 175 | ts := newServrEcho(t, testTextMessage64kb, false) 176 | client := &testMessageHandler{t: t, need: append([]byte(nil), testTextMessage64kb...), count: 1} 177 | client.callbedChan = make(chan bool, 1) 178 | c, err := Dial(ts.URL, WithClientCallback(client)) 179 | if err != nil { 180 | t.Error(err) 181 | return 182 | } 183 | 184 | c.StartReadLoop() 185 | 186 | tmp := append([]byte(nil), testTextMessage64kb...) 
187 | err = c.WriteMessage(Text, tmp) 188 | if err != nil { 189 | t.Error(err) 190 | return 191 | } 192 | select { 193 | case <-client.callbedChan: 194 | case <-time.After(time.Second / 3): 195 | } 196 | if atomic.LoadInt32(&client.callbed) != 1 { 197 | t.Errorf("not callbed:%d\n", client.callbed) 198 | } 199 | }) 200 | 201 | t.Run("ReadMessage_Fail_Rsv.1", func(t *testing.T) { 202 | run := int32(0) 203 | data := make(chan string, 1) 204 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 205 | c, err := Upgrade(w, r, WithServerOnMessageFunc(func(c *Conn, op Opcode, payload []byte) { 206 | err := c.WriteMessage(op, payload) 207 | if err != nil { 208 | t.Error(err) 209 | return 210 | } 211 | })) 212 | if err != nil { 213 | t.Error(err) 214 | } 215 | c.StartReadLoop() 216 | })) 217 | 218 | defer ts.Close() 219 | 220 | url := strings.ReplaceAll(ts.URL, "http", "ws") 221 | con, err := Dial(url, WithClientBufioParseMode(), WithClientCompression(), WithClientOnMessageFunc(func(c *Conn, mt Opcode, payload []byte) { 222 | atomic.AddInt32(&run, int32(1)) 223 | data <- string(payload) 224 | })) 225 | if err != nil { 226 | t.Error(err) 227 | } 228 | defer con.Close() 229 | 230 | // err = con.WriteMessage(Binary, []byte("hello")) 231 | maskValue := rand.Uint32() 232 | var fw fixedwriter.FixedWriter 233 | err = frame.WriteFrame(&fw, con.c, []byte("hello"), true, true, con.client, Binary, maskValue) 234 | if err != nil { 235 | t.Error(err) 236 | } 237 | 238 | con.StartReadLoop() 239 | select { 240 | case d := <-data: 241 | if d != "hello" { 242 | t.Errorf("write message or read message fail:got:%s, need:hello\n", d) 243 | } 244 | case <-time.After(1000 * time.Millisecond): 245 | } 246 | if atomic.LoadInt32(&run) > 0 { 247 | // 需要不运行server 248 | t.Errorf("need not run server") 249 | } 250 | }) 251 | 252 | t.Run("ReadMessage_Fail_Rsv.2", func(t *testing.T) { 253 | run := int32(0) 254 | data := make(chan string, 1) 255 | ts := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 256 | c, err := Upgrade(w, r, 257 | // WithServerDecompression(), 258 | WithServerDecompressAndCompress(), 259 | WithServerOnMessageFunc(func(c *Conn, op Opcode, payload []byte) { 260 | err := c.WriteMessage(op, payload) 261 | if err != nil { 262 | t.Error(err) 263 | return 264 | } 265 | })) 266 | if err != nil { 267 | t.Error(err) 268 | } 269 | c.StartReadLoop() 270 | })) 271 | 272 | defer ts.Close() 273 | 274 | url := strings.ReplaceAll(ts.URL, "http", "ws") 275 | con, err := Dial(url, 276 | WithClientDecompressAndCompress(), 277 | WithClientOnMessageFunc(func(c *Conn, mt Opcode, payload []byte) { 278 | atomic.AddInt32(&run, int32(1)) 279 | // data <- string(payload) 280 | })) 281 | if err != nil { 282 | t.Error(err) 283 | } 284 | defer con.Close() 285 | 286 | // err = con.WriteMessage(Binary, []byte("hello")) 287 | maskValue := rand.Uint32() 288 | var fw fixedwriter.FixedWriter 289 | err = frame.WriteFrame(&fw, con.c, []byte("hello"), true, true, con.client, Ping, maskValue) 290 | if err != nil { 291 | t.Error(err) 292 | } 293 | 294 | select { 295 | case <-data: 296 | case <-time.After(1000 * time.Millisecond): 297 | } 298 | if atomic.LoadInt32(&run) > 0 { 299 | // 需要不运行server 300 | t.Errorf("need not run server") 301 | } 302 | }) 303 | } 304 | 305 | // 测试分段frame 306 | func TestFragmentFrame(t *testing.T) { 307 | t.Run("FragmentFrame10", func(t *testing.T) { 308 | run := int32(0) 309 | data := make(chan string, 1) 310 | upgrade := NewUpgrade(WithServerBufioParseMode(), WithServerOnMessageFunc(func(c *Conn, op Opcode, payload []byte) { 311 | err := c.WriteMessage(op, payload) 312 | if err != nil { 313 | t.Error(err) 314 | return 315 | } 316 | })) 317 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 318 | c, err := upgrade.Upgrade(w, r) 319 | if err != nil { 320 | t.Error(err) 321 | } 322 | c.StartReadLoop() 323 | })) 324 | 325 | defer 
ts.Close() 326 | 327 | url := strings.ReplaceAll(ts.URL, "http", "ws") 328 | con, err := Dial(url, WithClientDisableBufioClearHack(), WithClientOnMessageFunc(func(c *Conn, mt Opcode, payload []byte) { 329 | atomic.AddInt32(&run, int32(1)) 330 | data <- string(payload) 331 | })) 332 | if err != nil { 333 | t.Error(err) 334 | } 335 | defer con.Close() 336 | 337 | err = con.writeFragment(Binary, []byte("hello"), 1) 338 | if err != nil { 339 | t.Error(err) 340 | return 341 | } 342 | con.StartReadLoop() 343 | select { 344 | case d := <-data: 345 | if d != "hello" { 346 | t.Errorf("write message or read message fail:got:%s, need:hello\n", d) 347 | } 348 | case <-time.After(1000 * time.Millisecond): 349 | } 350 | if atomic.LoadInt32(&run) != 1 { 351 | t.Error("not run server:method fail") 352 | } 353 | }) 354 | 355 | t.Run("Ping-FragmentFrame-Fail", func(t *testing.T) { 356 | run := int32(0) 357 | data := make(chan string, 1) 358 | upgrade := NewUpgrade(WithServerBufioParseMode(), WithServerOnCloseFunc(func(c *Conn, err error) { 359 | atomic.AddInt32(&run, int32(1)) 360 | data <- err.Error() 361 | })) 362 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 363 | c, err := upgrade.Upgrade(w, r) 364 | if err != nil { 365 | t.Error(err) 366 | } 367 | c.StartReadLoop() 368 | })) 369 | 370 | defer ts.Close() 371 | 372 | url := strings.ReplaceAll(ts.URL, "http", "ws") 373 | con, err := Dial(url, WithClientDisableBufioClearHack(), WithClientOnMessageFunc(func(c *Conn, mt Opcode, payload []byte) { 374 | })) 375 | if err != nil { 376 | t.Error(err) 377 | return 378 | } 379 | defer con.Close() 380 | 381 | err = con.writeFragment(Ping, []byte("ho"), 1) 382 | if err != nil { 383 | t.Error(err) 384 | return 385 | } 386 | con.StartReadLoop() 387 | select { 388 | case <-data: 389 | case <-time.After(1000 * time.Millisecond): 390 | } 391 | if atomic.LoadInt32(&run) != 1 { 392 | t.Error("not run server:method fail") 393 | } 394 | }) 395 | 396 | 
t.Run("Text-FragmentFrame-Fail", func(t *testing.T) { 397 | run := int32(0) 398 | data := make(chan string, 1) 399 | upgrade := NewUpgrade(WithServerBufioParseMode(), WithServerOnCloseFunc(func(c *Conn, err error) { 400 | atomic.AddInt32(&run, int32(1)) 401 | data <- err.Error() 402 | })) 403 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 404 | c, err := upgrade.Upgrade(w, r) 405 | if err != nil { 406 | t.Error(err) 407 | } 408 | c.StartReadLoop() 409 | })) 410 | 411 | defer ts.Close() 412 | 413 | url := strings.ReplaceAll(ts.URL, "http", "ws") 414 | con, err := Dial(url, WithClientDisableBufioClearHack(), WithClientOnMessageFunc(func(c *Conn, mt Opcode, payload []byte) { 415 | })) 416 | if err != nil { 417 | t.Error(err) 418 | } 419 | defer con.Close() 420 | // con.writeFragment(Ping, []byte("hello"), 1) 421 | 422 | maskValue := rand.Uint32() 423 | var fw fixedwriter.FixedWriter 424 | err = frame.WriteFrame(&fw, con.c, []byte("h"), false, false, con.client, Text, maskValue) 425 | if err != nil { 426 | t.Error(err) 427 | } 428 | maskValue = rand.Uint32() 429 | err = frame.WriteFrame(&fw, con.c, []byte{}, true, false, con.client, Text, maskValue) 430 | if err != nil { 431 | t.Error(err) 432 | } 433 | con.StartReadLoop() 434 | select { 435 | case <-data: 436 | case <-time.After(1000 * time.Millisecond): 437 | } 438 | if atomic.LoadInt32(&run) != 1 { 439 | t.Error("not run server:method fail") 440 | } 441 | }) 442 | 443 | // 分段传递,并且压缩 444 | t.Run("FragmentFrame-Compression", func(t *testing.T) { 445 | run := int32(0) 446 | data := make(chan string, 1) 447 | upgrade := NewUpgrade(WithServerBufioParseMode(), WithServerDecompression(), WithServerOnMessageFunc(func(c *Conn, op Opcode, payload []byte) { 448 | err := c.WriteMessage(op, payload) 449 | if err != nil { 450 | t.Error(err) 451 | return 452 | } 453 | })) 454 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 455 | c, err := 
upgrade.Upgrade(w, r) 456 | if err != nil { 457 | t.Error(err) 458 | } 459 | c.StartReadLoop() 460 | })) 461 | 462 | defer ts.Close() 463 | 464 | url := strings.ReplaceAll(ts.URL, "http", "ws") 465 | con, err := Dial(url, WithClientDisableBufioClearHack(), WithClientDecompressAndCompress(), WithClientOnMessageFunc(func(c *Conn, mt Opcode, payload []byte) { 466 | atomic.AddInt32(&run, int32(1)) 467 | data <- string(payload) 468 | })) 469 | if err != nil { 470 | t.Error(err) 471 | } 472 | defer con.Close() 473 | 474 | err = con.writeFragment(Binary, []byte("hello"), 1) 475 | if err != nil { 476 | t.Error(err) 477 | return 478 | } 479 | con.StartReadLoop() 480 | select { 481 | case d := <-data: 482 | if d != "hello" { 483 | t.Errorf("write message or read message fail:got:%s, need:hello\n", d) 484 | } 485 | case <-time.After(1000 * time.Millisecond): 486 | } 487 | if atomic.LoadInt32(&run) != 1 { 488 | t.Error("not run server:method fail") 489 | } 490 | }) 491 | 492 | t.Run("FragmentFrame-Small-Buffer", func(t *testing.T) { 493 | run := int32(0) 494 | data := make(chan string, 1) 495 | upgrade := NewUpgrade(WithServerBufioParseMode(), WithServerOnMessageFunc(func(c *Conn, op Opcode, payload []byte) { 496 | err := c.WriteMessage(op, payload) 497 | if err != nil { 498 | t.Error(err) 499 | return 500 | } 501 | })) 502 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 503 | c, err := upgrade.Upgrade(w, r) 504 | if err != nil { 505 | t.Error(err) 506 | } 507 | c.StartReadLoop() 508 | })) 509 | 510 | defer ts.Close() 511 | 512 | url := strings.ReplaceAll(ts.URL, "http", "ws") 513 | con, err := Dial(url, WithClientDisableBufioClearHack(), WithClientEnableUTF8Check(), 514 | WithClientDecompressAndCompress(), WithClientOnMessageFunc(func(c *Conn, mt Opcode, payload []byte) { 515 | atomic.AddInt32(&run, int32(1)) 516 | data <- string(payload) 517 | })) 518 | if err != nil { 519 | t.Error(err) 520 | } 521 | defer con.Close() 522 | 523 | 
sendData := []byte("hell") 524 | err = con.writeFragment(Text, sendData, 5) 525 | if err != nil { 526 | t.Errorf("error:%v", err) 527 | } 528 | 529 | con.StartReadLoop() 530 | 531 | select { 532 | case d := <-data: 533 | if d != string(sendData) { 534 | t.Errorf("write message or read message fail:got:%s, need:hello\n", d) 535 | } 536 | case <-time.After(1000 * time.Millisecond): 537 | } 538 | if atomic.LoadInt32(&run) != 1 { 539 | t.Error("not run server:method fail") 540 | } 541 | }) 542 | 543 | t.Run("FragmentFrame-Client-Not-UTF8", func(t *testing.T) { 544 | upgrade := NewUpgrade(WithServerBufioParseMode(), WithServerOnMessageFunc(func(c *Conn, op Opcode, payload []byte) { 545 | err := c.WriteMessage(op, payload) 546 | if err != nil { 547 | t.Error(err) 548 | return 549 | } 550 | })) 551 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 552 | c, err := upgrade.Upgrade(w, r) 553 | if err != nil { 554 | t.Error(err) 555 | } 556 | c.StartReadLoop() 557 | })) 558 | 559 | defer ts.Close() 560 | 561 | url := strings.ReplaceAll(ts.URL, "http", "ws") 562 | con, err := Dial(url, WithClientDisableBufioClearHack(), WithClientEnableUTF8Check(), 563 | WithClientDecompressAndCompress(), WithClientOnMessageFunc(func(c *Conn, mt Opcode, payload []byte) { 564 | })) 565 | if err != nil { 566 | t.Error(err) 567 | } 568 | defer con.Close() 569 | 570 | // 这里必须要报错 571 | err = con.writeFragment(Text, []byte{128, 129, 130, 131}, 1) 572 | if err == nil { 573 | t.Error("not error") 574 | } 575 | }) 576 | 577 | t.Run("FragmentFrame-Server-Not-UTF8", func(t *testing.T) { 578 | run := int32(0) 579 | data := make(chan string, 1) 580 | upgrade := NewUpgrade(WithServerBufioParseMode(), WithServerEnableUTF8Check(), WithServerOnCloseFunc(func(c *Conn, err error) { 581 | data <- err.Error() 582 | })) 583 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 584 | c, err := upgrade.Upgrade(w, r) 585 | if err != nil { 
586 | t.Error(err) 587 | } 588 | c.StartReadLoop() 589 | })) 590 | 591 | defer ts.Close() 592 | 593 | url := strings.ReplaceAll(ts.URL, "http", "ws") 594 | con, err := Dial(url, WithClientDisableBufioClearHack(), 595 | WithClientOnMessageFunc(func(c *Conn, mt Opcode, payload []byte) { 596 | })) 597 | if err != nil { 598 | t.Error(err) 599 | } 600 | defer con.Close() 601 | // 这里不能报错, 虽然使用了不合法的utf8的,但是没有开启检查 602 | err = con.writeFragment(Text, []byte{128, 129, 130, 131}, 1) 603 | if err != nil { 604 | t.Errorf("error :%v\n", err) 605 | return 606 | } 607 | con.StartReadLoop() 608 | select { 609 | case <-data: 610 | atomic.AddInt32(&run, 1) 611 | case <-time.After(500 * time.Millisecond): 612 | } 613 | 614 | if atomic.LoadInt32(&run) != 1 { 615 | t.Error("not run server:method fail") 616 | } 617 | }) 618 | } 619 | 620 | type testPingPongCloseHandler struct { 621 | DefCallback 622 | run int32 623 | data chan string 624 | } 625 | 626 | func (t *testPingPongCloseHandler) OnClose(c *Conn, err error) { 627 | fmt.Printf("%s\n", err.Error()) 628 | atomic.AddInt32(&t.run, 1) 629 | t.data <- "eof" 630 | } 631 | 632 | func Test_WriteControl(t *testing.T) { 633 | t.Run("WriteControl > maxControlFrameSize.message.fail", func(t *testing.T) { 634 | var shandler testPingPongCloseHandler 635 | shandler.data = make(chan string, 1) 636 | upgrade := NewUpgrade(WithServerBufioParseMode(), WithServerCallback(&shandler)) 637 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 638 | c, err := upgrade.Upgrade(w, r) 639 | if err != nil { 640 | t.Error(err) 641 | } 642 | c.StartReadLoop() 643 | })) 644 | 645 | defer ts.Close() 646 | 647 | url := strings.ReplaceAll(ts.URL, "http", "ws") 648 | con, err := Dial(url, WithClientOnMessageFunc(func(c *Conn, mt Opcode, payload []byte) { 649 | })) 650 | if err != nil { 651 | t.Error(err) 652 | } 653 | defer con.Close() 654 | 655 | err = con.WriteControl(Close, bytes.Repeat([]byte{1}, 126)) 656 | // 这里必须要报错 657 | 
if err == nil { 658 | t.Error("not error") 659 | } 660 | }) 661 | } 662 | 663 | func Test_API(t *testing.T) { 664 | t.Run("WriteTimeout", func(t *testing.T) { 665 | // 测试WriteTimeout的作用 666 | // 起个空服务,客户端写一个数据包,服务端不回包,让客户端触发ReadTimeout 667 | run := int32(0) 668 | runClose := make(chan struct{}, 1) 669 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 670 | c, err := Upgrade(w, r, 671 | WithServerOnMessageFunc(func(c *Conn, op Opcode, payload []byte) { 672 | 673 | })) 674 | if err != nil { 675 | t.Error(err) 676 | } 677 | c.StartReadLoop() 678 | })) 679 | 680 | defer ts.Close() 681 | 682 | url := strings.ReplaceAll(ts.URL, "http", "ws") 683 | con, err := Dial(url, 684 | WithClientReadTimeout(20*time.Millisecond), 685 | WithClientCallbackFunc(func(c *Conn) {}, func(c *Conn, o Opcode, b []byte) {}, func(c *Conn, err error) { 686 | atomic.AddInt32(&run, int32(1)) 687 | runClose <- struct{}{} 688 | })) 689 | if err != nil { 690 | t.Error(err) 691 | } 692 | defer con.Close() 693 | 694 | err = con.WriteTimeout(Text, []byte("hello"), 10*time.Millisecond) 695 | if err != nil { 696 | t.Error(err) 697 | } 698 | con.StartReadLoop() 699 | select { 700 | case <-runClose: 701 | case <-time.After(1000 * time.Millisecond): 702 | t.Errorf("13-15.client: WriteMessageDelay-timeout-send: timeout \n") 703 | } 704 | if atomic.LoadInt32(&run) != 1 { 705 | t.Error("not run server:method fail") 706 | } 707 | }) 708 | 709 | t.Run("NetConn", func(t *testing.T) { 710 | var shandler testPingPongCloseHandler 711 | shandler.data = make(chan string, 1) 712 | upgrade := NewUpgrade(WithServerBufioParseMode(), WithServerCallback(&shandler)) 713 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 714 | c, err := upgrade.Upgrade(w, r) 715 | if err != nil { 716 | t.Error(err) 717 | } 718 | if c.NetConn() != c.c { 719 | t.Error("server.not equal") 720 | } 721 | c.StartReadLoop() 722 | })) 723 | 724 | defer ts.Close() 725 | 
726 | url := strings.ReplaceAll(ts.URL, "http", "ws") 727 | con, err := Dial(url, WithClientOnMessageFunc(func(c *Conn, mt Opcode, payload []byte) { 728 | })) 729 | if err != nil { 730 | t.Error(err) 731 | } 732 | defer con.Close() 733 | if con.NetConn() != con.c { 734 | t.Error("client.not equal") 735 | } 736 | 737 | err = con.WriteControl(Close, bytes.Repeat([]byte{1}, 126)) 738 | // 这里必须要报错 739 | if err == nil { 740 | t.Error("not error") 741 | } 742 | }) 743 | } 744 | 745 | // 测试ping pong close control信息 746 | func TestPingPongClose(t *testing.T) { 747 | // 写一个超过maxControlFrameSize的消息 748 | t.Run("1.>maxControlFrameSize.fail", func(t *testing.T) { 749 | run := int32(0) 750 | var shandler testPingPongCloseHandler 751 | shandler.data = make(chan string, 1) 752 | upgrade := NewUpgrade(WithServerBufioParseMode(), WithServerCallback(&shandler)) 753 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 754 | c, err := upgrade.Upgrade(w, r) 755 | if err != nil { 756 | t.Error(err) 757 | } 758 | c.StartReadLoop() 759 | })) 760 | 761 | defer ts.Close() 762 | 763 | url := strings.ReplaceAll(ts.URL, "http", "ws") 764 | con, err := Dial(url, WithClientOnMessageFunc(func(c *Conn, mt Opcode, payload []byte) { 765 | atomic.AddInt32(&run, int32(1)) 766 | })) 767 | if err != nil { 768 | t.Error(err) 769 | } 770 | defer con.Close() 771 | 772 | err = con.WriteMessage(Close, bytes.Repeat([]byte("a"), maxControlFrameSize+3)) 773 | if err != nil { 774 | t.Error(err) 775 | return 776 | } 777 | con.StartReadLoop() 778 | select { 779 | case d := <-shandler.data: 780 | if d != "eof" { 781 | t.Errorf("write message or read message fail:got:%s, need:hello\n", d) 782 | } 783 | case <-time.After(1000 * time.Millisecond): 784 | } 785 | if atomic.LoadInt32(&shandler.run) != 1 { 786 | t.Error("not run server:method fail") 787 | } 788 | }) 789 | 790 | t.Run("3.WriteCloseEmpty.fail", func(t *testing.T) { 791 | run := int32(0) 792 | var shandler 
testPingPongCloseHandler 793 | shandler.data = make(chan string, 1) 794 | upgrade := NewUpgrade(WithServerBufioParseMode(), WithServerCallback(&shandler)) 795 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 796 | c, err := upgrade.Upgrade(w, r) 797 | if err != nil { 798 | t.Error(err) 799 | } 800 | c.StartReadLoop() 801 | })) 802 | 803 | defer ts.Close() 804 | 805 | url := strings.ReplaceAll(ts.URL, "http", "ws") 806 | con, err := Dial(url, WithClientOnMessageFunc(func(c *Conn, mt Opcode, payload []byte) { 807 | atomic.AddInt32(&run, int32(1)) 808 | })) 809 | if err != nil { 810 | t.Error(err) 811 | } 812 | defer con.Close() 813 | 814 | err = con.WriteMessage(Close, nil) 815 | if err != nil { 816 | t.Error(err) 817 | return 818 | } 819 | con.StartReadLoop() 820 | select { 821 | case d := <-shandler.data: 822 | if d != "eof" { 823 | t.Errorf("write message or read message fail:got:%s, need:hello\n", d) 824 | } 825 | case <-time.After(1000 * time.Millisecond): 826 | } 827 | if atomic.LoadInt32(&shandler.run) != 1 { 828 | t.Error("not run server:method fail") 829 | } 830 | }) 831 | 832 | t.Run("4.WriteClose Payload < 2.fail", func(t *testing.T) { 833 | run := int32(0) 834 | var shandler testPingPongCloseHandler 835 | shandler.data = make(chan string, 1) 836 | upgrade := NewUpgrade(WithServerBufioParseMode(), WithServerCallback(&shandler)) 837 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 838 | c, err := upgrade.Upgrade(w, r) 839 | if err != nil { 840 | t.Error(err) 841 | } 842 | c.StartReadLoop() 843 | })) 844 | 845 | defer ts.Close() 846 | 847 | url := strings.ReplaceAll(ts.URL, "http", "ws") 848 | con, err := Dial(url, WithClientOnMessageFunc(func(c *Conn, mt Opcode, payload []byte) { 849 | atomic.AddInt32(&run, int32(1)) 850 | })) 851 | if err != nil { 852 | t.Error(err) 853 | } 854 | defer con.Close() 855 | 856 | err = con.WriteMessage(Close, []byte{1}) 857 | if err != nil { 
858 | t.Error(err) 859 | return 860 | } 861 | con.StartReadLoop() 862 | select { 863 | case d := <-shandler.data: 864 | if d != "eof" { 865 | t.Errorf("write message or read message fail:got:%s, need:hello\n", d) 866 | } 867 | case <-time.After(1000 * time.Millisecond): 868 | } 869 | if atomic.LoadInt32(&shandler.run) != 1 { 870 | t.Error("not run server:method fail") 871 | } 872 | }) 873 | 874 | t.Run("5.WriteClose Payload > 2, utf8.check.fail", func(t *testing.T) { 875 | run := int32(0) 876 | var shandler testPingPongCloseHandler 877 | shandler.data = make(chan string, 1) 878 | upgrade := NewUpgrade(WithServerEnableUTF8Check(), WithServerBufioParseMode(), WithServerCallback(&shandler)) 879 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 880 | c, err := upgrade.Upgrade(w, r) 881 | if err != nil { 882 | t.Error(err) 883 | } 884 | c.StartReadLoop() 885 | })) 886 | 887 | defer ts.Close() 888 | 889 | url := strings.ReplaceAll(ts.URL, "http", "ws") 890 | con, err := Dial(url, WithClientOnMessageFunc(func(c *Conn, mt Opcode, payload []byte) { 891 | atomic.AddInt32(&run, int32(1)) 892 | })) 893 | if err != nil { 894 | t.Error(err) 895 | } 896 | defer con.Close() 897 | 898 | err = con.WriteMessage(Close, []byte{128, 129, 130, 131}) 899 | if err != nil { 900 | t.Error(err) 901 | return 902 | } 903 | con.StartReadLoop() 904 | select { 905 | case d := <-shandler.data: 906 | if d != "eof" { 907 | t.Errorf("write message or read message fail:got:%s, need:hello\n", d) 908 | } 909 | case <-time.After(1000 * time.Millisecond): 910 | } 911 | if atomic.LoadInt32(&shandler.run) != 1 { 912 | t.Error("not run server:method fail") 913 | } 914 | }) 915 | 916 | t.Run("6.WriteClose status code.fail", func(t *testing.T) { 917 | run := int32(0) 918 | var shandler testPingPongCloseHandler 919 | shandler.data = make(chan string, 1) 920 | upgrade := NewUpgrade(WithServerEnableUTF8Check(), WithServerBufioParseMode(), WithServerCallback(&shandler)) 921 
| ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 922 | c, err := upgrade.Upgrade(w, r) 923 | if err != nil { 924 | t.Error(err) 925 | } 926 | c.StartReadLoop() 927 | })) 928 | 929 | defer ts.Close() 930 | 931 | url := strings.ReplaceAll(ts.URL, "http", "ws") 932 | con, err := Dial(url, WithClientOnMessageFunc(func(c *Conn, mt Opcode, payload []byte) { 933 | atomic.AddInt32(&run, int32(1)) 934 | })) 935 | if err != nil { 936 | t.Error(err) 937 | } 938 | defer con.Close() 939 | 940 | var badSc StatusCode = 3 941 | err = con.WriteCloseTimeout(badSc, 10*time.Second) 942 | if err != nil { 943 | t.Error(err) 944 | return 945 | } 946 | con.StartReadLoop() 947 | select { 948 | case d := <-shandler.data: 949 | if d != "eof" { 950 | t.Errorf("write message or read message fail:got:%s, need:hello\n", d) 951 | } 952 | case <-time.After(1000 * time.Millisecond): 953 | } 954 | if atomic.LoadInt32(&shandler.run) != 1 { 955 | t.Error("not run server:method fail") 956 | } 957 | }) 958 | t.Run("7.WriteClose", func(t *testing.T) { 959 | run := int32(0) 960 | statusCodes := []StatusCode{ 961 | NormalClosure, 962 | EndpointGoingAway, 963 | ProtocolError, 964 | DataCannotAccept, 965 | NotConsistentMessageType, 966 | TerminatingConnection, 967 | TooBigMessage, 968 | NoExtensions, 969 | ServerTerminating, 970 | 1004, 971 | 3000, 972 | } 973 | 974 | for _, st := range statusCodes { 975 | var shandler testPingPongCloseHandler 976 | shandler.data = make(chan string, 1) 977 | upgrade := NewUpgrade(WithServerEnableUTF8Check(), WithServerBufioParseMode(), WithServerCallback(&shandler)) 978 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 979 | c, err := upgrade.Upgrade(w, r) 980 | if err != nil { 981 | t.Error(err) 982 | } 983 | c.StartReadLoop() 984 | })) 985 | 986 | defer ts.Close() 987 | 988 | url := strings.ReplaceAll(ts.URL, "http", "ws") 989 | con, err := Dial(url, WithClientOnMessageFunc(func(c *Conn, mt 
Opcode, payload []byte) { 990 | atomic.AddInt32(&run, int32(1)) 991 | })) 992 | if err != nil { 993 | t.Error(err) 994 | } 995 | defer con.Close() 996 | 997 | err = con.WriteCloseTimeout(st, 10*time.Second) 998 | if err != nil { 999 | t.Error(err) 1000 | return 1001 | } 1002 | con.StartReadLoop() 1003 | select { 1004 | case d := <-shandler.data: 1005 | if d != "eof" { 1006 | t.Errorf("write message or read message fail:got:%s, need:hello\n", d) 1007 | } 1008 | case <-time.After(1000 * time.Millisecond): 1009 | } 1010 | if atomic.LoadInt32(&shandler.run) != 1 { 1011 | t.Error("not run server:method fail") 1012 | } 1013 | } 1014 | }) 1015 | 1016 | t.Run("8.fail-control", func(t *testing.T) { 1017 | run := int32(0) 1018 | data := make(chan string, 1) 1019 | upgrade := NewUpgrade(WithServerOnCloseFunc(func(c *Conn, err error) { 1020 | atomic.AddInt32(&run, 1) 1021 | data <- err.Error() 1022 | })) 1023 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1024 | c, err := upgrade.Upgrade(w, r) 1025 | if err != nil { 1026 | t.Error(err) 1027 | } 1028 | c.StartReadLoop() 1029 | })) 1030 | 1031 | defer ts.Close() 1032 | 1033 | url := strings.ReplaceAll(ts.URL, "http", "ws") 1034 | con, err := Dial(url, WithClientDisableBufioClearHack(), 1035 | WithClientEnableUTF8Check(), WithClientOnCloseFunc(func(c *Conn, err error) { 1036 | })) 1037 | if err != nil { 1038 | t.Error(err) 1039 | } 1040 | defer con.Close() 1041 | // 这里必须要报错 1042 | err = con.WriteMessage(4 /*这是rfc保留的frame*/, []byte("hello")) 1043 | if err != nil { 1044 | t.Error("not error") 1045 | } 1046 | con.StartReadLoop() 1047 | select { 1048 | case <-data: 1049 | case <-time.After(500 * time.Millisecond): 1050 | } 1051 | 1052 | if atomic.LoadInt32(&run) != 1 { 1053 | t.Error("not run server:method fail") 1054 | } 1055 | }) 1056 | } 1057 | --------------------------------------------------------------------------------