├── autobahn ├── run_client.sh ├── config │ ├── fuzzingserver.json │ ├── fuzzingclient-io.json │ ├── fuzzingclient-elastic.json │ ├── fuzzingclient-onebyone.json │ ├── fuzzingclient-context-takeover-decompression.json │ ├── fuzzingclient-no-context-takeover-decompression.json │ ├── fuzzingclient-context-takeover-decompression-and-compression.json │ ├── fuzzingclient-no-context-takeover-decompression-and-compression.json │ ├── fuzzingclient.json │ └── fuzzingclient-all.json ├── run_server-elastic.sh ├── run_server.sh ├── run_server-io.sh ├── run_server-onebyone.sh ├── context-takeover-decompression.sh ├── no-context-takeover-decompression.sh ├── context-takeover-decompression-and-compression.sh ├── no-context-takeover-decompression-and-compression.sh ├── server │ ├── csr.pem │ ├── public.crt │ ├── privatekey.pem │ └── autobahn-server.go └── client │ └── autobahn-client.go ├── go.mod ├── .gitignore ├── upgrade_test.go ├── time_api.go ├── timer.go ├── server_options.go ├── opcode.go ├── Makefile ├── conn_write.go ├── go.sum ├── .github └── workflows │ └── go.yml ├── client_options.go ├── task_parse.go ├── select_task.go ├── utils.go ├── multi_event_loops_option.go ├── utils_test.go ├── callback.go ├── permessage_deflate.go ├── err.go ├── config.go ├── find_deadlock_test.sh ├── callback_test.go ├── event_loop.go ├── stat.go ├── status_codes.go ├── upgrade.go ├── server_handshake.go ├── multi_event_loops.go ├── client_test.go ├── client.go ├── common_options.go ├── README.md ├── client_option_test.go ├── LICENSE ├── conn_unix.go ├── server_test.go └── conn_core.go /autobahn/run_client.sh: -------------------------------------------------------------------------------- 1 | docker run -it --rm --net=host -v ${PWD}/config:/config -v ${PWD}/report:/report crossbario/autobahn-testsuite wstest -m fuzzingserver -s /config/fuzzingserver.json 2 | -------------------------------------------------------------------------------- /autobahn/config/fuzzingserver.json: -------------------------------------------------------------------------------- 1 | { 2 | "url": "ws://127.0.0.1:9005", 3 | "outdir": "./report/clients", 4 | "cases": [ 5 | "*" 6 | ], 7 | "exclude-cases": [], 8 | "exclude-agent-cases": {} 9 | } 10 | -------------------------------------------------------------------------------- /autobahn/config/fuzzingclient-io.json: -------------------------------------------------------------------------------- 1 | { 2 | "outdir": "./report/", 3 | "servers": [ 4 | { 5 | "agent": "non-tls-io", 6 | "url": "ws://127.0.0.1:9004/autobahn-io", 7 | "options": { 8 | "version": 18 9 | } 10 | } 11 | ], 12 | "cases": [ 13 | "*" 14 | ], 15 | "exclude-cases": [ 16 | "" 17 | ], 18 | "exclude-agent-cases": {} 19 | } -------------------------------------------------------------------------------- /autobahn/run_server-elastic.sh: -------------------------------------------------------------------------------- 1 | docker run -it --rm --net=host -v ${PWD}/config:/config -v /root/report:/report crossbario/autobahn-testsuite wstest -m fuzzingclient -s /config/fuzzingclient-elastic.json 2 | #docker run -it --rm --net=host -v ${PWD}/config:/config -v ${PWD}/report:/report crossbario/autobahn-testsuite wstest -m fuzzingclient -s /config/fuzzingclient.json 3 | -------------------------------------------------------------------------------- /autobahn/config/fuzzingclient-elastic.json: -------------------------------------------------------------------------------- 1 | { 2 | "outdir": "./report/", 3 | "servers": [ 4 | { 5 | "agent": "non-tls-elastic", 6 | "url": "ws://127.0.0.1:9004/autobahn-elastic", 7 | "options": { 8 | "version": 18 9 | } 10 | } 11 | ], 12 | "cases": [ 13 | "*" 14 | ], 15 | "exclude-cases": [ 16 | "" 17 | ], 18 | "exclude-agent-cases": {} 19 | } -------------------------------------------------------------------------------- /autobahn/config/fuzzingclient-onebyone.json: -------------------------------------------------------------------------------- 1 | { 2 | "outdir": "./report/", 3 | "servers": [ 4 | { 5 | "agent": "non-tls-onebyone", 6 | "url": "ws://127.0.0.1:9004/autobahn-onebyone", 7 | "options": { 8 | "version": 18 9 | } 10 | } 11 | ], 12 | "cases": [ 13 | "*" 14 | ], 15 | "exclude-cases": [ 16 | "" 17 | ], 18 | "exclude-agent-cases": {} 19 | } -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/antlabs/greatws 2 | 3 | go 1.24.1 4 | 5 | require ( 6 | github.com/antlabs/pulse v0.0.0-20250706072419-b71307af8032 7 | github.com/antlabs/task v0.0.0-20250706071410-2137462668b9 8 | github.com/antlabs/wsutil v0.1.10 9 | golang.org/x/sys v0.31.0 10 | ) 11 | 12 | require ( 13 | github.com/ebitengine/purego v0.8.4 // indirect 14 | github.com/klauspost/compress v1.17.8 // indirect 15 | ) 16 | -------------------------------------------------------------------------------- /autobahn/config/fuzzingclient-context-takeover-decompression.json: -------------------------------------------------------------------------------- 1 | { 2 | "outdir": "./report/", 3 | "servers": [ 4 | { 5 | "agent": "context-takeover-decompression-no-tls", 6 | "url": "ws://localhost:9004/context-takeover-decompression", 7 | "options": { 8 | "version": 18 9 | } 10 | } 11 | ], 12 | "cases": [ 13 | "*" 14 | ], 15 | "exclude-cases": [ 16 | "" 17 | ], 18 | "exclude-agent-cases": {} 19 | } -------------------------------------------------------------------------------- /autobahn/config/fuzzingclient-no-context-takeover-decompression.json: -------------------------------------------------------------------------------- 1 | { 2 | "outdir": "./report/", 3 | "servers": [ 4 | { 5 | "agent": "no-context-takeover-decompression-no-tls", 6 | "url": "ws://localhost:9004/no-context-takeover-decompression", 7 | "options": { 8 | "version": 18 9 | } 10 | } 11 | ], 12 | "cases": [ 13 | "*" 14 | ], 15 | "exclude-cases": [ 16 | "" 17 | ], 18 | "exclude-agent-cases": {} 19 | } -------------------------------------------------------------------------------- /autobahn/config/fuzzingclient-context-takeover-decompression-and-compression.json: -------------------------------------------------------------------------------- 1 | { 2 | "outdir": "./report/", 3 | "servers": [ 4 | { 5 | "agent": "context-takeover-decompression-and-compression-no-tls", 6 | "url": "ws://localhost:9004/context-takeover-decompression-and-compression", 7 | "options": { 8 | "version": 18 9 | } 10 | } 11 | ], 12 | "cases": [ 13 | "*" 14 | ], 15 | "exclude-cases": [ 16 | "" 17 | ], 18 | "exclude-agent-cases": {} 19 | } -------------------------------------------------------------------------------- /autobahn/config/fuzzingclient-no-context-takeover-decompression-and-compression.json: -------------------------------------------------------------------------------- 1 | { 2 | "outdir": "./report/", 3 | "servers": [ 4 | { 5 | "agent": "no-context-takeover-decompression-and-compression-no-tls", 6 | "url": "ws://localhost:9004/no-context-takeover-decompression-and-compression", 7 | "options": { 8 | "version": 18 9 | } 10 | } 11 | ], 12 | "cases": [ 13 | "*" 14 | ], 15 | "exclude-cases": [ 16 | "" 17 | ], 18 | "exclude-agent-cases": {} 19 | } -------------------------------------------------------------------------------- /autobahn/run_server.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ "$(uname)" = "Linux" ]; then 4 | docker run -it --rm --net=host -v ${PWD}/config:/config -v /root/report:/report crossbario/autobahn-testsuite wstest -m fuzzingclient -s /config/fuzzingclient-all.json 5 | else 6 | echo "not linux" 7 | docker run -it --rm --net=host -v ${PWD}/config:/config -v ${PWD}/report:/report crossbario/autobahn-testsuite wstest -m fuzzingclient -s /config/fuzzingclient-all.json 8 | fi 9 | #docker run -it --rm --net=host -v ${PWD}/config:/config -v ${PWD}/report:/report crossbario/autobahn-testsuite wstest -m fuzzingclient -s /config/fuzzingclient.json 10 | -------------------------------------------------------------------------------- /autobahn/run_server-io.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ "$(uname)" = "Linux" ]; then 4 | docker run -it --rm --net=host -v ${PWD}/config:/config -v /root/report:/report crossbario/autobahn-testsuite wstest -m fuzzingclient -s /config/fuzzingclient-io.json 5 | else 6 | echo "not linux" 7 | docker run -it --rm --net=host -v ${PWD}/config:/config -v ${PWD}/report:/report crossbario/autobahn-testsuite wstest -m fuzzingclient -s /config/fuzzingclient-io.json 8 | fi 9 | 10 | #docker run -it --rm --net=host -v ${PWD}/config:/config -v ${PWD}/report:/report crossbario/autobahn-testsuite wstest -m fuzzingclient -s /config/fuzzingclient.json 11 | -------------------------------------------------------------------------------- /autobahn/run_server-onebyone.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ "$(uname)" = "Linux" ]; then 4 | docker run -it --rm --net=host -v ${PWD}/config:/config -v /root/report:/report crossbario/autobahn-testsuite wstest -m fuzzingclient -s /config/fuzzingclient-onebyone.json 5 | else 6 | echo "not linux" 7 | docker run -it --rm --net=host -v ${PWD}/config:/config -v ${PWD}/report:/report crossbario/autobahn-testsuite wstest -m fuzzingclient -s /config/fuzzingclient-onebyone.json 8 | fi 9 | #docker run -it --rm --net=host -v ${PWD}/config:/config -v ${PWD}/report:/report crossbario/autobahn-testsuite wstest -m fuzzingclient -s /config/fuzzingclient.json 10 | -------------------------------------------------------------------------------- /autobahn/context-takeover-decompression.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ "$(uname)" = "Linux" ]; then 4 | docker run -it --rm --net=host -v ${PWD}/config:/config -v /root/report:/report crossbario/autobahn-testsuite wstest -m fuzzingclient -s /config/fuzzingclient-context-takeover-decompression.json 5 | else 6 | echo "not linux" 7 | docker run -it --rm --net=host -v ${PWD}/config:/config -v ${PWD}/report:/report crossbario/autobahn-testsuite wstest -m fuzzingclient -s /config/fuzzingclient-context-takeover-decompression.json 8 | fi 9 | 10 | #docker run -it --rm --net=host -v ${PWD}/config:/config -v ${PWD}/report:/report crossbario/autobahn-testsuite wstest -m fuzzingclient -s /config/fuzzingclient.json -------------------------------------------------------------------------------- /autobahn/no-context-takeover-decompression.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ "$(uname)" = "Linux" ]; then 4 | docker run -it --rm --net=host -v ${PWD}/config:/config -v /root/report:/report crossbario/autobahn-testsuite wstest -m fuzzingclient -s /config/fuzzingclient-no-context-takeover-decompression.json 5 | else 6 | echo "not linux" 7 | docker run -it --rm --net=host -v ${PWD}/config:/config -v ${PWD}/report:/report crossbario/autobahn-testsuite wstest -m fuzzingclient -s /config/fuzzingclient-no-context-takeover-decompression.json 8 | fi 9 | 10 | #docker run -it --rm --net=host -v ${PWD}/config:/config -v ${PWD}/report:/report crossbario/autobahn-testsuite wstest -m fuzzingclient -s /config/fuzzingclient.json -------------------------------------------------------------------------------- /autobahn/context-takeover-decompression-and-compression.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ "$(uname)" = "Linux" ]; then 4 | docker run -it --rm --net=host -v ${PWD}/config:/config -v /root/report:/report crossbario/autobahn-testsuite wstest -m fuzzingclient -s /config/fuzzingclient-context-takeover-decompression-and-compression.json 5 | else 6 | echo "not linux" 7 | docker run -it --rm --net=host -v ${PWD}/config:/config -v ${PWD}/report:/report crossbario/autobahn-testsuite wstest -m fuzzingclient -s /config/fuzzingclient-context-takeover-decompression-and-compression.json 8 | fi 9 | 10 | #docker run -it --rm --net=host -v ${PWD}/config:/config -v ${PWD}/report:/report crossbario/autobahn-testsuite wstest -m fuzzingclient -s /config/fuzzingclient.json -------------------------------------------------------------------------------- /autobahn/no-context-takeover-decompression-and-compression.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ "$(uname)" = "Linux" ]; then 4 | docker run -it --rm --net=host -v ${PWD}/config:/config -v /root/report:/report crossbario/autobahn-testsuite wstest -m fuzzingclient -s /config/fuzzingclient-no-context-takeover-decompression-and-compression.json 5 | else 6 | echo "not linux" 7 | docker run -it --rm --net=host -v ${PWD}/config:/config -v ${PWD}/report:/report crossbario/autobahn-testsuite wstest -m fuzzingclient -s /config/fuzzingclient-no-context-takeover-decompression-and-compression.json 8 | fi 9 | 10 | #docker run -it --rm --net=host -v ${PWD}/config:/config -v ${PWD}/report:/report crossbario/autobahn-testsuite wstest -m fuzzingclient -s /config/fuzzingclient.json -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # If you prefer the allow list template instead of the deny list, see community template: 2 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore 3 | # 4 | # Binaries for programs and plugins 5 | *.exe 6 | *.exe~ 7 | *.dll 8 | *.so 9 | *.dylib 10 | .idea 11 | /.DS_Store 12 | /.vscode 13 | 14 | # Test binary, built with `go test -c` 15 | *.test 16 | 17 | # Output of the go coverage tool, specifically when used with LiteIDE 18 | *.out 19 | 20 | # Dependency directories (remove the comment below to include it) 21 | # vendor/ 22 | 23 | # Go workspace file 24 | go.work 25 | 26 | autobahn-client-darwin-arm64 27 | autobahn-client-linux-amd64 28 | autobahn-server-darwin-arm64 29 | autobahn-server-darwin-arm64-arena 30 | autobahn-server-linux-amd64 31 | /autobahn/report -------------------------------------------------------------------------------- /upgrade_test.go: -------------------------------------------------------------------------------- 1 | package greatws 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | ) 8 | 9 | func TestUpgradeInner(t *testing.T) { 10 | tests := []struct { 11 | name string 12 | conf *Config 13 | wantErr bool 14 | }{ 15 | { 16 | name: "Test with nil EventLoop", 17 | conf: &Config{}, 18 | wantErr: true, 19 | }, 20 | // Add more test cases here 21 | } 22 | 23 | for _, tt := range tests { 24 | t.Run(tt.name, func(t *testing.T) { 25 | req, err := http.NewRequest("GET", "/", nil) 26 | if err != nil { 27 | t.Fatal(err) 28 | } 29 | 30 | rr := httptest.NewRecorder() 31 | _, err = upgradeInner(rr, req, tt.conf, nil) 32 | 33 | if (err != nil) != tt.wantErr { 34 | t.Errorf("upgradeInner() error = %v, wantErr %v", err, tt.wantErr) 35 | } 36 | }) 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /time_api.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // 抽象出timer定时器接口, 16 | // 包含reset和stop两个方法 17 | package greatws 18 | 19 | import "time" 20 | 21 | type Timer interface { 22 | Reset(d time.Duration) bool 23 | Stop() bool 24 | } 25 | -------------------------------------------------------------------------------- /timer.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | package greatws 15 | 16 | import "time" 17 | 18 | func afterFunc(d time.Duration, cb func()) *time.Timer { 19 | if true { 20 | return time.AfterFunc(d, cb) 21 | } 22 | panic("unreachable") 23 | } 24 | -------------------------------------------------------------------------------- /server_options.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package greatws 16 | 17 | type ServerOption func(*ConnOption) 18 | 19 | type ConnOption struct { 20 | Config 21 | } 22 | 23 | // 2. 设置服务端支持的子协议 24 | func WithServerSubprotocols(subprotocols []string) ServerOption { 25 | return func(o *ConnOption) { 26 | o.subProtocols = subprotocols 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /opcode.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package greatws 16 | 17 | import "github.com/antlabs/wsutil/opcode" 18 | 19 | type Opcode = opcode.Opcode 20 | 21 | const ( 22 | Continuation = opcode.Continuation 23 | Text = opcode.Text 24 | Binary = opcode.Binary 25 | // 3 - 7保留 26 | _ // 3 27 | _ 28 | _ // 5 29 | _ 30 | _ // 7 31 | Close = opcode.Close 32 | Ping = opcode.Ping 33 | Pong = opcode.Pong 34 | ) 35 | -------------------------------------------------------------------------------- /autobahn/server/csr.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE REQUEST----- 2 | MIICyDCCAbACAQAwgYIxCzAJBgNVBAYTAjExMREwDwYDVQQIDAhzaGFuZ2hhaTER 3 | MA8GA1UEBwwIc2hhbmdoYWkxETAPBgNVBAoMCHNoYW5naGFpMREwDwYDVQQLDAhz 4 | aGFuZ2hhaTERMA8GA1UEAwwIc2hhbmdoYWkxFDASBgkqhkiG9w0BCQEWBXEuY29t 5 | MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAtY9JJ7sFbZO1J2/L7vQg 6 | /u63/tTisUfc+Cxecn3NHnh+eaXzTYl08YFJWJX5HLFbIo24mV0UHzNWYihpZ+6y 7 | G6o37zbDJ/0k6QyRxSAi1CyI5K3LWJjrh3wlOiKXCqiwWYW61BvRloGN+gI6CD5S 8 | Kp3crGQE+IbXZ3W2bNmK4BzIJKfIcJekI0XarXcpFuyHxJWCfa6FGKCoOSx9GMK4 9 | w+moWNdM6h/IyX4LQr2B42+baDAm7uw22Eq14wvQNFDHorumlBl3JYZPOxOgKW0K 10 | Gphi8DexE+8JfT5zWnUoSonGc+tNYbOBPyGhkQgPVZ/wGUKC4Ee13JYIdsEIsZr7 11 | owIDAQABoAAwDQYJKoZIhvcNAQELBQADggEBAHzjf7/NrJVv3lKFguvyVr7ndZrt 12 | 7/dxaO5fGONbSqQDeKkikQbW4XMGgH9mQRBXhhb7k914sZW2l6a3zfQx0Qg/ocYD 13 | eW3Mw6wC0/WCuac62LewveJWZoe1aZr/emk35s8jhyochNw5WzmtLg0hoTAkd5VA 14 | 1uexFIlgbG6sktDiPjDwHwSZEvssFO4lPLvuqn887YkPH799qIPGKnnA6nFMixtX 15 | 7ndnHAjQOmALEO4COEKCC/Jl75ReqXqXCKIgrCzM0UKXWNTlFtJeTcx7t3fKLksP 16 | IvKXWZSjWFojHkXcgjg+sxN+xe/3BqfhBG1PleRgtybOiOlwjASBiZ95QM0= 17 | -----END CERTIFICATE REQUEST----- 18 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | all: 2 | # 1. server 3 | # mac, arm64 4 | GOOS=darwin GOARCH=arm64 go build -o autobahn-server-darwin-arm64.out ./autobahn/server/autobahn-server.go 5 | 6 | GOOS=darwin GOARCH=arm64 go build -race -o autobahn-server-darwin-arm64-race.out ./autobahn/server/autobahn-server.go 7 | # linux amd64 8 | CGO_ENABLE=0 GOOS=linux GOARCH=amd64 go build -o autobahn-server-linux-amd64.out ./autobahn/server/autobahn-server.go 9 | 10 | go build -race -o autobahn-server-linux-amd64-race.out ./autobahn/server/autobahn-server.go 11 | # windows amd64 12 | #GOOS=windows GOARCH=amd64 go build -o autobahn-server-windows-amd64.exe ./autobahn/server/autobahn-server.go 13 | 14 | # mac, arm64 15 | GOOS=darwin GOARCH=arm64 go build -o autobahn-client-darwin-arm64 ./autobahn/client/autobahn-client.go 16 | # linux amd64 17 | GOOS=linux GOARCH=amd64 go build -o autobahn-client-linux-amd64 ./autobahn/client/autobahn-client.go 18 | 19 | key: 20 | openssl genrsa 2048 > privatekey.pem 21 | openssl req -new -key privatekey.pem -out csr.pem 22 | openssl x509 -req -days 36500 -in csr.pem -signkey privatekey.pem -out public.crt 23 | -------------------------------------------------------------------------------- /conn_write.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | //go:build linux || darwin || netbsd || freebsd || openbsd || dragonfly 16 | // +build linux darwin netbsd freebsd openbsd dragonfly 17 | 18 | package greatws 19 | 20 | import "unsafe" 21 | 22 | type newConnWrite Conn 23 | 24 | func connToNewConn(c *Conn) *newConnWrite { 25 | return (*newConnWrite)(unsafe.Pointer(c)) 26 | } 27 | 28 | func (c *newConnWrite) Write(p []byte) (n int, err error) { 29 | c2 := (*Conn)(unsafe.Pointer(c)) 30 | return connWrite(c2, p) 31 | } 32 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/antlabs/pulse v0.0.0-20250706072419-b71307af8032 h1:/otfmYx78lTx9sizixAR1NlJs855LgUvPJaoNSoqKEw= 2 | github.com/antlabs/pulse v0.0.0-20250706072419-b71307af8032/go.mod h1:CR7Uwu/iaQvySMEXk44joshcjkxDS6hRFW8lp8Zq8rg= 3 | github.com/antlabs/task v0.0.0-20250706071410-2137462668b9 h1:fNu1p6qI8qGdbmMnxwmsCHPMs6pBm5m3UeCcMkjYSyA= 4 | github.com/antlabs/task v0.0.0-20250706071410-2137462668b9/go.mod h1:6YSWbWdHmW/4s9L0RkmiSCqAf0qmhmb119yzaSbJXww= 5 | github.com/antlabs/wsutil v0.1.10 h1:86p67dG8/iiQ+yZrHVl73OPHGnXfXopFSU0w84fLOdE= 6 | github.com/antlabs/wsutil v0.1.10/go.mod h1:Pk7xYOw3o5iEB6ukiOu+2uJMLYeMVVjJLazFD3okI2A= 7 | github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw= 8 | github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= 9 | github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU= 10 | github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= 11 | golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= 12 | golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= 13 | -------------------------------------------------------------------------------- /autobahn/config/fuzzingclient.json: -------------------------------------------------------------------------------- 1 | { 2 | "outdir": "./report/", 3 | "servers": [ 4 | { 5 | "agent": "no-context-takeover-decompression-and-compression-no-tls", 6 | "url": "ws://localhost:9004/no-context-takeover-decompression-and-compression", 7 | "options": { 8 | "version": 18 9 | } 10 | }, 11 | { 12 | "agent": "no-context-takeover-decompression-no-tls", 13 | "url": "ws://localhost:9004/no-context-takeover-decompression", 14 | "options": { 15 | "version": 18 16 | } 17 | }, 18 | { 19 | "agent": "context-takeover-decompression-and-compression-no-tls", 20 | "url": "ws://localhost:9004/context-takeover-decompression-and-compression", 21 | "options": { 22 | "version": 18 23 | } 24 | }, 25 | { 26 | "agent": "context-takeover-decompression-no-tls", 27 | "url": "ws://localhost:9004/context-takeover-decompression", 28 | "options": { 29 | "version": 18 30 | } 31 | } 32 | ], 33 | "cases": [ 34 | "*" 35 | ], 36 | "exclude-cases": [ 37 | "" 38 | ], 39 | "exclude-agent-cases": {} 40 | } -------------------------------------------------------------------------------- /autobahn/server/public.crt: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIDhDCCAmwCCQCAT4Vkb5M9+jANBgkqhkiG9w0BAQUFADCBgjELMAkGA1UEBhMC 3 | MTExETAPBgNVBAgMCHNoYW5naGFpMREwDwYDVQQHDAhzaGFuZ2hhaTERMA8GA1UE 4 | CgwIc2hhbmdoYWkxETAPBgNVBAsMCHNoYW5naGFpMREwDwYDVQQDDAhzaGFuZ2hh 5 | aTEUMBIGCSqGSIb3DQEJARYFcS5jb20wIBcNMjMwOTEzMDYxNDIwWhgPMjEyMzA4 6 | MjAwNjE0MjBaMIGCMQswCQYDVQQGEwIxMTERMA8GA1UECAwIc2hhbmdoYWkxETAP 7 | BgNVBAcMCHNoYW5naGFpMREwDwYDVQQKDAhzaGFuZ2hhaTERMA8GA1UECwwIc2hh 8 | bmdoYWkxETAPBgNVBAMMCHNoYW5naGFpMRQwEgYJKoZIhvcNAQkBFgVxLmNvbTCC 9 | ASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBALWPSSe7BW2TtSdvy+70IP7u 10 | t/7U4rFH3PgsXnJ9zR54fnml802JdPGBSViV+RyxWyKNuJldFB8zVmIoaWfushuq 11 | N+82wyf9JOkMkcUgItQsiOSty1iY64d8JToilwqosFmFutQb0ZaBjfoCOgg+Uiqd 12 | 3KxkBPiG12d1tmzZiuAcyCSnyHCXpCNF2q13KRbsh8SVgn2uhRigqDksfRjCuMPp 13 | qFjXTOofyMl+C0K9geNvm2gwJu7sNthKteML0DRQx6K7ppQZdyWGTzsToCltChqY 14 | YvA3sRPvCX0+c1p1KEqJxnPrTWGzgT8hoZEID1Wf8BlCguBHtdyWCHbBCLGa+6MC 15 | AwEAATANBgkqhkiG9w0BAQUFAAOCAQEAjK9WICUbhrzwgsFmgmQSAenyUO41wHGt 16 | USc66EHaa62648pEYnGAl0Ow5Lns94vWO48c8sT7XpPgoDxI8oz/GrHV8y0XdiLn 17 | Re/Zho+P079WoYzZ7j2/CCxVaI6O9yCWdfoobLRoEH0Rqmb6ihKsgnIfNu4oj8OP 18 | kJ8ncZ3tVC5rZ4qVEZRGjJu9zMO6OhS98i2Op8lxqMIb+lt4+zoAbg2svIEXRWyJ 19 | XD2KeDCv/It95uow1ZZp8SRoORqtJvcmD+wqRkQUkSRvTHhEYadA8Sx4/GYTmEvJ 20 | 945EicoBtL2N2tK2yhq+DjB89ua24SX6tUiZES9gRRtI9KKp0hlslw== 21 | -----END CERTIFICATE----- 22 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | pull_request: 6 | 7 | jobs: 8 | 9 | build: 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | go: [ '1.20'] 14 | name: Go ${{ matrix.go }} sample 15 | 16 | steps: 17 | 18 | - name: Set up Go 1.20 19 | uses: actions/setup-go@v1 20 | with: 21 | go-version: 1.20 22 | id: go 23 | 24 | - name: Check out code into the Go module directory 25 | uses: actions/checkout@v1 26 | 27 | - name: Get dependencies 28 | run: | 29 | go get -v -t -d ./... 30 | if [ -f Gopkg.toml ]; then 31 | curl https://raw.githubusercontent.com/golang/dep/master/install.sh | sh 32 | dep ensure 33 | fi 34 | - name: Build 35 | run: go build -v . 36 | 37 | # - name: Test-386 38 | # run: env GOARCH=386 go test -test.run=Test_Retry_sleep -v 39 | #run: env GOARCH=386 go test -v -coverprofile='coverage.out' -covermode=count ./... 40 | 41 | - name: Test-Race 42 | run: env GOARCH=amd64 go test -v -race ./... 43 | 44 | - name: Test-amd64 45 | run: env GOARCH=amd64 go test -v -coverprofile='coverage.out' -covermode=count ./... 46 | 47 | - name: Upload Coverage report 48 | uses: codecov/codecov-action@v1 49 | with: 50 | token: ${{secrets.CODECOV_TOKEN}} 51 | file: ./coverage.out 52 | -------------------------------------------------------------------------------- /autobahn/config/fuzzingclient-all.json: -------------------------------------------------------------------------------- 1 | { 2 | "outdir": "./report/", 3 | "servers": [ 4 | { 5 | "agent": "non-tls-io", 6 | "url": "ws://127.0.0.1:9004/autobahn-io", 7 | "options": { 8 | "version": 18 9 | } 10 | }, 11 | { 12 | "agent": "non-tls-onebyone", 13 | "url": "ws://127.0.0.1:9004/autobahn-onebyone", 14 | "options": { 15 | "version": 18 16 | } 17 | }, 18 | { 19 | "agent": "non-tls-elastic", 20 | "url": "ws://127.0.0.1:9004/autobahn-elastic", 21 | "options": { 22 | "version": 18 23 | } 24 | }, 25 | { 26 | "agent": "no-context-takeover-decompression-no-tls", 27 | "url": "ws://127.0.0.1:9004/no-context-takeover-decompression", 28 | "options": { 29 | "version": 18 30 | } 31 | }, 32 | { 33 | "agent": "no-context-takeover-decompression-and-compression-no-tls", 34 | "url": "ws://127.0.0.1:9004/no-context-takeover-decompression-and-compression", 35 | "options": { 36 | "version": 18 37 | } 38 | }, 39 | { 40 | "agent": "context-takeover-decompression-no-tls", 41 | "url": "ws://127.0.0.1:9004/context-takeover-decompression", 42 | "options": { 43 | "version": 18 44 | } 45 | }, 46 | { 47 | "agent": "context-takeover-decompression-and-compression-no-tls", 48 | "url": "ws://127.0.0.1:9004/context-takeover-decompression-and-compression", 49 | "options": { 50 | "version": 18 51 | } 52 | } 53 | ], 54 | "cases": [ 55 | "*" 56 | ], 57 | "exclude-cases": [ 58 | "" 59 | ], 60 | "exclude-agent-cases": {} 61 | } -------------------------------------------------------------------------------- /client_options.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package greatws 16 | 17 | import ( 18 | "crypto/tls" 19 | "net/http" 20 | "time" 21 | ) 22 | 23 | type ClientOption func(*DialOption) 24 | 25 | // 1.配置tls.config 26 | func WithClientTLSConfig(tls *tls.Config) ClientOption { 27 | return func(o *DialOption) { 28 | o.tlsConfig = tls 29 | } 30 | } 31 | 32 | // 2.配置http.Header 33 | func WithClientHTTPHeader(h http.Header) ClientOption { 34 | return func(o *DialOption) { 35 | o.Header = h 36 | } 37 | } 38 | 39 | // 3.配置握手时的timeout,tcp连接的timeout 40 | func WithClientDialTimeout(t time.Duration) ClientOption { 41 | return func(o *DialOption) { 42 | o.dialTimeout = t 43 | } 44 | } 45 | 46 | // 4.配置压缩 47 | func WithClientCompression() ClientOption { 48 | return func(o *DialOption) { 49 | o.Compression = true 50 | } 51 | } 52 | 53 | // 6.获取http header 54 | func WithClientBindHTTPHeader(h *http.Header) ClientOption { 55 | return func(o *DialOption) { 56 | o.bindClientHttpHeader = h 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /task_parse.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | package greatws 15 | 16 | import ( 17 | "runtime" 18 | "sync" 19 | ) 20 | 21 | type taskParse struct { 22 | allTaskParse []*taskParseNode 23 | } 24 | 25 | func newTaskParse() *taskParse { 26 | tp := &taskParse{ 27 | allTaskParse: make([]*taskParseNode, runtime.NumCPU()), 28 | } 29 | wg := sync.WaitGroup{} 30 | wg.Add(runtime.NumCPU()) 31 | 32 | tp.start(&wg) 33 | wg.Wait() 34 | return tp 35 | } 36 | 37 | func (t *taskParse) start(wg *sync.WaitGroup) { 38 | for i := 0; i < len(t.allTaskParse); i++ { 39 | t.allTaskParse[i] = &taskParseNode{ 40 | taskChan: make(chan func() bool, 1024), 41 | } 42 | go t.allTaskParse[i].run(wg) 43 | } 44 | } 45 | 46 | func (t *taskParse) addTask(fd int, f func() bool) { 47 | t.allTaskParse[fd%len(t.allTaskParse)].taskChan <- f 48 | } 49 | 50 | type taskParseNode struct { 51 | taskChan chan func() bool 52 | } 53 | 54 | func (tpn *taskParseNode) run(wg *sync.WaitGroup) { 55 | wg.Done() 56 | for f := range tpn.taskChan { 57 | if !f() { 58 | return 59 | } 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /select_task.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | package greatws 15 | 16 | import ( 17 | "context" 18 | 19 | "github.com/antlabs/task/task/driver" 20 | ) 21 | 22 | type selectTask struct { 23 | taskDriverName string 24 | task driver.Tasker 25 | } 26 | type selectTasks []selectTask 27 | 28 | func newSelectTask(ctx context.Context, initCount, min, max int, c *driver.Conf) []selectTask { 29 | 30 | all := driver.GetAllRegister() 31 | rv := make([]selectTask, 0, len(all)) 32 | for _, val := range all { 33 | task := val.Driver.New(ctx, initCount, min, max, c) 34 | rv = append(rv, selectTask{ 35 | taskDriverName: val.Name, 36 | task: task, 37 | }) 38 | } 39 | return rv 40 | } 41 | 42 | func (s *selectTasks) newTask(taskName string) driver.TaskExecutor { 43 | for _, val := range *s { 44 | if val.taskDriverName == taskName { 45 | return val.task.NewExecutor() 46 | } 47 | } 48 | 49 | panic("greatws: no task driver found:" + taskName) 50 | } 51 | 52 | func (s *selectTasks) GetGoroutines() int { 53 | total := 0 54 | for _, val := range *s { 55 | total += val.task.GetGoroutines() 56 | } 57 | 58 | return total 59 | } 60 | -------------------------------------------------------------------------------- /autobahn/server/privatekey.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEpAIBAAKCAQEAtY9JJ7sFbZO1J2/L7vQg/u63/tTisUfc+Cxecn3NHnh+eaXz 3 | TYl08YFJWJX5HLFbIo24mV0UHzNWYihpZ+6yG6o37zbDJ/0k6QyRxSAi1CyI5K3L 4 | WJjrh3wlOiKXCqiwWYW61BvRloGN+gI6CD5SKp3crGQE+IbXZ3W2bNmK4BzIJKfI 5 | cJekI0XarXcpFuyHxJWCfa6FGKCoOSx9GMK4w+moWNdM6h/IyX4LQr2B42+baDAm 6 | 7uw22Eq14wvQNFDHorumlBl3JYZPOxOgKW0KGphi8DexE+8JfT5zWnUoSonGc+tN 7 | YbOBPyGhkQgPVZ/wGUKC4Ee13JYIdsEIsZr7owIDAQABAoIBAQCTWrWe/1UKeCVA 8 | 2qWDTLQy9AB1XMaX56FZ8ni9J4kAv/62MI/lUDiPgcTLlvzV4sP6qVc3canRINNt 9 | WyshZUM83MwE5EdD/1qjosX0XX6nAXYhU0SEpagTEBkOs+AukHaAUd8uI13Zb1CR 10 | ppj+88WwPOtLJuo54waUO59RfMYP2S43Spk7Hxde+xFCCqjFgVRVCjAYHXPDc6PY 11 | g0pxS9vd55PtGnHOmoJ0sHDcYSP6D3pogJ6bqqKAitStW5pWFWjmS4EC1DkXXRm1 12 | 8jWb3kMYjx2XDpgPL8BTYual1AIJtfLPJLebgd5OB8loo2PQtXYI/7MY/PQ/Jj3O 13 | rsBOEh1BAoGBANwxWYHaLAFevjzxYv3IkoYpsaHyVAn8hud+AZuyMPs8owrE2ELM 14 | HX1LcTQznW5GcBOYwHbTNmq9swoSJpzOE8xdnh60s17/KMEkCpf1Wb10JXFDxbwD 15 | kzrVofJ1TGZpX0lOVM+BbufllxzVcDD4gYRQeUo4vpTON9w4O/IyE9/TAoGBANMV 16 | oyd1eDRvhYa0SrcNu448ujX94lw5wsHz73cKpoHjPnuEvMo3NT1RPMdxJUDtFBHi 17 | ROe8GZtrkYiaiWQFaa3oha0m7DZC9oYxFVEZiZjCC5wPb8uz3VNe6S7ZsQ9vwQjy 18 | nI1pPyHnIKfnVsa7VKjDEczGJUsVCUcWGSEGX+LxAoGBAMLgYklMX/nucgvZzzSw 19 | mQ1oRTABGmOkPXkPyjiT4knYhqv3PzcPE2JarJv4unJooLSXUm9Xyyd5MMXO/qF0 20 | uYz7pf+jCcUfqmAVl9KZcIz+CE1QH+age2Nsw2GkcrOIuq5URzdHZHKUfcMlG6Ab 21 | r/T8i/wmcHWedU3P7y4RKAnFAoGAYTsaSFbH4/9q1j2+HMvqlP5MGAq2dhz1JTok 22 | GAWD5Vizs1nVTKBZmcEN6iCoNFwAXqyHaOcwNHM8OlxU5QnJQB5XVQcUz3nQ7Mc3 23 | NoA47XCUwHIr4P0c+gZCCx6jfKTRmjmG+2x7dDZuyGi7hBdOS95vGJA9JXSvLVt0 24 | f83b5UECgYBm/4/YKSztVnOdz/ws2ch1MdAIag4w2ps9JrMI+Igdkq+LIdyWfrrF 25 | L9yFOKbjLR1TynulpEP4u7PEUH4g6mFriM/k+qLj9KukcNoVvnsfgz9sM+O3b54q 26 | tRlFyGQJefLNuGjs5ZbxN4PDYqzs0Y2RtOBEtmkF6rTdpSC5BUdQDQ== 27 | -----END RSA PRIVATE KEY----- 28 | -------------------------------------------------------------------------------- /utils.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package greatws 16 | 17 | import ( 18 | "crypto/sha1" 19 | "encoding/base64" 20 | "math/rand" 21 | "sync" 22 | "time" 23 | "unsafe" 24 | ) 25 | 26 | var rng = rand.New(rand.NewSource(time.Now().UnixNano())) 27 | 28 | var mu sync.Mutex 29 | var uuid = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") 30 | 31 | // StringToBytes 没有内存开销的转换 32 | // 33 | // func StringToBytes(s string) (b []byte) { 34 | // bh := (*reflect.SliceHeader)(unsafe.Pointer(&b)) 35 | // sh := (*reflect.StringHeader)(unsafe.Pointer(&s)) 36 | // bh.Data = sh.Data 37 | // bh.Len = sh.Len 38 | // bh.Cap = sh.Len 39 | // return b 40 | // } 41 | func StringToBytes(s string) []byte { 42 | return unsafe.Slice(unsafe.StringData(s), len(s)) 43 | } 44 | 45 | // func BytesToString(b []byte) string { 46 | // return *(*string)(unsafe.Pointer(&b)) 47 | // } 48 | 49 | func secWebSocketAccept() string { 50 | // rfc规定是16字节 51 | var key [16]byte 52 | mu.Lock() 53 | rng.Read(key[:]) 54 | mu.Unlock() 55 | return base64.StdEncoding.EncodeToString(key[:]) 56 | } 57 | 58 | func secWebSocketAcceptVal(val string) string { 59 | s := sha1.New() 60 | s.Write(StringToBytes(val)) 61 | s.Write(uuid) 62 | r := s.Sum(nil) 63 | return base64.StdEncoding.EncodeToString(r) 64 | } 65 | -------------------------------------------------------------------------------- /multi_event_loops_option.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | package greatws 15 | 16 | import ( 17 | "log/slog" 18 | ) 19 | 20 | type EvOption func(e *MultiEventLoop) 21 | 22 | // 开启几个事件循环, 控制io go程数量 23 | func WithEventLoops(num int) EvOption { 24 | return func(e *MultiEventLoop) { 25 | e.numLoops = num 26 | } 27 | } 28 | 29 | // 最小业务goroutine数量, 控制业务go程数量 30 | // initCount: 初始化的协程数 31 | // min: 最小协程数 32 | // max: 最大协程数 33 | func WithBusinessGoNum(initCount, min, max int) EvOption { 34 | return func(e *MultiEventLoop) { 35 | if initCount <= 0 { 36 | initCount = defTaskInitCount 37 | } 38 | 39 | if min <= 0 { 40 | min = defTaskMin 41 | } 42 | 43 | if max <= 0 { 44 | max = defTaskMax 45 | } 46 | e.configTask.initCount = initCount 47 | e.configTask.min = min 48 | e.configTask.max = max 49 | } 50 | } 51 | 52 | // 设置business go程池 对流量压测友好的模式 53 | // func WithBusinessGoTrafficMode() EvOption { 54 | // return func(e *MultiEventLoop) { 55 | // e.taskMode = trafficMode 56 | // } 57 | // } 58 | 59 | // 设置日志级别 60 | func WithLogLevel(level slog.Level) EvOption { 61 | return func(e *MultiEventLoop) { 62 | e.level = level 63 | } 64 | } 65 | 66 | // 设置每个事件循环一次返回的最大事件数量 67 | func WithMaxEventNum(num int) EvOption { 68 | return func(e *MultiEventLoop) { 69 | e.maxEventNum = num 70 | } 71 | } 72 | 73 | // 暂时不可用 74 | // 是否使用io_uring, 支持linux系统,需要内核版本6.2.0以上(以后只会在>=6.2.0的版本上测试) 75 | // func WithIoUring() EvOption { 76 | // return func(e *MultiEventLoop) { 77 | // e.flag |= EVENT_IOURING 78 | // } 79 | // } 80 | -------------------------------------------------------------------------------- /utils_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package greatws 16 | 17 | import ( 18 | "reflect" 19 | "testing" 20 | ) 21 | 22 | func TestStringToBytes(t *testing.T) { 23 | type args struct { 24 | s string 25 | } 26 | tests := []struct { 27 | name string 28 | args args 29 | wantB []byte 30 | }{ 31 | { 32 | name: "test1", 33 | args: args{s: "test1"}, 34 | wantB: []byte("test1"), 35 | }, 36 | { 37 | name: "test2", 38 | args: args{s: "test2"}, 39 | wantB: []byte("test2"), 40 | }, 41 | } 42 | for _, tt := range tests { 43 | t.Run(tt.name, func(t *testing.T) { 44 | if gotB := StringToBytes(tt.args.s); !reflect.DeepEqual(gotB, tt.wantB) { 45 | t.Errorf("StringToBytes() = %v, want %v", gotB, tt.wantB) 46 | } 47 | }) 48 | } 49 | } 50 | 51 | func Test_secWebSocketAccept(t *testing.T) { 52 | tests := []struct { 53 | name string 54 | want string 55 | }{ 56 | {name: ">0"}, 57 | } 58 | for _, tt := range tests { 59 | t.Run(tt.name, func(t *testing.T) { 60 | if got := secWebSocketAccept(); len(got) == 0 { 61 | t.Errorf("secWebSocketAccept() = %v, want %v", got, tt.want) 62 | } 63 | }) 64 | } 65 | } 66 | 67 | func Test_secWebSocketAcceptVal(t *testing.T) { 68 | type args struct { 69 | val string 70 | } 71 | tests := []struct { 72 | name string 73 | args args 74 | want string 75 | }{ 76 | {name: "test1", args: args{val: "test1"}}, 77 | } 78 | for _, tt := range tests { 79 | t.Run(tt.name, func(t *testing.T) { 80 | if got := secWebSocketAcceptVal(tt.args.val); len(got) == 0 { 81 | t.Errorf("secWebSocketAcceptVal() = %v, want %v", got, tt.want) 82 | } 83 | }) 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /callback.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | package greatws 15 | 16 | type ( 17 | Callback interface { 18 | OnOpen(*Conn) 19 | OnMessage(*Conn, Opcode, []byte) 20 | OnClose(*Conn, error) 21 | } 22 | ) 23 | 24 | type ( 25 | OnOpenFunc func(*Conn) 26 | ) 27 | 28 | // 1. 默认的OnOpen, OnMessage, OnClose都是空函数 29 | type DefCallback struct{} 30 | 31 | func (defcallback *DefCallback) OnOpen(_ *Conn) { 32 | } 33 | 34 | func (defcallback *DefCallback) OnMessage(_ *Conn, _ Opcode, _ []byte) { 35 | } 36 | 37 | func (defcallback *DefCallback) OnClose(_ *Conn, _ error) { 38 | } 39 | 40 | // 2. 只设置OnMessage, 和OnClose互斥 41 | type OnMessageFunc func(*Conn, Opcode, []byte) 42 | 43 | func (o OnMessageFunc) OnOpen(_ *Conn) { 44 | } 45 | 46 | func (o OnMessageFunc) OnMessage(c *Conn, op Opcode, data []byte) { 47 | o(c, op, data) 48 | } 49 | 50 | func (o OnMessageFunc) OnClose(_ *Conn, _ error) { 51 | } 52 | 53 | // 3. 只设置OnClose, 和OnMessage互斥 54 | type OnCloseFunc func(*Conn, error) 55 | 56 | func (o OnCloseFunc) OnOpen(_ *Conn) { 57 | } 58 | 59 | func (o OnCloseFunc) OnMessage(_ *Conn, _ Opcode, _ []byte) { 60 | } 61 | 62 | func (o OnCloseFunc) OnClose(c *Conn, err error) { 63 | o(c, err) 64 | } 65 | 66 | // 4. 函数转换为接口 67 | type funcToCallback struct { 68 | onOpen func(*Conn) 69 | onMessage func(*Conn, Opcode, []byte) 70 | onClose func(*Conn, error) 71 | } 72 | 73 | func (f *funcToCallback) OnOpen(c *Conn) { 74 | if f.onOpen != nil { 75 | f.onOpen(c) 76 | } 77 | } 78 | 79 | func (f *funcToCallback) OnMessage(c *Conn, op Opcode, data []byte) { 80 | if f.onMessage != nil { 81 | f.onMessage(c, op, data) 82 | } 83 | } 84 | 85 | func (f *funcToCallback) OnClose(c *Conn, err error) { 86 | if f.onClose != nil { 87 | f.onClose(c, err) 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /permessage_deflate.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package greatws 16 | 17 | import ( 18 | "sync/atomic" 19 | "unsafe" 20 | 21 | "github.com/antlabs/wsutil/deflate" 22 | ) 23 | 24 | // 压缩的入口函数 25 | func (c *Conn) encoode(payload *[]byte) (encodePayload *[]byte, err error) { 26 | 27 | ct := (c.pd.ClientContextTakeover && c.client || !c.client && c.pd.ServerContextTakeover) && c.pd.Compression 28 | // 上下文接管 29 | bit := uint8(0) 30 | if c.client { 31 | bit = c.pd.ClientMaxWindowBits 32 | } else { 33 | bit = c.pd.ServerMaxWindowBits 34 | } 35 | if ct { 36 | // 这里的读取是单go程的。所以不用加锁 37 | if atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&c.enCtx))) == nil { 38 | 39 | c.mu.Lock() //加锁 40 | if atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&c.enCtx))) != nil { 41 | goto decode 42 | } 43 | enCtx, err := deflate.NewCompressContextTakeover(bit) 44 | if err != nil { 45 | c.mu.Unlock() 46 | return nil, err 47 | } 48 | atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&c.enCtx)), unsafe.Pointer(enCtx)) 49 | c.mu.Unlock() 50 | } 51 | 52 | } 53 | c.mu.Lock() 54 | decode: 55 | defer c.mu.Unlock() 56 | // 处理上下文接管和非上下文接管两种情况 57 | // bit 为啥放在参数里面传递, 因为非上下文接管的时候,也需要正确处理bit 58 | return c.enCtx.Compress(payload, bit) 59 | } 60 | 61 | // 解压缩入口函数 62 | // 解压目前只在一个go程里面按序列处理,所以不需要加锁 63 | func (c *Conn) decode(payload *[]byte) (decodePayload *[]byte, err error) { 64 | ct := (c.pd.ClientContextTakeover && c.client || !c.client && c.pd.ServerContextTakeover) && c.pd.Decompression 65 | // 上下文接管 66 | if ct { 67 | if c.deCtx == nil { 68 | 69 | bit := uint8(0) 70 | 71 | if c.client { 72 | bit = c.pd.ClientMaxWindowBits 73 | } else { 74 | bit = c.pd.ServerMaxWindowBits 75 | } 76 | c.deCtx, err = deflate.NewDecompressContextTakeover(bit) 77 | if err != nil { 78 | return nil, err 79 | } 80 | } 81 | 82 | } 83 | 84 | // 上下文接管, deCtx是nil 85 | // 非上下文接管, deCtx是非nil 86 | 87 | return c.deCtx.Decompress(payload, c.readMaxMessage) 88 | } 89 | -------------------------------------------------------------------------------- /err.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package greatws 16 | 17 | import "errors" 18 | 19 | var ( 20 | // conn已经被关闭 21 | ErrClosed = errors.New("closed") 22 | 23 | ErrWrongStatusCode = errors.New("Wrong status code") 24 | ErrUpgradeFieldValue = errors.New("The value of the upgrade field is not 'websocket'") 25 | ErrConnectionFieldValue = errors.New("The value of the connection field is not 'upgrade'") 26 | ErrSecWebSocketAccept = errors.New("The value of Sec-WebSocketAaccept field is invalid") 27 | 28 | ErrHostCannotBeEmpty = errors.New("Host cannot be empty") 29 | ErrSecWebSocketKey = errors.New("The value of SEC websocket key field is wrong") 30 | ErrSecWebSocketVersion = errors.New("The value of SEC websocket version field is wrong, not 13") 31 | 32 | ErrHTTPProtocolNotSupported = errors.New("HTTP protocol not supported") 33 | 34 | ErrOnlyGETSupported = errors.New("error:Only get methods are supported") 35 | ErrMaxControlFrameSize = errors.New("error:max control frame size > 125, need <= 125") 36 | ErrRsv123 = errors.New("error:rsv1 or rsv2 or rsv3 has a value") 37 | ErrOpcode = errors.New("error:wrong opcode") 38 | ErrNOTBeFragmented = errors.New("error:since control message MUST NOT be fragmented") 39 | ErrFrameOpcode = errors.New("error:since all data frames after the initial data frame must have opcode 0.") 40 | ErrTextNotUTF8 = errors.New("error:text is not utf8 data") 41 | ErrClosePayloadTooSmall = errors.New("error:close payload too small") 42 | ErrCloseValue = errors.New("error:close value is wrong") // close值不对 43 | ErrEmptyClose = errors.New("error:close value is empty") // close的值是空的 44 | ErrWriteClosed = errors.New("write close") 45 | ) 46 | 47 | var ( 48 | // 事件循环为空 49 | ErrEventLoopEmpty = errors.New("event loop is empty, Need to call WithServerMultiEventLoop or WithClientMultiEventLoop for configuration") 50 | // 事件循环没有启动 51 | ErrEventLoopNotStart = errors.New("event loop not start") 52 | ) 53 | -------------------------------------------------------------------------------- /config.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package greatws 16 | 17 | import ( 18 | "time" 19 | 20 | "github.com/antlabs/wsutil/deflate" 21 | ) 22 | 23 | // Config 配置 24 | // 有两种方式可以配置相关值 25 | // 1. NewUpgrade, 这通常在只初始化一次的时候使用 26 | // 2. greatws.Upgrade(), 这通常在每次请求的时候使用,每个语法的配置参数不一样 27 | // 这样可以方便的在两种方式中使用, 不需要担心配置参数会有并发修改的情况 28 | type Config struct { 29 | cb Callback // 静态配置 30 | deflate.PermessageDeflateConf // 静态配置, 从WithXXX函数中获取 31 | tcpNoDelay bool // TODO: 加下这个功能 32 | replyPing bool // 开启自动回复 33 | ignorePong bool // 忽略pong消息 34 | disableBufioClearHack bool // 关闭bufio的clear hack优化 35 | utf8Check func([]byte) bool // utf8检查 36 | readTimeout time.Duration // 加下这个功能 37 | windowsMultipleTimesPayloadSize float32 // 设置几倍(1024+14)的payload大小 38 | maxDelayWriteNum int32 // 最大延迟包的个数, 默认值为10 39 | delayWriteInitBufferSize int32 // 延迟写入的初始缓冲区大小, 默认值是8k 40 | maxDelayWriteDuration time.Duration // 最大延迟时间, 默认值是10ms 41 | subProtocols []string // 设置支持的子协议 42 | multiEventLoop *MultiEventLoop // 事件循环 43 | runInGoTask string // 运行业务OnMessage的策略, 现在greatws集成三种OnMessage运行模式,分别是io, task 44 | readMaxMessage int64 // 最大消息大小 45 | flowBackPressureRemoveRead bool // 流控背压机制,移除读事件 46 | } 47 | 48 | // func (c *Config) useIoUring() bool { 49 | // return c.multiEventLoop.flag == EVENT_IOURING 50 | // } 51 | 52 | // 默认设置 53 | func (c *Config) defaultSetting() { 54 | c.cb = &DefCallback{} 55 | c.maxDelayWriteNum = 10 56 | c.windowsMultipleTimesPayloadSize = 1.0 57 | c.delayWriteInitBufferSize = 8 * 1024 58 | c.maxDelayWriteDuration = 10 * time.Millisecond 59 | // c.runInGoStrategy = taskStrategyBind 60 | c.tcpNoDelay = true 61 | // 对于text消息,默认不检查text是utf8字符 62 | c.utf8Check = func(b []byte) bool { return true } 63 | c.runInGoTask = "elastic" //默认使用elastic模块 64 | } 65 | 66 | func (c *Config) defaultSettingAfter() { 67 | 68 | if c.multiEventLoop == nil { 69 | c.multiEventLoop = getDefaultMultiEventLoop() 70 | } 71 | c.multiEventLoop.Start() 72 | } 73 | -------------------------------------------------------------------------------- /find_deadlock_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 设置超时时间(秒) 4 | TIMEOUT=30 5 | 6 | # 颜色定义 7 | RED='\033[0;31m' 8 | GREEN='\033[0;32m' 9 | YELLOW='\033[1;33m' 10 | NC='\033[0m' # No Color 11 | 12 | echo "🔍 开始逐个运行测试函数,查找死锁问题..." 13 | echo "⏰ 设置超时时间: ${TIMEOUT}秒" 14 | echo "==========================" 15 | 16 | # 测试函数列表 17 | declare -a tests=( 18 | "TestStringToBytes" 19 | "Test_secWebSocketAccept" 20 | "Test_secWebSocketAcceptVal" 21 | "TestUpgradeInner" 22 | "Test_Server_HandshakeFail" 23 | "Test_DefaultCallback" 24 | "Test_ClientOption" 25 | "Test_CommonOption" 26 | "Test_Client_Dial_Check_Header" 27 | "Test_Conn" 28 | "Test_ReadMessage" 29 | "TestFragmentFrame" 30 | "Test_WriteControl" 31 | "Test_API" 32 | "TestPingPongClose" 33 | ) 34 | 35 | # 记录结果 36 | passed_tests=() 37 | failed_tests=() 38 | timeout_tests=() 39 | 40 | # 运行单个测试函数 41 | run_test() { 42 | local test_func=$1 43 | 44 | echo -e "\n📝 运行测试: ${YELLOW}${test_func}${NC}" 45 | 46 | # 使用 Go 的内置超时机制运行测试 47 | go test -v -timeout ${TIMEOUT}s -run "^${test_func}$" . 2>&1 & 48 | local test_pid=$! 49 | 50 | # 等待测试完成或超时 51 | local start_time=$(date +%s) 52 | while kill -0 $test_pid 2>/dev/null; do 53 | local current_time=$(date +%s) 54 | local elapsed=$((current_time - start_time)) 55 | 56 | if [ $elapsed -ge $TIMEOUT ]; then 57 | kill -9 $test_pid 2>/dev/null 58 | wait $test_pid 2>/dev/null 59 | echo -e "⏰ ${RED}TIMEOUT (可能死锁)${NC}: ${test_func}" 60 | timeout_tests+=("${test_func}") 61 | return 62 | fi 63 | 64 | sleep 1 65 | done 66 | 67 | # 获取测试退出状态 68 | wait $test_pid 69 | local exit_code=$? 70 | 71 | case $exit_code in 72 | 0) 73 | echo -e "✅ ${GREEN}PASSED${NC}: ${test_func}" 74 | passed_tests+=("${test_func}") 75 | ;; 76 | *) 77 | echo -e "❌ ${RED}FAILED${NC}: ${test_func} (退出码: ${exit_code})" 78 | failed_tests+=("${test_func}") 79 | ;; 80 | esac 81 | } 82 | 83 | # 逐个运行测试 84 | for test in "${tests[@]}"; do 85 | run_test "$test" 86 | 87 | # 每个测试之间稍微等待一下 88 | sleep 2 89 | done 90 | 91 | # 输出总结 92 | echo -e "\n==========================" 93 | echo -e "📊 ${YELLOW}测试结果总结${NC}" 94 | echo -e "==========================" 95 | 96 | echo -e "\n✅ ${GREEN}通过的测试 (${#passed_tests[@]})${NC}:" 97 | for test in "${passed_tests[@]}"; do 98 | echo " - $test" 99 | done 100 | 101 | echo -e "\n❌ ${RED}失败的测试 (${#failed_tests[@]})${NC}:" 102 | for test in "${failed_tests[@]}"; do 103 | echo " - $test" 104 | done 105 | 106 | echo -e "\n⏰ ${RED}超时的测试 (疑似死锁) (${#timeout_tests[@]})${NC}:" 107 | for test in "${timeout_tests[@]}"; do 108 | echo " - $test" 109 | done 110 | 111 | if [ ${#timeout_tests[@]} -gt 0 ]; then 112 | echo -e "\n🚨 ${RED}发现 ${#timeout_tests[@]} 个疑似死锁的测试函数!${NC}" 113 | echo -e "建议进一步分析这些函数的实现。" 114 | 115 | echo -e "\n🔍 可以单独深入分析超时的测试:" 116 | for test in "${timeout_tests[@]}"; do 117 | echo " go test -v -timeout ${TIMEOUT}s -run '^${test}$'" 118 | done 119 | else 120 | echo -e "\n🎉 ${GREEN}未发现明显的死锁问题!${NC}" 121 | fi 122 | 123 | echo -e "\n📝 如需更详细的分析,可以使用以下命令:" 124 | echo " go test -v -run '^函数名$' -timeout ${TIMEOUT}s" 125 | echo " go test -race -v -run '^函数名$' -timeout ${TIMEOUT}s # 检测竞态条件" -------------------------------------------------------------------------------- /callback_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | package greatws 15 | 16 | import ( 17 | "log/slog" 18 | "net/http" 19 | "net/http/httptest" 20 | "strings" 21 | "sync/atomic" 22 | "testing" 23 | "time" 24 | ) 25 | 26 | type testDefaultCallback struct { 27 | DefCallback 28 | } 29 | 30 | func Test_DefaultCallback(t *testing.T) { 31 | 32 | m := NewMultiEventLoopAndStartMust(WithEventLoops(1), WithLogLevel(slog.LevelDebug), WithBusinessGoNum(1, 1, 1)) 33 | t.Run("local: default callback", func(t *testing.T) { 34 | run := int32(0) 35 | done := make(chan bool, 1) 36 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 37 | c, err := Upgrade(w, r, WithServerCallback(&testDefaultCallback{}), WithServerMultiEventLoop(m)) 38 | if err != nil { 39 | t.Error(err) 40 | } 41 | defer c.Close() 42 | 43 | err = c.WriteMessage(Binary, []byte("hello")) 44 | if err != nil { 45 | t.Error(err) 46 | } 47 | 48 | atomic.AddInt32(&run, int32(1)) 49 | done <- true 50 | })) 51 | 52 | defer ts.Close() 53 | 54 | url := strings.ReplaceAll(ts.URL, "http", "ws") 55 | con, err := Dial(url, WithClientCallback(&testDefaultCallback{}), WithClientMultiEventLoop(m)) 56 | if err != nil { 57 | t.Error(err) 58 | } 59 | defer con.Close() 60 | 61 | err = con.WriteMessage(Binary, []byte("hello")) 62 | if err != nil { 63 | t.Errorf("WriteMessage fail:%v\n", err) 64 | return 65 | } 66 | select { 67 | case <-done: 68 | case <-time.After(1000 * time.Millisecond): 69 | } 70 | if atomic.LoadInt32(&run) != 1 { 71 | t.Error("not run server:method fail") 72 | } 73 | }) 74 | 75 | t.Run("global: default callback", func(t *testing.T) { 76 | run := int32(0) 77 | done := make(chan bool, 1) 78 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 79 | c, err := Upgrade(w, r, WithServerCallback(&testDefaultCallback{}), WithServerMultiEventLoop(m)) 80 | if err != nil { 81 | t.Error(err) 82 | } 83 | err = c.WriteMessage(Binary, []byte("hello")) 84 | if err != nil { 85 | t.Error(err) 86 | return 87 | } 88 | atomic.AddInt32(&run, int32(1)) 89 | done <- true 90 | })) 91 | 92 | defer ts.Close() 93 | 94 | url := strings.ReplaceAll(ts.URL, "http", "ws") 95 | con, err := Dial(url, WithClientCallback(&testDefaultCallback{}), WithClientMultiEventLoop(m)) 96 | if err != nil { 97 | t.Error(err) 98 | } 99 | defer con.Close() 100 | 101 | err = con.WriteMessage(Binary, []byte("hello")) 102 | if err != nil { 103 | t.Errorf("WriteMessage:%v\n", err) 104 | return 105 | } 106 | select { 107 | case <-done: 108 | case <-time.After(100 * time.Millisecond): 109 | } 110 | if atomic.LoadInt32(&run) != 1 { 111 | t.Error("not run server:method fail") 112 | } 113 | }) 114 | } 115 | -------------------------------------------------------------------------------- /event_loop.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | package greatws 15 | 16 | import ( 17 | "context" 18 | "errors" 19 | "sync/atomic" 20 | "time" 21 | 22 | "github.com/antlabs/pulse/core" 23 | "github.com/antlabs/task/task/driver" 24 | ) 25 | 26 | type evFlag int 27 | 28 | const ( 29 | EVENT_EPOLL evFlag = 1 << iota 30 | EVENT_IOURING 31 | ) 32 | 33 | type EventLoop struct { 34 | maxFd int // highest file descriptor currently registered 35 | setSize int // max number of file descriptors tracked 36 | core.PollingApi 37 | shutdown bool 38 | parent *MultiEventLoop 39 | localTask selectTasks 40 | } 41 | 42 | // 初始化函数 43 | func CreateEventLoop(setSize int, flag evFlag, parent *MultiEventLoop) (e *EventLoop, err error) { 44 | e = &EventLoop{ 45 | setSize: setSize, 46 | maxFd: -1, 47 | parent: parent, 48 | } 49 | 50 | var c driver.Conf 51 | c.Log = parent.Logger 52 | // 初始化任务池 53 | e.localTask = newSelectTask(parent.ctx, parent.configTask.initCount, parent.configTask.min, parent.configTask.max, &c) 54 | 55 | // TODO+ 56 | // e.localTask.taskConfig = e.parent.configTask.taskConfig 57 | // e.localTask.taskMode = e.parent.configTask.taskMode 58 | // e.localTask.init() 59 | e.PollingApi, err = core.Create(core.TriggerType(flag)) 60 | return e, err 61 | } 62 | 63 | // 柔性关闭所有的连接 64 | func (e *EventLoop) Shutdown(ctx context.Context) error { 65 | return nil 66 | } 67 | 68 | func (el *EventLoop) Loop() { 69 | for !el.shutdown { 70 | _, err := el.Poll(time.Duration(time.Second*100), func(fd int, state core.State, err error) { 71 | c := el.parent.safeConns.Get(fd) 72 | if err != nil { 73 | if errors.Is(err, core.EAGAIN) { 74 | return 75 | } 76 | if c != nil { 77 | c.Close() 78 | } 79 | el.parent.Error("apiPoll", "err", err.Error()) 80 | return 81 | } 82 | 83 | if c == nil { 84 | el.parent.Logger.Error("apiPoll c is nil", "fd", fd) 85 | return 86 | } 87 | 88 | if state.IsWrite() && c.needFlush() { 89 | c.flush() 90 | } 91 | 92 | if state.IsRead() { 93 | if err := c.processWebsocketFrame(); err != nil { 94 | c.Close() 95 | } 96 | } 97 | }) 98 | if err != nil { 99 | el.parent.Error("apiPoll", "err", err.Error()) 100 | return 101 | } 102 | } 103 | } 104 | 105 | // 获取一个连接 106 | func (m *EventLoop) getConn(fd int) *Conn { 107 | return m.parent.safeConns.Get(fd) 108 | } 109 | 110 | func (el *EventLoop) del(c *Conn) { 111 | fd := c.getFd() 112 | atomic.AddInt64(&el.parent.curConn, -1) 113 | el.parent.safeConns.Del(fd) 114 | // el.conns.Delete(fd) 115 | closeFd(fd) 116 | } 117 | 118 | func (el *EventLoop) delRead(c *Conn) error { 119 | return el.Del(c.getFd()) 120 | } 121 | 122 | func (el *EventLoop) addWrite(c *Conn) error { 123 | return el.AddWrite(c.getFd()) 124 | } 125 | 126 | func (el *EventLoop) GetApiName() string { 127 | return el.Name() 128 | } 129 | -------------------------------------------------------------------------------- /stat.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | package greatws 15 | 16 | import "sync/atomic" 17 | 18 | // 统计信息 19 | type stat struct { 20 | readSyssall int64 // 读系统调用次数 21 | writeSyscall int64 // 写系统调用次数 22 | curConn int64 // 当前tcp连接数 23 | realloc int64 // 重新分配内存次数 24 | moveBytes uint64 // 移动字节数 25 | readEv int64 // 读事件次数 26 | writeEv int64 // 写事件次数 27 | pollEv int64 // poll事件次数, 包含读,写, 错误事件 28 | } 29 | 30 | // 对外接口,查询当前业务协程池个数 31 | func (m *MultiEventLoop) GetCurGoNum() (total int) { 32 | for _, v := range m.loops { 33 | // 本地任务数 34 | total += int(v.localTask.GetGoroutines()) 35 | } 36 | return 37 | } 38 | 39 | // 对外接口,查询业务协程池运行的当前业务数 40 | func (m *MultiEventLoop) GetCurTaskNum() (total int64) { 41 | for _, v := range m.loops { 42 | // 本地任务数 43 | total += int64(v.localTask.GetGoroutines()) 44 | } 45 | return 46 | } 47 | 48 | // 对外接口,查询移动字节数 49 | func (m *MultiEventLoop) GetMoveBytesNum() uint64 { 50 | return atomic.LoadUint64(&m.moveBytes) 51 | } 52 | 53 | // 对外接口,查询重新分配内存次数 54 | func (m *MultiEventLoop) GetReallocNum() int64 { 55 | return atomic.LoadInt64(&m.realloc) 56 | } 57 | 58 | // 对外接口,查询read syscall次数 59 | func (m *MultiEventLoop) GetReadSyscallNum() int64 { 60 | return atomic.LoadInt64(&m.readSyssall) 61 | } 62 | 63 | // 对外接口,查询write syscall次数 64 | func (m *MultiEventLoop) GetWriteSyscallNum() int64 { 65 | return atomic.LoadInt64(&m.writeSyscall) 66 | } 67 | 68 | // 对外接口,查询当前websocket连接数 69 | func (m *MultiEventLoop) GetCurConnNum() int64 { 70 | return atomic.LoadInt64(&m.curConn) 71 | } 72 | 73 | // 对外接口,查询poll read事件次数 74 | func (m *MultiEventLoop) GetReadEvNum() int64 { 75 | return atomic.LoadInt64(&m.readEv) 76 | } 77 | 78 | // 对外接口,查询poll write事件次数 79 | func (m *MultiEventLoop) GetWriteEvNum() int64 { 80 | return atomic.LoadInt64(&m.writeEv) 81 | } 82 | 83 | // 对外接口,查询poll 返回的事件总次数 84 | func (m *MultiEventLoop) GetPollEvNum() int64 { 85 | return atomic.LoadInt64(&m.pollEv) 86 | } 87 | 88 | // 对外接口,返回当前使用的api名字 89 | func (m *MultiEventLoop) GetApiName() string { 90 | if len(m.loops) == 0 { 91 | return "" 92 | } 93 | 94 | return m.loops[0].GetApiName() 95 | } 96 | 97 | // 对内接口 98 | func (m *MultiEventLoop) addRealloc() { 99 | atomic.AddInt64(&m.realloc, 1) 100 | } 101 | 102 | // 对内接口 103 | func (m *MultiEventLoop) addReadSyscall() { 104 | atomic.AddInt64(&m.readSyssall, 1) 105 | } 106 | 107 | // 对内接口 108 | func (m *MultiEventLoop) addWriteSyscall() { 109 | atomic.AddInt64(&m.writeSyscall, 1) 110 | } 111 | 112 | // 对内接口 113 | func (m *MultiEventLoop) addMoveBytes(n uint64) { 114 | atomic.AddUint64(&m.moveBytes, n) 115 | } 116 | 117 | // 对内接口 118 | func (m *MultiEventLoop) addReadEvNum() { 119 | atomic.AddInt64(&m.readEv, 1) 120 | } 121 | 122 | // 对内接口 123 | func (m *MultiEventLoop) addWriteEvNum() { 124 | atomic.AddInt64(&m.writeEv, 1) 125 | } 126 | 127 | // 对内接口 128 | func (m *MultiEventLoop) addPollEvNum() { 129 | atomic.AddInt64(&m.pollEv, 1) 130 | } 131 | -------------------------------------------------------------------------------- /autobahn/client/autobahn-client.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "log/slog" 6 | "runtime" 7 | "strconv" 8 | "sync" 9 | "time" 10 | 11 | "github.com/antlabs/greatws" 12 | ) 13 | 14 | // https://github.com/snapview/tokio-tungstenite/blob/master/examples/autobahn-client.rs 15 | 16 | const ( 17 | // host = "ws://192.168.128.44:9003" 18 | host = "ws://127.0.0.1:9005" 19 | agent = "greatws" 20 | ) 21 | 22 | type handler struct { 23 | m *greatws.MultiEventLoop 24 | } 25 | 26 | type echoHandler struct { 27 | wg *sync.WaitGroup 28 | done chan struct{} 29 | } 30 | 31 | func (e *echoHandler) OnOpen(c *greatws.Conn) { 32 | fmt.Printf("OnOpen::%p\n", c) 33 | } 34 | 35 | func (e *echoHandler) OnMessage(c *greatws.Conn, op greatws.Opcode, msg []byte) { 36 | // fmt.Printf("OnMessage: opcode:%s, msg.size:%d\n", op, len(msg)) 37 | if op == greatws.Text || op == greatws.Binary { 38 | // os.WriteFile("./debug.dat", msg, 0o644) 39 | // if err := c.WriteMessage(op, msg); err != nil { 40 | // fmt.Println("write fail:", err) 41 | // } 42 | if err := c.WriteTimeout(op, msg, 1*time.Minute); err != nil { 43 | fmt.Println("write fail:", err) 44 | } 45 | } 46 | } 47 | 48 | func (e *echoHandler) OnClose(c *greatws.Conn, err error) { 49 | fmt.Println("OnClose:", c, err) 50 | // defer e.wg.Done() 51 | close(e.done) 52 | } 53 | 54 | func (h *handler) getCaseCount() int { 55 | var count int 56 | done := make(chan bool, 1) 57 | c, err := greatws.Dial(fmt.Sprintf("%s/getCaseCount", host), greatws.WithClientMultiEventLoop(h.m), greatws.WithClientOnMessageFunc(func() greatws.OnMessageFunc { 58 | return func(c *greatws.Conn, op greatws.Opcode, msg []byte) { 59 | var err error 60 | count, err = strconv.Atoi(string(msg)) 61 | if err != nil { 62 | panic(err) 63 | } 64 | done <- true 65 | fmt.Printf("msg(%s)\n", msg) 66 | c.Close() 67 | } 68 | }())) 69 | if err != nil { 70 | panic(err) 71 | } 72 | defer c.Close() 73 | 74 | err = c.ReadLoop() 75 | <-done 76 | fmt.Printf("readloop rv:%s\n", err) 77 | return count 78 | } 79 | 80 | func (h *handler) runTest(caseNo int, wg *sync.WaitGroup) { 81 | done := make(chan struct{}) 82 | c, err := greatws.Dial(fmt.Sprintf("%s/runCase?case=%d&agent=%s", host, caseNo, agent), 83 | greatws.WithClientReplyPing(), 84 | greatws.WithClientEnableUTF8Check(), 85 | greatws.WithClientDecompressAndCompress(), 86 | greatws.WithClientContextTakeover(), 87 | greatws.WithClientMaxWindowsBits(10), 88 | greatws.WithClientCallback(&echoHandler{done: done, wg: wg}), 89 | greatws.WithClientMultiEventLoop(h.m), 90 | ) 91 | if err != nil { 92 | fmt.Println("Dial fail:", err) 93 | return 94 | } 95 | 96 | go func() { 97 | _ = c.ReadLoop() 98 | }() 99 | <-done 100 | } 101 | 102 | func (h *handler) updateReports() { 103 | c, err := greatws.Dial(fmt.Sprintf("%s/updateReports?agent=%s", host, agent), greatws.WithClientMultiEventLoop(h.m)) 104 | if err != nil { 105 | fmt.Println("Dial fail:", err) 106 | return 107 | } 108 | 109 | c.Close() 110 | } 111 | 112 | // 1.先通过接口获取case的总个数 113 | // 2.运行测试客户端client 114 | func main() { 115 | var h handler 116 | h.m = greatws.NewMultiEventLoopMust( 117 | greatws.WithEventLoops(runtime.NumCPU()/2), 118 | greatws.WithBusinessGoNum(50, 10, 10000), 119 | greatws.WithMaxEventNum(1000), 120 | greatws.WithLogLevel(slog.LevelError)) // epoll, kqueue 121 | 122 | h.m.Start() 123 | total := h.getCaseCount() 124 | var wg sync.WaitGroup 125 | // wg.Add(total) 126 | fmt.Println("total case:", total) 127 | for i := 1; i <= total; i++ { 128 | h.runTest(i, &wg) 129 | } 130 | // wg.Wait() 131 | h.updateReports() 132 | } 133 | -------------------------------------------------------------------------------- /status_codes.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package greatws 16 | 17 | import ( 18 | "encoding/binary" 19 | "strconv" 20 | "strings" 21 | ) 22 | 23 | // https://datatracker.ietf.org/doc/html/rfc6455#section-7.4.1 24 | // 这里记录了各种状态码的含义 25 | type StatusCode int16 26 | 27 | const ( 28 | // NormalClosure 正常关闭 29 | NormalClosure StatusCode = 1000 30 | // EndpointGoingAway 对端正在消失 31 | EndpointGoingAway StatusCode = 1001 32 | // ProtocolError 表示对端由于协议错误正在终止连接 33 | ProtocolError StatusCode = 1002 34 | // DataCannotAccept 收到一个不能接受的数据类型 35 | DataCannotAccept StatusCode = 1003 36 | // NotConsistentMessageType 表示对端正在终止连接, 消息类型不一致 37 | NotConsistentMessageType StatusCode = 1007 38 | // TerminatingConnection 表示对端正在终止连接, 没有好用的错误, 可以用这个错误码表示 39 | TerminatingConnection StatusCode = 1008 40 | // TooBigMessage 消息太大, 不能处理, 关闭连接 41 | TooBigMessage StatusCode = 1009 42 | // NoExtensions 只用于客户端, 服务端返回扩展消息 43 | NoExtensions StatusCode = 1010 44 | // ServerTerminating 服务端遇到意外情况, 中止请求 45 | ServerTerminating StatusCode = 1011 46 | ) 47 | 48 | func (s StatusCode) String() string { 49 | switch s { 50 | case NormalClosure: 51 | return "NormalClosure" 52 | case EndpointGoingAway: 53 | return "EndpointGoingAway" 54 | case ProtocolError: 55 | return "ProtocolError" 56 | case DataCannotAccept: 57 | return "DataCannotAccept" 58 | case NotConsistentMessageType: 59 | return "NotConsistentMessageType" 60 | case TerminatingConnection: 61 | return "TerminatingConnection" 62 | case TooBigMessage: 63 | return "TooBigMessage" 64 | case NoExtensions: 65 | return "NoExtensions" 66 | case ServerTerminating: 67 | return "ServerTerminating" 68 | } 69 | 70 | return "unknown" 71 | } 72 | 73 | func (s StatusCode) Error() string { 74 | return s.String() 75 | } 76 | 77 | func (s StatusCode) toBytes() (rv []byte) { 78 | rv = make([]byte, 2+len(s.String())) 79 | binary.BigEndian.PutUint16(rv, uint16(s)) 80 | copy(rv[2:], s.String()) 81 | return 82 | } 83 | 84 | type CloseErrMsg struct { 85 | Code StatusCode 86 | Msg string 87 | } 88 | 89 | func (c CloseErrMsg) Error() string { 90 | var out strings.Builder 91 | 92 | out.WriteString(" 0 { 101 | out.WriteString(c.Msg) 102 | } 103 | 104 | out.WriteString(">") 105 | return out.String() 106 | } 107 | 108 | func bytesToCloseErrMsg(payload []byte) *CloseErrMsg { 109 | var ce CloseErrMsg 110 | if len(payload) >= 2 { 111 | ce.Code = StatusCode(binary.BigEndian.Uint16(payload)) 112 | } 113 | 114 | if len(payload) >= 3 { 115 | ce.Msg = string(payload[3:]) 116 | } 117 | return &ce 118 | } 119 | 120 | func validCode(code uint16) bool { 121 | switch code { 122 | case 1004, 1005, 1006, 1015: 123 | return false 124 | } 125 | 126 | if code >= 1000 && code <= 1015 { 127 | return true 128 | } 129 | 130 | if code >= 3000 && code <= 4999 { 131 | return true 132 | } 133 | 134 | return false 135 | } 136 | -------------------------------------------------------------------------------- /upgrade.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package greatws 16 | 17 | import ( 18 | "bytes" 19 | "errors" 20 | "net" 21 | "net/http" 22 | "syscall" 23 | "time" 24 | 25 | "github.com/antlabs/wsutil/bytespool" 26 | "github.com/antlabs/wsutil/deflate" 27 | ) 28 | 29 | type UpgradeServer struct { 30 | config Config 31 | } 32 | 33 | func NewUpgrade(opts ...ServerOption) *UpgradeServer { 34 | var conf ConnOption 35 | conf.defaultSetting() 36 | for _, o := range opts { 37 | o(&conf) 38 | } 39 | conf.defaultSettingAfter() 40 | return &UpgradeServer{config: conf.Config} 41 | } 42 | 43 | func (u *UpgradeServer) Upgrade(w http.ResponseWriter, r *http.Request) (c *Conn, err error) { 44 | return upgradeInner(w, r, &u.config, nil) 45 | } 46 | 47 | func (u *UpgradeServer) UpgradeLocalCallback(w http.ResponseWriter, r *http.Request, cb Callback) (c *Conn, err error) { 48 | return upgradeInner(w, r, &u.config, cb) 49 | } 50 | 51 | func Upgrade(w http.ResponseWriter, r *http.Request, opts ...ServerOption) (c *Conn, err error) { 52 | var conf ConnOption 53 | conf.defaultSetting() 54 | for _, o := range opts { 55 | o(&conf) 56 | } 57 | 58 | conf.defaultSettingAfter() 59 | return upgradeInner(w, r, &conf.Config, nil) 60 | } 61 | 62 | func getFdFromConn(c net.Conn) (newFd int, err error) { 63 | sc, ok := c.(interface { 64 | SyscallConn() (syscall.RawConn, error) 65 | }) 66 | if !ok { 67 | return 0, errors.New("RawConn Unsupported") 68 | } 69 | rc, err := sc.SyscallConn() 70 | if err != nil { 71 | return 0, errors.New("RawConn Unsupported") 72 | } 73 | 74 | err = rc.Control(func(fd uintptr) { 75 | newFd = int(fd) 76 | }) 77 | if err != nil { 78 | return 0, err 79 | } 80 | 81 | return duplicateSocket(int(newFd)) 82 | } 83 | 84 | func upgradeInner(w http.ResponseWriter, r *http.Request, conf *Config, cb Callback) (wsCon *Conn, err error) { 85 | if conf.multiEventLoop == nil { 86 | return nil, ErrEventLoopEmpty 87 | } 88 | 89 | if ecode, err := checkRequest(r); err != nil { 90 | http.Error(w, err.Error(), ecode) 91 | return nil, err 92 | } 93 | 94 | hi, ok := w.(http.Hijacker) 95 | if !ok { 96 | return nil, ErrNotFoundHijacker 97 | } 98 | 99 | // var read *bufio.Reader 100 | var conn net.Conn 101 | conn, _, err = hi.Hijack() 102 | if err != nil { 103 | return nil, err 104 | } 105 | if !conf.disableBufioClearHack { 106 | // bufio2.ClearReadWriter(rw) 107 | } 108 | 109 | // 是否打开解压缩 110 | // 外层接收压缩, 并且客户端发送扩展过来 111 | var pd deflate.PermessageDeflateConf 112 | if conf.Decompression { 113 | pd, err = deflate.GetConnPermessageDeflate(r.Header) 114 | if err != nil { 115 | return nil, err 116 | } 117 | } 118 | 119 | buf := bytespool.GetUpgradeRespBytes() 120 | 121 | tmpWriter := bytes.NewBuffer((*buf)[:0]) 122 | defer func() { 123 | bytespool.PutUpgradeRespBytes(buf) 124 | tmpWriter = nil 125 | }() 126 | resetPermessageDeflate(&pd, conf) 127 | if err = prepareWriteResponse(r, tmpWriter, conf, pd); err != nil { 128 | return 129 | } 130 | 131 | if _, err := conn.Write(tmpWriter.Bytes()); err != nil { 132 | return nil, err 133 | } 134 | 135 | if err = conn.SetDeadline(time.Time{}); err != nil { 136 | return nil, err 137 | } 138 | 139 | fd, err := getFdFromConn(conn) 140 | if err != nil { 141 | conn.Close() 142 | return nil, err 143 | } 144 | // 已经dup了一份fd,所以这里可以关闭 145 | if err = conn.Close(); err != nil { 146 | return nil, err 147 | } 148 | 149 | if wsCon, err = newConn(int64(fd), false, conf); err != nil { 150 | return nil, err 151 | } 152 | wsCon.pd = pd 153 | wsCon.Callback = cb 154 | if cb == nil { 155 | wsCon.Callback = conf.cb 156 | } 157 | wsCon.Callback.OnOpen(wsCon) 158 | if wsCon.Callback == nil { 159 | panic("callback is nil") 160 | } 161 | if err = conf.multiEventLoop.add(wsCon); err != nil { 162 | return nil, err 163 | } 164 | 165 | return wsCon, nil 166 | } 167 | 168 | func resetPermessageDeflate(pd *deflate.PermessageDeflateConf, conf *Config) { 169 | pd.Decompression = pd.Enable && conf.Decompression 170 | pd.Compression = pd.Enable && conf.Compression 171 | pd.ServerContextTakeover = pd.Enable && pd.ServerContextTakeover && conf.ServerContextTakeover 172 | pd.ClientContextTakeover = pd.Enable && pd.ClientContextTakeover && conf.ClientContextTakeover 173 | } 174 | -------------------------------------------------------------------------------- /server_handshake.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package greatws 16 | 17 | import ( 18 | "errors" 19 | "fmt" 20 | "io" 21 | "net/http" 22 | "strings" 23 | 24 | "github.com/antlabs/wsutil/deflate" 25 | ) 26 | 27 | var ( 28 | ErrNotFoundHijacker = errors.New("not found Hijacker") 29 | bytesHeaderUpgrade = []byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ") 30 | bytesSecWebSocketExtensionsKey = []byte("Sec-WebSocket-Extensions: ") 31 | bytesCRLF = []byte("\r\n") 32 | bytesPutSecWebSocketProtocolKey = []byte("Sec-WebSocket-Protocol: ") 33 | strGetSecWebSocketProtocolKey = "Sec-WebSocket-Protocol" 34 | strWebSocketKey = "Sec-WebSocket-Key" 35 | ) 36 | 37 | func writeHeaderVal(w io.Writer, val []byte) (err error) { 38 | if _, err = w.Write(val); err != nil { 39 | return 40 | } 41 | 42 | if _, err = w.Write(bytesCRLF); err != nil { 43 | return 44 | } 45 | return 46 | } 47 | 48 | func subProtocol(subProtocol string, cnf *Config) string { 49 | if subProtocol == "" { 50 | return "" 51 | } 52 | 53 | subProtocols := strings.Split(subProtocol, ",") 54 | // 如果配置了subProtocols, 则检查客户端的subProtocols是否在配置的subProtocols中 55 | // 为什么要这么做,可以看下这个issue 56 | // https://github.com/antlabs/quickws/issues/12 57 | if len(cnf.subProtocols) > 0 { 58 | for _, clientSubProtocols := range subProtocols { 59 | clientSubProtocols = strings.TrimSpace(clientSubProtocols) 60 | for _, serverSubProtocols := range cnf.subProtocols { 61 | if clientSubProtocols == serverSubProtocols { 62 | return clientSubProtocols 63 | } 64 | } 65 | } 66 | } 67 | // echo Secf-WebSocket-Protocol 的值 68 | return subProtocol 69 | } 70 | 71 | // https://datatracker.ietf.org/doc/html/rfc6455#section-4.2.2 72 | // 第5小点 73 | func prepareWriteResponse(r *http.Request, w io.Writer, cnf *Config, pd deflate.PermessageDeflateConf) (err error) { 74 | // 写入响应头 75 | // 写入Sec-WebSocket-Accept key 76 | if _, err = w.Write(bytesHeaderUpgrade); err != nil { 77 | return 78 | } 79 | 80 | v := secWebSocketAcceptVal(r.Header.Get(strWebSocketKey)) 81 | // 写入Sec-WebSocket-Accept vla 82 | if err = writeHeaderVal(w, StringToBytes(v)); err != nil { 83 | return err 84 | } 85 | 86 | // 给客户端回个信, 表示支持解压缩模式 87 | if pd.Decompression { 88 | if _, err = w.Write(bytesSecWebSocketExtensionsKey); err != nil { 89 | return err 90 | } 91 | if _, err = w.Write([]byte(deflate.GenSecWebSocketExtensions(pd))); err != nil { 92 | return err 93 | } 94 | if _, err = w.Write(bytesCRLF); err != nil { 95 | return err 96 | } 97 | } 98 | 99 | v = r.Header.Get(strGetSecWebSocketProtocolKey) 100 | v = subProtocol(v, cnf) 101 | if len(v) > 0 { 102 | if _, err = w.Write(bytesPutSecWebSocketProtocolKey); err != nil { 103 | return 104 | } 105 | 106 | if err = writeHeaderVal(w, StringToBytes(v)); err != nil { 107 | return err 108 | } 109 | } 110 | 111 | _, err = w.Write(bytesCRLF) 112 | return err 113 | } 114 | 115 | // https://datatracker.ietf.org/doc/html/rfc6455#section-4.2.1 116 | // 按rfc标准, 先来一顿if else判断, 检查发的request是否满足标准 117 | func checkRequest(r *http.Request) (ecode int, err error) { 118 | // 不是get方法的 119 | if r.Method != http.MethodGet { 120 | // TODO错误消息 121 | return http.StatusMethodNotAllowed, fmt.Errorf("%w :%s", ErrOnlyGETSupported, r.Method) 122 | } 123 | // http版本低于1.1 124 | if !r.ProtoAtLeast(1, 1) { 125 | // TODO错误消息 126 | return http.StatusHTTPVersionNotSupported, ErrHTTPProtocolNotSupported 127 | } 128 | 129 | // 没有host字段的 130 | if r.Host == "" { 131 | return http.StatusBadRequest, ErrHostCannotBeEmpty 132 | } 133 | 134 | // Upgrade值不等于websocket的 135 | if upgrade := r.Header.Get("Upgrade"); !strings.EqualFold(upgrade, "websocket") { 136 | return http.StatusBadRequest, ErrUpgradeFieldValue 137 | } 138 | 139 | // Connection值不是Upgrade 140 | if conn := r.Header.Get("Connection"); !strings.EqualFold(conn, "Upgrade") { 141 | return http.StatusBadRequest, ErrConnectionFieldValue 142 | } 143 | 144 | // Sec-WebSocket-Key解码之后是16字节长度 145 | // TODO后续优化 146 | if len(r.Header.Get("Sec-WebSocket-Key")) == 0 { 147 | return http.StatusBadRequest, ErrSecWebSocketKey 148 | } 149 | 150 | // Sec-WebSocket-Version的版本不是13的 151 | if r.Header.Get("Sec-WebSocket-Version") != "13" { 152 | return http.StatusUpgradeRequired, ErrSecWebSocketVersion 153 | } 154 | 155 | // TODO Sec-WebSocket-Extensions 156 | return 0, nil 157 | } 158 | -------------------------------------------------------------------------------- /multi_event_loops.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | package greatws 15 | 16 | import ( 17 | "context" 18 | "log/slog" 19 | "os" 20 | "runtime" 21 | "sync" 22 | "sync/atomic" 23 | "time" 24 | 25 | "github.com/antlabs/pulse/core" 26 | _ "github.com/antlabs/task/task" 27 | ) 28 | 29 | type taskConfig struct { 30 | initCount int // 初始化的协程数 31 | min int // 最小协程数 32 | max int // 最大协程数 33 | } 34 | 35 | type multiEventLoopOption struct { 36 | numLoops int //起多少个event loop 37 | 38 | // 为何不设计全局池, 现在的做法是 39 | // fd是绑定到某个事件循环上的, 40 | // 任务池是绑定到某个事件循环上的,所以这里的任务池也绑定到对应的localTask上 41 | // 如果设计全局任务池,那么概念就会很乱,容易出错,也会临界区竞争 42 | configTask taskConfig 43 | // taskMode taskMode 44 | level slog.Level //控制日志等级 45 | maxEventNum int //每次epoll/kqueue返回时,一次最多处理多少事件 46 | } 47 | 48 | // 默认MultiEventLoop 49 | var DefaultMultiEventLoop *MultiEventLoop 50 | 51 | var defaultOnce sync.Once 52 | 53 | func getDefaultMultiEventLoop() *MultiEventLoop { 54 | 55 | defaultOnce.Do(func() { 56 | DefaultMultiEventLoop = NewMultiEventLoopMust(WithEventLoops(0), WithMaxEventNum(256), WithLogLevel(slog.LevelError)) // epoll, kqueue 57 | }) 58 | return DefaultMultiEventLoop 59 | } 60 | 61 | type MultiEventLoop struct { 62 | multiEventLoopOption //配置选项 63 | 64 | safeConns core.SafeConns[Conn] 65 | 66 | loops []*EventLoop 67 | parseLoop *taskParse 68 | 69 | flag evFlag // 是否使用io_uring,目前没有使用 70 | 71 | stat // 统计信息 72 | *slog.Logger 73 | 74 | evLoopStart uint32 75 | 76 | ctx context.Context 77 | 78 | once sync.Once 79 | } 80 | 81 | var ( 82 | defMaxEventNum = 256 83 | defTaskMin = 50 84 | defTaskMax = 30000 85 | defTaskInitCount = 8 86 | defNumLoops = runtime.NumCPU() 87 | ) 88 | 89 | // 这个函数会被调用两次 90 | // 默认 1个event loop分发io事件, 多个parse loop解析websocket包 91 | func (m *MultiEventLoop) initDefaultSetting() { 92 | 93 | if m.level == 0 { 94 | m.level = slog.LevelError // 95 | } 96 | if m.numLoops == 0 { 97 | m.numLoops = max(defNumLoops, 1) 98 | } 99 | 100 | if m.maxEventNum == 0 { 101 | m.maxEventNum = defMaxEventNum 102 | } 103 | 104 | if m.configTask.min == 0 { 105 | m.configTask.min = defTaskMin 106 | } else { 107 | m.configTask.min = max(m.configTask.min/(m.numLoops), 1) 108 | } 109 | 110 | if m.configTask.max == 0 { 111 | m.configTask.max = defTaskMax 112 | } else { 113 | m.configTask.max = max(m.configTask.max/(m.numLoops), 1) 114 | } 115 | 116 | if m.configTask.initCount == 0 { 117 | m.configTask.initCount = defTaskInitCount 118 | } else { 119 | m.configTask.initCount = max(m.configTask.initCount/(m.numLoops), 1) 120 | } 121 | 122 | if m.flag == 0 { 123 | m.flag = EVENT_EPOLL 124 | } 125 | } 126 | 127 | func NewMultiEventLoopMust(opts ...EvOption) *MultiEventLoop { 128 | m, err := NewMultiEventLoop(opts...) 129 | if err != nil { 130 | panic(err) 131 | } 132 | 133 | return m 134 | } 135 | 136 | // 创建一个多路事件循环 137 | func NewMultiEventLoop(opts ...EvOption) (e *MultiEventLoop, err error) { 138 | m := &MultiEventLoop{} 139 | m.safeConns.Init(core.GetMaxFd()) 140 | m.initDefaultSetting() 141 | for _, o := range opts { 142 | o(m) 143 | } 144 | m.initDefaultSetting() 145 | m.Logger = slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: m.level})) 146 | 147 | m.ctx = context.Background() 148 | m.loops = make([]*EventLoop, m.numLoops) 149 | 150 | for i := 0; i < m.numLoops; i++ { 151 | m.loops[i], err = CreateEventLoop(m.maxEventNum, m.flag, m) 152 | if err != nil { 153 | return nil, err 154 | } 155 | } 156 | return m, nil 157 | } 158 | 159 | // 初始化一个多路事件循环,并且运行它 160 | func NewMultiEventLoopAndStartMust(opts ...EvOption) (m *MultiEventLoop) { 161 | m = NewMultiEventLoopMust(opts...) 162 | m.Start() 163 | return m 164 | } 165 | 166 | // 启动多路事件循环 167 | func (m *MultiEventLoop) Start() { 168 | 169 | m.once.Do(func() { 170 | for _, loop := range m.loops { 171 | go loop.Loop() 172 | } 173 | time.Sleep(time.Millisecond * 10) 174 | atomic.StoreUint32(&m.evLoopStart, 1) 175 | }) 176 | } 177 | 178 | func (m *MultiEventLoop) Free() { 179 | for _, m := range m.loops { 180 | m.Free() 181 | } 182 | } 183 | func (m *MultiEventLoop) isStart() bool { 184 | return atomic.LoadUint32(&m.evLoopStart) == 1 185 | } 186 | 187 | func (m *MultiEventLoop) getEventLoop(fd int) *EventLoop { 188 | return m.loops[fd%len(m.loops)] 189 | } 190 | 191 | // 添加一个连接到多路事件循环 192 | func (m *MultiEventLoop) add(c *Conn) error { 193 | fd := c.getFd() 194 | if fd == -1 { 195 | return nil 196 | } 197 | index := fd % len(m.loops) 198 | m.safeConns.Add(fd, c) 199 | // m.loops[index].conns.Store(fd, c) 200 | if err := m.loops[index].AddRead(c.getFd()); err != nil { 201 | m.loops[index].del(c) 202 | return err 203 | } 204 | atomic.AddInt64(&m.curConn, 1) 205 | return nil 206 | } 207 | -------------------------------------------------------------------------------- /client_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | package greatws 15 | 16 | import ( 17 | "fmt" 18 | "net/http" 19 | "net/http/httptest" 20 | "strings" 21 | "sync/atomic" 22 | "testing" 23 | ) 24 | 25 | // 测试客户端Dial, 返回的http.Header 26 | func Test_Client_Dial_Check_Header(t *testing.T) { 27 | 28 | t.Run("Dial: valid resp: status code fail", func(t *testing.T) { 29 | done := make(chan bool, 1) 30 | run := int32(0) 31 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 32 | atomic.AddInt32(&run, int32(1)) 33 | done <- true 34 | })) 35 | 36 | defer ts.Close() 37 | 38 | rawURL := strings.ReplaceAll(ts.URL, "http", "ws") 39 | _, err := Dial(rawURL) 40 | if err == nil { 41 | t.Fatal("should be error") 42 | } 43 | }) 44 | 45 | t.Run("DialConf: valid resp : status code fail", func(t *testing.T) { 46 | done := make(chan bool, 1) 47 | run := int32(0) 48 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 49 | atomic.AddInt32(&run, int32(1)) 50 | done <- true 51 | })) 52 | 53 | defer ts.Close() 54 | 55 | cnf := ClientOptionToConf() 56 | rawURL := strings.ReplaceAll(ts.URL, "http", "ws") 57 | _, err := DialConf(rawURL, cnf) 58 | if err == nil { 59 | t.Fatal("should be error") 60 | } 61 | }) 62 | 63 | t.Run("Dial: valid resp: Upgrade field fail", func(t *testing.T) { 64 | done := make(chan bool, 1) 65 | run := int32(0) 66 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 67 | atomic.AddInt32(&run, int32(1)) 68 | w.WriteHeader(101) 69 | w.Header().Set("Upgrade", "xx") 70 | done <- true 71 | })) 72 | 73 | defer ts.Close() 74 | 75 | rawURL := strings.ReplaceAll(ts.URL, "http", "ws") 76 | _, err := Dial(rawURL) 77 | if err == nil { 78 | t.Fatal("should be error") 79 | } 80 | }) 81 | 82 | t.Run("DialConf: valid resp: Upgrade field fail", func(t *testing.T) { 83 | done := make(chan bool, 1) 84 | run := int32(0) 85 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 86 | atomic.AddInt32(&run, int32(1)) 87 | w.WriteHeader(101) 88 | w.Header().Set("Upgrade", "xx") 89 | done <- true 90 | })) 91 | 92 | defer ts.Close() 93 | 94 | cnf := ClientOptionToConf() 95 | rawURL := strings.ReplaceAll(ts.URL, "http", "ws") 96 | _, err := DialConf(rawURL, cnf) 97 | if err == nil { 98 | t.Fatal("should be error") 99 | } 100 | }) 101 | 102 | t.Run("Dial: valid resp: Connection fail", func(t *testing.T) { 103 | done := make(chan bool, 1) 104 | run := int32(0) 105 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 106 | atomic.AddInt32(&run, int32(1)) 107 | w.Header().Set("Upgrade", "websocket") 108 | w.Header().Set("Connection", "xx") 109 | w.WriteHeader(101) 110 | done <- true 111 | })) 112 | 113 | defer ts.Close() 114 | 115 | rawURL := strings.ReplaceAll(ts.URL, "http", "ws") 116 | _, err := Dial(rawURL) 117 | if err == nil { 118 | t.Fatal("should be error") 119 | } 120 | }) 121 | 122 | t.Run("DialConf: valid resp: Connection fail", func(t *testing.T) { 123 | done := make(chan bool, 1) 124 | run := int32(0) 125 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 126 | atomic.AddInt32(&run, int32(1)) 127 | w.Header().Set("Upgrade", "websocket") 128 | w.Header().Set("Connection", "xx") 129 | w.WriteHeader(101) 130 | done <- true 131 | })) 132 | 133 | defer ts.Close() 134 | 135 | cnf := ClientOptionToConf() 136 | rawURL := strings.ReplaceAll(ts.URL, "http", "ws") 137 | _, err := DialConf(rawURL, cnf) 138 | if err == nil { 139 | t.Fatal("should be error") 140 | } else { 141 | fmt.Printf("err: %v\n", err) 142 | } 143 | }) 144 | 145 | t.Run("Dial: valid resp: Sec-WebSocket-Accept fail", func(t *testing.T) { 146 | done := make(chan bool, 1) 147 | run := int32(0) 148 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 149 | atomic.AddInt32(&run, int32(1)) 150 | w.Header().Set("Upgrade", "websocket") 151 | w.Header().Set("Connection", "Upgrade") 152 | w.WriteHeader(101) 153 | done <- true 154 | })) 155 | 156 | defer ts.Close() 157 | 158 | rawURL := strings.ReplaceAll(ts.URL, "http", "ws") 159 | _, err := Dial(rawURL) 160 | if err == nil { 161 | t.Fatal("should be error") 162 | } 163 | }) 164 | 165 | t.Run("DialConf: valid resp: Sec-WebSocket-Accept fail", func(t *testing.T) { 166 | done := make(chan bool, 1) 167 | run := int32(0) 168 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 169 | atomic.AddInt32(&run, int32(1)) 170 | w.Header().Set("Upgrade", "websocket") 171 | w.Header().Set("Connection", "Upgrade") 172 | w.WriteHeader(101) 173 | done <- true 174 | })) 175 | 176 | defer ts.Close() 177 | 178 | cnf := ClientOptionToConf() 179 | rawURL := strings.ReplaceAll(ts.URL, "http", "ws") 180 | _, err := DialConf(rawURL, cnf) 181 | if err == nil { 182 | t.Fatal("should be error") 183 | } else { 184 | fmt.Printf("err: %v\n", err) 185 | } 186 | }) 187 | } 188 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package greatws 16 | 17 | import ( 18 | "bufio" 19 | "crypto/tls" 20 | "fmt" 21 | "net" 22 | "net/http" 23 | "net/url" 24 | "strings" 25 | "time" 26 | 27 | "github.com/antlabs/wsutil/bytespool" 28 | "github.com/antlabs/wsutil/deflate" 29 | "github.com/antlabs/wsutil/enum" 30 | "github.com/antlabs/wsutil/hostname" 31 | ) 32 | 33 | var ( 34 | defaultTimeout = time.Minute * 30 35 | ) 36 | 37 | type DialOption struct { 38 | Header http.Header 39 | u *url.URL 40 | tlsConfig *tls.Config 41 | dialTimeout time.Duration 42 | bindClientHttpHeader *http.Header // 握手成功之后, 客户端获取http.Header, 43 | Config 44 | } 45 | 46 | func ClientOptionToConf(opts ...ClientOption) *DialOption { 47 | var dial DialOption 48 | dial.defaultSetting() 49 | for _, o := range opts { 50 | o(&dial) 51 | } 52 | dial.defaultSettingAfter() 53 | return &dial 54 | } 55 | 56 | func DialConf(rawUrl string, conf *DialOption) (*Conn, error) { 57 | u, err := url.Parse(rawUrl) 58 | if err != nil { 59 | return nil, err 60 | } 61 | 62 | conf.u = u 63 | conf.dialTimeout = defaultTimeout 64 | if conf.Header == nil { 65 | conf.Header = make(http.Header) 66 | } 67 | 68 | return conf.Dial() 69 | } 70 | 71 | // https://datatracker.ietf.org/doc/html/rfc6455#section-4.1 72 | // 又是一顿if else, 咬文嚼字 73 | func Dial(rawUrl string, opts ...ClientOption) (*Conn, error) { 74 | var dial DialOption 75 | u, err := url.Parse(rawUrl) 76 | if err != nil { 77 | return nil, err 78 | } 79 | 80 | dial.u = u 81 | dial.dialTimeout = defaultTimeout 82 | if dial.Header == nil { 83 | dial.Header = make(http.Header) 84 | } 85 | 86 | dial.defaultSetting() 87 | for _, o := range opts { 88 | o(&dial) 89 | } 90 | 91 | dial.defaultSettingAfter() 92 | return dial.Dial() 93 | } 94 | 95 | // 准备握手的数据 96 | func (d *DialOption) handshake() (*http.Request, string, error) { 97 | switch { 98 | case d.u.Scheme == "wss": 99 | d.u.Scheme = "https" 100 | case d.u.Scheme == "ws": 101 | d.u.Scheme = "http" 102 | default: 103 | return nil, "", fmt.Errorf("Unknown scheme, only supports ws:// or wss://: got %s", d.u.Scheme) 104 | } 105 | 106 | // 满足4.1 107 | // 第2点 GET约束http 1.1版本约束 108 | req, err := http.NewRequest("GET", d.u.String(), nil) 109 | if err != nil { 110 | return nil, "", err 111 | } 112 | // 第5点 113 | d.Header.Add("Upgrade", "websocket") 114 | // 第6点 115 | d.Header.Add("Connection", "Upgrade") 116 | // 第7点 117 | secWebSocket := secWebSocketAccept() 118 | d.Header.Add("Sec-WebSocket-Key", secWebSocket) 119 | // TODO 第8点 120 | // 第9点 121 | d.Header.Add("Sec-WebSocket-Version", "13") 122 | 123 | if d.Decompression && d.Compression { 124 | d.Header.Add("Sec-WebSocket-Extensions", deflate.GenSecWebSocketExtensions(d.PermessageDeflateConf)) 125 | } 126 | 127 | req.Header = d.Header 128 | return req, secWebSocket, nil 129 | } 130 | 131 | // 检查服务端响应的数据 132 | // 4.2.2.5 133 | func (d *DialOption) validateRsp(rsp *http.Response, secWebSocket string) error { 134 | if rsp.StatusCode != 101 { 135 | return fmt.Errorf("%w %d", ErrWrongStatusCode, rsp.StatusCode) 136 | } 137 | 138 | // 第2点 139 | if !strings.EqualFold(rsp.Header.Get("Upgrade"), "websocket") { 140 | return ErrUpgradeFieldValue 141 | } 142 | 143 | // 第3点 144 | if !strings.EqualFold(rsp.Header.Get("Connection"), "Upgrade") { 145 | return ErrConnectionFieldValue 146 | } 147 | 148 | // 第4点 149 | if !strings.EqualFold(rsp.Header.Get("Sec-WebSocket-Accept"), secWebSocketAcceptVal(secWebSocket)) { 150 | return ErrSecWebSocketAccept 151 | } 152 | 153 | // TODO 5点 154 | 155 | // TODO 6点 156 | return nil 157 | } 158 | 159 | // wss已经修改为https 160 | func (d *DialOption) tlsConn(c net.Conn) net.Conn { 161 | if d.u.Scheme == "https" { 162 | cfg := d.tlsConfig 163 | if cfg == nil { 164 | cfg = &tls.Config{} 165 | } else { 166 | cfg = cfg.Clone() 167 | } 168 | 169 | if cfg.ServerName == "" { 170 | host := d.u.Host 171 | if pos := strings.Index(host, ":"); pos != -1 { 172 | host = host[:pos] 173 | } 174 | cfg.ServerName = host 175 | } 176 | return tls.Client(c, cfg) 177 | } 178 | 179 | return c 180 | } 181 | 182 | func (d *DialOption) Dial() (wsCon *Conn, err error) { 183 | if d.Config.multiEventLoop == nil { 184 | return nil, ErrEventLoopEmpty 185 | } 186 | 187 | if !d.Config.multiEventLoop.isStart() { 188 | return nil, ErrEventLoopNotStart 189 | } 190 | req, secWebSocket, err := d.handshake() 191 | if err != nil { 192 | return nil, err 193 | } 194 | 195 | hostName := hostname.GetHostName(d.u) 196 | var conn net.Conn 197 | conn, err = net.DialTimeout("tcp", hostName, d.dialTimeout) 198 | if err != nil { 199 | return nil, fmt.Errorf("net.Dial:%w", err) 200 | } 201 | 202 | err = conn.SetDeadline(time.Time{}) 203 | conn = d.tlsConn(conn) 204 | defer func() { 205 | if err != nil && conn != nil { 206 | conn.Close() 207 | conn = nil 208 | } 209 | }() 210 | 211 | if err = req.Write(conn); err != nil { 212 | return nil, fmt.Errorf("write req fail:%w", err) 213 | } 214 | 215 | br := bufio.NewReader(bufio.NewReader(conn)) 216 | rsp, err := http.ReadResponse(br, req) 217 | if err != nil { 218 | return nil, err 219 | } 220 | 221 | if d.bindClientHttpHeader != nil { 222 | *d.bindClientHttpHeader = rsp.Header.Clone() 223 | } 224 | 225 | pd, err := deflate.GetConnPermessageDeflate(rsp.Header) 226 | if err != nil { 227 | return nil, err 228 | } 229 | if d.Decompression { 230 | pd.Decompression = pd.Enable && d.Decompression 231 | } 232 | if d.Compression { 233 | pd.Compression = pd.Enable && d.Compression 234 | } 235 | 236 | if err = d.validateRsp(rsp, secWebSocket); err != nil { 237 | return 238 | } 239 | 240 | fd, err := getFdFromConn(conn) 241 | if err != nil { 242 | conn.Close() 243 | return nil, err 244 | } 245 | // 已经dup了一份fd,所以这里可以关闭 246 | if err = conn.Close(); err != nil { 247 | return nil, err 248 | } 249 | if wsCon, err = newConn(int64(fd), true, &d.Config); err != nil { 250 | return nil, err 251 | } 252 | wsCon.pd = pd 253 | wsCon.Callback = d.cb 254 | wsCon.OnOpen(wsCon) 255 | if br.Buffered() > 0 { 256 | b, err := br.Peek(br.Buffered()) 257 | if err != nil { 258 | return nil, err 259 | } 260 | 261 | wsCon.rbuf = bytespool.GetBytes(len(b) + enum.MaxFrameHeaderSize) 262 | 263 | copy(*wsCon.rbuf, b) 264 | wsCon.rw = len(b) 265 | if err = wsCon.processHeaderPayloadCallback(); err != nil { 266 | return nil, err 267 | } 268 | } 269 | if err = d.Config.multiEventLoop.add(wsCon); err != nil { 270 | return nil, err 271 | } 272 | return wsCon, nil 273 | } 274 | -------------------------------------------------------------------------------- /autobahn/server/autobahn-server.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "crypto/tls" 5 | _ "embed" 6 | "flag" 7 | "fmt" 8 | "log" 9 | "log/slog" 10 | "net" 11 | "net/http" 12 | _ "net/http/pprof" 13 | "runtime" 14 | "time" 15 | 16 | "github.com/antlabs/greatws" 17 | ) 18 | 19 | var runInEventLoop = flag.Bool("run-in-event-loop", false, "run in event loop") 20 | 21 | //go:embed public.crt 22 | var certPEMBlock []byte 23 | 24 | //go:embed privatekey.pem 25 | var keyPEMBlock []byte 26 | 27 | type echoHandler struct{} 28 | 29 | func (e *echoHandler) OnOpen(c *greatws.Conn) { 30 | // err := c.WriteMessage(greatws.Binary, make([]byte, 1<<28)) 31 | // if err != nil { 32 | // fmt.Printf("%s\n", err) 33 | // } 34 | // fmt.Printf("OnOpen: %p\n", c) 35 | } 36 | 37 | func (e *echoHandler) OnMessage(c *greatws.Conn, op greatws.Opcode, msg []byte) { 38 | // fmt.Printf("OnMessage: %s, len(%d), op:%d\n", msg, len(msg), op) 39 | // if err := c.WriteTimeout(op, msg, 3*time.Second); err != nil { 40 | // fmt.Println("write fail:", err) 41 | // } 42 | if err := c.WriteMessage(op, msg); err != nil { 43 | slog.Error("write fail:", "err", err.Error()) 44 | } 45 | } 46 | 47 | func (e *echoHandler) OnClose(c *greatws.Conn, err error) { 48 | defer c.Close() 49 | errMsg := "" 50 | if err != nil { 51 | errMsg = err.Error() 52 | } 53 | slog.Error("OnClose:", "err", errMsg) 54 | } 55 | 56 | type handler struct { 57 | m *greatws.MultiEventLoop 58 | parseLoop *greatws.MultiEventLoop 59 | } 60 | 61 | // 运行在业务线程 62 | 63 | // 运行在io线程 64 | func (h *handler) echoRunInIo(w http.ResponseWriter, r *http.Request) { 65 | opts := []greatws.ServerOption{ 66 | greatws.WithServerReplyPing(), 67 | greatws.WithServerDecompression(), 68 | greatws.WithServerIgnorePong(), 69 | greatws.WithServerCallback(&echoHandler{}), 70 | greatws.WithServerEnableUTF8Check(), 71 | greatws.WithServerReadTimeout(5 * time.Second), 72 | greatws.WithServerMultiEventLoop(h.m), 73 | greatws.WithServerCallbackInEventLoop(), 74 | } 75 | 76 | if *runInEventLoop { 77 | opts = append(opts, greatws.WithServerCallbackInEventLoop()) 78 | } 79 | 80 | c, err := greatws.Upgrade(w, r, opts...) 81 | if err != nil { 82 | slog.Error("Upgrade fail:", "err", err.Error()) 83 | } 84 | _ = c 85 | } 86 | 87 | func (h *handler) echoRunOneByOne(w http.ResponseWriter, r *http.Request) { 88 | opts := []greatws.ServerOption{ 89 | greatws.WithServerReplyPing(), 90 | greatws.WithServerDecompression(), 91 | greatws.WithServerIgnorePong(), 92 | greatws.WithServerCallback(&echoHandler{}), 93 | greatws.WithServerEnableUTF8Check(), 94 | // greatws.WithServerReadTimeout(5 * time.Second), 95 | greatws.WithServerMultiEventLoop(h.m), 96 | greatws.WithServerOneByOneMode(), 97 | } 98 | 99 | if *runInEventLoop { 100 | opts = append(opts, greatws.WithServerCallbackInEventLoop()) 101 | } 102 | 103 | c, err := greatws.Upgrade(w, r, opts...) 104 | if err != nil { 105 | slog.Error("Upgrade fail:", "err", err.Error()) 106 | } 107 | _ = c 108 | } 109 | 110 | func (h *handler) echoRunElastic(w http.ResponseWriter, r *http.Request) { 111 | opts := []greatws.ServerOption{ 112 | greatws.WithServerReplyPing(), 113 | greatws.WithServerDecompression(), 114 | greatws.WithServerIgnorePong(), 115 | greatws.WithServerCallback(&echoHandler{}), 116 | greatws.WithServerEnableUTF8Check(), 117 | greatws.WithServerReadTimeout(5 * time.Second), 118 | greatws.WithServerMultiEventLoop(h.m), 119 | greatws.WithServerElasticMode(), 120 | } 121 | 122 | if *runInEventLoop { 123 | opts = append(opts, greatws.WithServerCallbackInEventLoop()) 124 | } 125 | 126 | c, err := greatws.Upgrade(w, r, opts...) 127 | if err != nil { 128 | slog.Error("Upgrade fail:", "err", err.Error()) 129 | } 130 | _ = c 131 | } 132 | 133 | // 1.测试不接管上下文,只解压 134 | func (h *handler) echoNoContextDecompression(w http.ResponseWriter, r *http.Request) { 135 | c, err := greatws.Upgrade(w, r, 136 | greatws.WithServerReplyPing(), 137 | greatws.WithServerDecompression(), 138 | greatws.WithServerIgnorePong(), 139 | greatws.WithServerCallback(&echoHandler{}), 140 | greatws.WithServerEnableUTF8Check(), 141 | greatws.WithServerMultiEventLoop(h.m), 142 | ) 143 | if err != nil { 144 | fmt.Println("Upgrade fail:", err) 145 | return 146 | } 147 | 148 | _ = c.ReadLoop() 149 | } 150 | 151 | // 2.测试不接管上下文,压缩和解压 152 | func (h *handler) echoNoContextDecompressionAndCompression(w http.ResponseWriter, r *http.Request) { 153 | c, err := greatws.Upgrade(w, r, 154 | greatws.WithServerReplyPing(), 155 | greatws.WithServerDecompressAndCompress(), 156 | greatws.WithServerIgnorePong(), 157 | greatws.WithServerCallback(&echoHandler{}), 158 | greatws.WithServerEnableUTF8Check(), 159 | greatws.WithServerMultiEventLoop(h.m), 160 | ) 161 | if err != nil { 162 | fmt.Println("Upgrade fail:", err) 163 | return 164 | } 165 | 166 | _ = c.ReadLoop() 167 | } 168 | 169 | // 3.测试接管上下文,解压 170 | func (h *handler) echoContextTakeoverDecompression(w http.ResponseWriter, r *http.Request) { 171 | c, err := greatws.Upgrade(w, r, 172 | greatws.WithServerReplyPing(), 173 | greatws.WithServerDecompression(), 174 | greatws.WithServerIgnorePong(), 175 | greatws.WithServerContextTakeover(), 176 | greatws.WithServerCallback(&echoHandler{}), 177 | greatws.WithServerEnableUTF8Check(), 178 | greatws.WithServerMultiEventLoop(h.m), 179 | ) 180 | if err != nil { 181 | fmt.Println("Upgrade fail:", err) 182 | return 183 | } 184 | 185 | _ = c.ReadLoop() 186 | } 187 | 188 | // 4.测试接管上下文,压缩/解压缩 189 | func (h *handler) echoContextTakeoverDecompressionAndCompression(w http.ResponseWriter, r *http.Request) { 190 | c, err := greatws.Upgrade(w, r, 191 | greatws.WithServerReplyPing(), 192 | greatws.WithServerDecompressAndCompress(), 193 | greatws.WithServerIgnorePong(), 194 | greatws.WithServerContextTakeover(), 195 | greatws.WithServerCallback(&echoHandler{}), 196 | greatws.WithServerEnableUTF8Check(), 197 | greatws.WithServerMultiEventLoop(h.m), 198 | ) 199 | if err != nil { 200 | fmt.Println("Upgrade fail:", err) 201 | return 202 | } 203 | 204 | _ = c.ReadLoop() 205 | } 206 | 207 | func main() { 208 | flag.Parse() 209 | 210 | var h handler 211 | runtime.SetBlockProfileRate(1) 212 | 213 | go func() { 214 | log.Println(http.ListenAndServe(":6060", nil)) 215 | }() 216 | 217 | // debug io-uring 218 | // h.m = greatws.NewMultiEventLoopMust(greatws.WithEventLoops(0), greatws.WithMaxEventNum(1000), greatws.WithIoUring(), greatws.WithLogLevel(slog.LevelDebug)) 219 | h.m = greatws.NewMultiEventLoopMust( 220 | greatws.WithEventLoops(runtime.NumCPU()/2), 221 | greatws.WithBusinessGoNum(50, 10, 10000), 222 | greatws.WithMaxEventNum(256), 223 | greatws.WithLogLevel(slog.LevelError)) // epoll, kqueue 224 | h.m.Start() 225 | 226 | parseLoopOpt := []greatws.EvOption{ 227 | greatws.WithBusinessGoNum(50, 10, 10000), 228 | greatws.WithMaxEventNum(1000), 229 | greatws.WithLogLevel(slog.LevelError), 230 | } 231 | 232 | h.parseLoop = greatws.NewMultiEventLoopMust(parseLoopOpt...) // epoll, kqueue 233 | h.parseLoop.Start() 234 | 235 | fmt.Printf("apiname:%s\n", h.m.GetApiName()) 236 | 237 | go func() { 238 | for { 239 | time.Sleep(time.Second) 240 | fmt.Printf("curConn:%d, curTask:%d\n", h.m.GetCurConnNum(), h.m.GetCurTaskNum()) 241 | } 242 | }() 243 | mux := &http.ServeMux{} 244 | mux.HandleFunc("/autobahn-io", h.echoRunInIo) 245 | mux.HandleFunc("/autobahn-onebyone", h.echoRunOneByOne) 246 | mux.HandleFunc("/autobahn-elastic", h.echoRunElastic) 247 | mux.HandleFunc("/no-context-takeover-decompression", h.echoNoContextDecompression) 248 | mux.HandleFunc("/no-context-takeover-decompression-and-compression", h.echoNoContextDecompressionAndCompression) 249 | mux.HandleFunc("/context-takeover-decompression", h.echoContextTakeoverDecompression) 250 | mux.HandleFunc("/context-takeover-decompression-and-compression", h.echoContextTakeoverDecompressionAndCompression) 251 | 252 | rawTCP, err := net.Listen("tcp", ":9004") 253 | if err != nil { 254 | fmt.Println("Listen fail:", err) 255 | return 256 | } 257 | 258 | go func() { 259 | log.Println("non-tls server exit:", http.Serve(rawTCP, mux)) 260 | }() 261 | 262 | cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock) 263 | if err != nil { 264 | log.Fatalf("tls.X509KeyPair failed: %v", err) 265 | } 266 | tlsConfig := &tls.Config{ 267 | Certificates: []tls.Certificate{cert}, 268 | InsecureSkipVerify: true, 269 | } 270 | 271 | lnTLS, err := tls.Listen("tcp", "localhost:9005", tlsConfig) 272 | if err != nil { 273 | panic(err) 274 | } 275 | 276 | log.Println("tls server exit:", http.Serve(lnTLS, mux)) 277 | } 278 | -------------------------------------------------------------------------------- /common_options.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | package greatws 15 | 16 | import ( 17 | "time" 18 | "unicode/utf8" 19 | ) 20 | 21 | // 0. CallbackFunc 22 | func WithClientCallbackFunc(open OnOpenFunc, m OnMessageFunc, c OnCloseFunc) ClientOption { 23 | return func(o *DialOption) { 24 | o.cb = &funcToCallback{ 25 | onOpen: open, 26 | onMessage: m, 27 | onClose: c, 28 | } 29 | } 30 | } 31 | 32 | // 配置服务端回调函数 33 | func WithServerCallbackFunc(open OnOpenFunc, m OnMessageFunc, c OnCloseFunc) ServerOption { 34 | return func(o *ConnOption) { 35 | o.cb = &funcToCallback{ 36 | onOpen: open, 37 | onMessage: m, 38 | onClose: c, 39 | } 40 | } 41 | } 42 | 43 | // 1. callback 44 | // 配置客户端callback 45 | func WithClientCallback(cb Callback) ClientOption { 46 | return func(o *DialOption) { 47 | o.cb = cb 48 | } 49 | } 50 | 51 | // 配置服务端回调函数 52 | func WithServerCallback(cb Callback) ServerOption { 53 | return func(o *ConnOption) { 54 | o.cb = cb 55 | } 56 | } 57 | 58 | // 2. 设置TCP_NODELAY 59 | // 设置客户端TCP_NODELAY 60 | func WithClientTCPDelay() ClientOption { 61 | return func(o *DialOption) { 62 | o.tcpNoDelay = false 63 | } 64 | } 65 | 66 | // 设置TCP_NODELAY 为false, 开启nagle算法 67 | // 设置服务端TCP_NODELAY 68 | func WithServerTCPDelay() ServerOption { 69 | return func(o *ConnOption) { 70 | o.tcpNoDelay = false 71 | } 72 | } 73 | 74 | // 3.关闭utf8检查 75 | func WithServerEnableUTF8Check() ServerOption { 76 | return func(o *ConnOption) { 77 | o.utf8Check = utf8.Valid 78 | } 79 | } 80 | 81 | func WithClientEnableUTF8Check() ClientOption { 82 | return func(o *DialOption) { 83 | o.utf8Check = utf8.Valid 84 | } 85 | } 86 | 87 | // 4.仅仅配置OnMessae函数 88 | // 仅仅配置OnMessae函数 89 | func WithServerOnMessageFunc(cb OnMessageFunc) ServerOption { 90 | return func(o *ConnOption) { 91 | o.cb = OnMessageFunc(cb) 92 | } 93 | } 94 | 95 | // 仅仅配置OnMessae函数 96 | func WithClientOnMessageFunc(cb OnMessageFunc) ClientOption { 97 | return func(o *DialOption) { 98 | o.cb = OnMessageFunc(cb) 99 | } 100 | } 101 | 102 | // 5. 103 | // 配置自动回应ping frame, 当收到ping, 回一个pong 104 | func WithServerReplyPing() ServerOption { 105 | return func(o *ConnOption) { 106 | o.replyPing = true 107 | } 108 | } 109 | 110 | // 配置自动回应ping frame, 当收到ping, 回一个pong 111 | func WithClientReplyPing() ClientOption { 112 | return func(o *DialOption) { 113 | o.replyPing = true 114 | } 115 | } 116 | 117 | // 6 配置忽略pong消息 118 | func WithClientIgnorePong() ClientOption { 119 | return func(o *DialOption) { 120 | o.ignorePong = true 121 | } 122 | } 123 | 124 | func WithServerIgnorePong() ServerOption { 125 | return func(o *ConnOption) { 126 | o.ignorePong = true 127 | } 128 | } 129 | 130 | // 7. 131 | // 设置几倍payload的缓冲区 132 | // 只有解析方式是窗口的时候才有效 133 | // 如果为1.0就是1024 + 14, 如果是2.0就是2048 + 14 134 | func WithServerWindowsMultipleTimesPayloadSize(mt float32) ServerOption { 135 | return func(o *ConnOption) { 136 | if mt < 1.0 { 137 | mt = 1.0 138 | } 139 | o.windowsMultipleTimesPayloadSize = mt 140 | } 141 | } 142 | 143 | func WithClientWindowsMultipleTimesPayloadSize(mt float32) ClientOption { 144 | return func(o *DialOption) { 145 | if mt < 1.0 { 146 | mt = 1.0 147 | } 148 | o.windowsMultipleTimesPayloadSize = mt 149 | } 150 | } 151 | 152 | // 10 配置解压缩 153 | func WithClientDecompression() ClientOption { 154 | return func(o *DialOption) { 155 | o.Decompression = true 156 | } 157 | } 158 | 159 | func WithServerDecompression() ServerOption { 160 | return func(o *ConnOption) { 161 | o.Decompression = true 162 | } 163 | } 164 | 165 | // 11 关闭bufio clear hack优化 166 | func WithServerDisableBufioClearHack() ServerOption { 167 | return func(o *ConnOption) { 168 | o.disableBufioClearHack = true 169 | } 170 | } 171 | 172 | func WithClientDisableBufioClearHack() ClientOption { 173 | return func(o *DialOption) { 174 | o.disableBufioClearHack = true 175 | } 176 | } 177 | 178 | // 13. 配置延迟发送 179 | // 配置延迟最大发送时间 180 | func WithServerMaxDelayWriteDuration(d time.Duration) ServerOption { 181 | return func(o *ConnOption) { 182 | o.maxDelayWriteDuration = d 183 | } 184 | } 185 | 186 | // 13. 配置延迟发送 187 | // 配置延迟最大发送时间 188 | func WithClientMaxDelayWriteDuration(d time.Duration) ClientOption { 189 | return func(o *DialOption) { 190 | o.maxDelayWriteDuration = d 191 | } 192 | } 193 | 194 | // 14.1 配置最大延迟个数.server 195 | func WithServerMaxDelayWriteNum(n int32) ServerOption { 196 | return func(o *ConnOption) { 197 | o.maxDelayWriteNum = n 198 | } 199 | } 200 | 201 | // 14.2 配置最大延迟个数.client 202 | func WithClientMaxDelayWriteNum(n int32) ClientOption { 203 | return func(o *DialOption) { 204 | o.maxDelayWriteNum = n 205 | } 206 | } 207 | 208 | // 15.1 配置延迟包的初始化buffer大小 209 | func WithServerDelayWriteInitBufferSize(n int32) ServerOption { 210 | return func(o *ConnOption) { 211 | o.delayWriteInitBufferSize = n 212 | } 213 | } 214 | 215 | // 15.2 配置延迟包的初始化buffer大小 216 | func WithClientDelayWriteInitBufferSize(n int32) ClientOption { 217 | return func(o *DialOption) { 218 | o.delayWriteInitBufferSize = n 219 | } 220 | } 221 | 222 | // 16. 配置读超时时间 223 | // 224 | // 16.1 .设置服务端读超时时间 225 | func WithServerReadTimeout(t time.Duration) ServerOption { 226 | return func(o *ConnOption) { 227 | o.readTimeout = t 228 | } 229 | } 230 | 231 | // 16.2 .设置客户端读超时时间 232 | func WithClientReadTimeout(t time.Duration) ClientOption { 233 | return func(o *DialOption) { 234 | o.readTimeout = t 235 | } 236 | } 237 | 238 | // 17。 只配置OnClose 239 | // 17.1 配置服务端OnClose 240 | func WithServerOnCloseFunc(onClose func(c *Conn, err error)) ServerOption { 241 | return func(o *ConnOption) { 242 | o.cb = OnCloseFunc(onClose) 243 | } 244 | } 245 | 246 | // 17.2 配置客户端OnClose 247 | func WithClientOnCloseFunc(onClose func(c *Conn, err error)) ClientOption { 248 | return func(o *DialOption) { 249 | o.cb = OnCloseFunc(onClose) 250 | } 251 | } 252 | 253 | // 18.1 配置服务端Callback相关方法在io event loop中执行 254 | func WithServerCallbackInEventLoop() ServerOption { 255 | return func(o *ConnOption) { 256 | o.runInGoTask = "io" 257 | } 258 | } 259 | 260 | // 18.2 配置服务端Callback相关方法在io event loop中执行 261 | func WithClientCallbackInEventLoop() ClientOption { 262 | return func(o *DialOption) { 263 | o.runInGoTask = "io" 264 | } 265 | } 266 | 267 | // 默认模式 268 | // 19.1 配置服务端使用onebyone模式处理请求,从生命周期的开始到结束,这个Message只会被这个go程处理 269 | func WithServerOneByOneMode() ServerOption { 270 | return func(o *ConnOption) { 271 | o.runInGoTask = "onebyone" 272 | } 273 | } 274 | 275 | // 默认模式 276 | // 19.2 配置客户端使用onebyone模式处理请求,从生命周期的开始到结束,这个Message只会被这个go程处理 277 | func WithClientOneByOneMode() ClientOption { 278 | return func(o *DialOption) { 279 | o.runInGoTask = "onebyone" 280 | } 281 | } 282 | 283 | func WithServerElasticMode() ServerOption { 284 | return func(o *ConnOption) { 285 | o.runInGoTask = "elastic" 286 | } 287 | } 288 | 289 | func WithClientElasticMode() ClientOption { 290 | return func(o *DialOption) { 291 | o.runInGoTask = "elastic" 292 | } 293 | } 294 | 295 | // 20.1 配置自定义task, 需要确保传入的值是有效的,不然会panic 296 | func WithServerCustomTaskMode(taskName string) ServerOption { 297 | return func(o *ConnOption) { 298 | if len(taskName) > 0 { 299 | o.runInGoTask = taskName 300 | } 301 | } 302 | } 303 | 304 | // 20.2 配置自定义task, 需要确保传入的值是有效的,不然会panic 305 | func WithClientCustomTaskMode(taskName string) ClientOption { 306 | return func(o *DialOption) { 307 | if len(taskName) > 0 { 308 | o.runInGoTask = taskName 309 | } 310 | } 311 | } 312 | 313 | // 20.1 配置event 314 | func WithServerMultiEventLoop(m *MultiEventLoop) ServerOption { 315 | return func(o *ConnOption) { 316 | o.multiEventLoop = m 317 | } 318 | } 319 | 320 | // 20.2 配置event 321 | func WithClientMultiEventLoop(m *MultiEventLoop) ClientOption { 322 | return func(o *DialOption) { 323 | o.multiEventLoop = m 324 | } 325 | } 326 | 327 | // 21.1 配置压缩和解压缩 328 | func WithServerDecompressAndCompress() ServerOption { 329 | return func(o *ConnOption) { 330 | o.Compression = true 331 | o.Decompression = true 332 | } 333 | } 334 | 335 | // 21.2 配置压缩和解压缩 336 | func WithClientDecompressAndCompress() ClientOption { 337 | return func(o *DialOption) { 338 | o.Compression = true 339 | o.Decompression = true 340 | } 341 | } 342 | 343 | // 21.1 设置客户端支持上下文接管, 默认不支持上下文接管 344 | func WithClientContextTakeover() ClientOption { 345 | return func(o *DialOption) { 346 | o.ClientContextTakeover = true 347 | } 348 | } 349 | 350 | // 21.2 设置服务端支持上下文接管, 默认不支持上下文接管 351 | func WithServerContextTakeover() ServerOption { 352 | return func(o *ConnOption) { 353 | o.ServerContextTakeover = true 354 | } 355 | } 356 | 357 | // 21.1 设置客户端最大窗口位数,使用上下文接管时,这个参数才有效 358 | func WithClientMaxWindowsBits(bits uint8) ClientOption { 359 | return func(o *DialOption) { 360 | if bits < 8 || bits > 15 { 361 | return 362 | } 363 | o.ClientMaxWindowBits = bits 364 | } 365 | } 366 | 367 | // 22.2 设置服务端最大窗口位数, 使用上下文接管时,这个参数才有效 368 | func WithServerMaxWindowBits(bits uint8) ServerOption { 369 | return func(o *ConnOption) { 370 | if bits < 8 || bits > 15 { 371 | return 372 | } 373 | o.ServerMaxWindowBits = bits 374 | } 375 | } 376 | 377 | // 22.1 设置客户端最大可以读取的message的大小, 默认没有限制 378 | func WithClientReadMaxMessage(size int64) ClientOption { 379 | return func(o *DialOption) { 380 | o.readMaxMessage = size 381 | } 382 | } 383 | 384 | // 22.2 设置服务端最大可以读取的message的大小,默认没有限制 385 | func WithServerReadMaxMessage(size int64) ServerOption { 386 | return func(o *ConnOption) { 387 | o.readMaxMessage = size 388 | } 389 | } 390 | 391 | func WithFlowBackPressureRemoveRead() ServerOption { 392 | return func(o *ConnOption) { 393 | o.flowBackPressureRemoveRead = true 394 | } 395 | } 396 | 397 | func WithClientFlowBackPressureRemoveRead() ClientOption { 398 | return func(o *DialOption) { 399 | o.flowBackPressureRemoveRead = true 400 | } 401 | } 402 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # greatws 2 | 3 | 支持海量连接的websocket库,callback写法 4 | 5 | ![Go](https://github.com/antlabs/greatws/workflows/Go/badge.svg) 6 | [![codecov](https://codecov.io/gh/antlabs/greatws/branch/master/graph/badge.svg)](https://codecov.io/gh/antlabs/greatws) 7 | [![Go Report Card](https://goreportcard.com/badge/github.com/antlabs/greatws)](https://goreportcard.com/report/github.com/antlabs/greatws) 8 | 9 | ## 处理流程 10 | 11 | ![greatws.png](https://github.com/antlabs/images/blob/main/greatws/greatws.png?raw=true) 12 | 13 | # 特性 14 | 15 | * 支持 epoll/kqueue 16 | * 低内存占用 17 | * 高tps 18 | * 对websocket的兼容性较高,完整实现rfc6455, rfc7692 19 | 20 | # 暂不支持 21 | 22 | * ssl 23 | * windows 24 | * io-uring 25 | 26 | # 警告⚠️ 27 | 28 | 早期阶段,暂时不建议生产使用 29 | 30 | ## 内容 31 | 32 | * [安装](#Installation) 33 | * [例子](#example) 34 | * [net/http升级到websocket服务端](#net-http升级到websocket服务端) 35 | * [gin升级到websocket服务端](#gin升级到websocket服务端) 36 | * [客户端](#客户端) 37 | * [配置函数](#配置函数) 38 | * [客户端配置参数](#客户端配置) 39 | * [配置header](#配置header) 40 | * [配置握手时的超时时间](#配置握手时的超时时间) 41 | * [配置自动回复ping消息](#配置自动回复ping消息) 42 | * [配置客户端最大读取message](#配置客户端最大读message) 43 | * [配置客户端压缩和解压消息](#配置客户端压缩和解压消息) 44 | * [配置客户端上下文接管](#配置客户端上下文接管) 45 | * [服务配置参数](#服务端配置) 46 | * [配置服务自动回复ping消息](#配置服务自动回复ping消息) 47 | * [配置服务端最大读取message](#配置服务端最大读message) 48 | * [配置服务端解压消息](#配置服务端解压消息) 49 | * [配置服务端压缩和解压消息](#配置服务端压缩和解压消息) 50 | * [配置服务端上下文接管](#配置服务端上下文接管) 51 | 52 | # 例子-服务端 53 | 54 | ### net http升级到websocket服务端 55 | 56 | ```go 57 | 58 | package main 59 | 60 | import ( 61 | "fmt" 62 | 63 | "github.com/antlabs/greatws" 64 | ) 65 | 66 | type echoHandler struct{} 67 | 68 | func (e *echoHandler) OnOpen(c *greatws.Conn) { 69 | // fmt.Printf("OnOpen: %p\n", c) 70 | } 71 | 72 | func (e *echoHandler) OnMessage(c *greatws.Conn, op greatws.Opcode, msg []byte) { 73 | if err := c.WriteTimeout(op, msg, 3*time.Second); err != nil { 74 | fmt.Println("write fail:", err) 75 | } 76 | // if err := c.WriteMessage(op, msg); err != nil { 77 | // slog.Error("write fail:", err) 78 | // } 79 | } 80 | 81 | func (e *echoHandler) OnClose(c *greatws.Conn, err error) { 82 | errMsg := "" 83 | if err != nil { 84 | errMsg = err.Error() 85 | } 86 | slog.Error("OnClose:", errMsg) 87 | } 88 | 89 | type handler struct { 90 | m *greatws.MultiEventLoop 91 | } 92 | 93 | func (h *handler) echo(w http.ResponseWriter, r *http.Request) { 94 | c, err := greatws.Upgrade(w, r, 95 | greatws.WithServerReplyPing(), 96 | // greatws.WithServerDecompression(), 97 | greatws.WithServerIgnorePong(), 98 | greatws.WithServerCallback(&echoHandler{}), 99 | // greatws.WithServerEnableUTF8Check(), 100 | greatws.WithServerReadTimeout(5*time.Second), 101 | greatws.WithServerMultiEventLoop(h.m), 102 | ) 103 | if err != nil { 104 | slog.Error("Upgrade fail:", "err", err.Error()) 105 | } 106 | _ = c 107 | } 108 | 109 | func main() { 110 | 111 | var h handler 112 | 113 | h.m = greatws.NewMultiEventLoopMust(greatws.WithEventLoops(0), greatws.WithMaxEventNum(256), greatws.WithLogLevel(slog.LevelError)) // epoll, kqueue 114 | h.m.Start() 115 | fmt.Printf("apiname:%s\n", h.m.GetApiName()) 116 | 117 | mux := &http.ServeMux{} 118 | mux.HandleFunc("/autobahn", h.echo) 119 | 120 | rawTCP, err := net.Listen("tcp", ":9001") 121 | if err != nil { 122 | fmt.Println("Listen fail:", err) 123 | return 124 | } 125 | log.Println("non-tls server exit:", http.Serve(rawTCP, mux)) 126 | } 127 | ``` 128 | 129 | [返回](#内容) 130 | 131 | ### gin升级到websocket服务端 132 | 133 | ```go 134 | package main 135 | 136 | import ( 137 | "fmt" 138 | 139 | "github.com/antlabs/greatws" 140 | "github.com/gin-gonic/gin" 141 | ) 142 | 143 | type handler struct{ 144 | m *greatws.MultiEventLoop 145 | } 146 | 147 | func (h *handler) OnOpen(c *greatws.Conn) { 148 | fmt.Printf("服务端收到一个新的连接") 149 | } 150 | 151 | func (h *handler) OnMessage(c *greatws.Conn, op greatws.Opcode, msg []byte) { 152 | // 如果msg的生命周期不是在OnMessage中结束,需要拷贝一份 153 | // newMsg := make([]byte, len(msg)) 154 | // copy(newMsg, msg) 155 | 156 | fmt.Printf("收到客户端消息:%s\n", msg) 157 | c.WriteMessage(op, msg) 158 | // os.Stdout.Write(msg) 159 | } 160 | 161 | func (h *handler) OnClose(c *greatws.Conn, err error) { 162 | fmt.Printf("服务端连接关闭:%v\n", err) 163 | } 164 | 165 | func main() { 166 | r := gin.Default() 167 | var h handler 168 | h.m = greatws.NewMultiEventLoopMust(greatws.WithEventLoops(0), greatws.WithMaxEventNum(256), greatws.WithLogLevel(slog.LevelError)) // epoll, kqueue 169 | h.m.Start() 170 | 171 | r.GET("/", func(c *gin.Context) { 172 | con, err := greatws.Upgrade(c.Writer, c.Request, greatws.WithServerCallback(h.m), greatws.WithServerMultiEventLoop(h.m)) 173 | if err != nil { 174 | return 175 | } 176 | con.StartReadLoop() 177 | }) 178 | r.Run() 179 | } 180 | ``` 181 | 182 | [返回](#内容) 183 | 184 | ### 客户端 185 | 186 | ```go 187 | package main 188 | 189 | import ( 190 | "fmt" 191 | "time" 192 | 193 | "github.com/antlabs/greatws" 194 | ) 195 | 196 | var m *greatws.MultiEventLoop 197 | type handler struct{} 198 | 199 | func (h *handler) OnOpen(c *greatws.Conn) { 200 | fmt.Printf("客户端连接成功\n") 201 | } 202 | 203 | func (h *handler) OnMessage(c *greatws.Conn, op greatws.Opcode, msg []byte) { 204 | // 如果msg的生命周期不是在OnMessage中结束,需要拷贝一份 205 | // newMsg := make([]byte, len(msg)) 206 | // copy(newMsg, msg) 207 | 208 | fmt.Printf("收到服务端消息:%s\n", msg) 209 | c.WriteMessage(op, msg) 210 | time.Sleep(time.Second) 211 | } 212 | 213 | func (h *handler) OnClose(c *greatws.Conn, err error) { 214 | fmt.Printf("客户端端连接关闭:%v\n", err) 215 | } 216 | 217 | func main() { 218 | m = greatws.NewMultiEventLoopMust(greatws.WithEventLoops(0), greatws.WithMaxEventNum(256), greatws.WithLogLevel(slog.LevelError)) // epoll, kqueue 219 | m.Start() 220 | c, err := greatws.Dial("ws://127.0.0.1:8080/", greatws.WithClientCallback(&handler{}), greatws.WithServerMultiEventLoop(h.m)) 221 | if err != nil { 222 | fmt.Printf("连接失败:%v\n", err) 223 | return 224 | } 225 | 226 | c.WriteMessage(opcode.Text, []byte("hello")) 227 | time.Sleep(time.Hour) //demo里面等待下OnMessage 看下执行效果,因为greatws.Dial和WriteMessage都是非阻塞的函数调用,不会卡住主go程 228 | } 229 | ``` 230 | 231 | [返回](#内容) 232 | 233 | ## 配置函数 234 | 235 | ### 客户端配置参数 236 | 237 | #### 配置header 238 | 239 | ```go 240 | func main() { 241 | greatws.Dial("ws://127.0.0.1:12345/test", greatws.WithClientHTTPHeader(http.Header{ 242 | "h1": "v1", 243 | "h2":"v2", 244 | })) 245 | } 246 | ``` 247 | 248 | [返回](#内容) 249 | 250 | #### 配置握手时的超时时间 251 | 252 | ```go 253 | func main() { 254 | greatws.Dial("ws://127.0.0.1:12345/test", greatws.WithClientDialTimeout(2 * time.Second)) 255 | } 256 | ``` 257 | 258 | [返回](#内容) 259 | 260 | #### 配置自动回复ping消息 261 | 262 | ```go 263 | func main() { 264 | greatws.Dial("ws://127.0.0.1:12345/test", greatws.WithClientReplyPing()) 265 | } 266 | ``` 267 | 268 | [返回](#内容) 269 | 270 | #### 配置客户端最大读message 271 | 272 | ```go 273 | // 限制客户端最大服务返回返回的最大包是1024,如果超过这个大小报错 274 | greatws.Dial("ws://127.0.0.1:12345/test", greatws.WithClientReadMaxMessage(1024)) 275 | ``` 276 | 277 | [返回](#内容) 278 | 279 | #### 配置客户端压缩和解压消息 280 | 281 | ```go 282 | func main() { 283 | greatws.Dial("ws://127.0.0.1:12345/test", greatws.WithClientDecompressAndCompress()) 284 | } 285 | ``` 286 | 287 | [返回](#内容) 288 | 289 | #### 配置客户端上下文接管 290 | 291 | ```go 292 | func main() { 293 | greatws.Dial("ws://127.0.0.1:12345/test", greatws.WithClientContextTakeover()) 294 | } 295 | ``` 296 | 297 | [返回](#内容) 298 | 299 | ### 服务端配置参数 300 | 301 | #### 配置服务自动回复ping消息 302 | 303 | ```go 304 | func main() { 305 | c, err := greatws.Upgrade(w, r, greatws.WithServerReplyPing()) 306 | if err != nil { 307 | fmt.Println("Upgrade fail:", err) 308 | return 309 | } 310 | } 311 | ``` 312 | 313 | [返回](#内容) 314 | 315 | #### 配置服务端最大读message 316 | 317 | ```go 318 | func main() { 319 | // 配置服务端读取客户端最大的包是1024大小, 超过该值报错 320 | c, err := greatws.Upgrade(w, r, greatws.WithServerReadMaxMessage(1024)) 321 | if err != nil { 322 | fmt.Println("Upgrade fail:", err) 323 | return 324 | } 325 | } 326 | ``` 327 | 328 | [返回](#内容) 329 | 330 | #### 配置服务端解压消息 331 | 332 | ```go 333 | func main() { 334 | // 配置服务端读取客户端最大的包是1024大小, 超过该值报错 335 | c, err := greatws.Upgrade(w, r, greatws.WithServerDecompression()) 336 | if err != nil { 337 | fmt.Println("Upgrade fail:", err) 338 | return 339 | } 340 | } 341 | ``` 342 | 343 | [返回](#内容) 344 | 345 | #### 配置服务端压缩和解压消息 346 | 347 | ```go 348 | func main() { 349 | c, err := greatws.Upgrade(w, r, greatws.WithServerDecompressAndCompress()) 350 | if err != nil { 351 | fmt.Println("Upgrade fail:", err) 352 | return 353 | } 354 | } 355 | ``` 356 | 357 | [返回](#内容) 358 | 359 | #### 配置服务端上下文接管 360 | 361 | ```go 362 | func main() { 363 | // 配置服务端读取客户端最大的包是1024大小, 超过该值报错 364 | c, err := greatws.Upgrade(w, r, greatws.WithServerContextTakeover) 365 | if err != nil { 366 | fmt.Println("Upgrade fail:", err) 367 | return 368 | } 369 | } 370 | ``` 371 | 372 | [返回](#内容) 373 | 374 | ## 100w websocket长链接测试 375 | 376 | ### e5 洋垃圾机器 377 | 378 | * cpu=e5 2686(单路) 379 | * memory=32GB 380 | 381 | ``` 382 | BenchType : BenchEcho 383 | Framework : greatws 384 | TPS : 106014 385 | EER : 218.54 386 | Min : 49.26us 387 | Avg : 94.08ms 388 | Max : 954.33ms 389 | TP50 : 45.76ms 390 | TP75 : 52.27ms 391 | TP90 : 336.85ms 392 | TP95 : 427.07ms 393 | TP99 : 498.66ms 394 | Used : 18.87s 395 | Total : 2000000 396 | Success : 2000000 397 | Failed : 0 398 | Conns : 1000000 399 | Concurrency: 10000 400 | Payload : 1024 401 | CPU Min : 184.90% 402 | CPU Avg : 485.10% 403 | CPU Max : 588.31% 404 | MEM Min : 563.40M 405 | MEM Avg : 572.40M 406 | MEM Max : 594.48M 407 | ``` 408 | 409 | ### 5800h cpu 410 | 411 | * cpu=5800h 412 | * memory=64GB 413 | 414 | ``` 415 | BenchType : BenchEcho 416 | Framework : greatws 417 | TPS : 103544 418 | EER : 397.07 419 | Min : 26.51us 420 | Avg : 95.79ms 421 | Max : 1.34s 422 | TP50 : 58.26ms 423 | TP75 : 60.94ms 424 | TP90 : 62.50ms 425 | TP95 : 63.04ms 426 | TP99 : 63.47ms 427 | Used : 40.76s 428 | Total : 5000000 429 | Success : 4220634 430 | Failed : 779366 431 | Conns : 1000000 432 | Concurrency: 10000 433 | Payload : 1024 434 | CPU Min : 30.54% 435 | CPU Avg : 260.77% 436 | CPU Max : 335.88% 437 | MEM Min : 432.25M 438 | MEM Avg : 439.71M 439 | MEM Max : 449.62M 440 | ``` 441 | -------------------------------------------------------------------------------- /client_option_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package greatws 16 | 17 | import ( 18 | "log/slog" 19 | "net/http" 20 | "net/http/httptest" 21 | "strings" 22 | "sync/atomic" 23 | "testing" 24 | "time" 25 | ) 26 | 27 | func Test_ClientOption(t *testing.T) { 28 | m := NewMultiEventLoopAndStartMust(WithEventLoops(1), WithLogLevel(slog.LevelDebug), WithBusinessGoNum(1, 1, 1)) 29 | t.Run("ClientOption.WithClientHTTPHeader", func(t *testing.T) { 30 | done := make(chan string, 1) 31 | run := int32(0) 32 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 33 | v := r.Header.Get("A") 34 | con, err := Upgrade(w, r, WithServerMultiEventLoop(m)) 35 | if err != nil { 36 | t.Error(err) 37 | return 38 | } 39 | 40 | defer con.Close() 41 | atomic.AddInt32(&run, 1) 42 | done <- v 43 | })) 44 | 45 | defer ts.Close() 46 | 47 | url := strings.ReplaceAll(ts.URL, "http", "ws") 48 | con, err := Dial(url, WithClientHTTPHeader(http.Header{ 49 | "A": []string{"A"}, 50 | }), WithClientCallback(&testDefaultCallback{}), WithClientMultiEventLoop(m)) 51 | if err != nil { 52 | t.Error(err) 53 | return 54 | } 55 | defer con.Close() 56 | 57 | select { 58 | case v := <-done: 59 | if v != "A" { 60 | t.Error("header fail") 61 | } 62 | case <-time.After(1000 * time.Millisecond): 63 | } 64 | if atomic.LoadInt32(&run) != 1 { 65 | t.Error("not run server:method fail") 66 | } 67 | }) 68 | 69 | // TODO: 现在不支持tls 配置 70 | // t.Run("ClientOption.WithClientTLSConfig", func(t *testing.T) { 71 | // done := make(chan string, 1) 72 | // run := int32(0) 73 | // ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 74 | // v := r.Header.Get("A") 75 | // atomic.AddInt32(&run, 1) 76 | // done <- v 77 | // con, err := Upgrade(w, r) 78 | // if err != nil { 79 | // t.Error(err) 80 | // return 81 | // } 82 | 83 | // defer con.Close() 84 | // })) 85 | 86 | // defer ts.Close() 87 | 88 | // url := strings.ReplaceAll(ts.URL, "http", "ws") 89 | // con, err := Dial(url, 90 | // WithClientTLSConfig(&tls.Config{InsecureSkipVerify: true}), 91 | // WithClientHTTPHeader(http.Header{ 92 | // "A": []string{"A"}, 93 | // }), WithClientCallback(&testDefaultCallback{})) 94 | // if err != nil { 95 | // t.Error(err) 96 | // return 97 | // } 98 | // defer con.Close() 99 | 100 | // select { 101 | // case v := <-done: 102 | // if v != "A" { 103 | // t.Error("header fail") 104 | // } 105 | // case <-time.After(1000 * time.Millisecond): 106 | // } 107 | // if atomic.LoadInt32(&run) != 1 { 108 | // t.Error("not run server:method fail") 109 | // } 110 | // }) 111 | 112 | t.Run("6.1 Dial: WithClientBindHTTPHeader and echo Sec-Websocket-Protocol", func(t *testing.T) { 113 | m := NewMultiEventLoopAndStartMust(WithEventLoops(1), WithLogLevel(slog.LevelDebug), WithBusinessGoNum(1, 1, 1)) 114 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 115 | _, err := Upgrade(w, r, WithServerMultiEventLoop(m)) 116 | if err != nil { 117 | t.Error(err) 118 | } 119 | })) 120 | 121 | defer ts.Close() 122 | 123 | url := strings.ReplaceAll(ts.URL, "http", "ws") 124 | h := make(http.Header) 125 | con, err := Dial(url, WithClientBindHTTPHeader(&h), WithClientMultiEventLoop(m), WithClientHTTPHeader(http.Header{ 126 | "Sec-WebSocket-Protocol": []string{"token"}, 127 | })) 128 | if err != nil { 129 | t.Error(err) 130 | } 131 | defer con.Close() 132 | 133 | if h["Sec-Websocket-Protocol"][0] != "token" { 134 | t.Error("header fail") 135 | } 136 | }) 137 | 138 | t.Run("6.2 DialConf: WithClientBindHTTPHeader and echo Sec-Websocket-Protocol", func(t *testing.T) { 139 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 140 | _, err := Upgrade(w, r, WithServerMultiEventLoop(m)) 141 | if err != nil { 142 | t.Error(err) 143 | } 144 | })) 145 | 146 | defer ts.Close() 147 | 148 | url := strings.ReplaceAll(ts.URL, "http", "ws") 149 | h := make(http.Header) 150 | con, err := DialConf(url, ClientOptionToConf(WithClientBindHTTPHeader(&h), WithClientMultiEventLoop(m), WithClientHTTPHeader(http.Header{ 151 | "Sec-WebSocket-Protocol": []string{"token"}, 152 | }))) 153 | if err != nil { 154 | t.Error(err) 155 | } 156 | defer con.Close() 157 | 158 | if h["Sec-Websocket-Protocol"][0] != "token" { 159 | t.Error("header fail") 160 | } 161 | }) 162 | 163 | // t.Run("18 Dial: WithClientDialFunc.1", func(t *testing.T) { 164 | // ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 165 | // _, err := Upgrade(w, r, WithServerOnMessageFunc(func(c *Conn, o Opcode, b []byte) { 166 | // c.WriteMessage(o, b) 167 | // c.Close() 168 | // }), WithServerMultiEventLoop(m)) 169 | // if err != nil { 170 | // t.Error(err) 171 | // } 172 | 173 | // })) 174 | 175 | // proxyAddr, err := net.Listen("tcp", "127.0.0.1:0") 176 | // if err != nil { 177 | // t.Error(err) 178 | // } 179 | // defer ts.Close() 180 | 181 | // go func() { 182 | // newConn, err := proxyAddr.Accept() 183 | // if err != nil { 184 | // t.Error(err) 185 | // } 186 | 187 | // newConn.SetDeadline(time.Now().Add(30 * time.Second)) 188 | 189 | // buf := make([]byte, 128) 190 | // if _, err := io.ReadFull(newConn, buf[:3]); err != nil { 191 | // t.Errorf("read failed: %v", err) 192 | // return 193 | // } 194 | 195 | // // socks version 5, 1 authentication method, no auth 196 | // if want := []byte{5, 1, 0}; !bytes.Equal(want, buf[:len(want)]) { 197 | // t.Errorf("read %x, want %x", buf[:len(want)], want) 198 | // } 199 | 200 | // // socks version 5, connect command, reserved, ipv4 address, port 80 201 | // if _, err := newConn.Write([]byte{5, 0}); err != nil { 202 | // t.Errorf("write failed: %v", err) 203 | // return 204 | // } 205 | 206 | // // ver cmd rsv atyp dst.addr dst.port 207 | // if _, err := io.ReadFull(newConn, buf[:10]); err != nil { 208 | // t.Errorf("read failed: %v", err) 209 | // return 210 | // } 211 | // if want := []byte{5, 1, 0, 1}; !bytes.Equal(want, buf[:len(want)]) { 212 | // t.Errorf("read %x, want %x", buf[:len(want)], want) 213 | // return 214 | // } 215 | // buf[1] = 0 216 | // if _, err := newConn.Write(buf[:10]); err != nil { 217 | // t.Errorf("write failed: %v", err) 218 | // return 219 | // } 220 | 221 | // // 提取ip 222 | // ip := net.IP(buf[4:8]) 223 | // port := binary.BigEndian.Uint16(buf[8:10]) 224 | 225 | // c2, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: ip, Port: int(port)}) 226 | // if err != nil { 227 | // t.Errorf("dial failed; %v", err) 228 | // return 229 | // } 230 | // defer c2.Close() 231 | // done := make(chan struct{}) 232 | // go func() { 233 | // io.Copy(newConn, c2) 234 | // close(done) 235 | // }() 236 | // io.Copy(c2, newConn) 237 | // <-done 238 | // }() 239 | 240 | // got := make([]byte, 0, 128) 241 | // url := strings.ReplaceAll(ts.URL, "http", "ws") 242 | // c, err := Dial(url, WithClientDialFunc(func() (Dialer, error) { 243 | // return proxy.SOCKS5("tcp", proxyAddr.Addr().String(), nil, nil) 244 | // }), WithClientOnMessageFunc(func(c *Conn, o Opcode, b []byte) { 245 | // got = append(got, b...) 246 | // c.Close() 247 | // })) 248 | // if err != nil { 249 | // t.Error(err) 250 | // } 251 | 252 | // data := []byte("hello world") 253 | // c.WriteMessage(Binary, data) 254 | // c.ReadLoop() 255 | 256 | // t.Log("got", string(got), "want", string(data)) 257 | // if !bytes.Equal(got, data) { 258 | // t.Errorf("got %s, want %s", got, data) 259 | // } 260 | // }) 261 | 262 | // t.Run("18 Dial: WithClientDialFunc.2", func(t *testing.T) { 263 | // ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 264 | // _, err := Upgrade(w, r, WithServerOnMessageFunc(func(c *Conn, o Opcode, b []byte) { 265 | // c.WriteMessage(o, b) 266 | // c.Close() 267 | // })) 268 | // if err != nil { 269 | // t.Error(err) 270 | // } 271 | // })) 272 | 273 | // proxyAddr, err := net.Listen("tcp", "127.0.0.1:0") 274 | // if err != nil { 275 | // t.Error(err) 276 | // } 277 | // defer ts.Close() 278 | 279 | // go func() { 280 | // newConn, err := proxyAddr.Accept() 281 | // if err != nil { 282 | // t.Error(err) 283 | // } 284 | 285 | // newConn.SetDeadline(time.Now().Add(30 * time.Second)) 286 | 287 | // buf := make([]byte, 128) 288 | // if _, err := io.ReadFull(newConn, buf[:3]); err != nil { 289 | // t.Errorf("read failed: %v", err) 290 | // return 291 | // } 292 | 293 | // // socks version 5, 1 authentication method, no auth 294 | // if want := []byte{5, 1, 0}; !bytes.Equal(want, buf[:len(want)]) { 295 | // t.Errorf("read %x, want %x", buf[:len(want)], want) 296 | // } 297 | 298 | // // socks version 5, connect command, reserved, ipv4 address, port 80 299 | // if _, err := newConn.Write([]byte{5, 0}); err != nil { 300 | // t.Errorf("write failed: %v", err) 301 | // return 302 | // } 303 | 304 | // // ver cmd rsv atyp dst.addr dst.port 305 | // if _, err := io.ReadFull(newConn, buf[:10]); err != nil { 306 | // t.Errorf("read failed: %v", err) 307 | // return 308 | // } 309 | // if want := []byte{5, 1, 0, 1}; !bytes.Equal(want, buf[:len(want)]) { 310 | // t.Errorf("read %x, want %x", buf[:len(want)], want) 311 | // return 312 | // } 313 | // buf[1] = 0 314 | // if _, err := newConn.Write(buf[:10]); err != nil { 315 | // t.Errorf("write failed: %v", err) 316 | // return 317 | // } 318 | 319 | // // 提取ip 320 | // ip := net.IP(buf[4:8]) 321 | // port := binary.BigEndian.Uint16(buf[8:10]) 322 | 323 | // c2, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: ip, Port: int(port)}) 324 | // if err != nil { 325 | // t.Errorf("dial failed; %v", err) 326 | // return 327 | // } 328 | // defer c2.Close() 329 | // done := make(chan struct{}) 330 | // go func() { 331 | // io.Copy(newConn, c2) 332 | // close(done) 333 | // }() 334 | // io.Copy(c2, newConn) 335 | // <-done 336 | // }() 337 | 338 | // got := make([]byte, 0, 128) 339 | // url := strings.ReplaceAll(ts.URL, "http", "ws") 340 | // c, err := DialConf(url, ClientOptionToConf(WithClientDialFunc(func() (Dialer, error) { 341 | // return proxy.SOCKS5("tcp", proxyAddr.Addr().String(), nil, nil) 342 | // }), WithClientOnMessageFunc(func(c *Conn, o Opcode, b []byte) { 343 | // got = append(got, b...) 344 | // c.Close() 345 | // }))) 346 | // if err != nil { 347 | // t.Error(err) 348 | // } 349 | 350 | // data := []byte("hello world") 351 | // c.WriteMessage(Binary, data) 352 | // c.ReadLoop() 353 | 354 | // t.Log("got", string(got), "want", string(data)) 355 | // if !bytes.Equal(got, data) { 356 | // t.Errorf("got %s, want %s", got, data) 357 | // } 358 | // }) 359 | } 360 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /conn_unix.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | //go:build linux || darwin || netbsd || freebsd || openbsd || dragonfly 16 | // +build linux darwin netbsd freebsd openbsd dragonfly 17 | 18 | package greatws 19 | 20 | import ( 21 | "errors" 22 | "fmt" 23 | "io" 24 | "log/slog" 25 | "net" 26 | "sync" 27 | "sync/atomic" 28 | "syscall" 29 | "time" 30 | "unsafe" 31 | 32 | "github.com/antlabs/pulse/core" 33 | "github.com/antlabs/task/task/driver" 34 | "github.com/antlabs/wsutil/bytespool" 35 | "github.com/antlabs/wsutil/deflate" 36 | "github.com/antlabs/wsutil/enum" 37 | "github.com/antlabs/wsutil/myonce" 38 | "golang.org/x/sys/unix" 39 | ) 40 | 41 | type writeState int32 42 | 43 | const ( 44 | writeDefault writeState = 1 << iota 45 | writeEagain 46 | writeSuccess 47 | ) 48 | 49 | func (s writeState) String() string { 50 | switch s { 51 | case writeDefault: 52 | return "default" 53 | case writeEagain: 54 | return "eagain" 55 | default: 56 | return "invalid" 57 | } 58 | } 59 | 60 | var ( 61 | ErrInvalidDeadline = errors.New("invalid deadline") 62 | // 读超时 63 | ErrReadTimeout = errors.New("read timeout") 64 | // 写超时 65 | ErrWriteTimeout = errors.New("write timeout") 66 | ) 67 | 68 | // Conn大小改变历史,增加上下文接管,从<160到184 69 | type Conn struct { 70 | conn 71 | 72 | Callback // callback移至conn中 73 | pd deflate.PermessageDeflateConf // 上下文接管的控制参数, 由于每个comm的配置都可能不一样,所以需要放在Conn里面 74 | mu sync.Mutex // 锁 75 | *Config // 配置 76 | deCtx *deflate.DeCompressContextTakeover // 解压缩上下文 77 | enCtx *deflate.CompressContextTakeover // 压缩上下文 78 | parent *EventLoop // event loop 79 | task driver.TaskExecutor // 任务,该任务会进协程池里面执行 80 | rtime *time.Timer // 控制读超时 81 | wtime *time.Timer // 控制写超时 82 | 83 | // mu2由 onCloseOnce使用, 这里使用新锁只是为了简化维护的难度 84 | // 也可以共用mu,区别 优点:节约内存,缺点:容易出现死锁和需要精心调试代码 85 | // 这里选择维护简单 86 | mu2 sync.Mutex 87 | onCloseOnce myonce.MyOnce // 保证只调用一次OnClose函数 88 | closed int32 // 是否关闭 89 | } 90 | 91 | func newConn(fd int64, client bool, conf *Config) (*Conn, error) { 92 | c := &Conn{ 93 | conn: conn{ 94 | fd: fd, 95 | client: client, 96 | }, 97 | // 初始化不分配内存,只有在需要的时候才分配 98 | Config: conf, 99 | parent: conf.multiEventLoop.getEventLoop(int(fd)), 100 | } 101 | 102 | c.task = c.parent.localTask.newTask(conf.runInGoTask) 103 | if conf.readTimeout > 0 { 104 | err := c.setReadDeadline(time.Now().Add(conf.readTimeout)) 105 | if err != nil { 106 | return nil, err 107 | } 108 | } 109 | return c, nil 110 | } 111 | 112 | // 这是一个空函数,兼容下quickws的接口 113 | func (c *Conn) StartReadLoop() { 114 | 115 | } 116 | 117 | // 这是一个空函数,兼容下quickws的接口 118 | func (c *Conn) ReadLoop() error { 119 | return nil 120 | } 121 | 122 | func duplicateSocket(socketFD int) (int, error) { 123 | return unix.Dup(socketFD) 124 | } 125 | 126 | // 没有加锁的版本,有外层已经有锁保护,所以不需要加锁 127 | func (c *Conn) closeWithoutLockOnClose(err error, onClose bool) { 128 | 129 | if c.isClosed() { 130 | return 131 | } 132 | 133 | if err != nil { 134 | err = io.EOF 135 | } 136 | fd := c.getFd() 137 | c.getLogger().Debug("close conn", slog.Int64("fd", int64(fd))) 138 | c.parent.del(c) 139 | atomic.StoreInt64(&c.fd, -1) 140 | atomic.StoreInt32(&c.closed, 1) 141 | 142 | // 这个必须要放在后面 143 | if onClose { 144 | c.onCloseOnce.Do(&c.mu2, func() { 145 | c.OnClose(c, err) 146 | }) 147 | if c.task != nil { 148 | c.task.Close(nil) 149 | } 150 | } 151 | 152 | } 153 | 154 | func (c *Conn) closeNoLock(err error) { 155 | 156 | c.closeWithoutLockOnClose(err, true) 157 | } 158 | 159 | func (c *Conn) closeWithLock(err error) { 160 | if c.isClosed() { 161 | return 162 | } 163 | 164 | c.mu.Lock() 165 | if c.isClosed() { 166 | c.mu.Unlock() 167 | return 168 | } 169 | 170 | if err == nil { 171 | err = io.EOF 172 | } 173 | c.closeWithoutLockOnClose(err, false) 174 | 175 | c.mu.Unlock() 176 | 177 | // 这个必须要放在后面, 不然会死锁,因为Close会调用用户的OnClose 178 | // 用户的OnClose也有可能调用Close, 所以使用flags来判断是否已经关闭 179 | if atomic.LoadInt32(&c.closed) == 1 { 180 | c.onCloseOnce.Do(&c.mu2, func() { 181 | c.OnClose(c, err) 182 | }) 183 | } 184 | } 185 | 186 | func (c *Conn) getPtr() int { 187 | return int(uintptr(unsafe.Pointer(c))) 188 | } 189 | 190 | // Conn Write入口, 原始定义是func (c *Conn) Write() (n int, err error) 191 | // 会引起误用,所以隐藏起来, 作为一个websocket库,直接暴露tcp的write接口也不合适 192 | func connWrite(c *Conn, b []byte) (n int, err error) { 193 | if c.isClosed() { 194 | return 0, ErrClosed 195 | } 196 | 197 | return c.write(b) 198 | } 199 | 200 | func (c *Conn) needFlush() bool { 201 | c.mu.Lock() 202 | defer c.mu.Unlock() 203 | return len(c.wbufList) > 0 204 | } 205 | 206 | func (c *Conn) flush() { 207 | if _, err := connWrite(c, nil); err != nil { 208 | slog.Error("failed to flush write buffer", "error", err) 209 | } 210 | } 211 | 212 | // writeToSocket 尝试将数据写入 socket,并处理中断与临时错误 213 | func (c *Conn) writeToSocket(data []byte) (int, error) { 214 | 215 | n, err := core.Write(c.getFd(), data) 216 | if err == nil { 217 | return n, nil 218 | } 219 | if err == syscall.EINTR { 220 | return 0, err // 被信号中断,直接返回 221 | } 222 | if err == syscall.EAGAIN { 223 | return 0, err // 资源暂时不可用 224 | } 225 | return 0, err // 其他错误直接返回 226 | 227 | } 228 | 229 | // appendToWbufList 将数据添加到写缓冲区列表 230 | // 先检查最后一个缓冲区是否有足够空间,如果有就直接append 231 | // 如果没有,将部分数据append到最后一个缓冲区,剩余部分创建新的readBufferSize大小的缓冲区 232 | func (c *Conn) appendToWbufList(data []byte, oldLen int) { 233 | if len(data) == 0 { 234 | return 235 | } 236 | 237 | // 如果wbufList为空,直接创建新的缓冲区 238 | if len(c.wbufList) == 0 { 239 | // 使用原始长度,对齐,提升复用率 240 | newBuf := bytespool.GetBytes(len(data) + oldLen) 241 | copy(*newBuf, data) 242 | *newBuf = (*newBuf)[:len(data)] 243 | c.wbufList = append(c.wbufList, newBuf) 244 | return 245 | } 246 | 247 | // 获取最后一个缓冲区 248 | lastBuf := c.wbufList[len(c.wbufList)-1] 249 | remainingSpace := cap(*lastBuf) - len(*lastBuf) 250 | 251 | // 如果最后一个缓冲区有足够空间,直接append 252 | if remainingSpace >= len(data) { 253 | *lastBuf = append(*lastBuf, data...) 254 | return 255 | } 256 | 257 | // 如果空间不够,先填满最后一个缓冲区 258 | if remainingSpace > 0 { 259 | *lastBuf = append(*lastBuf, data[:remainingSpace]...) 260 | data = data[remainingSpace:] // 剩余的数据 261 | } 262 | 263 | // 为剩余数据创建新的缓冲区(使用readBufferSize大小) 264 | for len(data) > 0 { 265 | newBuf := bytespool.GetBytes(len(data) + oldLen) 266 | copySize := len(data) 267 | if copySize > cap(*newBuf) { 268 | copySize = cap(*newBuf) 269 | } 270 | copy(*newBuf, data[:copySize]) 271 | *newBuf = (*newBuf)[:copySize] 272 | c.wbufList = append(c.wbufList, newBuf) 273 | data = data[copySize:] 274 | } 275 | } 276 | 277 | // handlePartialWrite 处理部分写入的情况,创建新缓冲区存储剩余数据 278 | func (c *Conn) handlePartialWrite(data *[]byte, n int, needAppend bool) error { 279 | if n < 0 { 280 | n = 0 281 | } 282 | 283 | // 如果已经全部写入,不需要创建新缓冲区 284 | if n >= len(*data) { 285 | return nil 286 | } 287 | 288 | remainingData := (*data)[n:] 289 | if needAppend { 290 | c.appendToWbufList(remainingData, len(*data)) 291 | } else { 292 | copy(*data, (*data)[n:]) 293 | *data = (*data)[:len(*data)-n] 294 | } 295 | 296 | // 部分写入成功,或者全部失败 297 | // 如果启用了流量背压机制且有部分写入,先删除读事件 298 | if c.Config.flowBackPressureRemoveRead { 299 | if delErr := c.eventLoop().delRead(c); delErr != nil { 300 | slog.Error("failed to delete read event", "error", delErr) 301 | } 302 | } else { 303 | if err := c.eventLoop().addWrite(c); err != nil { 304 | slog.Error("failed to add write event", "error", err) 305 | return err 306 | } 307 | } 308 | 309 | return nil 310 | } 311 | 312 | func (c *Conn) write(data []byte) (int, error) { 313 | 314 | if atomic.LoadInt64(&c.fd) == -1 { 315 | return 0, net.ErrClosed 316 | } 317 | 318 | if len(data) == 0 && len(c.wbufList) == 0 { 319 | return 0, nil 320 | } 321 | 322 | if len(c.wbufList) == 0 { 323 | n, err := c.writeToSocket(data) 324 | if errors.Is(err, core.EAGAIN) || errors.Is(err, core.EINTR) || err == nil { 325 | if n == len(data) { 326 | return n, nil 327 | } 328 | // 把剩余数据放到缓冲区 329 | if err := c.handlePartialWrite(&data, n, true); err != nil { 330 | c.closeNoLock(err) 331 | return 0, err 332 | } 333 | return len(data), nil 334 | } 335 | 336 | // 发生严重错误 337 | c.closeNoLock(err) 338 | return n, err 339 | } 340 | 341 | if len(data) > 0 { 342 | c.appendToWbufList(data, len(data)) 343 | } 344 | 345 | i := 0 346 | for i < len(c.wbufList) { 347 | wbuf := c.wbufList[i] 348 | n, err := c.writeToSocket(*wbuf) 349 | if errors.Is(err, core.EAGAIN) || errors.Is(err, core.EINTR) || err == nil /*写入成功,也有n != len(*wbuf)的情况*/ { 350 | if n == len(*wbuf) { 351 | bytespool.PutBytes(wbuf) 352 | c.wbufList[i] = nil 353 | i++ 354 | continue 355 | } 356 | // 移动剩余数据到缓冲区开始位置 357 | if err := c.handlePartialWrite(wbuf, n, false); err != nil { 358 | c.closeNoLock(err) 359 | return 0, err 360 | } 361 | 362 | // 移动未处理的缓冲区到列表开始位置 363 | copy(c.wbufList, c.wbufList[i:]) 364 | c.wbufList = c.wbufList[:len(c.wbufList)-i] 365 | return len(data), nil 366 | } 367 | 368 | c.closeNoLock(err) 369 | return n, err 370 | } 371 | 372 | // 所有数据都已写入 373 | c.wbufList = c.wbufList[:0] 374 | // 需要进的逻辑 375 | // 1.如果是垂直触发模式,并且启用了流量背压机制,重新添加读事件 376 | // 2.如果是水平触发模式也重新添加读事件,为了去掉写事件 377 | // 3.如果是垂直触发模式,并且启用了流量背压机制,则需要添加读事件 378 | 379 | // 不需要进的逻辑 380 | // 1.如果是垂直触发模式,并且没有启用流量背压机制,不需要重新添加事件, TODO 381 | 382 | if err := c.eventLoop().ResetRead(c.getFd()); err != nil { 383 | slog.Error("failed to reset read event", "error", err) 384 | } 385 | return len(data), nil 386 | } 387 | 388 | // kqueu/epoll模式下,读取数据 389 | // 该函数从缓冲区读取数据,并且解析出websocket frame 390 | // 有几种情况需要处理下 391 | // 1. 缓冲区空间不句够,需要扩容 392 | // 2. 缓冲区数据不够,并且一次性读取了多个frame 393 | func (c *Conn) processWebsocketFrame() (err error) { 394 | // 1. 处理frame header 395 | // if !c.useIoUring() { 396 | if c.rbuf == nil { 397 | c.rbuf = bytespool.GetBytes(int(float32(c.rh.PayloadLen)*c.windowsMultipleTimesPayloadSize) + enum.MaxFrameHeaderSize) 398 | } 399 | 400 | if c.readTimeout > 0 { 401 | // if err = c.setReadDeadline(time.Time{}); err != nil { 402 | // return err 403 | // } 404 | c.setReadDeadline(time.Now().Add(c.readTimeout)) 405 | } 406 | n := 0 407 | var success bool 408 | // 不使用io_uring的直接调用read获取buffer数据 409 | for i := 0; ; i++ { 410 | fd := atomic.LoadInt64(&c.fd) 411 | c.mu.Lock() 412 | n, err = unix.Read(int(fd), (*c.rbuf)[c.rw:]) 413 | c.mu.Unlock() 414 | c.multiEventLoop.addReadSyscall() 415 | // fmt.Printf("i = %d, n = %d, fd = %d, rbuf = %d, rw:%d, err = %v, %v, payload:%d\n", 416 | // i, n, c.fd, len((*c.rbuf)[c.rw:]), c.rw+n, err, time.Now(), c.rh.PayloadLen) 417 | if err != nil { 418 | // 信号中断,继续读 419 | if errors.Is(err, unix.EINTR) { 420 | continue 421 | } 422 | // 出错返回 423 | if !errors.Is(err, unix.EAGAIN) && !errors.Is(err, unix.EWOULDBLOCK) { 424 | goto fail 425 | } 426 | // 缓冲区没有数据,等待可读 427 | err = nil 428 | break 429 | } 430 | 431 | // 读到eof,直接关闭 432 | if n == 0 && len((*c.rbuf)[c.rw:]) > 0 { 433 | c.closeWithLock(io.EOF) 434 | c.onCloseOnce.Do(&c.mu2, func() { 435 | c.OnClose(c, io.EOF) 436 | }) 437 | err = io.EOF 438 | goto fail 439 | } 440 | 441 | if n > 0 { 442 | c.rw += n 443 | } 444 | 445 | if len((*c.rbuf)[c.rw:]) == 0 { 446 | // 说明缓存区已经满了。需要扩容 447 | // 并且如果使用epoll ET mode,需要继续读取,直到返回EAGAIN, 不然会丢失数据 448 | // 结合以上两种,缓存区满了就直接处理frame,解析出payload的长度,得到一个刚刚好的缓存区 449 | if _, err = c.readHeader(); err != nil { 450 | err = fmt.Errorf("read header err: %w", err) 451 | goto fail 452 | } 453 | if _, err = c.readPayloadAndCallback(); err != nil { 454 | err = fmt.Errorf("read header err: %w", err) 455 | goto fail 456 | } 457 | 458 | // TODO 459 | // if len((*c.rbuf)[c.rw:]) == 0 { 460 | // // 461 | // // panic(fmt.Sprintf("需要扩容:rw(%d):rr(%d):currState(%v)", c.rw, c.rr, c.curState.String())) 462 | // } 463 | continue 464 | } 465 | } 466 | 467 | for i := 0; ; i++ { 468 | success, err = c.readHeader() 469 | if err != nil { 470 | err = fmt.Errorf("read header err: %w", err) 471 | goto fail 472 | } 473 | 474 | if !success { 475 | goto success 476 | } 477 | success, err = c.readPayloadAndCallback() 478 | if err != nil { 479 | err = fmt.Errorf("read payload err: %w", err) 480 | goto fail 481 | } 482 | 483 | if !success { 484 | goto success 485 | } 486 | } 487 | 488 | success: 489 | fail: 490 | // 回收read buffer至内存池中 491 | if err != nil || c.rbuf != nil && c.rr == c.rw { 492 | c.rr, c.rw = 0, 0 493 | bytespool.PutBytes(c.rbuf) 494 | c.rbuf = nil 495 | } 496 | 497 | if err != nil { 498 | // 如果是status code类型,要回写符合rfc的close包 499 | c.writeAndMaybeOnClose(err) 500 | } 501 | return err 502 | } 503 | 504 | func (c *Conn) processHeaderPayloadCallback() (err error) { 505 | var success bool 506 | for i := 0; ; i++ { 507 | success, err = c.readHeader() 508 | if err != nil { 509 | err = fmt.Errorf("read header err: %w", err) 510 | goto fail 511 | } 512 | 513 | if !success { 514 | goto success 515 | } 516 | success, err = c.readPayloadAndCallback() 517 | if err != nil { 518 | err = fmt.Errorf("read payload err: %w", err) 519 | goto fail 520 | } 521 | 522 | if !success { 523 | goto success 524 | } 525 | } 526 | success: 527 | fail: 528 | // 回收read buffer至内存池中 529 | if err != nil || c.rbuf != nil && c.rr == c.rw { 530 | c.rr, c.rw = 0, 0 531 | bytespool.PutBytes(c.rbuf) 532 | c.rbuf = nil 533 | } 534 | return err 535 | } 536 | 537 | func (c *Conn) setDeadlineInner(t **time.Timer, tm time.Time, err error) error { 538 | if t == nil { 539 | return nil 540 | } 541 | c.mu.Lock() 542 | // c.getLogger().Error("Conn-lock", "addr", uintptr(unsafe.Pointer(c))) 543 | defer func() { 544 | // c.getLogger().Error("Conn-unlock", "addr", uintptr(unsafe.Pointer(c))) 545 | c.mu.Unlock() 546 | }() 547 | if tm.IsZero() { 548 | if *t != nil { 549 | // c.getLogger().Error("conn-reset", "addr", uintptr(unsafe.Pointer(c))) 550 | (*t).Stop() 551 | *t = nil 552 | } 553 | return nil 554 | } 555 | 556 | d := time.Until(tm) 557 | if d < 0 { 558 | return ErrInvalidDeadline 559 | } 560 | 561 | if *t == nil { 562 | *t = afterFunc(d, func() { 563 | c.closeWithLock(err) 564 | }) 565 | } else { 566 | (*t).Reset(d) 567 | } 568 | return nil 569 | } 570 | 571 | func (c *Conn) setReadDeadline(t time.Time) error { 572 | return c.setDeadlineInner(&c.rtime, t, ErrReadTimeout) 573 | } 574 | 575 | func (c *Conn) setWriteDeadline(t time.Time) error { 576 | return c.setDeadlineInner(&c.wtime, t, ErrWriteTimeout) 577 | 578 | } 579 | func closeFd(fd int) { 580 | unix.Close(int(fd)) 581 | } 582 | 583 | func (c *Conn) eventLoop() *EventLoop { 584 | return c.parent 585 | } 586 | -------------------------------------------------------------------------------- /server_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | package greatws 15 | 16 | import ( 17 | "fmt" 18 | "log/slog" 19 | "net" 20 | "net/http" 21 | "net/http/httptest" 22 | "strings" 23 | "sync/atomic" 24 | "testing" 25 | "time" 26 | ) 27 | 28 | // 测试服务端握手失败的情况 29 | func Test_Server_HandshakeFail(t *testing.T) { 30 | // u := NewUpgrade() 31 | t.Run("local config:case:method fail", func(t *testing.T) { 32 | run := int32(0) 33 | done := make(chan bool, 1) 34 | m := NewMultiEventLoopMust(WithEventLoops(0), WithMaxEventNum(1000), WithLogLevel(slog.LevelError)) // epoll, kqueue 35 | m.Start() 36 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 37 | _, err := Upgrade(w, r, WithServerMultiEventLoop(m)) 38 | if err == nil { 39 | t.Error("upgrade method fail") 40 | } 41 | atomic.AddInt32(&run, int32(1)) 42 | done <- true 43 | })) 44 | 45 | defer ts.Close() 46 | 47 | url := ts.URL 48 | req, err := http.NewRequest("POST", url, nil) 49 | if err != nil { 50 | t.Error(err) 51 | } 52 | _, err = http.DefaultClient.Do(req) 53 | if err != nil { 54 | t.Error(err) 55 | return 56 | } 57 | select { 58 | case <-done: 59 | case <-time.After(100 * time.Millisecond): 60 | } 61 | if atomic.LoadInt32(&run) != 1 { 62 | t.Error("not run server:method fail") 63 | } 64 | }) 65 | 66 | t.Run("global config:case:method fail", func(t *testing.T) { 67 | run := int32(0) 68 | 69 | done := make(chan bool, 1) 70 | m := NewMultiEventLoopMust(WithEventLoops(0), WithMaxEventNum(1000), WithLogLevel(slog.LevelError)) // epoll, kqueue 71 | m.Start() 72 | upgrade := NewUpgrade(WithServerMultiEventLoop(m)) 73 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 74 | _, err := upgrade.Upgrade(w, r) 75 | if err == nil { 76 | t.Error("upgrade method fail") 77 | } 78 | atomic.AddInt32(&run, int32(1)) 79 | done <- true 80 | })) 81 | 82 | defer ts.Close() 83 | 84 | url := ts.URL 85 | req, err := http.NewRequest("POST", url, nil) 86 | if err != nil { 87 | t.Error(err) 88 | } 89 | _, err = http.DefaultClient.Do(req) 90 | if err != nil { 91 | t.Error(err) 92 | return 93 | } 94 | select { 95 | case <-done: 96 | case <-time.After(100 * time.Millisecond): 97 | } 98 | if atomic.LoadInt32(&run) != 1 { 99 | t.Error("not run server:method fail") 100 | } 101 | }) 102 | 103 | t.Run("local config:case:http proto fail", func(t *testing.T) { 104 | run := int32(0) 105 | done := make(chan bool, 1) 106 | m := NewMultiEventLoopMust(WithEventLoops(0), WithMaxEventNum(1000), WithLogLevel(slog.LevelError)) // epoll, kqueue 107 | m.Start() 108 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 109 | _, err := Upgrade(w, r, WithServerMultiEventLoop(m)) 110 | if err == nil { 111 | t.Error("upgrade http proto fail") 112 | } 113 | atomic.AddInt32(&run, int32(1)) 114 | done <- true 115 | })) 116 | 117 | url := strings.ReplaceAll(ts.URL, "http://", "") 118 | defer ts.Close() 119 | c, err := net.Dial("tcp", url) 120 | if err != nil { 121 | t.Error(err) 122 | } 123 | _, err = c.Write([]byte("GET / HTTP/1.0\r\nHost: localhost:8080\r\n\r\n")) 124 | if err != nil { 125 | t.Error(err) 126 | } 127 | c.Close() 128 | select { 129 | case <-done: 130 | case <-time.After(100 * time.Millisecond): 131 | } 132 | 133 | if atomic.LoadInt32(&run) != 1 { 134 | t.Error("not run server:http proto fail") 135 | } 136 | }) 137 | 138 | t.Run("global config:case:http proto fail", func(t *testing.T) { 139 | run := int32(0) 140 | done := make(chan bool, 1) 141 | m := NewMultiEventLoopMust(WithEventLoops(0), WithMaxEventNum(1000), WithLogLevel(slog.LevelError)) // epoll, kqueue 142 | m.Start() 143 | upgrade := NewUpgrade(WithServerMultiEventLoop(m)) 144 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 145 | _, err := upgrade.Upgrade(w, r) 146 | if err == nil { 147 | t.Error("upgrade http proto fail") 148 | } 149 | atomic.AddInt32(&run, int32(1)) 150 | done <- true 151 | })) 152 | 153 | url := strings.ReplaceAll(ts.URL, "http://", "") 154 | defer ts.Close() 155 | c, err := net.Dial("tcp", url) 156 | if err != nil { 157 | t.Error(err) 158 | } 159 | _, err = c.Write([]byte("GET / HTTP/1.0\r\n\r\n")) 160 | if err != nil { 161 | t.Error(err) 162 | return 163 | } 164 | // c.Write([]byte("GET / HTTP/1.0\r\nHost: localhost:8080\r\n\r\n")) 165 | c.Close() 166 | 167 | select { 168 | case <-done: 169 | case <-time.After(100 * time.Millisecond): 170 | } 171 | if atomic.LoadInt32(&run) != 1 { 172 | t.Error("not run server:http proto fail") 173 | } 174 | }) 175 | 176 | t.Run("local config:case:host empty", func(t *testing.T) { 177 | run := int32(0) 178 | done := make(chan bool, 1) 179 | m := NewMultiEventLoopMust(WithEventLoops(0), WithMaxEventNum(1000), WithLogLevel(slog.LevelError)) // epoll, kqueue 180 | m.Start() 181 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 182 | _, err := Upgrade(w, r, WithServerMultiEventLoop(m)) 183 | if err == nil { 184 | t.Error("upgrade host fail") 185 | } 186 | atomic.AddInt32(&run, int32(1)) 187 | done <- true 188 | })) 189 | 190 | defer ts.Close() 191 | 192 | url := strings.ReplaceAll(ts.URL, "http://", "") 193 | c, err := net.Dial("tcp", url) 194 | if err != nil { 195 | t.Error(err) 196 | } 197 | _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: \r\n\r\n")) 198 | if err != nil { 199 | t.Error(err) 200 | return 201 | } 202 | defer c.Close() 203 | select { 204 | case <-done: 205 | case <-time.After(100 * time.Millisecond): 206 | } 207 | if atomic.LoadInt32(&run) != 1 { 208 | t.Error("not run server:host empty") 209 | } 210 | }) 211 | 212 | t.Run("global config:case:upgrade fail", func(t *testing.T) { 213 | run := int32(0) 214 | done := make(chan bool, 1) 215 | m := NewMultiEventLoopMust(WithEventLoops(0), WithMaxEventNum(1000), WithLogLevel(slog.LevelError)) // epoll, kqueue 216 | m.Start() 217 | upgrade := NewUpgrade(WithServerMultiEventLoop(m)) 218 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 219 | _, err := upgrade.Upgrade(w, r) 220 | if err == nil { 221 | t.Error("upgrade : upgrade field fail") 222 | } 223 | atomic.AddInt32(&run, int32(1)) 224 | done <- true 225 | })) 226 | 227 | url := strings.ReplaceAll(ts.URL, "http://", "") 228 | defer ts.Close() 229 | c, err := net.Dial("tcp", url) 230 | if err != nil { 231 | t.Error(err) 232 | } 233 | wbuf := []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: %s\r\nUpgrade: xx\r\n\r\n", url)) 234 | _, err = c.Write(wbuf) 235 | if err != nil { 236 | t.Error(err) 237 | return 238 | } 239 | c.Close() 240 | 241 | select { 242 | case <-done: 243 | case <-time.After(100 * time.Millisecond): 244 | } 245 | if atomic.LoadInt32(&run) != 1 { 246 | t.Error("not run server:upgrade field fail") 247 | } 248 | }) 249 | 250 | t.Run("local config:case:upgrade fail", func(t *testing.T) { 251 | run := int32(0) 252 | done := make(chan bool, 1) 253 | m := NewMultiEventLoopMust(WithEventLoops(0), WithMaxEventNum(1000), WithLogLevel(slog.LevelError)) // epoll, kqueue 254 | m.Start() 255 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 256 | _, err := Upgrade(w, r, WithServerMultiEventLoop(m)) 257 | if err == nil { 258 | t.Error("upgrade : upgrade field fail") 259 | } 260 | atomic.AddInt32(&run, int32(1)) 261 | done <- true 262 | })) 263 | 264 | url := strings.ReplaceAll(ts.URL, "http://", "") 265 | defer ts.Close() 266 | c, err := net.Dial("tcp", url) 267 | if err != nil { 268 | t.Error(err) 269 | } 270 | wbuf := []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: %s\r\nUpgrade: xx\r\n\r\n", url)) 271 | _, err = c.Write(wbuf) 272 | if err != nil { 273 | t.Error(err) 274 | return 275 | } 276 | c.Close() 277 | 278 | select { 279 | case <-done: 280 | case <-time.After(100 * time.Millisecond): 281 | } 282 | if atomic.LoadInt32(&run) != 1 { 283 | t.Error("not run server:upgrade field fail") 284 | } 285 | }) 286 | 287 | t.Run("global config:case:Connection fail", func(t *testing.T) { 288 | run := int32(0) 289 | 290 | done := make(chan bool, 1) 291 | m := NewMultiEventLoopMust(WithEventLoops(0), WithMaxEventNum(1000), WithLogLevel(slog.LevelError)) // epoll, kqueue 292 | m.Start() 293 | upgrade := NewUpgrade(WithServerMultiEventLoop(m)) 294 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 295 | _, err := upgrade.Upgrade(w, r) 296 | if err == nil { 297 | t.Error("upgrade : Connection field fail") 298 | } 299 | atomic.AddInt32(&run, int32(1)) 300 | done <- true 301 | })) 302 | 303 | url := strings.ReplaceAll(ts.URL, "http://", "") 304 | defer ts.Close() 305 | c, err := net.Dial("tcp", url) 306 | if err != nil { 307 | t.Error(err) 308 | } 309 | wbuf := []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: %s\r\nUpgrade: websocket\r\nConnection: xx\r\n\r\n", url)) 310 | _, err = c.Write(wbuf) 311 | if err != nil { 312 | t.Error(err) 313 | return 314 | } 315 | c.Close() 316 | 317 | select { 318 | case <-done: 319 | case <-time.After(100 * time.Millisecond): 320 | } 321 | if atomic.LoadInt32(&run) != 1 { 322 | t.Error("not run server:Connection field fail") 323 | } 324 | }) 325 | 326 | t.Run("local config:case:Connection fail", func(t *testing.T) { 327 | run := int32(0) 328 | done := make(chan bool, 1) 329 | m := NewMultiEventLoopMust(WithEventLoops(0), WithMaxEventNum(1000), WithLogLevel(slog.LevelError)) // epoll, kqueue 330 | m.Start() 331 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 332 | _, err := Upgrade(w, r, WithServerMultiEventLoop(m)) 333 | if err == nil { 334 | t.Error("upgrade : Connection field fail") 335 | } 336 | atomic.AddInt32(&run, int32(1)) 337 | done <- true 338 | })) 339 | 340 | url := strings.ReplaceAll(ts.URL, "http://", "") 341 | defer ts.Close() 342 | c, err := net.Dial("tcp", url) 343 | if err != nil { 344 | t.Error(err) 345 | } 346 | wbuf := []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: %s\r\nUpgrade: websocket\r\nConnection: xx\r\n\r\n", url)) 347 | _, err = c.Write(wbuf) 348 | if err != nil { 349 | t.Error(err) 350 | return 351 | } 352 | c.Close() 353 | 354 | select { 355 | case <-done: 356 | case <-time.After(100 * time.Millisecond): 357 | } 358 | if atomic.LoadInt32(&run) != 1 { 359 | t.Error("not run server:Connection field fail") 360 | } 361 | }) 362 | 363 | t.Run("global config:case: Sec-WebSocket-Key fail", func(t *testing.T) { 364 | run := int32(0) 365 | 366 | done := make(chan bool, 1) 367 | m := NewMultiEventLoopMust(WithEventLoops(0), WithMaxEventNum(1000), WithLogLevel(slog.LevelError)) // epoll, kqueue 368 | m.Start() 369 | upgrade := NewUpgrade(WithServerMultiEventLoop(m)) 370 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 371 | _, err := upgrade.Upgrade(w, r) 372 | if err == nil { 373 | t.Error("upgrade : Connection field fail") 374 | } 375 | atomic.AddInt32(&run, int32(1)) 376 | done <- true 377 | })) 378 | 379 | url := strings.ReplaceAll(ts.URL, "http://", "") 380 | defer ts.Close() 381 | c, err := net.Dial("tcp", url) 382 | if err != nil { 383 | t.Error(err) 384 | } 385 | wbuf := []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: %s\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n", url)) 386 | _, err = c.Write(wbuf) 387 | if err != nil { 388 | t.Error(err) 389 | return 390 | } 391 | c.Close() 392 | 393 | select { 394 | case <-done: 395 | case <-time.After(100 * time.Millisecond): 396 | } 397 | if atomic.LoadInt32(&run) != 1 { 398 | t.Error("not run server:Connection field fail") 399 | } 400 | }) 401 | 402 | t.Run("local config:case: Sec-WebSocket-Key fail", func(t *testing.T) { 403 | run := int32(0) 404 | done := make(chan bool, 1) 405 | m := NewMultiEventLoopMust(WithEventLoops(0), WithMaxEventNum(1000), WithLogLevel(slog.LevelError)) // epoll, kqueue 406 | m.Start() 407 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 408 | _, err := Upgrade(w, r, WithServerMultiEventLoop(m)) 409 | if err == nil { 410 | t.Error("upgrade : Connection field fail") 411 | } 412 | atomic.AddInt32(&run, int32(1)) 413 | done <- true 414 | })) 415 | 416 | url := strings.ReplaceAll(ts.URL, "http://", "") 417 | defer ts.Close() 418 | c, err := net.Dial("tcp", url) 419 | if err != nil { 420 | t.Error(err) 421 | } 422 | wbuf := []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: %s\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n", url)) 423 | _, err = c.Write(wbuf) 424 | if err != nil { 425 | t.Error(err) 426 | } 427 | c.Close() 428 | 429 | select { 430 | case <-done: 431 | case <-time.After(100 * time.Millisecond): 432 | } 433 | if atomic.LoadInt32(&run) != 1 { 434 | t.Error("not run server:Connection field fail") 435 | } 436 | }) 437 | 438 | t.Run("global config:case: Sec-WebSocket-Version fail", func(t *testing.T) { 439 | run := int32(0) 440 | 441 | done := make(chan bool, 1) 442 | m := NewMultiEventLoopMust(WithEventLoops(0), WithMaxEventNum(1000), WithLogLevel(slog.LevelError)) // epoll, kqueue 443 | m.Start() 444 | upgrade := NewUpgrade(WithServerMultiEventLoop(m)) 445 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 446 | _, err := upgrade.Upgrade(w, r) 447 | if err == nil { 448 | t.Error("upgrade : Connection field fail") 449 | } 450 | atomic.AddInt32(&run, int32(1)) 451 | done <- true 452 | })) 453 | 454 | url := strings.ReplaceAll(ts.URL, "http://", "") 455 | defer ts.Close() 456 | c, err := net.Dial("tcp", url) 457 | if err != nil { 458 | t.Error(err) 459 | return 460 | } 461 | wbuf := []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: %s\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: key\r\n\r\n", url)) 462 | _, err = c.Write(wbuf) 463 | if err != nil { 464 | t.Error(err) 465 | return 466 | } 467 | c.Close() 468 | 469 | select { 470 | case <-done: 471 | case <-time.After(100 * time.Millisecond): 472 | } 473 | if atomic.LoadInt32(&run) != 1 { 474 | t.Error("not run server:Connection field fail") 475 | } 476 | }) 477 | 478 | t.Run("local config:case: Sec-WebSocket-Version fail", func(t *testing.T) { 479 | run := int32(0) 480 | done := make(chan bool, 1) 481 | m := NewMultiEventLoopMust(WithEventLoops(0), WithMaxEventNum(1000), WithLogLevel(slog.LevelError)) // epoll, kqueue 482 | m.Start() 483 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 484 | _, err := Upgrade(w, r, WithServerMultiEventLoop(m)) 485 | if err == nil { 486 | t.Error("upgrade : Connection field fail") 487 | } 488 | atomic.AddInt32(&run, int32(1)) 489 | done <- true 490 | })) 491 | 492 | url := strings.ReplaceAll(ts.URL, "http://", "") 493 | defer ts.Close() 494 | c, err := net.Dial("tcp", url) 495 | if err != nil { 496 | t.Error(err) 497 | } 498 | wbuf := []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: %s\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: key\r\n\r\n", url)) 499 | _, err = c.Write(wbuf) 500 | if err != nil { 501 | t.Error(err) 502 | } 503 | c.Close() 504 | 505 | select { 506 | case <-done: 507 | case <-time.After(100 * time.Millisecond): 508 | } 509 | if atomic.LoadInt32(&run) != 1 { 510 | t.Error("not run server:Connection field fail") 511 | } 512 | }) 513 | } 514 | -------------------------------------------------------------------------------- /conn_core.go: -------------------------------------------------------------------------------- 1 | // Copyright 2023-2024 antlabs. All rights reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package greatws 16 | 17 | import ( 18 | "encoding/binary" 19 | "errors" 20 | "fmt" 21 | "log/slog" 22 | "math/rand" 23 | "sync/atomic" 24 | "time" 25 | 26 | "github.com/antlabs/wsutil/bytespool" 27 | "github.com/antlabs/wsutil/enum" 28 | "github.com/antlabs/wsutil/errs" 29 | "github.com/antlabs/wsutil/fixedwriter" 30 | "github.com/antlabs/wsutil/frame" 31 | "github.com/antlabs/wsutil/mask" 32 | "github.com/antlabs/wsutil/opcode" 33 | ) 34 | 35 | const ( 36 | maxControlFrameSize = 125 37 | ) 38 | 39 | type frameState int8 40 | 41 | func (f frameState) String() string { 42 | switch f { 43 | case frameStateHeaderStart: 44 | return "frameStateHeaderStart" 45 | case frameStateHeaderPayloadAndMask: 46 | return "frameStateHeaderPayloadAndMask" 47 | case frameStatePayload: 48 | return "frameStatePayload" 49 | } 50 | return "" 51 | } 52 | 53 | const ( 54 | frameStateHeaderStart frameState = iota 55 | frameStateHeaderPayloadAndMask 56 | frameStatePayload 57 | ) 58 | 59 | // 内部的conn, 只包含fd, 读缓冲区, 写缓冲区, 状态机, 分段帧缓冲区 60 | // 这一层本来是和epoll/kqueue 等系统调用打交道的 61 | type conn struct { 62 | fd int64 // 文件描述符fd 63 | rbuf *[]byte // 读缓冲区 64 | rr int // rbuf读索引,rfc标准里面有超过4个字节的大包,所以索引只能用int类型 65 | rw int // rbuf写索引,rfc标准里面有超过4个字节的大包,所以索引只能用int类型 66 | wbufList []*[]byte // 写缓冲区, 当直接Write失败时,会将数据写入缓冲区 67 | lenAndMaskSize int // payload长度和掩码的长度 68 | rh frame.FrameHeader // frame头部 69 | fragmentFramePayload *[]byte // 存放分片帧的缓冲区, TODO: 这个可以优化下 把Test_DefaultCallback和 fragmentFrameHeader 放到一个结构体里面 70 | fragmentFrameHeader *frame.FrameHeader // 存放分段帧的头部 71 | lastPayloadLen int32 // 上一次读取的payload长度, TODO启用 72 | curState frameState // 保存当前状态机的状态 73 | client bool // 客户端为true,服务端为false 74 | } 75 | 76 | func (c *Conn) getLogger() *slog.Logger { 77 | return c.multiEventLoop.Logger 78 | } 79 | 80 | func (c *Conn) addTask(f func() bool) { 81 | if c.isClosed() { 82 | return 83 | } 84 | 85 | err := c.task.AddTask(&c.mu, f) 86 | if err != nil { 87 | c.getLogger().Error("addTask", "err", err.Error()) 88 | } 89 | 90 | } 91 | 92 | func (c *Conn) getFd() int { 93 | return int(atomic.LoadInt64(&c.fd)) 94 | } 95 | 96 | // 基于状态机解析frame 97 | func (c *Conn) readHeader() (sucess bool, err error) { 98 | state := c.curState 99 | // 开始解析frame 100 | if state == frameStateHeaderStart { 101 | // 小于最小的frame头部长度, 有空间就挪一挪 102 | if len(*c.rbuf)-c.rr < enum.MaxFrameHeaderSize { 103 | c.leftMove() 104 | } 105 | // fin rsv1 rsv2 rsv3 opcode 106 | if c.rw-c.rr < 2 { 107 | return false, nil 108 | } 109 | c.rh.Head = (*c.rbuf)[c.rr] 110 | 111 | // h.Fin = head[0]&(1<<7) > 0 112 | // h.Rsv1 = head[0]&(1<<6) > 0 113 | // h.Rsv2 = head[0]&(1<<5) > 0 114 | // h.Rsv3 = head[0]&(1<<4) > 0 115 | c.rh.Opcode = opcode.Opcode(c.rh.Head & 0xF) 116 | 117 | maskAndPayloadLen := (*c.rbuf)[c.rr+1] 118 | have := 0 119 | c.rh.Mask = maskAndPayloadLen&(1<<7) > 0 120 | 121 | if c.rh.Mask { 122 | have += 4 123 | } 124 | 125 | c.rh.PayloadLen = int64(maskAndPayloadLen & 0x7F) 126 | switch { 127 | // 长度 128 | case c.rh.PayloadLen >= 0 && c.rh.PayloadLen <= 125: 129 | case c.rh.PayloadLen == 126: 130 | // 2字节长度 131 | have += 2 132 | // size += 2 133 | case c.rh.PayloadLen == 127: 134 | // 8字节长度 135 | have += 8 136 | // size += 8 137 | default: 138 | // 预期之外的, 直接报错 139 | return sucess, errs.ErrFramePayloadLength 140 | } 141 | c.curState, state = frameStateHeaderPayloadAndMask, frameStateHeaderPayloadAndMask 142 | c.lenAndMaskSize = have 143 | c.rr += 2 144 | 145 | } 146 | 147 | if state == frameStateHeaderPayloadAndMask { 148 | if c.rw-c.rr < c.lenAndMaskSize { 149 | return 150 | } 151 | have := c.lenAndMaskSize 152 | head := (*c.rbuf)[c.rr : c.rr+have] 153 | switch c.rh.PayloadLen { 154 | case 126: 155 | c.rh.PayloadLen = int64(binary.BigEndian.Uint16(head[:2])) 156 | head = head[2:] 157 | case 127: 158 | c.rh.PayloadLen = int64(binary.BigEndian.Uint64(head[:8])) 159 | head = head[8:] 160 | } 161 | 162 | if c.readMaxMessage > 0 && c.rh.PayloadLen > c.readMaxMessage { 163 | return false, TooBigMessage 164 | } 165 | 166 | if c.rh.Mask { 167 | c.rh.MaskKey = binary.LittleEndian.Uint32(head[:4]) 168 | } 169 | c.curState = frameStatePayload 170 | c.rr += c.lenAndMaskSize 171 | return true, nil 172 | } 173 | 174 | return state == frameStatePayload, nil 175 | } 176 | 177 | func (c *Conn) failRsv1(op opcode.Opcode) bool { 178 | // 解压缩没有开启 179 | if !c.pd.Decompression { 180 | return true 181 | } 182 | 183 | // 不是text和binary 184 | if op != opcode.Text && op != opcode.Binary { 185 | return true 186 | } 187 | 188 | return false 189 | } 190 | 191 | func (c *Conn) leftMove() { 192 | if c.rr == 0 { 193 | return 194 | } 195 | // b.CountMove++ 196 | // b.MoveBytes += b.W - b.R 197 | n := copy(*c.rbuf, (*c.rbuf)[c.rr:c.rw]) 198 | c.rw -= c.rr 199 | c.rr = 0 200 | c.multiEventLoop.addMoveBytes(uint64(n)) 201 | } 202 | 203 | func (c *Conn) writeCap() int { 204 | return len((*c.rbuf)[c.rw:]) 205 | } 206 | 207 | // 需要考虑几种情况 208 | // 返回完整Payload逻辑 209 | // 1. 当前的rbuf长度不够,需要重新分配 210 | // 2. 当前的rbuf长度够,但是数据没有读完整 211 | // 返回分片Paylod逻辑 212 | // TODO 213 | func (c *Conn) readPayload() (f frame.Frame2, success bool, err error) { 214 | // 如果缓存区不够, 重新分配 215 | multipletimes := c.windowsMultipleTimesPayloadSize 216 | // 已读取未处理的数据 217 | readUnhandle := int64(c.rw - c.rr) 218 | // 情况 1,需要读的长度 > 剩余可用空间(未写的+已经被读取走的) 219 | if c.rh.PayloadLen-readUnhandle > int64(len((*c.rbuf)[c.rw:])+c.rr) { 220 | // 1.取得旧的buf 221 | oldBuf := c.rbuf 222 | // 2.获取新的buf 223 | newBuf := bytespool.GetBytes(int(float32(c.rh.PayloadLen)*multipletimes) + enum.MaxFrameHeaderSize) 224 | // 把旧的数据拷贝到新的buf里 225 | copy(*newBuf, (*oldBuf)[c.rr:c.rw]) 226 | c.rw -= c.rr 227 | c.rr = 0 228 | 229 | // 3.重置缓存区 230 | c.rbuf = newBuf 231 | // 4.将旧的buf放回池子里 232 | bytespool.PutBytes(oldBuf) 233 | c.multiEventLoop.addRealloc() 234 | 235 | // 情况 2。 空间是够的,需要挪一挪, 把已经读过的覆盖掉 236 | } else if c.rh.PayloadLen-readUnhandle > int64(c.writeCap()) { 237 | c.leftMove() 238 | } 239 | 240 | // 前面的reset已经保证了,buffer的大小是够的 241 | needRead := c.rh.PayloadLen - readUnhandle 242 | 243 | // fmt.Printf("needRead:%d:rr(%d):rw(%d):PayloadLen(%d), %v\n", needRead, c.rr, c.rw, c.rh.PayloadLen, c.rbuf) 244 | if needRead > 0 { 245 | return 246 | } 247 | c.lastPayloadLen = int32(c.rh.PayloadLen) 248 | // 普通frame 249 | newBuf := bytespool.GetBytes(int(c.rh.PayloadLen) + enum.MaxFrameHeaderSize) 250 | copy(*newBuf, (*c.rbuf)[c.rr:c.rr+int(c.rh.PayloadLen)]) 251 | newBuf2 := (*newBuf)[:c.rh.PayloadLen] //修改下len 252 | f.Payload = &newBuf2 253 | 254 | f.FrameHeader = c.rh 255 | c.rr += int(c.rh.PayloadLen) 256 | 257 | if len(*c.rbuf)-c.rw < enum.MaxFrameHeaderSize { 258 | c.leftMove() 259 | } 260 | 261 | return f, true, nil 262 | } 263 | 264 | func (c *Conn) processCallback(f frame.Frame2) (err error) { 265 | op := f.Opcode 266 | if c.fragmentFrameHeader != nil { 267 | op = c.fragmentFrameHeader.Opcode 268 | } 269 | 270 | rsv1 := f.GetRsv1() 271 | // 检查Rsv1 rsv2 Rfd, errsv3 272 | if rsv1 && c.failRsv1(op) || f.GetRsv2() || f.GetRsv3() { 273 | err = fmt.Errorf("%w:Rsv1(%t) Rsv2(%t) rsv2(%t) compression:%t", ErrRsv123, rsv1, f.GetRsv2(), f.GetRsv3(), c.pd.Compression) 274 | return c.writeErrAndOnClose(ProtocolError, err) 275 | } 276 | 277 | maskKey := c.rh.MaskKey 278 | needMask := c.rh.Mask 279 | 280 | fin := f.GetFin() 281 | // 分段的frame 282 | if c.fragmentFrameHeader != nil && !f.Opcode.IsControl() { 283 | if f.Opcode == 0 { 284 | // TODO 优化, 需要放到单独的业务go程, 目前为了保证时序性,先放到io go程里面 285 | if needMask { 286 | mask.Mask(*f.Payload, maskKey) 287 | } 288 | 289 | if c.fragmentFramePayload == nil { 290 | c.fragmentFramePayload = f.Payload 291 | } else { 292 | *c.fragmentFramePayload = append(*c.fragmentFramePayload, *f.Payload...) 293 | bytespool.PutBytes(f.Payload) 294 | } 295 | 296 | f.Payload = nil 297 | 298 | // 分段的在这返回 299 | if fin { 300 | // 解压缩 301 | fragmentFrameHeader := c.fragmentFrameHeader 302 | fragmentFramePayload := c.fragmentFramePayload 303 | decompression := c.pd.Decompression 304 | c.fragmentFrameHeader = nil 305 | c.fragmentFramePayload = nil 306 | 307 | // 进入业务协程执行 308 | c.addTask(func() (exit bool) { 309 | if fragmentFrameHeader.GetRsv1() && decompression { 310 | tempBuf, err := c.decode(fragmentFramePayload) 311 | if err != nil { 312 | // return err 313 | c.closeWithLock(err) 314 | return false 315 | } 316 | 317 | // 回收这块内存到pool里面 318 | bytespool.PutBytes(fragmentFramePayload) 319 | fragmentFramePayload = tempBuf 320 | } 321 | // 这里的check按道理应该放到f.Fin前面, 会更符合rfc的标准, 前提是c.utf8Check修改成流式解析 322 | // TODO c.utf8Check 修改成流式解析 323 | if fragmentFrameHeader.Opcode == opcode.Text && !c.utf8Check(*fragmentFramePayload) { 324 | c.onCloseOnce.Do(&c.mu2, func() { 325 | c.Callback.OnClose(c, ErrTextNotUTF8) 326 | }) 327 | // return ErrTextNotUTF8 328 | c.closeWithLock(nil) 329 | return false 330 | } 331 | 332 | c.Callback.OnMessage(c, fragmentFrameHeader.Opcode, *fragmentFramePayload) 333 | bytespool.PutBytes(fragmentFramePayload) 334 | return false 335 | }) 336 | } 337 | return nil 338 | } 339 | 340 | c.writeErrAndOnClose(ProtocolError, ErrFrameOpcode) 341 | return ErrFrameOpcode 342 | } 343 | 344 | if f.Opcode == opcode.Text || f.Opcode == opcode.Binary { 345 | if !fin { 346 | prevFrame := f.FrameHeader 347 | // 第一次分段 348 | 349 | // TODO 放到单独的业务go程, 目前为了保证时序性,先放到io go程里面 350 | if needMask { 351 | mask.Mask(*f.Payload, maskKey) 352 | } 353 | if c.fragmentFramePayload == nil { 354 | // greatws和quickws,这时的f.Payload是单独分配出来的,所以转移下变量的所有权就行 355 | c.fragmentFramePayload = f.Payload 356 | f.Payload = nil 357 | } 358 | 359 | // 让fragmentFrame的Payload指向readBuf, readBuf 原引用直接丢弃 360 | c.fragmentFrameHeader = &prevFrame 361 | return 362 | } 363 | 364 | // var payloadPtr atomic.Pointer[[]byte] 365 | decompression := c.pd.Decompression 366 | payload := f.Payload 367 | f.Payload = nil 368 | // payloadPtr.Store(f.Payload) 369 | 370 | // text或者binary进入业务协程执行 371 | c.addTask(func() bool { 372 | return c.processCallbackData(f, payload, rsv1, decompression, needMask, maskKey) 373 | }) 374 | 375 | return 376 | } 377 | 378 | if f.Opcode == Close || f.Opcode == Ping || f.Opcode == Pong { 379 | 380 | // 消息体的内容比较小,直接在io go程里面处理 381 | if needMask { 382 | mask.Mask(*f.Payload, maskKey) 383 | } 384 | // 对方发的控制消息太大 385 | if f.PayloadLen > maxControlFrameSize { 386 | c.writeErrAndOnClose(ProtocolError, ErrMaxControlFrameSize) 387 | return ErrMaxControlFrameSize 388 | } 389 | // Close, Ping, Pong 不能分片 390 | if !fin { 391 | c.writeErrAndOnClose(ProtocolError, ErrNOTBeFragmented) 392 | return ErrNOTBeFragmented 393 | } 394 | 395 | if f.Opcode == Close { 396 | if len(*f.Payload) == 0 { 397 | c.writeErrAndOnClose(NormalClosure, &CloseErrMsg{Code: NormalClosure}) 398 | return nil 399 | } 400 | 401 | if len(*f.Payload) < 2 { 402 | return c.writeErrAndOnClose(ProtocolError, ErrClosePayloadTooSmall) 403 | } 404 | 405 | if !c.utf8Check((*f.Payload)[2:]) { 406 | return c.writeErrAndOnClose(ProtocolError, ErrTextNotUTF8) 407 | } 408 | 409 | code := binary.BigEndian.Uint16(*f.Payload) 410 | if !validCode(code) { 411 | return c.writeErrAndOnClose(ProtocolError, ErrCloseValue) 412 | } 413 | 414 | // 回敬一个close包 415 | if err := c.WriteTimeout(Close, *f.Payload, 2*time.Second); err != nil { 416 | return err 417 | } 418 | 419 | err = bytesToCloseErrMsg(*f.Payload) 420 | c.onCloseOnce.Do(&c.mu2, func() { 421 | c.Callback.OnClose(c, err) 422 | }) 423 | return err 424 | } 425 | 426 | if f.Opcode == Ping { 427 | // 回一个pong包 428 | if c.replyPing { 429 | if err := c.WriteTimeout(Pong, *f.Payload, 2*time.Second); err != nil { 430 | c.onCloseOnce.Do(&c.mu2, func() { 431 | c.Callback.OnClose(c, err) 432 | }) 433 | return err 434 | } 435 | // 进入业务协程执行 436 | payload := f.Payload 437 | // here 438 | c.addTask(func() bool { 439 | return c.processPing(f, payload) 440 | }) 441 | return 442 | } 443 | } 444 | 445 | if f.Opcode == Pong && c.ignorePong { 446 | return 447 | } 448 | 449 | // 进入业务协程执行 450 | c.addTask(func() bool { 451 | c.Callback.OnMessage(c, f.Opcode, nil) 452 | return false 453 | }) 454 | return 455 | } 456 | // 检查Opcode 457 | c.writeErrAndOnClose(ProtocolError, ErrOpcode) 458 | return ErrOpcode 459 | } 460 | 461 | func (c *Conn) processPing(f frame.Frame2, payload *[]byte) bool { 462 | c.Callback.OnMessage(c, f.Opcode, *payload) 463 | bytespool.PutBytes(payload) 464 | return false 465 | } 466 | 467 | // 如果是text或者binary的消息, 在这里调用OnMessage函数 468 | func (c *Conn) processCallbackData(f frame.Frame2, payload *[]byte, rsv1 bool, decompression bool, needMask bool, maskKey uint32) (ok bool) { 469 | var err error 470 | if needMask { 471 | mask.Mask(*payload, maskKey) 472 | } 473 | decodePayload := payload 474 | if rsv1 && decompression { 475 | // 不分段的解压缩 476 | decodePayload, err = c.decode(payload) 477 | if err != nil { 478 | c.closeWithLock(err) 479 | bytespool.PutBytes(payload) 480 | return false 481 | } 482 | defer bytespool.PutBytes(decodePayload) 483 | } 484 | 485 | if f.Opcode == opcode.Text { 486 | if !c.utf8Check(*decodePayload) { 487 | c.closeWithLock(nil) 488 | c.onCloseOnce.Do(&c.mu2, func() { 489 | c.Callback.OnClose(c, ErrTextNotUTF8) 490 | }) 491 | return false 492 | } 493 | } 494 | 495 | c.Callback.OnMessage(c, f.Opcode, *decodePayload) 496 | bytespool.PutBytes(payload) 497 | return false 498 | } 499 | 500 | func (c *Conn) writeAndMaybeOnClose(err error) error { 501 | var sc *StatusCode 502 | defer func() { 503 | c.onCloseOnce.Do(&c.mu2, func() { 504 | c.Callback.OnClose(c, err) 505 | }) 506 | }() 507 | 508 | if errors.As(err, &sc) { 509 | if err := c.WriteTimeout(opcode.Close, sc.toBytes(), 2*time.Second); err != nil { 510 | return err 511 | } 512 | } 513 | return nil 514 | } 515 | 516 | func (c *Conn) writeErrAndOnClose(code StatusCode, userErr error) error { 517 | defer func() { 518 | c.onCloseOnce.Do(&c.mu2, func() { 519 | c.Callback.OnClose(c, userErr) 520 | }) 521 | }() 522 | if err := c.WriteTimeout(opcode.Close, code.toBytes(), 2*time.Second); err != nil { 523 | return err 524 | } 525 | 526 | return userErr 527 | } 528 | 529 | func (c *Conn) readPayloadAndCallback() (sucess bool, err error) { 530 | if c.curState == frameStatePayload { 531 | f, success, err := c.readPayload() 532 | if err != nil { 533 | c.getLogger().Error("readPayloadAndCallback.read payload err", "err", err.Error()) 534 | return sucess, err 535 | } 536 | 537 | // fmt.Printf("read payload, success:%t, %v\n", success, f.Payload) 538 | if success { 539 | if err := c.processCallback(f); err != nil { 540 | c.closeWithLock(err) 541 | return false, err 542 | } 543 | c.curState = frameStateHeaderStart 544 | return true, err 545 | } 546 | } 547 | return false, nil 548 | } 549 | 550 | func (c *Conn) isClosed() bool { 551 | return atomic.LoadInt32(&c.closed) == 1 552 | } 553 | 554 | func (c *Conn) WriteMessage(op Opcode, writeBuf []byte) (err error) { 555 | if c.isClosed() { 556 | return ErrClosed 557 | } 558 | 559 | if op == opcode.Text { 560 | if !c.utf8Check(writeBuf) { 561 | return ErrTextNotUTF8 562 | } 563 | } 564 | 565 | rsv1 := c.pd.Compression && (op == opcode.Text || op == opcode.Binary) 566 | if rsv1 { 567 | writeBufPtr, err := c.encoode(&writeBuf) 568 | if err != nil { 569 | return err 570 | } 571 | 572 | defer bytespool.PutBytes(writeBufPtr) 573 | writeBuf = *writeBufPtr 574 | } 575 | 576 | maskValue := uint32(0) 577 | if c.client { 578 | maskValue = rand.Uint32() 579 | } 580 | 581 | var fw fixedwriter.FixedWriter 582 | 583 | c.mu.Lock() 584 | err = frame.WriteFrame(&fw, connToNewConn(c), writeBuf, true, rsv1, c.client, op, maskValue) 585 | c.mu.Unlock() 586 | 587 | return err 588 | } 589 | 590 | // 写分段数据, 目前主要是单元测试使用 591 | func (c *Conn) writeFragment(op Opcode, writeBuf []byte, maxFragment int /*单个段最大size*/) (err error) { 592 | if len(writeBuf) < maxFragment { 593 | return c.WriteMessage(op, writeBuf) 594 | } 595 | 596 | if op == opcode.Text { 597 | if !c.utf8Check(writeBuf) { 598 | return ErrTextNotUTF8 599 | } 600 | } 601 | 602 | rsv1 := c.pd.Compression && (op == opcode.Text || op == opcode.Binary) 603 | if rsv1 { 604 | writeBufPtr, err := c.encoode(&writeBuf) 605 | if err != nil { 606 | return err 607 | } 608 | defer bytespool.PutBytes(writeBufPtr) 609 | writeBuf = *writeBufPtr 610 | } 611 | 612 | // f.Opcode = op 613 | // f.PayloadLen = int64(len(writeBuf)) 614 | maskValue := uint32(0) 615 | if c.client { 616 | maskValue = rand.Uint32() 617 | } 618 | 619 | var fw fixedwriter.FixedWriter 620 | for len(writeBuf) > 0 { 621 | if len(writeBuf) > maxFragment { 622 | if err := frame.WriteFrame(&fw, connToNewConn(c), writeBuf[:maxFragment], false, rsv1, c.client, op, maskValue); err != nil { 623 | return err 624 | } 625 | writeBuf = writeBuf[maxFragment:] 626 | op = Continuation 627 | continue 628 | } 629 | return frame.WriteFrame(&fw, connToNewConn(c), writeBuf, true, rsv1, c.client, op, maskValue) 630 | } 631 | return nil 632 | } 633 | 634 | // TODO 635 | func (c *Conn) WriteTimeout(op Opcode, data []byte, t time.Duration) (err error) { 636 | if err = c.setWriteDeadline(time.Now().Add(t)); err != nil { 637 | return 638 | } 639 | 640 | defer func() { _ = c.setWriteDeadline(time.Time{}) }() 641 | return c.WriteMessage(op, data) 642 | } 643 | 644 | func (c *Conn) WriteControl(op Opcode, data []byte) (err error) { 645 | if len(data) > maxControlFrameSize { 646 | return ErrMaxControlFrameSize 647 | } 648 | return c.WriteMessage(op, data) 649 | } 650 | 651 | func (c *Conn) WriteCloseTimeout(sc StatusCode, t time.Duration) (err error) { 652 | buf := sc.toBytes() 653 | return c.WriteTimeout(opcode.Close, buf, t) 654 | } 655 | 656 | // data 不能超过125字节 657 | func (c *Conn) WritePing(data []byte) (err error) { 658 | return c.WriteControl(Ping, data[:]) 659 | } 660 | 661 | // data 不能超过125字节 662 | func (c *Conn) WritePong(data []byte) (err error) { 663 | return c.WriteControl(Pong, data[:]) 664 | } 665 | 666 | func (c *Conn) Close() error { 667 | if c == nil { 668 | return nil 669 | } 670 | 671 | c.closeWithLock(nil) 672 | return nil 673 | } 674 | --------------------------------------------------------------------------------