├── .gitignore ├── .gitlab-ci.yml ├── LICENSE ├── README.md ├── build.gradle ├── build_android.sh ├── build_desktop.sh ├── build_ios.sh ├── clangwrap.sh ├── cli.go ├── cli ├── client_test.go ├── common.go ├── echo_server.go ├── httpClient.go ├── logger.go ├── stunnelbidirection.go └── websocketbidirConnection.go ├── clib.patch ├── go.mod ├── go.sum ├── gradle └── wrapper │ ├── gradle-wrapper.jar │ └── gradle-wrapper.properties ├── gradlew ├── gradlew.bat ├── jitpack.yml ├── target.sh └── websocket ├── .circleci └── config.yml ├── .github └── release-drafter.yml ├── .gitignore ├── AUTHORS ├── LICENSE ├── README.md ├── client.go ├── client_server_test.go ├── client_test.go ├── compression.go ├── compression_test.go ├── conn.go ├── conn_broadcast_test.go ├── conn_test.go ├── doc.go ├── example_test.go ├── examples ├── autobahn │ ├── README.md │ ├── config │ │ └── fuzzingclient.json │ └── server.go ├── chat │ ├── README.md │ ├── client.go │ ├── home.html │ ├── hub.go │ └── main.go ├── command │ ├── README.md │ ├── home.html │ └── main.go ├── echo │ ├── README.md │ ├── client.go │ └── server.go └── filewatch │ ├── README.md │ └── main.go ├── go.mod ├── go.sum ├── join.go ├── join_test.go ├── json.go ├── json_test.go ├── mask.go ├── mask_safe.go ├── mask_test.go ├── prepared.go ├── prepared_test.go ├── proxy.go ├── server.go ├── server_test.go ├── tls_handshake.go ├── tls_handshake_116.go ├── util.go ├── util_test.go └── x_net_proxy.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | 3 | # Output of the go coverage tool, specifically when used with LiteIDE 4 | *.out 5 | /.idea/ -------------------------------------------------------------------------------- /.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | image: $CI_REGISTRY_IMAGE 2 | 3 | stages: 4 | - test 5 | - build 6 | 7 | format: 8 | stage: test 9 | before_script: 10 | - apk 
-U upgrade 11 | - apk --no-cache add curl 12 | - apk add go 13 | - export PATH=$PATH:~/go/bin 14 | script: 15 | - go fmt $(go list ./... | grep -v /vendor/) 16 | - go vet $(go list ./... | grep -v /vendor/) 17 | - go test -race $(go list ./... | grep -v /vendor/) 18 | 19 | BuildDesktop: 20 | stage: build 21 | before_script: 22 | - apk -U upgrade 23 | - apk --no-cache add curl 24 | - apk add go 25 | - export PATH=$PATH:~/go/bin 26 | script: 27 | - chmod +x build_desktop.sh 28 | - ./build_desktop.sh 29 | artifacts: 30 | paths: 31 | - ./build 32 | 33 | BuildAndroidLibrary: 34 | stage: build 35 | before_script: 36 | - apk -U upgrade 37 | - apk --no-cache add curl 38 | - apk add go 39 | - export PATH=$PATH:~/go/bin 40 | - extras ndk -n 21.3.6528147 41 | script: 42 | - ./build_android.sh 43 | artifacts: 44 | paths: 45 | - ./build 46 | 47 | BuildiOSFramework: 48 | stage: build 49 | tags: [ macos11qt6 ] 50 | script: 51 | - chmod +x build_ios.sh 52 | - ./build_ios.sh 53 | artifacts: 54 | paths: 55 | - ./build -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Windscribe Tunnel Proxy for client apps. 2 | This program forwards OpenVPN tcp traffic to WSTunnel or Stunnel server. 3 | 4 | ## Build 5 | 1. To build android library Run `build_android.sh`. (Require android sdk + ndk) 6 | 2. To build ios framework Run `build_ios.sh` (Requires xcode build tools) 7 | 3. To build binaries for desktop Run `build_desktop.sh` 8 | 9 | 10 | ## Download from jitpack (Android only) 11 | [![](https://jitpack.io/v/Windscribe/wstunnel.svg)](https://jitpack.io/#Windscribe/wstunnel) 12 | 13 | ## Use Library 14 | Import Library/Framework & Start proxy. 
15 | ```val logFile = File(appContext.filesDir, PROXY_LOG).path 16 | initialise(BuildConfig.DEV, logFile) 17 | registerTunnelCallback(callback) 18 | if (isWSTunnel) { 19 | val remote = "wss://$ip:$port/$PROXY_TUNNEL_PROTOCOL/$PROXY_TUNNEL_ADDRESS/$WS_TUNNEL_PORT" 20 | startProxy(":$PROXY_TUNNEL_PORT", remote, 1, mtu) 21 | } else { 22 | val remote = "https://$ip:$port" 23 | startProxy(":$PROXY_TUNNEL_PORT", remote, 2, mtu) 24 | } 25 | ``` 26 | ## Start binary 27 | ```Flags: 28 | -d, --dev Turns on verbose logging. 29 | -h, --help help for root 30 | -l, --listenAddress string Local port for proxy > :65479 (default ":65479") 31 | -f, --logFilePath string Path to log file > file.log 32 | -m, --mtu int 1500 (default 1500) 33 | -r, --remoteAddress string Wstunnel > wss://$ip:$port/tcp/127.0.0.1/$WS_TUNNEL_PORT Stunnel > https://$ip:$port 34 | -t, --tunnelType int WStunnel > 1 , Stunnel > 2 (default 1) 35 | $ cli -l :65479 -r wss://$ip:$port/tcp/127.0.0.1/$WS_TUNNEL_PORT -t 1 -m 1500 -f file.log -d true 36 | $ cli -l :65479 -r https://$ip:$port -t 2 -m 1500 -f file.log -d true 37 | ``` 38 | 39 | ## Dependencies 40 | 1. Gorilla WebSocket for wstunnel [Link](https://github.com/gorilla/websocket) 41 | 2. Cobra for cli [Link](https://github.com/spf13/cobra) 42 | 3.
Zap for logging [Link](https://github.com/uber-go/zap)

--------------------------------------------------------------------------------
/build.gradle:
--------------------------------------------------------------------------------
apply plugin: 'maven-publish'
def LIB_GROUP_ID = 'com.windscribe'
def LIB_ARTIFACT_ID = 'proxy'
def LIB_VERSION = '1.0.0'
def aarFile = layout.buildDirectory.file("proxy.aar")

// NOTE(review): exec is called in the task's configuration closure (there is no
// doLast), so it presumably runs whenever Gradle configures the project, and
// ignoreExitValue hides any failure of build_android.sh. Confirm this is intended.
task build() {
    exec {
        executable "./build_android.sh"
        ignoreExitValue true
    }
}
afterEvaluate {
    publishing {
        publications {
            release(MavenPublication) {
                groupId LIB_GROUP_ID
                artifactId LIB_ARTIFACT_ID
                version LIB_VERSION
                artifact(aarFile)
            }
        }
    }
}
--------------------------------------------------------------------------------
/build_android.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Cross-compiles cli.go into build/android/<abi>/libproxy.so (c-shared) for every
# Android ABI. Needs bash ([[ ]] and arrays are not POSIX sh) and ANDROID_NDK set
# to an NDK installation. Stops at the first failing command.
set -e

export PATH=$PATH:~/go/bin
go mod tidy
# -f: a fresh checkout has no build/android yet, which must not abort the script.
rm -rf build/android
export CGO_ENABLED=1
export CGO_CFLAGS="-fstack-protector-strong"

if [[ "$(uname)" == "Darwin" ]]; then
  PLATFORM="darwin"
elif [[ "$(uname)" == "Linux" ]]; then
  PLATFORM="linux"
else
  PLATFORM="unknown"
fi
TOOLCHAIN_BIN="$ANDROID_NDK/toolchains/llvm/prebuilt/$PLATFORM-x86_64/bin"

# build_abi <android-abi> <GOARCH> <NDK clang wrapper>
build_abi() {
  local output_dir="./build/android/$1"
  mkdir -p "$output_dir"
  echo "Building $output_dir/libproxy.so"
  GOOS=android GOARCH="$2" CC="$TOOLCHAIN_BIN/$3" \
    go build -ldflags "-s -w" -buildmode=c-shared -o "$output_dir/libproxy.so" cli.go
  # c-shared also writes a C header next to the library; keep only the .so.
  rm "$output_dir/libproxy.h"
}

build_abi arm64-v8a arm64 aarch64-linux-android21-clang
build_abi armeabi-v7a arm armv7a-linux-androideabi21-clang
build_abi x86 386 i686-linux-android21-clang
build_abi x86_64 amd64 x86_64-linux-android21-clang

echo 'Build successful...'
--------------------------------------------------------------------------------
/build_desktop.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Builds standalone wstunnel binaries for each desktop OS/arch into build/desktop.
rm -rf build
platforms=("windows/arm64" "windows/amd64" "darwin/amd64" "linux/amd64" "linux/arm64")

for platform in "${platforms[@]}"
do
  platform_split=(${platform//\// })
  GOOS=${platform_split[0]}
  GOARCH=${platform_split[1]}
  output_name='wstunnel-'$GOOS'-'$GOARCH
  if [ "$GOOS" = "windows" ]; then
    output_name+='.exe'
  fi
  echo "Building $output_name"
  mkdir -p build/desktop
  # -l disables inlining. Bounds checking (-B) is deliberately left ON: this
  # binary handles data received from remote servers.
  if ! env GOOS="$GOOS" GOARCH="$GOARCH" go build -o "build/desktop/$output_name" -a -gcflags=all="-l" -ldflags="-w -s"; then
    echo 'An error has occurred!'
    exit 1
  fi
done
--------------------------------------------------------------------------------
/build_ios.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Builds proxy.a (c-archive) for tvOS/iOS devices and simulators and bundles them
# into build/Proxy.xcframework. Stops at the first failing command.
set -e

# Common variables
export GOOS=ios
export GOARCH=arm64
export CGO_ENABLED=1

# build <sdk> <min-os-version> <Xcode platform name>
build() {
  local sdk=$1
  local min_version=$2
  local platform_sdk=$3

  export SDK=$sdk
  CC="clang -arch arm64 -isysroot /Applications/Xcode.app/Contents/Developer/Platforms/${platform_sdk}.platform/Developer/SDKs/${platform_sdk}17.5.sdk -m${sdk}-version-min=${min_version} -fembed-bitcode"
  # NOTE(review): CGO_CFLAGS/CGO_LDFLAGS are not exported, so `go build` does not
  # see them unless they were already exported by the caller. Confirm intent.
  CGO_CFLAGS=""
  CGO_LDFLAGS="-framework CoreFoundation"
  export CC

  output_dir="./build/${sdk}/arm64"
  rm -rf "$output_dir"
  mkdir -p "$output_dir"
  go build -buildmode=c-archive -o "$output_dir/proxy.a" cli.go
}

# Build for Apple TVOS
build "appletvos" "17.0" "AppleTVOS"
echo "Apple TVOS framework at ./build/appletvos/arm64/proxy.a"

# Build for Apple TV Simulator
build "appletvsimulator" "17.0" "AppleTVSimulator"
echo "Apple TV Simulator framework at ./build/appletvsimulator/arm64/proxy.a"

# Build for iPhoneOS
build "iphoneos" "12.0" "iPhoneOS"
echo "iPhoneOS framework at ./build/iphoneos/arm64/proxy.a"

# Build for iPhoneSimulator
build "iphonesimulator" "12.0" "iPhoneSimulator"
echo "iPhoneSimulator framework at ./build/iphonesimulator/arm64/proxy.a"

# Create a combined headers directory
combined_headers_dir="./build/combined_headers"
rm -rf "$combined_headers_dir"
mkdir -p "$combined_headers_dir"

# Copy header files from all builds
cp ./build/appletvos/arm64/*.h "$combined_headers_dir"
# same for all platforms
#cp ./build/iphoneos/arm64/*.h "$combined_headers_dir"
#cp ./build/iphonesimulator/arm64/*.h "$combined_headers_dir"

# Create .xcframework
rm -rf
./build/Proxy.xcframework 54 | xcodebuild -create-xcframework \ 55 | -library ./build/appletvos/arm64/proxy.a -headers "$combined_headers_dir" \ 56 | -library ./build/appletvsimulator/arm64/proxy.a -headers "$combined_headers_dir" \ 57 | -library ./build/iphoneos/arm64/proxy.a -headers "$combined_headers_dir" \ 58 | -library ./build/iphonesimulator/arm64/proxy.a -headers "$combined_headers_dir" \ 59 | -output ./build/Proxy.xcframework 60 | 61 | echo "Combined framework at ./build/Proxy.xcframework" -------------------------------------------------------------------------------- /clangwrap.sh: -------------------------------------------------------------------------------- 1 | #!/bin/zsh 2 | 3 | CLANG=$(xcrun --sdk "$SDK" --find clang) 4 | 5 | exec "$CLANG" -target "$TARGET" -isysroot "$SDK_PATH" "$@" -------------------------------------------------------------------------------- /cli.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | //"C" 5 | "github.com/Windscribe/wstunnel/cli" 6 | "github.com/spf13/cobra" 7 | "os" 8 | //_ "runtime/cgo" 9 | ) 10 | 11 | var listenAddress string 12 | var remoteAddress string 13 | var tunnelType int 14 | var mtu int 15 | var extraTlsPadding bool 16 | var logFilePath string 17 | var dev = false 18 | 19 | var rootCmd = &cobra.Command{ 20 | Use: "root", 21 | Short: "Starts local proxy and connects to server.", 22 | Long: "Starts local proxy and sets up connection to the server. 
At minimum it requires remote server address and log file path.", 23 | Run: func(cmd *cobra.Command, args []string) { 24 | Initialise(dev, logFilePath) 25 | started := StartProxy(listenAddress, remoteAddress, tunnelType, mtu, extraTlsPadding) 26 | if started == false { 27 | os.Exit(0) 28 | } 29 | }, 30 | } 31 | 32 | func init() { 33 | rootCmd.PersistentFlags().StringVarP(&listenAddress, "listenAddress", "l", ":65479", "Local port for proxy > :65479") 34 | rootCmd.PersistentFlags().StringVarP(&remoteAddress, "remoteAddress", "r", "", "Wstunnel > wss://$ip:$port/tcp/127.0.0.1/$WS_TUNNEL_PORT Stunnel > https://$ip:$port") 35 | _ = rootCmd.MarkPersistentFlagRequired("remoteAddress") 36 | rootCmd.PersistentFlags().IntVarP(&tunnelType, "tunnelType", "t", 1, "WStunnel > 1 , Stunnel > 2") 37 | rootCmd.PersistentFlags().IntVarP(&mtu, "mtu", "m", 1500, "1500") 38 | rootCmd.PersistentFlags().BoolVarP(&extraTlsPadding, "extraTlsPadding", "p", false, "Add Extra TLS Padding to ClientHello packet.") 39 | rootCmd.PersistentFlags().StringVarP(&logFilePath, "logFilePath", "f", "", "Path to log file > file.log") 40 | _ = rootCmd.MarkPersistentFlagRequired("logFilePath") 41 | rootCmd.PersistentFlags().BoolVarP(&dev, "dev", "d", false, "Turns on verbose logging.") 42 | } 43 | 44 | func main() { 45 | _, err := rootCmd.ExecuteC() 46 | if err != nil { 47 | return 48 | } 49 | } 50 | 51 | //export Callback is used by http client to send events to host app 52 | var primaryListenerSocketFd int = -1 53 | 54 | //export Initialise 55 | func Initialise(development bool, logFilePath string) { 56 | cli.InitLogger(development, logFilePath) 57 | } 58 | 59 | //export StartProxy 60 | func StartProxy(listenAddress string, remoteAddress string, tunnelType int, mtu int, extraPadding bool) bool { 61 | cli.Logger.Infof("Starting proxy with listenAddress: %s remoteAddress %s tunnelType: %d mtu %d", listenAddress, remoteAddress, tunnelType, mtu) 62 | err := cli.NewHTTPClient(listenAddress, remoteAddress, 
tunnelType, mtu, func(fd int) { 63 | primaryListenerSocketFd = fd 64 | cli.Logger.Info("Socket ready to protect.") 65 | }, cli.Channel, extraPadding).Run() 66 | if err != nil { 67 | return false 68 | } 69 | return true 70 | } 71 | 72 | //export Stop 73 | func Stop() { 74 | cli.Logger.Info("Disconnect signal from host app.") 75 | cli.Channel <- "done" 76 | } 77 | 78 | //export GetPrimaryListenerSocketFd 79 | func GetPrimaryListenerSocketFd() int { 80 | return primaryListenerSocketFd 81 | } 82 | -------------------------------------------------------------------------------- /cli/client_test.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | var protocol = "ws://" 11 | var echoServerAddress = "localhost:8080" 12 | var path = "/ws" 13 | var webSocketServerAddress = fmt.Sprintf("%s%s%s", protocol, echoServerAddress, path) 14 | var tcpServerAddress = "localhost:1194" 15 | var dataToSend = []byte("Send me this message back.") 16 | 17 | func TestEndToEndConnection(t *testing.T) { 18 | InitLogger(true, "") 19 | //Ws server 20 | startServer(echoServerAddress, path) 21 | //Tcp server 22 | go func() { 23 | err := NewHTTPClient(tcpServerAddress, webSocketServerAddress, 1, 1600, func(fd int) { 24 | t.Log(fd) 25 | }, Channel, false).Run() 26 | if err != nil { 27 | t.Fail() 28 | return 29 | } 30 | }() 31 | time.Sleep(time.Millisecond * 100) 32 | //Client 1 33 | _, client1Err := mockClientConnection() 34 | if client1Err != nil { 35 | t.Fail() 36 | return 37 | } 38 | //Client 2 39 | _, client2Err := mockClientConnection() 40 | if client2Err != nil { 41 | t.Fail() 42 | return 43 | } 44 | //Exit 45 | time.Sleep(time.Millisecond * 100) 46 | Channel <- "done" 47 | //Client 3 48 | _, client3Err := mockClientConnection() 49 | if client3Err == nil { 50 | t.Fail() 51 | return 52 | } 53 | t.Log("Test is successful.") 54 | } 55 | 56 | func mockClientConnection() 
(string, error) { 57 | var conn, connErr = net.Dial("tcp", tcpServerAddress) 58 | if connErr != nil { 59 | return "", connErr 60 | } 61 | _, writeErr := conn.Write(dataToSend) 62 | if writeErr != nil { 63 | return "", writeErr 64 | } 65 | time.Sleep(time.Second * 1) 66 | data := make([]byte, 30) 67 | _, err := conn.Read(data) 68 | return string(data), err 69 | } 70 | -------------------------------------------------------------------------------- /cli/common.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | const ( 4 | // BufferSize is the size of the intermediate buffer for network packets 5 | BufferSize = 1024 6 | ) 7 | 8 | // Runner defines a basic interface with only a Run() function 9 | type Runner interface { 10 | Run() error 11 | } 12 | -------------------------------------------------------------------------------- /cli/echo_server.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "github.com/gorilla/websocket" 5 | "log" 6 | "net/http" 7 | ) 8 | 9 | // EchoServer simple web socket server for testing. 
10 | type EchoServer struct { 11 | clients map[string]*websocket.Conn 12 | } 13 | 14 | var upgrader = websocket.Upgrader{ 15 | CheckOrigin: func(r *http.Request) bool { 16 | return true 17 | }, 18 | } 19 | 20 | func (server *EchoServer) handleRequest(w http.ResponseWriter, r *http.Request) { 21 | connection, err := upgrader.Upgrade(w, r, nil) 22 | if err != nil { 23 | return 24 | } 25 | clientId := connection.RemoteAddr().String() 26 | server.clients[clientId] = connection 27 | for { 28 | messageType, message, err := connection.ReadMessage() 29 | if err != nil || messageType == websocket.CloseMessage { 30 | break 31 | } 32 | go server.echoMessageBack(clientId, message) 33 | } 34 | connection.Close() 35 | delete(server.clients, clientId) 36 | } 37 | 38 | func startServer(address string, path string) *EchoServer { 39 | server := EchoServer{ 40 | clients: map[string]*websocket.Conn{}, 41 | } 42 | http.HandleFunc(path, server.handleRequest) 43 | go func() { 44 | err := http.ListenAndServe(address, nil) 45 | if err != nil { 46 | log.Fatal(err) 47 | } 48 | }() 49 | return &server 50 | } 51 | 52 | func (server *EchoServer) echoMessageBack(clientId string, message []byte) { 53 | err := server.clients[clientId].WriteMessage(websocket.BinaryMessage, message) 54 | if err != nil { 55 | return 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /cli/httpClient.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "github.com/gorilla/websocket" 7 | tls "github.com/refraction-networking/utls" 8 | "math/rand" 9 | "net" 10 | "net/http" 11 | "net/url" 12 | "sync" 13 | "syscall" 14 | "time" 15 | ) 16 | 17 | //export WSTunnel wraps OpenVPN tcp traffic in to Websocket 18 | const WSTunnel = 1 19 | 20 | //export Stunnel wraps OpenVPN tcp traffic in to regular tcp. 
21 | const Stunnel = 2 22 | 23 | //export Channel is used by host app to send events to http client. 24 | var Channel = make(chan string) 25 | 26 | // httpClient 27 | // sets up tcp server and remote connections. 28 | // ////////////////////////////////////////////////////////////////////////////// 29 | type httpClient struct { 30 | listenTCP string 31 | remoteServer string 32 | tunnelType int 33 | mtu int 34 | callback func(fd int) 35 | channel chan string 36 | extraPadding bool 37 | } 38 | 39 | func NewHTTPClient(listenTCP, remoteServer string, tunnelType int, mtu int, callback func(fd int), channel chan string, extraPadding bool) Runner { 40 | return &httpClient{ 41 | listenTCP: listenTCP, 42 | remoteServer: remoteServer, 43 | tunnelType: tunnelType, 44 | mtu: mtu, 45 | callback: callback, 46 | channel: channel, 47 | extraPadding: extraPadding, 48 | } 49 | } 50 | 51 | // Run stars tcp server and connect to remote server. 52 | func (h *httpClient) Run() error { 53 | tcpAdr, err := net.ResolveTCPAddr("tcp", h.listenTCP) 54 | if err != nil { 55 | Logger.Errorf("Error resolving tcp address: %s", err) 56 | return err 57 | } 58 | tcpConnection, err := net.ListenTCP("tcp", tcpAdr) 59 | if err != nil { 60 | return err 61 | } 62 | defer tcpConnection.Close() 63 | Logger.Infof("Listening on %s", h.listenTCP) 64 | doneMutex := sync.Mutex{} 65 | done := false 66 | isDone := func() bool { 67 | doneMutex.Lock() 68 | defer doneMutex.Unlock() 69 | return done 70 | } 71 | go func() { 72 | select { 73 | case msg := <-h.channel: 74 | if msg == "done" { 75 | doneMutex.Lock() 76 | defer doneMutex.Unlock() 77 | done = true 78 | _ = tcpConnection.Close() 79 | } 80 | } 81 | }() 82 | for !isDone() { 83 | tcpConn, err := tcpConnection.Accept() 84 | if err != nil { 85 | continue 86 | } 87 | Logger.Infof("New connection from %s", tcpConn.RemoteAddr().String()) 88 | if h.tunnelType == WSTunnel { 89 | handleWsTunnelConnection(h, tcpConn) 90 | } else if h.tunnelType == Stunnel { 91 | 
handleStunnelConnection(h, tcpConn) 92 | } else { 93 | Logger.Fatal("Invalid tunnel type specified.") 94 | } 95 | } 96 | return err 97 | } 98 | 99 | func handleStunnelConnection(h *httpClient, localConn net.Conn) { 100 | remoteConn, err := h.createRemoteConnection() 101 | if err != nil { 102 | Logger.Errorf("%s - Remote server connection > Error while dialing %s: %s", localConn.RemoteAddr(), h.remoteServer, err) 103 | _ = localConn.Close() 104 | return 105 | } 106 | err = remoteConn.HandshakeContext(context.Background()) 107 | if err != nil { 108 | _ = localConn.Close() 109 | Logger.Errorf("Error on handshake: %s", err) 110 | return 111 | } 112 | Logger.Info("Starting stunnel bi-direction connection.") 113 | b := NewStunnelBiDirection(localConn, remoteConn, h.mtu) 114 | go b.Run() 115 | } 116 | 117 | func (h *httpClient) createRemoteConnection() (*tls.UConn, error) { 118 | customNetDialer := h.createDialer() 119 | cfg := &tls.Config{ 120 | InsecureSkipVerify: true, 121 | } 122 | remoteUrl, err := url.Parse(h.remoteServer) 123 | if err != nil { 124 | return nil, err 125 | } 126 | netConn, err := customNetDialer.Dial("tcp", remoteUrl.Host) 127 | if err != nil { 128 | return nil, err 129 | } 130 | cfg.ServerName = remoteUrl.Hostname() 131 | 132 | remoteConn := tls.UClient(netConn, cfg, tls.HelloCustom) 133 | clientHelloSpec, err := tls.UTLSIdToSpec(tls.HelloRandomizedALPN) 134 | if err != nil { 135 | return nil, fmt.Errorf("uTlsConn.generateRandomizedSpec error: %+v", err) 136 | } 137 | 138 | if h.extraPadding { 139 | rand.Seed(time.Now().Unix()) 140 | alreadyHasPadding := false 141 | for _, ext := range clientHelloSpec.Extensions { 142 | if _, ok := ext.(*tls.UtlsPaddingExtension); ok { 143 | alreadyHasPadding = true 144 | ext.(*tls.UtlsPaddingExtension).PaddingLen = 2000 + rand.Intn(10000) 145 | ext.(*tls.UtlsPaddingExtension).WillPad = true 146 | ext.(*tls.UtlsPaddingExtension).GetPaddingLen = nil 147 | break 148 | } 149 | } 150 | if !alreadyHasPadding { 151 | 
clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &tls.UtlsPaddingExtension{PaddingLen: 2000 + rand.Intn(10000), WillPad: true, GetPaddingLen: nil}) 152 | } 153 | } 154 | 155 | err = remoteConn.ApplyPreset(&clientHelloSpec) 156 | if err != nil { 157 | return nil, fmt.Errorf("uTlsConn.ApplyPreset error: %+v", err) 158 | } 159 | 160 | return remoteConn, nil 161 | } 162 | 163 | func handleWsTunnelConnection(h *httpClient, tcpConn net.Conn) { 164 | wsConn, wsErr := h.createWsConnection(tcpConn.RemoteAddr().String()) 165 | if wsErr != nil || wsConn == nil { 166 | Logger.Errorf("%s - Ws connection > Error while dialing %s: %s", tcpConn.RemoteAddr(), h.remoteServer, wsErr) 167 | _ = tcpConn.Close() 168 | return 169 | } 170 | b := NewBidirConnection(tcpConn, wsConn, time.Second*10, h.mtu) 171 | go b.Run() 172 | } 173 | 174 | func (h *httpClient) toUrl(asString string) (string, error) { 175 | asURL, err := url.Parse(asString) 176 | if err != nil { 177 | return asString, err 178 | } 179 | return asURL.String(), nil 180 | } 181 | 182 | // createDialer creates custom dialer which provides access to socket fd 183 | func (h *httpClient) createDialer() *net.Dialer { 184 | customNetDialer := &net.Dialer{} 185 | // Access underlying socket fd before connecting to it. 186 | customNetDialer.Control = func(network, address string, c syscall.RawConn) error { 187 | return c.Control(func(fd uintptr) { 188 | Logger.Infof("Received socket fd %d", fd) 189 | i := int(fd) 190 | h.callback(i) 191 | }) 192 | } 193 | return customNetDialer 194 | } 195 | 196 | // createWsConnection creates a connection to websocket server. 
197 | func (h *httpClient) createWsConnection(remoteAddr string) (wsConn *websocket.Conn, err error) { 198 | wsConnectUrl := h.remoteServer 199 | for { 200 | var wsURL string 201 | wsURL, err = h.toUrl(wsConnectUrl) 202 | if err != nil { 203 | return 204 | } 205 | Logger.Infof("%s - Connecting to %s", remoteAddr, wsURL) 206 | var httpResponse *http.Response 207 | dialer := *websocket.DefaultDialer 208 | dialer.TLSClientConfig = &tls.Config{ 209 | InsecureSkipVerify: true, 210 | } 211 | customNetDialer := h.createDialer() 212 | dialer.NetDial = func(network, addr string) (net.Conn, error) { 213 | return customNetDialer.Dial(network, addr) 214 | } 215 | wsConn, httpResponse, err = dialer.Dial(wsURL, nil) 216 | if wsConn != nil { 217 | Logger.Info("Successfully connected to remote server.") 218 | } else if err != nil { 219 | Logger.Errorf("Failed to connect to remote server.. %s", err) 220 | } 221 | if httpResponse != nil { 222 | switch httpResponse.StatusCode { 223 | case http.StatusMovedPermanently, http.StatusFound, http.StatusSeeOther, http.StatusTemporaryRedirect, http.StatusPermanentRedirect: 224 | wsConnectUrl = httpResponse.Header.Get("Location") 225 | Logger.Infof("%s - Redirect to %s", remoteAddr, wsConnectUrl) 226 | continue 227 | } 228 | } 229 | return 230 | } 231 | } 232 | -------------------------------------------------------------------------------- /cli/logger.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "go.uber.org/zap" 5 | "go.uber.org/zap/zapcore" 6 | "log" 7 | "time" 8 | ) 9 | 10 | var Logger *zap.SugaredLogger 11 | 12 | // InitLogger initializes the logger. 
13 | func InitLogger(development bool, logFilePath string) { 14 | cfg := zap.NewProductionConfig() 15 | outputPaths := []string{"stdout"} 16 | if logFilePath != "" { 17 | outputPaths = append(outputPaths, logFilePath) 18 | } 19 | cfg.OutputPaths = outputPaths 20 | 21 | cfg.Encoding = "json" 22 | cfg.EncoderConfig.EncodeDuration = zapcore.NanosDurationEncoder 23 | cfg.EncoderConfig.EncodeLevel = levelEncoder 24 | cfg.EncoderConfig.EncodeTime = syslogTimeEncoder 25 | cfg.EncoderConfig.MessageKey = "msg" 26 | cfg.EncoderConfig.CallerKey = "" 27 | cfg.EncoderConfig.NameKey = "mod" 28 | cfg.EncoderConfig.TimeKey = "tm" // Important: Set the TimeKey 29 | 30 | if !development { 31 | cfg.EncoderConfig.StacktraceKey = "" 32 | } 33 | 34 | zapLogger, err := cfg.Build(zap.AddCallerSkip(1)) 35 | if err != nil { 36 | log.Fatal(err) 37 | } 38 | 39 | Logger = zapLogger.With(zap.String("mod", "wstunnel")).Sugar() 40 | 41 | if logFilePath != "" { 42 | Logger.Info("Logging to stdout and file: ", zap.String("file", logFilePath)) 43 | } else { 44 | Logger.Info("Logging to stdout") 45 | } 46 | } 47 | 48 | func syslogTimeEncoder(t time.Time, enc zapcore.PrimitiveArrayEncoder) { 49 | enc.AppendString(t.Format("2006-01-02 15:04:05.000")) 50 | } 51 | 52 | func levelEncoder(level zapcore.Level, enc zapcore.PrimitiveArrayEncoder) { 53 | enc.AppendString(level.String()) 54 | } 55 | -------------------------------------------------------------------------------- /cli/stunnelbidirection.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | tls "github.com/refraction-networking/utls" 5 | "net" 6 | "os" 7 | ) 8 | 9 | // StunnelBiDirection 10 | // creates an object to transfer data between the TCP clients and remote server in bidirectional way 11 | type StunnelBiDirection struct { 12 | localConn net.Conn 13 | remoteConn *tls.UConn 14 | mtu int 15 | } 16 | 17 | func NewStunnelBiDirection(localConn net.Conn, remoteConn *tls.UConn, 
mtu int) Runner { 18 | return &StunnelBiDirection{ 19 | localConn, remoteConn, mtu, 20 | } 21 | } 22 | 23 | func (s *StunnelBiDirection) Run() error { 24 | go s.sendTCPToStunnel() 25 | s.sendStunnelToTCP() 26 | return nil 27 | } 28 | 29 | // sendTCPToStunnel copies tcp traffic to remote server 30 | func (s *StunnelBiDirection) sendTCPToStunnel() { 31 | defer s.close() 32 | data := make([]byte, s.mtu) 33 | for { 34 | readSize, err := s.localConn.Read(data) 35 | if err != nil && !os.IsTimeout(err) { 36 | return 37 | } 38 | _, _ = s.remoteConn.Write(data[:readSize]) 39 | if err != nil { 40 | return 41 | } 42 | } 43 | } 44 | 45 | // sendStunnelToTCP copies remote server traffic to tcp connection. 46 | func (s *StunnelBiDirection) sendStunnelToTCP() { 47 | defer s.close() 48 | data := make([]byte, s.mtu) 49 | for { 50 | readSize, err := s.remoteConn.Read(data) 51 | if err != nil && !os.IsTimeout(err) { 52 | break 53 | } 54 | _, _ = s.localConn.Write(data[:readSize]) 55 | if err != nil { 56 | return 57 | } 58 | } 59 | } 60 | 61 | // close closes connections. 
62 | func (s *StunnelBiDirection) close() { 63 | _ = s.remoteConn.Close() 64 | _ = s.localConn.Close() 65 | } 66 | -------------------------------------------------------------------------------- /cli/websocketbidirConnection.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "github.com/gorilla/websocket" 5 | "net" 6 | "os" 7 | "time" 8 | ) 9 | 10 | // WebSocketBiDirection 11 | // Creates an object to transfer data between the TCP clients and remote server in bidirectional way 12 | type WebSocketBiDirection struct { 13 | tcpConn net.Conn 14 | wsConn *websocket.Conn 15 | tcpReadTimeout time.Duration 16 | mtu int 17 | } 18 | 19 | func NewBidirConnection(tcpConn net.Conn, wsConn *websocket.Conn, tcpReadTimeout time.Duration, mtu int) Runner { 20 | return &WebSocketBiDirection{ 21 | tcpConn: tcpConn, 22 | wsConn: wsConn, 23 | tcpReadTimeout: tcpReadTimeout, 24 | mtu: mtu, 25 | } 26 | } 27 | 28 | // sendTCPToWS copies tcp traffic to web socket connection. 29 | func (b *WebSocketBiDirection) sendTCPToWS() { 30 | defer b.close() 31 | data := make([]byte, b.mtu) 32 | for { 33 | if b.tcpReadTimeout > 0 { 34 | _ = b.tcpConn.SetReadDeadline(time.Now().Add(b.tcpReadTimeout)) 35 | } 36 | readSize, err := b.tcpConn.Read(data) 37 | if err != nil && !os.IsTimeout(err) { 38 | return 39 | } 40 | 41 | if err := b.wsConn.WriteMessage(websocket.BinaryMessage, data[:readSize]); err != nil { 42 | return 43 | } 44 | } 45 | } 46 | 47 | // sendWSToTCP copies web socket traffic to tcp connection. 
48 | func (b *WebSocketBiDirection) sendWSToTCP() { 49 | defer b.close() 50 | data := make([]byte, b.mtu) 51 | for { 52 | messageType, wsReader, err := b.wsConn.NextReader() 53 | if err != nil { 54 | return 55 | } 56 | if messageType != websocket.BinaryMessage { 57 | Logger.Infof("WSToTCP - Got wrong message type from WS: %s", messageType) 58 | return 59 | } 60 | 61 | for { 62 | readSize, err := wsReader.Read(data) 63 | if err != nil { 64 | break 65 | } 66 | 67 | if _, err := b.tcpConn.Write(data[:readSize]); err != nil { 68 | return 69 | } 70 | } 71 | } 72 | } 73 | 74 | func (b *WebSocketBiDirection) Run() error { 75 | go b.sendTCPToWS() 76 | b.sendWSToTCP() 77 | return nil 78 | } 79 | 80 | // close closes connections. 81 | func (b *WebSocketBiDirection) close() { 82 | _ = b.wsConn.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(time.Second)) 83 | _ = b.wsConn.Close() 84 | _ = b.tcpConn.Close() 85 | } 86 | -------------------------------------------------------------------------------- /clib.patch: -------------------------------------------------------------------------------- 1 | Subject: [PATCH] clib 2 | --- 3 | Index: cli.go 4 | IDEA additional info: 5 | Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP 6 | <+>UTF-8 7 | =================================================================== 8 | diff --git a/cli.go b/cli.go 9 | --- a/cli.go (revision ac075a725b30b5d3a0a830bbc4615d11b2674fb6) 10 | +++ b/cli.go (date 1720546823021) 11 | @@ -1,10 +1,10 @@ 12 | package main 13 | 14 | import ( 15 | - //"C" 16 | + "C" 17 | "github.com/spf13/cobra" 18 | "os" 19 | - //_ "runtime/cgo" 20 | + _ "runtime/cgo" 21 | ) 22 | 23 | var listenAddress string 24 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/Windscribe/wstunnel 2 | 3 | go 1.18 4 | 5 | replace github.com/gorilla/websocket => ./websocket 6 | 7 | require 
( 8 | github.com/gorilla/websocket v1.4.2 9 | github.com/refraction-networking/utls v1.3.2 10 | github.com/spf13/cobra v1.7.0 11 | go.uber.org/zap v1.23.0 12 | ) 13 | 14 | require ( 15 | github.com/andybalholm/brotli v1.0.4 // indirect 16 | github.com/gaukas/godicttls v0.0.3 // indirect 17 | github.com/inconshreveable/mousetrap v1.1.0 // indirect 18 | github.com/klauspost/compress v1.15.15 // indirect 19 | github.com/spf13/pflag v1.0.5 // indirect 20 | go.uber.org/atomic v1.10.0 // indirect 21 | go.uber.org/multierr v1.8.0 // indirect 22 | golang.org/x/crypto v0.5.0 // indirect 23 | golang.org/x/sys v0.5.0 // indirect 24 | ) 25 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= 2 | github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= 3 | github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= 4 | github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= 5 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 7 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 8 | github.com/gaukas/godicttls v0.0.3 h1:YNDIf0d9adcxOijiLrEzpfZGAkNwLRzPaG6OjU7EITk= 9 | github.com/gaukas/godicttls v0.0.3/go.mod h1:l6EenT4TLWgTdwslVb4sEMOCf7Bv0JAK67deKr9/NCI= 10 | github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= 11 | github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= 12 | github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU= 13 | github.com/klauspost/compress v1.15.15 
h1:EF27CXIuDsYJ6mmvtBRlEuB2UVOqHG1tAXgZ7yIO+lw= 14 | github.com/klauspost/compress v1.15.15/go.mod h1:ZcK2JAFqKOpnBlxcLsJzYfrS9X1akm9fHZNnD9+Vo/4= 15 | github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= 16 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 17 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 18 | github.com/refraction-networking/utls v1.1.5/go.mod h1:jRQxtYi7nkq1p28HF2lwOH5zQm9aC8rpK0O9lIIzGh8= 19 | github.com/refraction-networking/utls v1.3.2 h1:o+AkWB57mkcoW36ET7uJ002CpBWHu0KPxi6vzxvPnv8= 20 | github.com/refraction-networking/utls v1.3.2/go.mod h1:fmoaOww2bxzzEpIKOebIsnBvjQpqP7L2vcm/9KUfm/E= 21 | github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= 22 | github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= 23 | github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= 24 | github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= 25 | github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= 26 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 27 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 28 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 29 | github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= 30 | go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= 31 | go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= 32 | go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= 33 | go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI= 34 | go.uber.org/multierr v1.8.0 h1:dg6GjLku4EH+249NNmoIciG9N/jURbDG+pFlTkhzIC8= 35 | go.uber.org/multierr v1.8.0/go.mod 
h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= 36 | go.uber.org/zap v1.23.0 h1:OjGQ5KQDEUawVHxNwQgPpiypGHOxo2mNZsOqTak4fFY= 37 | go.uber.org/zap v1.23.0/go.mod h1:D+nX8jyLsMHMYrln8A0rJjFt/T/9/bGgIhAqxv5URuY= 38 | golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= 39 | golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE= 40 | golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU= 41 | golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= 42 | golang.org/x/net v0.0.0-20220909164309-bea034e7d591/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= 43 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 44 | golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 45 | golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 46 | golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 47 | golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= 48 | golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 49 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 50 | golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= 51 | golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 52 | golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= 53 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 54 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 55 | gopkg.in/yaml.v3 
v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 56 | gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 57 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 58 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 59 | -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Windscribe/wstunnel/c191d6e13771317499277f23a5dad369b9c23605/gradle/wrapper/gradle-wrapper.jar -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.properties: -------------------------------------------------------------------------------- 1 | distributionBase=GRADLE_USER_HOME 2 | distributionPath=wrapper/dists 3 | distributionUrl=https\://services.gradle.org/distributions/gradle-7.5.1-bin.zip 4 | zipStoreBase=GRADLE_USER_HOME 5 | zipStorePath=wrapper/dists 6 | -------------------------------------------------------------------------------- /gradlew: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # 4 | # Copyright © 2015-2021 the original authors. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # https://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | # 18 | 19 | ############################################################################## 20 | # 21 | # Gradle start up script for POSIX generated by Gradle. 22 | # 23 | # Important for running: 24 | # 25 | # (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is 26 | # noncompliant, but you have some other compliant shell such as ksh or 27 | # bash, then to run this script, type that shell name before the whole 28 | # command line, like: 29 | # 30 | # ksh Gradle 31 | # 32 | # Busybox and similar reduced shells will NOT work, because this script 33 | # requires all of these POSIX shell features: 34 | # * functions; 35 | # * expansions «$var», «${var}», «${var:-default}», «${var+SET}», 36 | # «${var#prefix}», «${var%suffix}», and «$( cmd )»; 37 | # * compound commands having a testable exit status, especially «case»; 38 | # * various built-in commands including «command», «set», and «ulimit». 39 | # 40 | # Important for patching: 41 | # 42 | # (2) This script targets any POSIX shell, so it avoids extensions provided 43 | # by Bash, Ksh, etc; in particular arrays are avoided. 44 | # 45 | # The "traditional" practice of packing multiple parameters into a 46 | # space-separated string is a well documented source of bugs and security 47 | # problems, so this is (mostly) avoided, by progressively accumulating 48 | # options in "$@", and eventually passing that to Java. 49 | # 50 | # Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, 51 | # and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; 52 | # see the in-line comments for details. 53 | # 54 | # There are tweaks for specific operating systems such as AIX, CygWin, 55 | # Darwin, MinGW, and NonStop. 56 | # 57 | # (3) This script is generated from the Groovy template 58 | # https://github.com/gradle/gradle/blob/master/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt 59 | # within the Gradle project. 
60 | # 61 | # You can find Gradle at https://github.com/gradle/gradle/. 62 | # 63 | ############################################################################## 64 | 65 | # Attempt to set APP_HOME 66 | 67 | # Resolve links: $0 may be a link 68 | app_path=$0 69 | 70 | # Need this for daisy-chained symlinks. 71 | while 72 | APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path 73 | [ -h "$app_path" ] 74 | do 75 | ls=$( ls -ld "$app_path" ) 76 | link=${ls#*' -> '} 77 | case $link in #( 78 | /*) app_path=$link ;; #( 79 | *) app_path=$APP_HOME$link ;; 80 | esac 81 | done 82 | 83 | APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit 84 | 85 | APP_NAME="Gradle" 86 | APP_BASE_NAME=${0##*/} 87 | 88 | # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 89 | DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' 90 | 91 | # Use the maximum available, or set MAX_FD != -1 to use that value. 92 | MAX_FD=maximum 93 | 94 | warn () { 95 | echo "$*" 96 | } >&2 97 | 98 | die () { 99 | echo 100 | echo "$*" 101 | echo 102 | exit 1 103 | } >&2 104 | 105 | # OS specific support (must be 'true' or 'false'). 106 | cygwin=false 107 | msys=false 108 | darwin=false 109 | nonstop=false 110 | case "$( uname )" in #( 111 | CYGWIN* ) cygwin=true ;; #( 112 | Darwin* ) darwin=true ;; #( 113 | MSYS* | MINGW* ) msys=true ;; #( 114 | NONSTOP* ) nonstop=true ;; 115 | esac 116 | 117 | CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar 118 | 119 | 120 | # Determine the Java command to use to start the JVM. 121 | if [ -n "$JAVA_HOME" ] ; then 122 | if [ -x "$JAVA_HOME/jre/sh/java" ] ; then 123 | # IBM's JDK on AIX uses strange locations for the executables 124 | JAVACMD=$JAVA_HOME/jre/sh/java 125 | else 126 | JAVACMD=$JAVA_HOME/bin/java 127 | fi 128 | if [ ! 
-x "$JAVACMD" ] ; then 129 | die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME 130 | 131 | Please set the JAVA_HOME variable in your environment to match the 132 | location of your Java installation." 133 | fi 134 | else 135 | JAVACMD=java 136 | which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 137 | 138 | Please set the JAVA_HOME variable in your environment to match the 139 | location of your Java installation." 140 | fi 141 | 142 | # Increase the maximum file descriptors if we can. 143 | if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then 144 | case $MAX_FD in #( 145 | max*) 146 | MAX_FD=$( ulimit -H -n ) || 147 | warn "Could not query maximum file descriptor limit" 148 | esac 149 | case $MAX_FD in #( 150 | '' | soft) :;; #( 151 | *) 152 | ulimit -n "$MAX_FD" || 153 | warn "Could not set maximum file descriptor limit to $MAX_FD" 154 | esac 155 | fi 156 | 157 | # Collect all arguments for the java command, stacking in reverse order: 158 | # * args from the command line 159 | # * the main class name 160 | # * -classpath 161 | # * -D...appname settings 162 | # * --module-path (only if needed) 163 | # * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. 
164 | 165 | # For Cygwin or MSYS, switch paths to Windows format before running java 166 | if "$cygwin" || "$msys" ; then 167 | APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) 168 | CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) 169 | 170 | JAVACMD=$( cygpath --unix "$JAVACMD" ) 171 | 172 | # Now convert the arguments - kludge to limit ourselves to /bin/sh 173 | for arg do 174 | if 175 | case $arg in #( 176 | -*) false ;; # don't mess with options #( 177 | /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath 178 | [ -e "$t" ] ;; #( 179 | *) false ;; 180 | esac 181 | then 182 | arg=$( cygpath --path --ignore --mixed "$arg" ) 183 | fi 184 | # Roll the args list around exactly as many times as the number of 185 | # args, so each arg winds up back in the position where it started, but 186 | # possibly modified. 187 | # 188 | # NB: a `for` loop captures its iteration list before it begins, so 189 | # changing the positional parameters here affects neither the number of 190 | # iterations, nor the values presented in `arg`. 191 | shift # remove old arg 192 | set -- "$@" "$arg" # push replacement arg 193 | done 194 | fi 195 | 196 | # Collect all arguments for the java command; 197 | # * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of 198 | # shell script including quotes and variable substitutions, so put them in 199 | # double quotes to make sure that they get re-expanded; and 200 | # * put everything else in single quotes, so that it's not re-expanded. 201 | 202 | set -- \ 203 | "-Dorg.gradle.appname=$APP_BASE_NAME" \ 204 | -classpath "$CLASSPATH" \ 205 | org.gradle.wrapper.GradleWrapperMain \ 206 | "$@" 207 | 208 | # Stop when "xargs" is not available. 209 | if ! command -v xargs >/dev/null 2>&1 210 | then 211 | die "xargs is not available" 212 | fi 213 | 214 | # Use "xargs" to parse quoted args. 215 | # 216 | # With -n1 it outputs one arg per line, with the quotes and backslashes removed. 
217 | # 218 | # In Bash we could simply go: 219 | # 220 | # readarray ARGS < <( xargs -n1 <<<"$var" ) && 221 | # set -- "${ARGS[@]}" "$@" 222 | # 223 | # but POSIX shell has neither arrays nor command substitution, so instead we 224 | # post-process each arg (as a line of input to sed) to backslash-escape any 225 | # character that might be a shell metacharacter, then use eval to reverse 226 | # that process (while maintaining the separation between arguments), and wrap 227 | # the whole thing up as a single "set" statement. 228 | # 229 | # This will of course break if any of these variables contains a newline or 230 | # an unmatched quote. 231 | # 232 | 233 | eval "set -- $( 234 | printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | 235 | xargs -n1 | 236 | sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | 237 | tr '\n' ' ' 238 | )" '"$@"' 239 | 240 | exec "$JAVACMD" "$@" 241 | -------------------------------------------------------------------------------- /gradlew.bat: -------------------------------------------------------------------------------- 1 | @rem 2 | @rem Copyright 2015 the original author or authors. 3 | @rem 4 | @rem Licensed under the Apache License, Version 2.0 (the "License"); 5 | @rem you may not use this file except in compliance with the License. 6 | @rem You may obtain a copy of the License at 7 | @rem 8 | @rem https://www.apache.org/licenses/LICENSE-2.0 9 | @rem 10 | @rem Unless required by applicable law or agreed to in writing, software 11 | @rem distributed under the License is distributed on an "AS IS" BASIS, 12 | @rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | @rem See the License for the specific language governing permissions and 14 | @rem limitations under the License. 
15 | @rem 16 | 17 | @if "%DEBUG%"=="" @echo off 18 | @rem ########################################################################## 19 | @rem 20 | @rem Gradle startup script for Windows 21 | @rem 22 | @rem ########################################################################## 23 | 24 | @rem Set local scope for the variables with windows NT shell 25 | if "%OS%"=="Windows_NT" setlocal 26 | 27 | set DIRNAME=%~dp0 28 | if "%DIRNAME%"=="" set DIRNAME=. 29 | set APP_BASE_NAME=%~n0 30 | set APP_HOME=%DIRNAME% 31 | 32 | @rem Resolve any "." and ".." in APP_HOME to make it shorter. 33 | for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi 34 | 35 | @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 36 | set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" 37 | 38 | @rem Find java.exe 39 | if defined JAVA_HOME goto findJavaFromJavaHome 40 | 41 | set JAVA_EXE=java.exe 42 | %JAVA_EXE% -version >NUL 2>&1 43 | if %ERRORLEVEL% equ 0 goto execute 44 | 45 | echo. 46 | echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 47 | echo. 48 | echo Please set the JAVA_HOME variable in your environment to match the 49 | echo location of your Java installation. 50 | 51 | goto fail 52 | 53 | :findJavaFromJavaHome 54 | set JAVA_HOME=%JAVA_HOME:"=% 55 | set JAVA_EXE=%JAVA_HOME%/bin/java.exe 56 | 57 | if exist "%JAVA_EXE%" goto execute 58 | 59 | echo. 60 | echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 61 | echo. 62 | echo Please set the JAVA_HOME variable in your environment to match the 63 | echo location of your Java installation. 
64 | 65 | goto fail 66 | 67 | :execute 68 | @rem Setup the command line 69 | 70 | set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar 71 | 72 | 73 | @rem Execute Gradle 74 | "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* 75 | 76 | :end 77 | @rem End local scope for the variables with windows NT shell 78 | if %ERRORLEVEL% equ 0 goto mainEnd 79 | 80 | :fail 81 | rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of 82 | rem the _cmd.exe /c_ return code! 83 | set EXIT_CODE=%ERRORLEVEL% 84 | if %EXIT_CODE% equ 0 set EXIT_CODE=1 85 | if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% 86 | exit /b %EXIT_CODE% 87 | 88 | :mainEnd 89 | if "%OS%"=="Windows_NT" endlocal 90 | 91 | :omega 92 | -------------------------------------------------------------------------------- /jitpack.yml: -------------------------------------------------------------------------------- 1 | before_install: 2 | - apt-get --quiet update --yes 3 | - apt-get --quiet install --yes wget tar unzip lib32stdc++6 lib32z1 4 | - wget --quiet --output-document=android-sdk.zip https://dl.google.com/android/repository/sdk-tools-linux-4333796.zip 5 | - unzip -d android-sdk-linux android-sdk.zip 6 | - echo y | android-sdk-linux/tools/bin/sdkmanager "platforms;android-33" >/dev/null 7 | - echo y | android-sdk-linux/tools/bin/sdkmanager "platform-tools" >/dev/null 8 | - echo y | android-sdk-linux/tools/bin/sdkmanager "build-tools;33.0.2" >/dev/null 9 | - echo y | android-sdk-linux/tools/bin/sdkmanager "ndk;21.3.6528147" >/dev/null 10 | - export ANDROID_HOME=$PWD/android-sdk-linux 11 | - export PATH=$PATH:$PWD/android-sdk-linux/platform-tools/ 12 | - chmod +x ./gradlew 13 | - set +o pipefail 14 | - yes | android-sdk-linux/tools/bin/sdkmanager --licenses 15 | - set -o pipefail 16 | install: 17 | - wget https://storage.googleapis.com/golang/go1.18.linux-amd64.tar.gz 18 
| - tar -C ~ -xzvf go1.18.linux-amd64.tar.gz 19 | - export PATH="$HOME/go/bin:$PATH" 20 | - ./build_android.sh 21 | - ./gradlew build publishToMavenLocal -------------------------------------------------------------------------------- /target.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | SDK_PATH=$(xcrun --sdk "$SDK" --show-sdk-path) 4 | export SDK_PATH 5 | 6 | if [ "$GOARCH" = "amd64" ]; then 7 | CARCH="x86_64" 8 | elif [ "$GOARCH" = "arm64" ]; then 9 | CARCH="arm64" 10 | fi 11 | 12 | if [ "$SDK" = "iphoneos" ]; then 13 | export TARGET="$CARCH-apple-ios$MIN_VERSION" 14 | elif [ "$SDK" = "iphonesimulator" ]; then 15 | export TARGET="$CARCH-apple-ios$MIN_VERSION-simulator" 16 | fi -------------------------------------------------------------------------------- /websocket/.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | 3 | jobs: 4 | "test": 5 | parameters: 6 | version: 7 | type: string 8 | default: "latest" 9 | golint: 10 | type: boolean 11 | default: true 12 | modules: 13 | type: boolean 14 | default: true 15 | goproxy: 16 | type: string 17 | default: "" 18 | docker: 19 | - image: "cimg/go:<< parameters.version >>" 20 | working_directory: /home/circleci/project/go/src/github.com/gorilla/websocket 21 | environment: 22 | GO111MODULE: "on" 23 | GOPROXY: "<< parameters.goproxy >>" 24 | steps: 25 | - checkout 26 | - run: 27 | name: "Print the Go version" 28 | command: > 29 | go version 30 | - run: 31 | name: "Fetch dependencies" 32 | command: > 33 | if [[ << parameters.modules >> = true ]]; then 34 | go mod download 35 | export GO111MODULE=on 36 | else 37 | go get -v ./...
38 | fi 39 | # Only run gofmt, vet & lint against the latest Go version 40 | - run: 41 | name: "Run golint" 42 | command: > 43 | if [ << parameters.version >> = "latest" ] && [ << parameters.golint >> = true ]; then 44 | go get -u golang.org/x/lint/golint 45 | golint ./... 46 | fi 47 | - run: 48 | name: "Run gofmt" 49 | command: > 50 | if [[ << parameters.version >> = "latest" ]]; then 51 | diff -u <(echo -n) <(gofmt -d -e .) 52 | fi 53 | - run: 54 | name: "Run go vet" 55 | command: > 56 | if [[ << parameters.version >> = "latest" ]]; then 57 | go vet -v ./... 58 | fi 59 | - run: 60 | name: "Run go test (+ race detector)" 61 | command: > 62 | go test -v -race ./... 63 | 64 | workflows: 65 | tests: 66 | jobs: 67 | - test: 68 | matrix: 69 | parameters: 70 | version: ["1.18", "1.17", "1.16"] 71 | -------------------------------------------------------------------------------- /websocket/.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | # Config for https://github.com/apps/release-drafter 2 | template: | 3 | 4 | 5 | 6 | ## CHANGELOG 7 | $CHANGES 8 | -------------------------------------------------------------------------------- /websocket/.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | 24 | .idea/ 25 | *.iml 26 | -------------------------------------------------------------------------------- /websocket/AUTHORS: -------------------------------------------------------------------------------- 1 | # This is the official list of Gorilla WebSocket authors for copyright 2 | # purposes. 
3 | # 4 | # Please keep the list sorted. 5 | 6 | Gary Burd 7 | Google LLC (https://opensource.google.com/) 8 | Joachim Bauch 9 | 10 | -------------------------------------------------------------------------------- /websocket/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are met: 5 | 6 | Redistributions of source code must retain the above copyright notice, this 7 | list of conditions and the following disclaimer. 8 | 9 | Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 12 | 13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 14 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 15 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 16 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 17 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 18 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 19 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 20 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 21 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 22 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
23 | -------------------------------------------------------------------------------- /websocket/README.md: -------------------------------------------------------------------------------- 1 | # Gorilla WebSocket 2 | 3 | [![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket) 4 | [![CircleCI](https://circleci.com/gh/gorilla/websocket.svg?style=svg)](https://circleci.com/gh/gorilla/websocket) 5 | 6 | Gorilla WebSocket is a [Go](http://golang.org/) implementation of the 7 | [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. 8 | 9 | 10 | --- 11 | 12 | ⚠️ **[The Gorilla WebSocket Package is looking for a new maintainer](https://github.com/gorilla/websocket/issues/370)** 13 | 14 | --- 15 | 16 | ### Documentation 17 | 18 | * [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc) 19 | * [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat) 20 | * [Command example](https://github.com/gorilla/websocket/tree/master/examples/command) 21 | * [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo) 22 | * [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch) 23 | 24 | ### Status 25 | 26 | The Gorilla WebSocket package provides a complete and tested implementation of 27 | the [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. The 28 | package API is stable. 29 | 30 | ### Installation 31 | 32 | go get github.com/gorilla/websocket 33 | 34 | ### Protocol Compliance 35 | 36 | The Gorilla WebSocket package passes the server tests in the [Autobahn Test 37 | Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn 38 | subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn). 
39 | 40 | -------------------------------------------------------------------------------- /websocket/client.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package websocket 6 | 7 | import ( 8 | "bytes" 9 | "context" 10 | "errors" 11 | "fmt" 12 | "io" 13 | "io/ioutil" 14 | "net" 15 | "net/http" 16 | "net/http/httptrace" 17 | "net/url" 18 | "strings" 19 | "time" 20 | ) 21 | import tls "github.com/refraction-networking/utls" 22 | 23 | // ErrBadHandshake is returned when the server response to opening handshake is 24 | // invalid. 25 | var ErrBadHandshake = errors.New("websocket: bad handshake") 26 | 27 | var errInvalidCompression = errors.New("websocket: invalid compression negotiation") 28 | 29 | // NewClient creates a new client connection using the given net connection. 30 | // The URL u specifies the host and request URI. Use requestHeader to specify 31 | // the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies 32 | // (Cookie). Use the response.Header to get the selected subprotocol 33 | // (Sec-WebSocket-Protocol) and cookies (Set-Cookie). 34 | // 35 | // If the WebSocket handshake fails, ErrBadHandshake is returned along with a 36 | // non-nil *http.Response so that callers can handle redirects, authentication, 37 | // etc. 38 | // 39 | // Deprecated: Use Dialer instead. 
40 | func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) { 41 | d := Dialer{ 42 | ReadBufferSize: readBufSize, 43 | WriteBufferSize: writeBufSize, 44 | NetDial: func(net, addr string) (net.Conn, error) { 45 | return netConn, nil 46 | }, 47 | } 48 | return d.Dial(u.String(), requestHeader) 49 | } 50 | 51 | // A Dialer contains options for connecting to WebSocket server. 52 | // 53 | // It is safe to call Dialer's methods concurrently. 54 | type Dialer struct { 55 | // NetDial specifies the dial function for creating TCP connections. If 56 | // NetDial is nil, net.Dial is used. 57 | NetDial func(network, addr string) (net.Conn, error) 58 | 59 | // NetDialContext specifies the dial function for creating TCP connections. If 60 | // NetDialContext is nil, NetDial is used. 61 | NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error) 62 | 63 | // NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If 64 | // NetDialTLSContext is nil, NetDialContext is used. 65 | // If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and 66 | // TLSClientConfig is ignored. 67 | NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) 68 | 69 | // Proxy specifies a function to return a proxy for a given 70 | // Request. If the function returns a non-nil error, the 71 | // request is aborted with the provided error. 72 | // If Proxy is nil or returns a nil *URL, no proxy is used. 73 | Proxy func(*http.Request) (*url.URL, error) 74 | 75 | // TLSClientConfig specifies the TLS configuration to use with tls.Client. 76 | // If nil, the default configuration is used. 77 | // If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake 78 | // is done there and TLSClientConfig is ignored. 
79 | TLSClientConfig *tls.Config 80 | 81 | // HandshakeTimeout specifies the duration for the handshake to complete. 82 | HandshakeTimeout time.Duration 83 | 84 | // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer 85 | // size is zero, then a useful default size is used. The I/O buffer sizes 86 | // do not limit the size of the messages that can be sent or received. 87 | ReadBufferSize, WriteBufferSize int 88 | 89 | // WriteBufferPool is a pool of buffers for write operations. If the value 90 | // is not set, then write buffers are allocated to the connection for the 91 | // lifetime of the connection. 92 | // 93 | // A pool is most useful when the application has a modest volume of writes 94 | // across a large number of connections. 95 | // 96 | // Applications should use a single pool for each unique value of 97 | // WriteBufferSize. 98 | WriteBufferPool BufferPool 99 | 100 | // Subprotocols specifies the client's requested subprotocols. 101 | Subprotocols []string 102 | 103 | // EnableCompression specifies if the client should attempt to negotiate 104 | // per message compression (RFC 7692). Setting this value to true does not 105 | // guarantee that compression will be supported. Currently only "no context 106 | // takeover" modes are supported. 107 | EnableCompression bool 108 | 109 | // Jar specifies the cookie jar. 110 | // If Jar is nil, cookies are not sent in requests and ignored 111 | // in responses. 112 | Jar http.CookieJar 113 | } 114 | 115 | // Dial creates a new client connection by calling DialContext with a background context. 
116 | func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { 117 | return d.DialContext(context.Background(), urlStr, requestHeader) 118 | } 119 | 120 | var errMalformedURL = errors.New("malformed ws or wss URL") 121 | 122 | func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { 123 | hostPort = u.Host 124 | hostNoPort = u.Host 125 | if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") { 126 | hostNoPort = hostNoPort[:i] 127 | } else { 128 | switch u.Scheme { 129 | case "wss": 130 | hostPort += ":443" 131 | case "https": 132 | hostPort += ":443" 133 | default: 134 | hostPort += ":80" 135 | } 136 | } 137 | return hostPort, hostNoPort 138 | } 139 | 140 | // DefaultDialer is a dialer with all fields set to the default values. 141 | var DefaultDialer = &Dialer{ 142 | Proxy: http.ProxyFromEnvironment, 143 | HandshakeTimeout: 45 * time.Second, 144 | } 145 | 146 | // nilDialer is dialer to use when receiver is nil. 147 | var nilDialer = *DefaultDialer 148 | 149 | // DialContext creates a new client connection. Use requestHeader to specify the 150 | // origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie). 151 | // Use the response.Header to get the selected subprotocol 152 | // (Sec-WebSocket-Protocol) and cookies (Set-Cookie). 153 | // 154 | // The context will be used in the request and in the Dialer. 155 | // 156 | // If the WebSocket handshake fails, ErrBadHandshake is returned along with a 157 | // non-nil *http.Response so that callers can handle redirects, authentication, 158 | // etcetera. The response body may not contain the entire response and does not 159 | // need to be closed by the application. 
160 | func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { 161 | if d == nil { 162 | d = &nilDialer 163 | } 164 | d.TLSClientConfig = &tls.Config{ 165 | InsecureSkipVerify: true, 166 | } 167 | 168 | challengeKey, err := generateChallengeKey() 169 | if err != nil { 170 | return nil, nil, err 171 | } 172 | 173 | u, err := url.Parse(urlStr) 174 | if err != nil { 175 | return nil, nil, err 176 | } 177 | 178 | switch u.Scheme { 179 | case "ws": 180 | u.Scheme = "http" 181 | case "wss": 182 | u.Scheme = "https" 183 | default: 184 | return nil, nil, errMalformedURL 185 | } 186 | 187 | if u.User != nil { 188 | // User name and password are not allowed in websocket URIs. 189 | return nil, nil, errMalformedURL 190 | } 191 | 192 | req := &http.Request{ 193 | Method: http.MethodGet, 194 | URL: u, 195 | Proto: "HTTP/1.1", 196 | ProtoMajor: 1, 197 | ProtoMinor: 1, 198 | Header: make(http.Header), 199 | Host: u.Host, 200 | } 201 | req = req.WithContext(ctx) 202 | 203 | // Set the cookies present in the cookie jar of the dialer 204 | if d.Jar != nil { 205 | for _, cookie := range d.Jar.Cookies(u) { 206 | req.AddCookie(cookie) 207 | } 208 | } 209 | 210 | // Set the request headers using the capitalization for names and values in 211 | // RFC examples. Although the capitalization shouldn't matter, there are 212 | // servers that depend on it. The Header.Set method is not used because the 213 | // method canonicalizes the header names. 
214 | req.Header["Upgrade"] = []string{"websocket"} 215 | req.Header["Connection"] = []string{"Upgrade"} 216 | req.Header["Sec-WebSocket-Key"] = []string{challengeKey} 217 | req.Header["Sec-WebSocket-Version"] = []string{"13"} 218 | if len(d.Subprotocols) > 0 { 219 | req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")} 220 | } 221 | for k, vs := range requestHeader { 222 | switch { 223 | case k == "Host": 224 | if len(vs) > 0 { 225 | req.Host = vs[0] 226 | } 227 | case k == "Upgrade" || 228 | k == "Connection" || 229 | k == "Sec-Websocket-Key" || 230 | k == "Sec-Websocket-Version" || 231 | k == "Sec-Websocket-Extensions" || 232 | (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0): 233 | return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) 234 | case k == "Sec-Websocket-Protocol": 235 | req.Header["Sec-WebSocket-Protocol"] = vs 236 | default: 237 | req.Header[k] = vs 238 | } 239 | } 240 | 241 | if d.EnableCompression { 242 | req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"} 243 | } 244 | 245 | if d.HandshakeTimeout != 0 { 246 | var cancel func() 247 | ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout) 248 | defer cancel() 249 | } 250 | 251 | // Get network dial function. 
252 | var netDial func(network, add string) (net.Conn, error) 253 | 254 | switch u.Scheme { 255 | case "http": 256 | if d.NetDialContext != nil { 257 | netDial = func(network, addr string) (net.Conn, error) { 258 | return d.NetDialContext(ctx, network, addr) 259 | } 260 | } else if d.NetDial != nil { 261 | netDial = d.NetDial 262 | } 263 | case "https": 264 | if d.NetDialTLSContext != nil { 265 | netDial = func(network, addr string) (net.Conn, error) { 266 | return d.NetDialTLSContext(ctx, network, addr) 267 | } 268 | } else if d.NetDialContext != nil { 269 | netDial = func(network, addr string) (net.Conn, error) { 270 | return d.NetDialContext(ctx, network, addr) 271 | } 272 | } else if d.NetDial != nil { 273 | netDial = d.NetDial 274 | } 275 | default: 276 | return nil, nil, errMalformedURL 277 | } 278 | 279 | if netDial == nil { 280 | netDialer := &net.Dialer{} 281 | netDial = func(network, addr string) (net.Conn, error) { 282 | return netDialer.DialContext(ctx, network, addr) 283 | } 284 | } 285 | 286 | // If needed, wrap the dial function to set the connection deadline. 287 | if deadline, ok := ctx.Deadline(); ok { 288 | forwardDial := netDial 289 | netDial = func(network, addr string) (net.Conn, error) { 290 | c, err := forwardDial(network, addr) 291 | if err != nil { 292 | return nil, err 293 | } 294 | err = c.SetDeadline(deadline) 295 | if err != nil { 296 | c.Close() 297 | return nil, err 298 | } 299 | return c, nil 300 | } 301 | } 302 | 303 | // If needed, wrap the dial function to connect through a proxy. 
304 | if d.Proxy != nil { 305 | proxyURL, err := d.Proxy(req) 306 | if err != nil { 307 | return nil, nil, err 308 | } 309 | if proxyURL != nil { 310 | dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial)) 311 | if err != nil { 312 | return nil, nil, err 313 | } 314 | netDial = dialer.Dial 315 | } 316 | } 317 | 318 | hostPort, hostNoPort := hostPortNoPort(u) 319 | trace := httptrace.ContextClientTrace(ctx) 320 | if trace != nil && trace.GetConn != nil { 321 | trace.GetConn(hostPort) 322 | } 323 | 324 | netConn, err := netDial("tcp", hostPort) 325 | if err != nil { 326 | return nil, nil, err 327 | } 328 | if trace != nil && trace.GotConn != nil { 329 | trace.GotConn(httptrace.GotConnInfo{ 330 | Conn: netConn, 331 | }) 332 | } 333 | 334 | defer func() { 335 | if netConn != nil { 336 | netConn.Close() 337 | } 338 | }() 339 | 340 | if u.Scheme == "https" && d.NetDialTLSContext == nil { 341 | // If NetDialTLSContext is set, assume that the TLS handshake has already been done 342 | 343 | cfg := cloneTLSConfig(d.TLSClientConfig) 344 | if cfg.ServerName == "" { 345 | cfg.ServerName = hostNoPort 346 | } 347 | tlsConn := tls.UClient(netConn, cfg, tls.HelloRandomizedNoALPN) 348 | netConn = tlsConn 349 | 350 | if trace != nil && trace.TLSHandshakeStart != nil { 351 | trace.TLSHandshakeStart() 352 | } 353 | err := doHandshake(ctx, tlsConn, cfg) 354 | //if trace != nil && trace.TLSHandshakeDone != nil { 355 | // trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) 356 | //} 357 | 358 | if err != nil { 359 | return nil, nil, err 360 | } 361 | } 362 | 363 | conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil) 364 | 365 | if err := req.Write(netConn); err != nil { 366 | return nil, nil, err 367 | } 368 | 369 | if trace != nil && trace.GotFirstResponseByte != nil { 370 | if peek, err := conn.br.Peek(1); err == nil && len(peek) == 1 { 371 | trace.GotFirstResponseByte() 372 | } 373 | } 374 | 375 | resp, err := 
http.ReadResponse(conn.br, req) 376 | if err != nil { 377 | if d.TLSClientConfig != nil { 378 | for _, proto := range d.TLSClientConfig.NextProtos { 379 | if proto != "http/1.1" { 380 | return nil, nil, fmt.Errorf( 381 | "websocket: protocol %q was given but is not supported;"+ 382 | "sharing tls.Config with net/http Transport can cause this error: %w", 383 | proto, err, 384 | ) 385 | } 386 | } 387 | } 388 | return nil, nil, err 389 | } 390 | 391 | if d.Jar != nil { 392 | if rc := resp.Cookies(); len(rc) > 0 { 393 | d.Jar.SetCookies(u, rc) 394 | } 395 | } 396 | 397 | if resp.StatusCode != 101 || 398 | !tokenListContainsValue(resp.Header, "Upgrade", "websocket") || 399 | !tokenListContainsValue(resp.Header, "Connection", "upgrade") || 400 | resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { 401 | // Before closing the network connection on return from this 402 | // function, slurp up some of the response to aid application 403 | // debugging. 404 | buf := make([]byte, 1024) 405 | n, _ := io.ReadFull(resp.Body, buf) 406 | resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n])) 407 | return nil, resp, ErrBadHandshake 408 | } 409 | 410 | for _, ext := range parseExtensions(resp.Header) { 411 | if ext[""] != "permessage-deflate" { 412 | continue 413 | } 414 | _, snct := ext["server_no_context_takeover"] 415 | _, cnct := ext["client_no_context_takeover"] 416 | if !snct || !cnct { 417 | return nil, resp, errInvalidCompression 418 | } 419 | conn.newCompressionWriter = compressNoContextTakeover 420 | conn.newDecompressionReader = decompressNoContextTakeover 421 | break 422 | } 423 | 424 | resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) 425 | conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") 426 | 427 | netConn.SetDeadline(time.Time{}) 428 | netConn = nil // to avoid close in defer. 
429 | return conn, resp, nil 430 | } 431 | 432 | func cloneTLSConfig(cfg *tls.Config) *tls.Config { 433 | if cfg == nil { 434 | return &tls.Config{} 435 | } 436 | return cfg.Clone() 437 | } 438 | -------------------------------------------------------------------------------- /websocket/client_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2014 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package websocket 6 | 7 | import ( 8 | "net/url" 9 | "testing" 10 | ) 11 | 12 | var hostPortNoPortTests = []struct { 13 | u *url.URL 14 | hostPort, hostNoPort string 15 | }{ 16 | {&url.URL{Scheme: "ws", Host: "example.com"}, "example.com:80", "example.com"}, 17 | {&url.URL{Scheme: "wss", Host: "example.com"}, "example.com:443", "example.com"}, 18 | {&url.URL{Scheme: "ws", Host: "example.com:7777"}, "example.com:7777", "example.com"}, 19 | {&url.URL{Scheme: "wss", Host: "example.com:7777"}, "example.com:7777", "example.com"}, 20 | } 21 | 22 | func TestHostPortNoPort(t *testing.T) { 23 | for _, tt := range hostPortNoPortTests { 24 | hostPort, hostNoPort := hostPortNoPort(tt.u) 25 | if hostPort != tt.hostPort { 26 | t.Errorf("hostPortNoPort(%v) returned hostPort %q, want %q", tt.u, hostPort, tt.hostPort) 27 | } 28 | if hostNoPort != tt.hostNoPort { 29 | t.Errorf("hostPortNoPort(%v) returned hostNoPort %q, want %q", tt.u, hostNoPort, tt.hostNoPort) 30 | } 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /websocket/compression.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 

package websocket

import (
	"compress/flate"
	"errors"
	"io"
	"strings"
	"sync"
)

const (
	minCompressionLevel     = -2 // flate.HuffmanOnly not defined in Go < 1.6
	maxCompressionLevel     = flate.BestCompression
	defaultCompressionLevel = 1
)

var (
	// flateWriterPools holds one pool of *flate.Writer per compression
	// level, indexed by level-minCompressionLevel.
	flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool
	// flateReaderPool recycles flate readers; each is Reset onto a new
	// source before use.
	flateReaderPool = sync.Pool{New: func() interface{} {
		return flate.NewReader(nil)
	}}
)

// decompressNoContextTakeover returns a reader that inflates the compressed
// payload read from r. The flate reader comes from flateReaderPool and is
// returned to it when the wrapper is closed or reaches EOF.
func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
	const tail =
	// Add four bytes as specified in RFC
	"\x00\x00\xff\xff" +
		// Add final block to squelch unexpected EOF error from flate reader.
		"\x01\x00\x00\xff\xff"

	// The pool only ever holds values produced by flate.NewReader (see New
	// above and flateReadWrapper.Close), so both assertions are expected to hold.
	fr, _ := flateReaderPool.Get().(io.ReadCloser)
	fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
	return &flateReadWrapper{fr}
}

// isValidCompressionLevel reports whether level lies within
// [minCompressionLevel, maxCompressionLevel].
func isValidCompressionLevel(level int) bool {
	return minCompressionLevel <= level && level <= maxCompressionLevel
}

// compressNoContextTakeover returns a writer that deflates everything
// written to it at the given level and forwards the output to w through a
// truncWriter, which withholds the final four bytes. level indexes
// flateWriterPools directly, so it must satisfy isValidCompressionLevel.
func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser {
	p := &flateWriterPools[level-minCompressionLevel]
	tw := &truncWriter{w: w}
	fw, _ := p.Get().(*flate.Writer)
	if fw == nil {
		fw, _ = flate.NewWriter(tw, level)
	} else {
		fw.Reset(tw)
	}
	return &flateWriteWrapper{fw: fw, tw: tw, p: p}
}

// truncWriter is an io.Writer that writes all but the last four bytes of the
// stream to another io.Writer.
type truncWriter struct {
	w io.WriteCloser
	n int     // number of bytes currently held in p (at most len(p))
	p [4]byte // the most recent bytes, withheld from w
}

// Write forwards p to w except for the newest four bytes seen so far,
// which stay in w.p.
func (w *truncWriter) Write(p []byte) (int, error) {
	n := 0

	// fill buffer first for simplicity.
	if w.n < len(w.p) {
		n = copy(w.p[w.n:], p)
		p = p[n:]
		w.n += n
		if len(p) == 0 {
			return n, nil
		}
	}

	// The buffer is full. m is how many held-back bytes must be released to
	// make room for the newest bytes of p.
	m := len(p)
	if m > len(w.p) {
		m = len(w.p)
	}

	// Emit the oldest m held-back bytes.
	if nn, err := w.w.Write(w.p[:m]); err != nil {
		return n + nn, err
	}

	// Shift the remaining held bytes forward, hold back the last m bytes of
	// p, and emit everything of p before them.
	copy(w.p[:], w.p[m:])
	copy(w.p[len(w.p)-m:], p[len(p)-m:])
	nn, err := w.w.Write(p[:len(p)-m])
	return n + nn, err
}

// flateWriteWrapper is the io.WriteCloser returned by
// compressNoContextTakeover.
type flateWriteWrapper struct {
	fw *flate.Writer // nil once closed
	tw *truncWriter
	p  *sync.Pool // pool fw is returned to on Close
}

// Write deflates p; it fails with errWriteClosed after Close.
func (w *flateWriteWrapper) Write(p []byte) (int, error) {
	if w.fw == nil {
		return 0, errWriteClosed
	}
	return w.fw.Write(p)
}

// Close flushes the deflate stream and returns the flate.Writer to its pool.
// It then checks that the four withheld bytes are 00 00 ff ff and closes
// the underlying writer. If those bytes differ, an internal error is
// returned and the underlying writer is not closed. A Flush error takes
// precedence over the underlying Close error.
func (w *flateWriteWrapper) Close() error {
	if w.fw == nil {
		return errWriteClosed
	}
	err1 := w.fw.Flush()
	w.p.Put(w.fw)
	w.fw = nil
	if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
		return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
	}
	err2 := w.tw.w.Close()
	if err1 != nil {
		return err1
	}
	return err2
}

// flateReadWrapper is the io.ReadCloser returned by
// decompressNoContextTakeover.
type flateReadWrapper struct {
	fr io.ReadCloser // nil once closed / returned to the pool
}

// Read inflates into p; it fails with io.ErrClosedPipe after Close.
func (r *flateReadWrapper) Read(p []byte) (int, error) {
	if r.fr == nil {
		return 0, io.ErrClosedPipe
	}
	n, err := r.fr.Read(p)
	if err == io.EOF {
		// Preemptively place the reader back in the pool. This helps with
		// scenarios where the application does not call NextReader() soon after
		// this final read.
		r.Close()
	}
	return n, err
}

// Close closes the flate reader and returns it to flateReaderPool. A second
// Close reports io.ErrClosedPipe.
func (r *flateReadWrapper) Close() error {
	if r.fr == nil {
		return io.ErrClosedPipe
	}
	err := r.fr.Close()
	flateReaderPool.Put(r.fr)
	r.fr = nil
	return err
}

--------------------------------------------------------------------------------
/websocket/compression_test.go:
--------------------------------------------------------------------------------
package websocket

import (
	"bytes"
	"fmt"
	"io"
	"io/ioutil"
	"testing"
)

// nopCloser adapts an io.Writer to io.WriteCloser with a no-op Close.
type nopCloser struct{ io.Writer }

func (nopCloser) Close() error { return nil }

// TestTruncWriter feeds data to a truncWriter in chunks of 1..10 bytes and
// checks that everything except the last len(w.p) bytes reaches the
// underlying writer.
func TestTruncWriter(t *testing.T) {
	const data = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijlkmnopqrstuvwxyz987654321"
	for n := 1; n <= 10; n++ {
		var b bytes.Buffer
		w := &truncWriter{w: nopCloser{&b}}
		p := []byte(data)
		for len(p) > 0 {
			m := len(p)
			if m > n {
				m = n
			}
			w.Write(p[:m])
			p = p[m:]
		}
		if b.String() != data[:len(data)-len(w.p)] {
			t.Errorf("%d: %q", n, b.String())
		}
	}
}

// textMessages returns num distinct short text payloads for the benchmarks.
func textMessages(num int) [][]byte {
	messages := make([][]byte, num)
	for i := 0; i < num; i++ {
		msg := fmt.Sprintf("planet: %d, country: %d, city: %d, street: %d", i, i, i, i)
		messages[i] = []byte(msg)
	}
	return messages
}

// BenchmarkWriteNoCompression measures WriteMessage on a client test
// connection that writes to ioutil.Discard.
func BenchmarkWriteNoCompression(b *testing.B) {
	w := ioutil.Discard
	c := newTestConn(nil, w, false)
	messages := textMessages(100)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		c.WriteMessage(TextMessage, messages[i%len(messages)])
	}
	b.ReportAllocs()
}

// BenchmarkWriteWithCompression is BenchmarkWriteNoCompression with
// no-context-takeover write compression enabled.
func BenchmarkWriteWithCompression(b *testing.B) {
	w := ioutil.Discard
	c := newTestConn(nil, w, false)
	messages := textMessages(100)
	c.enableWriteCompression = true
	c.newCompressionWriter = compressNoContextTakeover
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		c.WriteMessage(TextMessage, messages[i%len(messages)])
	}
	b.ReportAllocs()
}

// TestValidCompressionLevel checks that SetCompressionLevel rejects levels
// just outside [minCompressionLevel, maxCompressionLevel] and accepts both
// bounds.
func TestValidCompressionLevel(t *testing.T) {
	c := newTestConn(nil, nil, false)
	for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} {
		if err := c.SetCompressionLevel(level); err == nil {
			t.Errorf("no error for level %d", level)
		}
	}
	for _, level := range []int{minCompressionLevel, maxCompressionLevel} {
		if err := c.SetCompressionLevel(level); err != nil {
			t.Errorf("error for level %d", level)
		}
	}
}

--------------------------------------------------------------------------------
/websocket/conn_broadcast_test.go:
--------------------------------------------------------------------------------
// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package websocket

import (
	"io"
	"io/ioutil"
	"sync/atomic"
	"testing"
)

// broadcastBench allows to run broadcast benchmarks.
// In every broadcast benchmark we create many connections, then send the same
// message into every connection and wait for all writes complete. This emulates
// an application where many connections listen to the same data - i.e. PUB/SUB
// scenarios with many subscribers in one channel.
type broadcastBench struct {
	w           io.Writer        // destination shared by every test connection
	closeCh     chan struct{}    // closed by close() to stop the writer goroutines
	doneCh      chan struct{}    // signaled when count reaches a multiple of the connection count
	count       int32            // total completed writes; updated with sync/atomic
	conns       []*broadcastConn // one entry per simulated subscriber
	compression bool             // enable write compression on every connection
	usePrepared bool             // broadcast a PreparedMessage instead of a raw payload
}

// broadcastMessage is one broadcast payload; prepared is non-nil when the
// benchmark broadcasts prepared messages.
type broadcastMessage struct {
	payload  []byte
	prepared *PreparedMessage
}

// broadcastConn pairs a test Conn with the channel its writer goroutine
// receives messages from.
type broadcastConn struct {
	conn  *Conn
	msgCh chan *broadcastMessage
}

// newBroadcastConn wraps c with a message channel buffered to one element.
func newBroadcastConn(c *Conn) *broadcastConn {
	return &broadcastConn{
		conn:  c,
		msgCh: make(chan *broadcastMessage, 1),
	}
}

// newBroadcastBench creates a benchmark with 10000 connections that all
// write to ioutil.Discard.
func newBroadcastBench(usePrepared, compression bool) *broadcastBench {
	bench := &broadcastBench{
		w:           ioutil.Discard,
		doneCh:      make(chan struct{}),
		closeCh:     make(chan struct{}),
		usePrepared: usePrepared,
		compression: compression,
	}
	bench.makeConns(10000)
	return bench
}

// makeConns creates numConns server-side test connections and starts one
// writer goroutine per connection. Each goroutine writes every message it
// receives, increments b.count, and signals doneCh whenever the count
// reaches a multiple of numConns; it exits when closeCh is closed.
func (b *broadcastBench) makeConns(numConns int) {
	conns := make([]*broadcastConn, numConns)

	for i := 0; i < numConns; i++ {
		c := newTestConn(nil, b.w, true)
		if b.compression {
			c.enableWriteCompression = true
			c.newCompressionWriter = compressNoContextTakeover
		}
		conns[i] = newBroadcastConn(c)
		go func(c *broadcastConn) {
			for {
				select {
				case msg := <-c.msgCh:
					if msg.prepared != nil {
						c.conn.WritePreparedMessage(msg.prepared)
					} else {
						c.conn.WriteMessage(TextMessage, msg.payload)
					}
					val := atomic.AddInt32(&b.count, 1)
					if val%int32(numConns) == 0 {
						b.doneCh <- struct{}{}
					}
				case <-b.closeCh:
					return
				}
			}
		}(conns[i])
	}
	b.conns = conns
}

// close stops all writer goroutines started by makeConns.
func (b *broadcastBench) close() {
	close(b.closeCh)
}

// broadcastOnce sends msg to every connection, then waits until the writer
// goroutines have completed a full round of writes.
func (b *broadcastBench) broadcastOnce(msg *broadcastMessage) {
	for _, c := range b.conns {
		c.msgCh <- msg
	}
	<-b.doneCh
}

// BenchmarkBroadcast measures broadcasting one text payload to every
// connection, with and without compression and prepared messages.
func BenchmarkBroadcast(b *testing.B) {
	benchmarks := []struct {
		name        string
		usePrepared bool
		compression bool
	}{
		{"NoCompression", false, false},
		{"Compression", false, true},
		{"NoCompressionPrepared", true, false},
		{"CompressionPrepared", true, true},
	}
	payload := textMessages(1)[0]
	for _, bm := range benchmarks {
		b.Run(bm.name, func(b *testing.B) {
			bench := newBroadcastBench(bm.usePrepared, bm.compression)
			defer bench.close()
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				message := &broadcastMessage{
					payload: payload,
				}
				if bench.usePrepared {
					// The preparation cost is part of each measured iteration.
					pm, _ := NewPreparedMessage(TextMessage, message.payload)
					message.prepared = pm
				}
				bench.broadcastOnce(message)
			}
			b.ReportAllocs()
		})
	}
}

--------------------------------------------------------------------------------
/websocket/conn_test.go:
--------------------------------------------------------------------------------
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4 | 5 | package websocket 6 | 7 | import ( 8 | "bufio" 9 | "bytes" 10 | "errors" 11 | "fmt" 12 | "io" 13 | "io/ioutil" 14 | "net" 15 | "reflect" 16 | "sync" 17 | "testing" 18 | "testing/iotest" 19 | "time" 20 | ) 21 | 22 | var _ net.Error = errWriteTimeout 23 | 24 | type fakeNetConn struct { 25 | io.Reader 26 | io.Writer 27 | } 28 | 29 | func (c fakeNetConn) Close() error { return nil } 30 | func (c fakeNetConn) LocalAddr() net.Addr { return localAddr } 31 | func (c fakeNetConn) RemoteAddr() net.Addr { return remoteAddr } 32 | func (c fakeNetConn) SetDeadline(t time.Time) error { return nil } 33 | func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil } 34 | func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil } 35 | 36 | type fakeAddr int 37 | 38 | var ( 39 | localAddr = fakeAddr(1) 40 | remoteAddr = fakeAddr(2) 41 | ) 42 | 43 | func (a fakeAddr) Network() string { 44 | return "net" 45 | } 46 | 47 | func (a fakeAddr) String() string { 48 | return "str" 49 | } 50 | 51 | // newTestConn creates a connnection backed by a fake network connection using 52 | // default values for buffering. 
53 | func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn { 54 | return newConn(fakeNetConn{Reader: r, Writer: w}, isServer, 1024, 1024, nil, nil, nil) 55 | } 56 | 57 | func TestFraming(t *testing.T) { 58 | frameSizes := []int{ 59 | 0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 60 | // 65536, 65537 61 | } 62 | var readChunkers = []struct { 63 | name string 64 | f func(io.Reader) io.Reader 65 | }{ 66 | {"half", iotest.HalfReader}, 67 | {"one", iotest.OneByteReader}, 68 | {"asis", func(r io.Reader) io.Reader { return r }}, 69 | } 70 | writeBuf := make([]byte, 65537) 71 | for i := range writeBuf { 72 | writeBuf[i] = byte(i) 73 | } 74 | var writers = []struct { 75 | name string 76 | f func(w io.Writer, n int) (int, error) 77 | }{ 78 | {"iocopy", func(w io.Writer, n int) (int, error) { 79 | nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n])) 80 | return int(nn), err 81 | }}, 82 | {"write", func(w io.Writer, n int) (int, error) { 83 | return w.Write(writeBuf[:n]) 84 | }}, 85 | {"string", func(w io.Writer, n int) (int, error) { 86 | return io.WriteString(w, string(writeBuf[:n])) 87 | }}, 88 | } 89 | 90 | for _, compress := range []bool{false, true} { 91 | for _, isServer := range []bool{true, false} { 92 | for _, chunker := range readChunkers { 93 | 94 | var connBuf bytes.Buffer 95 | wc := newTestConn(nil, &connBuf, isServer) 96 | rc := newTestConn(chunker.f(&connBuf), nil, !isServer) 97 | if compress { 98 | wc.newCompressionWriter = compressNoContextTakeover 99 | rc.newDecompressionReader = decompressNoContextTakeover 100 | } 101 | for _, n := range frameSizes { 102 | for _, writer := range writers { 103 | name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name) 104 | 105 | w, err := wc.NextWriter(TextMessage) 106 | if err != nil { 107 | t.Errorf("%s: wc.NextWriter() returned %v", name, err) 108 | continue 109 | } 110 | nn, err := writer.f(w, n) 111 | if err != nil || nn != n { 112 | t.Errorf("%s: 
w.Write(writeBuf[:n]) returned %d, %v", name, nn, err) 113 | continue 114 | } 115 | err = w.Close() 116 | if err != nil { 117 | t.Errorf("%s: w.Close() returned %v", name, err) 118 | continue 119 | } 120 | 121 | opCode, r, err := rc.NextReader() 122 | if err != nil || opCode != TextMessage { 123 | t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err) 124 | continue 125 | } 126 | 127 | t.Logf("frame size: %d", n) 128 | rbuf, err := ioutil.ReadAll(r) 129 | if err != nil { 130 | t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) 131 | continue 132 | } 133 | 134 | if len(rbuf) != n { 135 | t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n) 136 | continue 137 | } 138 | 139 | for i, b := range rbuf { 140 | if byte(i) != b { 141 | t.Errorf("%s: bad byte at offset %d", name, i) 142 | break 143 | } 144 | } 145 | } 146 | } 147 | } 148 | } 149 | } 150 | } 151 | 152 | func TestControl(t *testing.T) { 153 | const message = "this is a ping/pong messsage" 154 | for _, isServer := range []bool{true, false} { 155 | for _, isWriteControl := range []bool{true, false} { 156 | name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl) 157 | var connBuf bytes.Buffer 158 | wc := newTestConn(nil, &connBuf, isServer) 159 | rc := newTestConn(&connBuf, nil, !isServer) 160 | if isWriteControl { 161 | wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)) 162 | } else { 163 | w, err := wc.NextWriter(PongMessage) 164 | if err != nil { 165 | t.Errorf("%s: wc.NextWriter() returned %v", name, err) 166 | continue 167 | } 168 | if _, err := w.Write([]byte(message)); err != nil { 169 | t.Errorf("%s: w.Write() returned %v", name, err) 170 | continue 171 | } 172 | if err := w.Close(); err != nil { 173 | t.Errorf("%s: w.Close() returned %v", name, err) 174 | continue 175 | } 176 | var actualMessage string 177 | rc.SetPongHandler(func(s string) error { actualMessage = s; return nil }) 178 | rc.NextReader() 179 | if actualMessage != message { 180 | 
t.Errorf("%s: pong=%q, want %q", name, actualMessage, message) 181 | continue 182 | } 183 | } 184 | } 185 | } 186 | } 187 | 188 | // simpleBufferPool is an implementation of BufferPool for TestWriteBufferPool. 189 | type simpleBufferPool struct { 190 | v interface{} 191 | } 192 | 193 | func (p *simpleBufferPool) Get() interface{} { 194 | v := p.v 195 | p.v = nil 196 | return v 197 | } 198 | 199 | func (p *simpleBufferPool) Put(v interface{}) { 200 | p.v = v 201 | } 202 | 203 | func TestWriteBufferPool(t *testing.T) { 204 | const message = "Now is the time for all good people to come to the aid of the party." 205 | 206 | var buf bytes.Buffer 207 | var pool simpleBufferPool 208 | rc := newTestConn(&buf, nil, false) 209 | 210 | // Specify writeBufferSize smaller than message size to ensure that pooling 211 | // works with fragmented messages. 212 | wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, len(message)-1, &pool, nil, nil) 213 | 214 | if wc.writeBuf != nil { 215 | t.Fatal("writeBuf not nil after create") 216 | } 217 | 218 | // Part 1: test NextWriter/Write/Close 219 | 220 | w, err := wc.NextWriter(TextMessage) 221 | if err != nil { 222 | t.Fatalf("wc.NextWriter() returned %v", err) 223 | } 224 | 225 | if wc.writeBuf == nil { 226 | t.Fatal("writeBuf is nil after NextWriter") 227 | } 228 | 229 | writeBufAddr := &wc.writeBuf[0] 230 | 231 | if _, err := io.WriteString(w, message); err != nil { 232 | t.Fatalf("io.WriteString(w, message) returned %v", err) 233 | } 234 | 235 | if err := w.Close(); err != nil { 236 | t.Fatalf("w.Close() returned %v", err) 237 | } 238 | 239 | if wc.writeBuf != nil { 240 | t.Fatal("writeBuf not nil after w.Close()") 241 | } 242 | 243 | if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { 244 | t.Fatal("writeBuf not returned to pool") 245 | } 246 | 247 | opCode, p, err := rc.ReadMessage() 248 | if opCode != TextMessage || err != nil { 249 | t.Fatalf("ReadMessage() returned %d, p, %v", opCode, 
err) 250 | } 251 | 252 | if s := string(p); s != message { 253 | t.Fatalf("message is %s, want %s", s, message) 254 | } 255 | 256 | // Part 2: Test WriteMessage. 257 | 258 | if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil { 259 | t.Fatalf("wc.WriteMessage() returned %v", err) 260 | } 261 | 262 | if wc.writeBuf != nil { 263 | t.Fatal("writeBuf not nil after wc.WriteMessage()") 264 | } 265 | 266 | if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { 267 | t.Fatal("writeBuf not returned to pool after WriteMessage") 268 | } 269 | 270 | opCode, p, err = rc.ReadMessage() 271 | if opCode != TextMessage || err != nil { 272 | t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) 273 | } 274 | 275 | if s := string(p); s != message { 276 | t.Fatalf("message is %s, want %s", s, message) 277 | } 278 | } 279 | 280 | // TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool. 281 | func TestWriteBufferPoolSync(t *testing.T) { 282 | var buf bytes.Buffer 283 | var pool sync.Pool 284 | wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil) 285 | rc := newTestConn(&buf, nil, false) 286 | 287 | const message = "Hello World!" 288 | for i := 0; i < 3; i++ { 289 | if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil { 290 | t.Fatalf("wc.WriteMessage() returned %v", err) 291 | } 292 | opCode, p, err := rc.ReadMessage() 293 | if opCode != TextMessage || err != nil { 294 | t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) 295 | } 296 | if s := string(p); s != message { 297 | t.Fatalf("message is %s, want %s", s, message) 298 | } 299 | } 300 | } 301 | 302 | // errorWriter is an io.Writer than returns an error on all writes. 303 | type errorWriter struct{} 304 | 305 | func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("error") } 306 | 307 | // TestWriteBufferPoolError ensures that buffer is returned to pool after error 308 | // on write. 
309 | func TestWriteBufferPoolError(t *testing.T) { 310 | 311 | // Part 1: Test NextWriter/Write/Close 312 | 313 | var pool simpleBufferPool 314 | wc := newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil) // every write to the network fails 315 | 316 | w, err := wc.NextWriter(TextMessage) 317 | if err != nil { 318 | t.Fatalf("wc.NextWriter() returned %v", err) 319 | } 320 | 321 | if wc.writeBuf == nil { 322 | t.Fatal("writeBuf is nil after NextWriter") 323 | } 324 | 325 | writeBufAddr := &wc.writeBuf[0] 326 | 327 | if _, err := io.WriteString(w, "Hello"); err != nil { 328 | t.Fatalf("io.WriteString(w, message) returned %v", err) 329 | } 330 | 331 | if err := w.Close(); err == nil { // the flush to errorWriter is expected to fail 332 | t.Fatalf("w.Close() did not return error") 333 | } 334 | 335 | if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { 336 | t.Fatal("writeBuf not returned to pool") 337 | } 338 | 339 | // Part 2: Test WriteMessage 340 | 341 | wc = newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil) // same pool, so the buffer put back above should be handed out again 342 | 343 | if err := wc.WriteMessage(TextMessage, []byte("Hello")); err == nil { 344 | t.Fatalf("wc.WriteMessage did not return error") 345 | } 346 | 347 | if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { 348 | t.Fatal("writeBuf not returned to pool") 349 | } 350 | } 351 | 352 | func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) { 353 | const bufSize = 512 354 | 355 | expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"} 356 | 357 | var b1, b2 bytes.Buffer 358 | wc := newConn(&fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize, nil, nil, nil) 359 | rc := newTestConn(&b1, &b2, true) 360 | 361 | w, _ := wc.NextWriter(BinaryMessage) 362 | w.Write(make([]byte, bufSize+bufSize/2)) // 1.5x bufSize, presumably more than the write buffer holds, so a first frame is sent before the close 363 | wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second)) // close frame sent while the binary message is still open 364 | w.Close() 365 | 366 | op, r, err := rc.NextReader() 367 | if op !=
BinaryMessage || err != nil { 368 | t.Fatalf("NextReader() returned %d, %v", op, err) 369 | } 370 | _, err = io.Copy(ioutil.Discard, r) // the message reader must report the close as the *CloseError sent above 371 | if !reflect.DeepEqual(err, expectedErr) { 372 | t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr) 373 | } 374 | _, _, err = rc.NextReader() 375 | if !reflect.DeepEqual(err, expectedErr) { 376 | t.Fatalf("NextReader() returned %v, want %v", err, expectedErr) 377 | } 378 | } 379 | 380 | func TestEOFWithinFrame(t *testing.T) { 381 | const bufSize = 64 382 | 383 | for n := 0; ; n++ { // n is the number of encoded bytes kept 384 | var b bytes.Buffer 385 | wc := newTestConn(nil, &b, false) 386 | rc := newTestConn(&b, nil, true) 387 | 388 | w, _ := wc.NextWriter(BinaryMessage) 389 | w.Write(make([]byte, bufSize)) 390 | w.Close() 391 | 392 | if n >= b.Len() { // every truncation point has been tried 393 | break 394 | } 395 | b.Truncate(n) // keep only the first n bytes of the encoded message 396 | 397 | op, r, err := rc.NextReader() 398 | if err == errUnexpectedEOF { // cut before NextReader could start the message 399 | continue 400 | } 401 | if op != BinaryMessage || err != nil { 402 | t.Fatalf("%d: NextReader() returned %d, %v", n, op, err) 403 | } 404 | _, err = io.Copy(ioutil.Discard, r) 405 | if err != errUnexpectedEOF { 406 | t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF) 407 | } 408 | _, _, err = rc.NextReader() 409 | if err != errUnexpectedEOF { 410 | t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF) 411 | } 412 | } 413 | } 414 | 415 | func TestEOFBeforeFinalFrame(t *testing.T) { 416 | const bufSize = 512 417 | 418 | var b1, b2 bytes.Buffer 419 | wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, bufSize, nil, nil, nil) 420 | rc := newTestConn(&b1, &b2, true) 421 | 422 | w, _ := wc.NextWriter(BinaryMessage) 423 | w.Write(make([]byte, bufSize+bufSize/2)) // w is never closed, so the final frame is never written 424 | 425 | op, r, err := rc.NextReader() 426 | if op != BinaryMessage || err != nil { 427 | t.Fatalf("NextReader() returned %d, %v", op, err) 428 | } 429 | _, err = io.Copy(ioutil.Discard, r) 430 | if err != errUnexpectedEOF { 431 | t.Fatalf("io.Copy() returned %v, want %v", err,
errUnexpectedEOF) 432 | } 433 | _, _, err = rc.NextReader() 434 | if err != errUnexpectedEOF { 435 | t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF) 436 | } 437 | } 438 | 439 | func TestWriteAfterMessageWriterClose(t *testing.T) { 440 | wc := newTestConn(nil, &bytes.Buffer{}, false) 441 | w, _ := wc.NextWriter(BinaryMessage) 442 | io.WriteString(w, "hello") 443 | if err := w.Close(); err != nil { 444 | t.Fatalf("unexpected error closing message writer, %v", err) 445 | } 446 | 447 | if _, err := io.WriteString(w, "world"); err == nil { // writes to an explicitly closed writer must fail 448 | t.Fatalf("no error writing after close") 449 | } 450 | 451 | w, _ = wc.NextWriter(BinaryMessage) 452 | io.WriteString(w, "hello") 453 | 454 | // close w by getting next writer 455 | _, err := wc.NextWriter(BinaryMessage) 456 | if err != nil { 457 | t.Fatalf("unexpected error getting next writer, %v", err) 458 | } 459 | 460 | if _, err := io.WriteString(w, "world"); err == nil { // w was implicitly closed by the NextWriter call above 461 | t.Fatalf("no error writing after close") 462 | } 463 | } 464 | 465 | func TestReadLimit(t *testing.T) { 466 | t.Run("Test ReadLimit is enforced", func(t *testing.T) { 467 | const readLimit = 512 468 | message := make([]byte, readLimit+1) 469 | 470 | var b1, b2 bytes.Buffer 471 | wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil) // presumably a write buffer smaller than the message, so it is split into frames 472 | rc := newTestConn(&b1, &b2, true) 473 | rc.SetReadLimit(readLimit) 474 | 475 | // Send message at the limit with interleaved pong. 476 | w, _ := wc.NextWriter(BinaryMessage) 477 | w.Write(message[:readLimit-1]) 478 | wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)) 479 | w.Write(message[:1]) 480 | w.Close() 481 | 482 | // Send message larger than the limit.
483 | wc.WriteMessage(BinaryMessage, message[:readLimit+1]) // one byte over the limit 484 | 485 | op, _, err := rc.NextReader() 486 | if op != BinaryMessage || err != nil { 487 | t.Fatalf("1: NextReader() returned %d, %v", op, err) 488 | } 489 | op, r, err := rc.NextReader() 490 | if op != BinaryMessage || err != nil { 491 | t.Fatalf("2: NextReader() returned %d, %v", op, err) 492 | } 493 | _, err = io.Copy(ioutil.Discard, r) // reading the oversized message must fail with ErrReadLimit 494 | if err != ErrReadLimit { 495 | t.Fatalf("io.Copy() returned %v", err) 496 | } 497 | }) 498 | 499 | t.Run("Test that ReadLimit cannot be overflowed", func(t *testing.T) { 500 | const readLimit = 1 501 | 502 | var b1, b2 bytes.Buffer 503 | rc := newTestConn(&b1, &b2, true) 504 | rc.SetReadLimit(readLimit) 505 | 506 | // First, send a non-final binary message 507 | b1.Write([]byte("\x02\x81")) // FIN=0, opcode 0x2 (binary); MASK=1, 7-bit payload length 1 508 | 509 | // Mask key 510 | b1.Write([]byte("\x00\x00\x00\x00")) // all-zero key: masking leaves the payload bytes unchanged 511 | 512 | // First payload 513 | b1.Write([]byte("A")) 514 | 515 | // Next, send a negative-length, non-final continuation frame 516 | b1.Write([]byte("\x00\xFF\x80\x00\x00\x00\x00\x00\x00\x00")) // FIN=0, opcode 0x0; MASK=1, length code 127 => 64-bit length 0x8000000000000000 (negative as int64) 517 | 518 | // Mask key 519 | b1.Write([]byte("\x00\x00\x00\x00")) 520 | 521 | // Next, send a too long, final continuation frame 522 | b1.Write([]byte("\x80\xFF\x00\x00\x00\x00\x00\x00\x00\x05")) // FIN=1, opcode 0x0; MASK=1, 64-bit length 5 523 | 524 | // Mask key 525 | b1.Write([]byte("\x00\x00\x00\x00")) 526 | 527 | // Too-long payload 528 | b1.Write([]byte("BCDEF")) 529 | 530 | op, r, err := rc.NextReader() 531 | if op != BinaryMessage || err != nil { 532 | t.Fatalf("1: NextReader() returned %d, %v", op, err) 533 | } 534 | 535 | var buf [10]byte 536 | var read int 537 | n, err := r.Read(buf[:]) 538 | if err != nil && err != ErrReadLimit { 539 | t.Fatalf("unexpected error testing read limit: %v", err) 540 | } 541 | read += n 542 | 543 | n, err = r.Read(buf[:]) 544 | if err != nil && err != ErrReadLimit { 545 | t.Fatalf("unexpected error testing read limit: %v", err) 546 | } 547 | read += n 548 | 549 | if err == nil && read > readLimit { // no error reported, yet more than readLimit bytes were delivered 550 | t.Fatalf("read
limit exceeded: limit %d, read %d", readLimit, read) 551 | } 552 | }) 553 | } 554 | 555 | func TestAddrs(t *testing.T) { 556 | c := newTestConn(nil, nil, true) 557 | if c.LocalAddr() != localAddr { 558 | t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr) 559 | } 560 | if c.RemoteAddr() != remoteAddr { 561 | t.Errorf("RemoteAddr = %v, want %v", c.RemoteAddr(), remoteAddr) 562 | } 563 | } 564 | 565 | func TestDeprecatedUnderlyingConn(t *testing.T) { 566 | var b1, b2 bytes.Buffer 567 | fc := fakeNetConn{Reader: &b1, Writer: &b2} 568 | c := newConn(fc, true, 1024, 1024, nil, nil, nil) 569 | ul := c.UnderlyingConn() // deprecated accessor; TestNetConn checks NetConn the same way 570 | if ul != fc { 571 | t.Fatalf("Underlying conn is not what it should be.") 572 | } 573 | } 574 | 575 | func TestNetConn(t *testing.T) { 576 | var b1, b2 bytes.Buffer 577 | fc := fakeNetConn{Reader: &b1, Writer: &b2} 578 | c := newConn(fc, true, 1024, 1024, nil, nil, nil) 579 | ul := c.NetConn() 580 | if ul != fc { 581 | t.Fatalf("Underlying conn is not what it should be.") 582 | } 583 | } 584 | 585 | func TestBufioReadBytes(t *testing.T) { 586 | // Test calling bufio.Reader.ReadBytes for a value longer than the read buffer size.
587 | 588 | m := make([]byte, 512) 589 | m[len(m)-1] = '\n' 590 | 591 | var b1, b2 bytes.Buffer 592 | wc := newConn(fakeNetConn{Writer: &b1}, false, len(m)+64, len(m)+64, nil, nil, nil) 593 | rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil) // buffers 64 bytes smaller than the message 594 | 595 | w, _ := wc.NextWriter(BinaryMessage) 596 | w.Write(m) 597 | w.Close() 598 | 599 | op, r, err := rc.NextReader() 600 | if op != BinaryMessage || err != nil { 601 | t.Fatalf("NextReader() returned %d, %v", op, err) 602 | } 603 | 604 | br := bufio.NewReader(r) 605 | p, err := br.ReadBytes('\n') 606 | if err != nil { 607 | t.Fatalf("ReadBytes() returned %v", err) 608 | } 609 | if len(p) != len(m) { 610 | t.Fatalf("read returned %d bytes, want %d bytes", len(p), len(m)) 611 | } 612 | } 613 | 614 | var closeErrorTests = []struct { // IsCloseError cases: ok is the expected result for err and codes 615 | err error 616 | codes []int 617 | ok bool 618 | }{ 619 | {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true}, 620 | {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false}, 621 | {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true}, 622 | {errors.New("hello"), []int{CloseNormalClosure}, false}, 623 | } 624 | 625 | func TestCloseError(t *testing.T) { 626 | for _, tt := range closeErrorTests { 627 | ok := IsCloseError(tt.err, tt.codes...)
628 | if ok != tt.ok { 629 | t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok) 630 | } 631 | } 632 | } 633 | 634 | var unexpectedCloseErrorTests = []struct { // IsUnexpectedCloseError cases: ok is the expected result for err and codes 635 | err error 636 | codes []int 637 | ok bool 638 | }{ 639 | {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false}, 640 | {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true}, 641 | {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false}, 642 | {errors.New("hello"), []int{CloseNormalClosure}, false}, 643 | } 644 | 645 | func TestUnexpectedCloseErrors(t *testing.T) { 646 | for _, tt := range unexpectedCloseErrorTests { 647 | ok := IsUnexpectedCloseError(tt.err, tt.codes...) 648 | if ok != tt.ok { 649 | t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok) 650 | } 651 | } 652 | } 653 | 654 | type blockingWriter struct { // Write closes c1, then blocks until c2 is closed 655 | c1, c2 chan struct{} 656 | } 657 | 658 | func (w blockingWriter) Write(p []byte) (int, error) { 659 | // Allow main to continue 660 | close(w.c1) 661 | // Wait for panic in main 662 | <-w.c2 663 | return len(p), nil 664 | } 665 | 666 | func TestConcurrentWritePanic(t *testing.T) { 667 | w := blockingWriter{make(chan struct{}), make(chan struct{})} 668 | c := newTestConn(nil, w, false) 669 | go func() { 670 | c.WriteMessage(TextMessage, []byte{}) 671 | }() 672 | 673 | // wait for goroutine to block in write.
674 | <-w.c1 675 | 676 | defer func() { 677 | close(w.c2) 678 | if v := recover(); v != nil { // the second, concurrent WriteMessage is expected to panic; recovering lets the test pass 679 | return 680 | } 681 | }() 682 | 683 | c.WriteMessage(TextMessage, []byte{}) 684 | t.Fatal("should not get here") 685 | } 686 | 687 | type failingReader struct{} // Read always returns io.EOF 688 | 689 | func (r failingReader) Read(p []byte) (int, error) { 690 | return 0, io.EOF 691 | } 692 | 693 | func TestFailedConnectionReadPanic(t *testing.T) { 694 | c := newTestConn(failingReader{}, nil, false) 695 | 696 | defer func() { 697 | if v := recover(); v != nil { 698 | return 699 | } 700 | }() 701 | 702 | for i := 0; i < 20000; i++ { // reads on the failed connection are expected to panic before the loop ends 703 | c.ReadMessage() 704 | } 705 | t.Fatal("should not get here") 706 | } 707 | -------------------------------------------------------------------------------- /websocket/doc.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package websocket implements the WebSocket protocol defined in RFC 6455. 6 | // 7 | // Overview 8 | // 9 | // The Conn type represents a WebSocket connection. A server application calls 10 | // the Upgrader.Upgrade method from an HTTP request handler to get a *Conn: 11 | // 12 | // var upgrader = websocket.Upgrader{ 13 | // ReadBufferSize: 1024, 14 | // WriteBufferSize: 1024, 15 | // } 16 | // 17 | // func handler(w http.ResponseWriter, r *http.Request) { 18 | // conn, err := upgrader.Upgrade(w, r, nil) 19 | // if err != nil { 20 | // log.Println(err) 21 | // return 22 | // } 23 | // ... Use conn to send and receive messages. 24 | // } 25 | // 26 | // Call the connection's WriteMessage and ReadMessage methods to send and 27 | // receive messages as a slice of bytes.
This snippet of code shows how to echo 28 | // messages using these methods: 29 | // 30 | // for { 31 | // messageType, p, err := conn.ReadMessage() 32 | // if err != nil { 33 | // log.Println(err) 34 | // return 35 | // } 36 | // if err := conn.WriteMessage(messageType, p); err != nil { 37 | // log.Println(err) 38 | // return 39 | // } 40 | // } 41 | // 42 | // In the above snippet of code, p is a []byte and messageType is an int with value 43 | // websocket.BinaryMessage or websocket.TextMessage. 44 | // 45 | // An application can also send and receive messages using the io.WriteCloser 46 | // and io.Reader interfaces. To send a message, call the connection's NextWriter 47 | // method to get an io.WriteCloser, write the message to the writer and close 48 | // the writer when done. To receive a message, call the connection's NextReader 49 | // method to get an io.Reader and read until io.EOF is returned. This snippet 50 | // shows how to echo messages using the NextWriter and NextReader methods: 51 | // 52 | // for { 53 | // messageType, r, err := conn.NextReader() 54 | // if err != nil { 55 | // return err 56 | // } 57 | // w, err := conn.NextWriter(messageType) 58 | // if err != nil { 59 | // return err 60 | // } 61 | // if _, err := io.Copy(w, r); err != nil { 62 | // return err 63 | // } 64 | // if err := w.Close(); err != nil { 65 | // return err 66 | // } 67 | // } 68 | // 69 | // Data Messages 70 | // 71 | // The WebSocket protocol distinguishes between text and binary data messages. 72 | // Text messages are interpreted as UTF-8 encoded text. The interpretation of 73 | // binary messages is left to the application. 74 | // 75 | // This package uses the TextMessage and BinaryMessage integer constants to 76 | // identify the two data message types. The ReadMessage and NextReader methods 77 | // return the type of the received message. The messageType argument to the 78 | // WriteMessage and NextWriter methods specifies the type of a sent message.
79 | // 80 | // It is the application's responsibility to ensure that text messages are 81 | // valid UTF-8 encoded text. 82 | // 83 | // Control Messages 84 | // 85 | // The WebSocket protocol defines three types of control messages: close, ping 86 | // and pong. Call the connection WriteControl, WriteMessage or NextWriter 87 | // methods to send a control message to the peer. 88 | // 89 | // Connections handle received close messages by calling the handler function 90 | // set with the SetCloseHandler method and by returning a *CloseError from the 91 | // NextReader, ReadMessage or the message Read method. The default close 92 | // handler sends a close message to the peer. 93 | // 94 | // Connections handle received ping messages by calling the handler function 95 | // set with the SetPingHandler method. The default ping handler sends a pong 96 | // message to the peer. 97 | // 98 | // Connections handle received pong messages by calling the handler function 99 | // set with the SetPongHandler method. The default pong handler does nothing. 100 | // If an application sends ping messages, then the application should set a 101 | // pong handler to receive the corresponding pong. 102 | // 103 | // The control message handler functions are called from the NextReader, 104 | // ReadMessage and message reader Read methods. The default close and ping 105 | // handlers can block these methods for a short time when the handler writes to 106 | // the connection. 107 | // 108 | // The application must read the connection to process close, ping and pong 109 | // messages sent from the peer. If the application is not otherwise interested 110 | // in messages from the peer, then the application should start a goroutine to 111 | // read and discard messages from the peer. 
A simple example is: 112 | // 113 | // func readLoop(c *websocket.Conn) { 114 | // for { 115 | // if _, _, err := c.NextReader(); err != nil { 116 | // c.Close() 117 | // break 118 | // } 119 | // } 120 | // } 121 | // 122 | // Concurrency 123 | // 124 | // Connections support one concurrent reader and one concurrent writer. 125 | // 126 | // Applications are responsible for ensuring that no more than one goroutine 127 | // calls the write methods (NextWriter, SetWriteDeadline, WriteMessage, 128 | // WriteJSON, EnableWriteCompression, SetCompressionLevel) concurrently and 129 | // that no more than one goroutine calls the read methods (NextReader, 130 | // SetReadDeadline, ReadMessage, ReadJSON, SetPongHandler, SetPingHandler) 131 | // concurrently. 132 | // 133 | // The Close and WriteControl methods can be called concurrently with all other 134 | // methods. 135 | // 136 | // Origin Considerations 137 | // 138 | // Web browsers allow Javascript applications to open a WebSocket connection to 139 | // any host. It's up to the server to enforce an origin policy using the Origin 140 | // request header sent by the browser. 141 | // 142 | // The Upgrader calls the function specified in the CheckOrigin field to check 143 | // the origin. If the CheckOrigin function returns false, then the Upgrade 144 | // method fails the WebSocket handshake with HTTP status 403. 145 | // 146 | // If the CheckOrigin field is nil, then the Upgrader uses a safe default: fail 147 | // the handshake if the Origin request header is present and the Origin host is 148 | // not equal to the Host request header. 149 | // 150 | // The deprecated package-level Upgrade function does not perform origin 151 | // checking. The application is responsible for checking the Origin header 152 | // before calling the Upgrade function. 153 | // 154 | // Buffers 155 | // 156 | // Connections buffer network input and output to reduce the number 157 | // of system calls when reading or writing messages. 
158 | // 159 | // Write buffers are also used for constructing WebSocket frames. See RFC 6455, 160 | // Section 5 for a discussion of message framing. A WebSocket frame header is 161 | // written to the network each time a write buffer is flushed to the network. 162 | // Decreasing the size of the write buffer can increase the amount of framing 163 | // overhead on the connection. 164 | // 165 | // The buffer sizes in bytes are specified by the ReadBufferSize and 166 | // WriteBufferSize fields in the Dialer and Upgrader. The Dialer uses a default 167 | // size of 4096 when a buffer size field is set to zero. The Upgrader reuses 168 | // buffers created by the HTTP server when a buffer size field is set to zero. 169 | // The HTTP server buffers have a size of 4096 at the time of this writing. 170 | // 171 | // The buffer sizes do not limit the size of a message that can be read or 172 | // written by a connection. 173 | // 174 | // Buffers are held for the lifetime of the connection by default. If the 175 | // Dialer or Upgrader WriteBufferPool field is set, then a connection holds the 176 | // write buffer only when writing a message. 177 | // 178 | // Applications should tune the buffer sizes to balance memory use and 179 | // performance. Increasing the buffer size uses more memory, but can reduce the 180 | // number of system calls to read or write the network. In the case of writing, 181 | // increasing the buffer size can reduce the number of frame headers written to 182 | // the network. 183 | // 184 | // Some guidelines for setting buffer parameters are: 185 | // 186 | // Limit the buffer sizes to the maximum expected message size. Buffers larger 187 | // than the largest message do not provide any benefit. 188 | // 189 | // Depending on the distribution of message sizes, setting the buffer size to 190 | // a value less than the maximum expected message size can greatly reduce memory 191 | // use with a small impact on performance. 
Here's an example: If 99% of the 192 | // messages are smaller than 256 bytes and the maximum message size is 512 193 | // bytes, then a buffer size of 256 bytes will result in 1.01 times as many system calls 194 | // as a buffer size of 512 bytes. The memory savings are 50%. 195 | // 196 | // A write buffer pool is useful when the application has a modest number 197 | // of writes over a large number of connections. When buffers are pooled, a larger 198 | // buffer size has a reduced impact on total memory use and has the benefit of 199 | // reducing system calls and frame overhead. 200 | // 201 | // Compression EXPERIMENTAL 202 | // 203 | // Per message compression extensions (RFC 7692) are experimentally supported 204 | // by this package in a limited capacity. Setting the EnableCompression option 205 | // to true in Dialer or Upgrader will attempt to negotiate per message deflate 206 | // support. 207 | // 208 | // var upgrader = websocket.Upgrader{ 209 | // EnableCompression: true, 210 | // } 211 | // 212 | // If compression was successfully negotiated with the connection's peer, any 213 | // message received in compressed form will be automatically decompressed. 214 | // All Read methods will return uncompressed bytes. 215 | // 216 | // Per message compression of messages written to a connection can be enabled 217 | // or disabled by calling the corresponding Conn method: 218 | // 219 | // conn.EnableWriteCompression(false) 220 | // 221 | // Currently this package does not support compression with "context takeover". 222 | // This means that messages must be compressed and decompressed in isolation, 223 | // without retaining sliding window or dictionary state across messages. For 224 | // more details refer to RFC 7692. 225 | // 226 | // Use of compression is experimental and may result in decreased performance.
227 | package websocket 228 | -------------------------------------------------------------------------------- /websocket/example_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package websocket_test 6 | 7 | import ( 8 | "log" 9 | "net/http" 10 | "testing" 11 | 12 | "github.com/gorilla/websocket" 13 | ) 14 | 15 | var ( 16 | c *websocket.Conn // never assigned: the example has no Output comment, so go test compiles it but does not run it 17 | req *http.Request 18 | ) 19 | 20 | // The websocket.IsUnexpectedCloseError function is useful for identifying 21 | // application and protocol errors. 22 | // 23 | // This server application works with a client application running in the 24 | // browser. The client application does not explicitly close the websocket. The 25 | // only expected close message from the client has the code 26 | // websocket.CloseGoingAway. All other close messages are likely the 27 | // result of an application or protocol error and are logged to aid debugging. 28 | func ExampleIsUnexpectedCloseError() { 29 | for { 30 | messageType, p, err := c.ReadMessage() 31 | if err != nil { 32 | if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) { 33 | log.Printf("error: %v, user-agent: %v", err, req.Header.Get("User-Agent")) 34 | } 35 | return 36 | } 37 | processMessage(messageType, p) 38 | } 39 | } 40 | 41 | func processMessage(mt int, p []byte) {} // stub standing in for application-specific handling 42 | 43 | // TestX prevents godoc from showing this entire file in the example. Remove 44 | // this function when a second example is added.
45 | func TestX(t *testing.T) {} 46 | -------------------------------------------------------------------------------- /websocket/examples/autobahn/README.md: -------------------------------------------------------------------------------- 1 | # Test Server 2 | 3 | This package contains a server for the [Autobahn WebSockets Test Suite](https://github.com/crossbario/autobahn-testsuite). 4 | 5 | To test the server, run 6 | 7 | go run server.go 8 | 9 | and start the client test driver 10 | 11 | mkdir -p reports 12 | docker run -it --rm \ 13 | -v ${PWD}/config:/config \ 14 | -v ${PWD}/reports:/reports \ 15 | crossbario/autobahn-testsuite \ 16 | wstest -m fuzzingclient -s /config/fuzzingclient.json 17 | 18 | When the client completes, it writes a report to reports/index.html. 19 | -------------------------------------------------------------------------------- /websocket/examples/autobahn/config/fuzzingclient.json: -------------------------------------------------------------------------------- 1 | { 2 | "cases": ["*"], 3 | "exclude-cases": [], 4 | "exclude-agent-cases": {}, 5 | "outdir": "/reports", 6 | "options": {"failByDrop": false}, 7 | "servers": [ 8 | { 9 | "agent": "ReadAllWriteMessage", 10 | "url": "ws://host.docker.internal:9000/m" 11 | }, 12 | { 13 | "agent": "ReadAllWritePreparedMessage", 14 | "url": "ws://host.docker.internal:9000/p" 15 | }, 16 | { 17 | "agent": "CopyFull", 18 | "url": "ws://host.docker.internal:9000/f" 19 | }, 20 | { 21 | "agent": "ReadAllWrite", 22 | "url": "ws://host.docker.internal:9000/r" 23 | }, 24 | { 25 | "agent": "CopyWriterOnly", 26 | "url": "ws://host.docker.internal:9000/c" 27 | } 28 | ] 29 | } 30 | -------------------------------------------------------------------------------- /websocket/examples/autobahn/server.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Command server is a test server for the Autobahn WebSockets Test Suite. 6 | package main 7 | 8 | import ( 9 | "errors" 10 | "flag" 11 | "io" 12 | "log" 13 | "net/http" 14 | "time" 15 | "unicode/utf8" 16 | 17 | "github.com/gorilla/websocket" 18 | ) 19 | 20 | var upgrader = websocket.Upgrader{ 21 | ReadBufferSize: 4096, 22 | WriteBufferSize: 4096, 23 | EnableCompression: true, 24 | CheckOrigin: func(r *http.Request) bool { // accept any Origin: this server only serves the Autobahn test driver 25 | return true 26 | }, 27 | } 28 | 29 | // echoCopy echoes each client message with io.Copy, validating UTF-8 for text messages. 30 | func echoCopy(w http.ResponseWriter, r *http.Request, writerOnly bool) { 31 | conn, err := upgrader.Upgrade(w, r, nil) 32 | if err != nil { 33 | log.Println("Upgrade:", err) 34 | return 35 | } 36 | defer conn.Close() 37 | for { 38 | mt, mr, err := conn.NextReader() 39 | if err != nil { 40 | if err != io.EOF { 41 | log.Println("NextReader:", err) 42 | } 43 | return 44 | } 45 | if mt == websocket.TextMessage { 46 | mr = &validator{r: mr} // wrap once; its Read reports errInvalidUTF8 on bad input 47 | } 48 | mw, err := conn.NextWriter(mt) 49 | if err != nil { 50 | log.Println("NextWriter:", err) 51 | return 52 | } 53 | if writerOnly { 54 | _, err = io.Copy(struct{ io.Writer }{mw}, mr) // expose only Write, hiding any ReadFrom method io.Copy could otherwise use 55 | } else { 56 | _, err = io.Copy(mw, mr) 57 | } 58 | if err != nil { 59 | if err == errInvalidUTF8 { 60 | conn.WriteControl(websocket.CloseMessage, 61 | websocket.FormatCloseMessage(websocket.CloseInvalidFramePayloadData, ""), 62 | time.Time{}) 63 | } 64 | log.Println("Copy:", err) 65 | return 66 | } 67 | err = mw.Close() 68 | if err != nil { 69 | log.Println("Close:", err) 70 | return 71 | } 72 | } 73 | } 74 | 75 | func echoCopyWriterOnly(w http.ResponseWriter, r *http.Request) { 76 | echoCopy(w, r, true) 77 | } 78 | 79 | func echoCopyFull(w http.ResponseWriter, r *http.Request) { 80 | echoCopy(w, r, false) 81 | } 82 | 83 | // echoReadAll
echoes messages from the client by reading the entire message 84 | // with conn.ReadMessage; writeMessage and writePrepared select how it is written back. 85 | func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage, writePrepared bool) { 86 | conn, err := upgrader.Upgrade(w, r, nil) 87 | if err != nil { 88 | log.Println("Upgrade:", err) 89 | return 90 | } 91 | defer conn.Close() 92 | for { 93 | mt, b, err := conn.ReadMessage() 94 | if err != nil { 95 | if err != io.EOF { 96 | log.Println("ReadMessage:", err) 97 | } 98 | return 99 | } 100 | if mt == websocket.TextMessage { 101 | if !utf8.Valid(b) { 102 | conn.WriteControl(websocket.CloseMessage, 103 | websocket.FormatCloseMessage(websocket.CloseInvalidFramePayloadData, ""), 104 | time.Time{}) 105 | log.Println("ReadAll: invalid utf8") 106 | return // the close frame has been sent; never echo the invalid payload 107 | } 108 | } 109 | if writeMessage { 110 | if !writePrepared { 111 | err = conn.WriteMessage(mt, b) 112 | if err != nil { 113 | log.Println("WriteMessage:", err) 114 | return 115 | } 116 | } else { 117 | pm, err := websocket.NewPreparedMessage(mt, b) 118 | if err != nil { 119 | log.Println("NewPreparedMessage:", err) 120 | return 121 | } 122 | err = conn.WritePreparedMessage(pm) 123 | if err != nil { 124 | log.Println("WritePreparedMessage:", err) 125 | return 126 | } 127 | } 128 | } else { 129 | mw, err := conn.NextWriter(mt) 130 | if err != nil { 131 | log.Println("NextWriter:", err) 132 | return 133 | } 134 | if _, err := mw.Write(b); err != nil { 135 | log.Println("Writer:", err) 136 | return 137 | } 138 | if err := mw.Close(); err != nil { 139 | log.Println("Close:", err) 140 | return 141 | } 142 | } 143 | } 144 | } 145 | 146 | func echoReadAllWriter(w http.ResponseWriter, r *http.Request) { 147 | echoReadAll(w, r, false, false) 148 | } 149 | 150 | func echoReadAllWriteMessage(w http.ResponseWriter, r *http.Request) { 151 | echoReadAll(w, r, true, false) 152 | } 153 | 154 | func echoReadAllWritePreparedMessage(w http.ResponseWriter, r *http.Request) { 155 | echoReadAll(w, r, true, true) 156 | } 157 | 158 | func serveHome(w http.ResponseWriter, r
*http.Request) { 159 | if r.URL.Path != "/" { 160 | http.Error(w, "Not found.", http.StatusNotFound) 161 | return 162 | } 163 | if r.Method != http.MethodGet { 164 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 165 | return 166 | } 167 | w.Header().Set("Content-Type", "text/html; charset=utf-8") 168 | io.WriteString(w, "Echo Server") 169 | } 170 | 171 | var addr = flag.String("addr", ":9000", "http service address") 172 | 173 | func main() { 174 | flag.Parse() 175 | http.HandleFunc("/", serveHome) 176 | http.HandleFunc("/c", echoCopyWriterOnly) 177 | http.HandleFunc("/f", echoCopyFull) 178 | http.HandleFunc("/r", echoReadAllWriter) 179 | http.HandleFunc("/m", echoReadAllWriteMessage) 180 | http.HandleFunc("/p", echoReadAllWritePreparedMessage) 181 | err := http.ListenAndServe(*addr, nil) // default mux; no server timeouts are configured 182 | if err != nil { 183 | log.Fatal("ListenAndServe: ", err) 184 | } 185 | } 186 | 187 | type validator struct { // io.Reader wrapper that checks its stream is valid UTF-8 while reading 188 | state int // DFA state carried across Read calls 189 | x rune // code point decoded so far 190 | r io.Reader // wrapped reader 191 | } 192 | 193 | var errInvalidUTF8 = errors.New("invalid utf8") 194 | 195 | func (r *validator) Read(p []byte) (int, error) { 196 | n, err := r.r.Read(p) 197 | state := r.state 198 | x := r.x 199 | for _, b := range p[:n] { // run each byte read through the DFA 200 | state, x = decode(state, x, b) 201 | if state == utf8Reject { 202 | break 203 | } 204 | } 205 | r.state = state 206 | r.x = x 207 | if state == utf8Reject || (err == io.EOF && state != utf8Accept) { // bad byte, or EOF inside a multi-byte sequence 208 | return n, errInvalidUTF8 209 | } 210 | return n, err 211 | } 212 | 213 | // UTF-8 decoder from http://bjoern.hoehrmann.de/utf-8/decoder/dfa/ 214 | // 215 | // Copyright (c) 2008-2009 Bjoern Hoehrmann 216 | // 217 | // Permission is hereby granted, free of charge, to any person obtaining a copy 218 | // of this software and associated documentation files (the "Software"), to 219 | // deal in the Software without restriction, including without limitation the 220 | // rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 221 | // sell copies of the
Software, and to permit persons to whom the Software is 222 | // furnished to do so, subject to the following conditions: 223 | // 224 | // The above copyright notice and this permission notice shall be included in 225 | // all copies or substantial portions of the Software. 226 | // 227 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 228 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 229 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 230 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 231 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 232 | // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 233 | // IN THE SOFTWARE. 234 | var utf8d = [...]byte{ // entries 0-255 map a byte to its character class; then 16 transitions per state 235 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 00..1f 236 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 20..3f 237 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 40..5f 238 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 60..7f 239 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, // 80..9f 240 | 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // a0..bf 241 | 8, 8, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // c0..df 242 | 0xa, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x4, 0x3, 0x3, // e0..ef 243 | 0xb, 0x6, 0x6, 0x6, 0x5, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, // f0..ff 244 | 0x0, 0x1, 0x2, 0x3, 0x5, 0x8, 0x7, 0x1, 0x1, 0x1, 0x4, 0x6, 0x1, 0x1, 0x1, 0x1, // s0..s0 245 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1,
1, 1, 1, 1, // s1..s2 246 | 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, // s3..s4 247 | 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 3, 1, 1, 1, 1, 1, 1, // s5..s6 248 | 1, 3, 1, 1, 1, 1, 1, 3, 1, 3, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // s7..s8 249 | } 250 | 251 | const ( 252 | utf8Accept = 0 // a complete, valid sequence has been consumed 253 | utf8Reject = 1 // invalid input; every transition out of this state leads back to it 254 | ) 255 | 256 | func decode(state int, x rune, b byte) (int, rune) { // one DFA step: returns the next state and the updated code point 257 | t := utf8d[b] 258 | if state != utf8Accept { 259 | x = rune(b&0x3f) | (x << 6) // continuation byte: shift in its low 6 bits 260 | } else { 261 | x = rune((0xff >> t) & b) // start of a sequence: mask off the lead-byte marker bits using its class 262 | } 263 | state = int(utf8d[256+state*16+int(t)]) 264 | return state, x 265 | } 266 | -------------------------------------------------------------------------------- /websocket/examples/chat/README.md: -------------------------------------------------------------------------------- 1 | # Chat Example 2 | 3 | This application shows how to use the 4 | [websocket](https://github.com/gorilla/websocket) package to implement a simple 5 | web chat application. 6 | 7 | ## Running the example 8 | 9 | The example requires a working Go development environment. The [Getting 10 | Started](http://golang.org/doc/install) page describes how to install the 11 | development environment. 12 | 13 | Once you have Go up and running, you can download, build and run the example 14 | using the following commands. 15 | 16 | $ go get github.com/gorilla/websocket 17 | $ cd `go list -f '{{.Dir}}' github.com/gorilla/websocket/examples/chat` 18 | $ go run *.go 19 | 20 | To use the chat example, open http://localhost:8080/ in your browser. 21 | 22 | ## Server 23 | 24 | The server application defines two types, `Client` and `Hub`. The server 25 | creates an instance of the `Client` type for each websocket connection. A 26 | `Client` acts as an intermediary between the websocket connection and a single 27 | instance of the `Hub` type.
The `Hub` maintains a set of registered clients and 28 | broadcasts messages to the clients. 29 | 30 | The application runs one goroutine for the `Hub` and two goroutines for each 31 | `Client`. The goroutines communicate with each other using channels. The `Hub` 32 | has channels for registering clients, unregistering clients and broadcasting 33 | messages. A `Client` has a buffered channel of outbound messages. One of the 34 | client's goroutines reads messages from this channel and writes the messages to 35 | the websocket. The other client goroutine reads messages from the websocket and 36 | sends them to the hub. 37 | 38 | ### Hub 39 | 40 | The code for the `Hub` type is in 41 | [hub.go](https://github.com/gorilla/websocket/blob/master/examples/chat/hub.go). 42 | The application's `main` function starts the hub's `run` method as a goroutine. 43 | Clients send requests to the hub using the `register`, `unregister` and 44 | `broadcast` channels. 45 | 46 | The hub registers clients by adding the client pointer as a key in the 47 | `clients` map. The map value is always true. 48 | 49 | The unregister code is a little more complicated. In addition to deleting the 50 | client pointer from the `clients` map, the hub closes the client's `send` 51 | channel to signal the client that no more messages will be sent to the client. 52 | 53 | The hub handles messages by looping over the registered clients and sending the 54 | message to the client's `send` channel. If the client's `send` buffer is full, 55 | then the hub assumes that the client is dead or stuck. In this case, the hub 56 | unregisters the client and closes its `send` channel; the client's `writePump` then sends a close message and closes the websocket. 57 | 58 | ### Client 59 | 60 | The code for the `Client` type is in [client.go](https://github.com/gorilla/websocket/blob/master/examples/chat/client.go). 61 | 62 | The `serveWs` function is registered by the application's `main` function as 63 | an HTTP handler.
The handler upgrades the HTTP connection to the WebSocket 64 | protocol, creates a client and registers the client with the hub. The client is 65 | unregistered later by a deferred call in `readPump` once reading from the websocket fails. 66 | 67 | Next, the HTTP handler starts the client's `writePump` method as a goroutine. 68 | This method transfers messages from the client's send channel to the websocket 69 | connection. The writer method exits when the channel is closed by the hub or 70 | there's an error writing to the websocket connection. 71 | 72 | Finally, the HTTP handler starts the client's `readPump` method as a goroutine and returns. This method 73 | transfers inbound messages from the websocket to the hub. 74 | 75 | WebSocket connections [support one concurrent reader and one concurrent 76 | writer](https://godoc.org/github.com/gorilla/websocket#hdr-Concurrency). The 77 | application ensures that these concurrency requirements are met by executing 78 | all reads from the `readPump` goroutine and all writes from the `writePump` 79 | goroutine. 80 | 81 | To improve efficiency under high load, the `writePump` function coalesces 82 | pending chat messages in the `send` channel to a single WebSocket message. This 83 | reduces the number of system calls and the amount of data sent over the 84 | network. 85 | 86 | ## Frontend 87 | 88 | The frontend code is in [home.html](https://github.com/gorilla/websocket/blob/master/examples/chat/home.html). 89 | 90 | On document load, the script checks for websocket functionality in the browser. 91 | If websocket functionality is available, then the script opens a connection to 92 | the server and registers a callback to handle messages from the server. The 93 | callback appends the message to the chat log using the `appendLog` function. 94 | 95 | To allow the user to manually scroll through the chat log without interruption 96 | from new messages, the `appendLog` function checks the scroll position before 97 | adding new content.
If the chat log is scrolled to the bottom, then the 98 | function scrolls new content into view after adding the content. Otherwise, the 99 | scroll position is not changed. 100 | 101 | The form handler writes the user input to the websocket and clears the input 102 | field. 103 | -------------------------------------------------------------------------------- /websocket/examples/chat/client.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package main 6 | 7 | import ( 8 | "bytes" 9 | "log" 10 | "net/http" 11 | "time" 12 | 13 | "github.com/gorilla/websocket" 14 | ) 15 | 16 | const ( 17 | // Time allowed to write a message to the peer. 18 | writeWait = 10 * time.Second 19 | 20 | // Time allowed to read the next pong message from the peer. 21 | pongWait = 60 * time.Second 22 | 23 | // Send pings to peer with this period. Must be less than pongWait. 24 | pingPeriod = (pongWait * 9) / 10 25 | 26 | // Maximum message size allowed from peer. 27 | maxMessageSize = 512 28 | ) 29 | 30 | var ( 31 | newline = []byte{'\n'} 32 | space = []byte{' '} 33 | ) 34 | 35 | var upgrader = websocket.Upgrader{ 36 | ReadBufferSize: 1024, 37 | WriteBufferSize: 1024, 38 | } 39 | 40 | // Client is a middleman between the websocket connection and the hub. 41 | type Client struct { 42 | hub *Hub 43 | 44 | // The websocket connection. 45 | conn *websocket.Conn 46 | 47 | // Buffered channel of outbound messages. 48 | send chan []byte 49 | } 50 | 51 | // readPump pumps messages from the websocket connection to the hub. 52 | // 53 | // The application runs readPump in a per-connection goroutine. The application 54 | // ensures that there is at most one reader on a connection by executing all 55 | // reads from this goroutine. 
56 | func (c *Client) readPump() { 57 | defer func() { 58 | c.hub.unregister <- c 59 | c.conn.Close() 60 | }() 61 | c.conn.SetReadLimit(maxMessageSize) 62 | c.conn.SetReadDeadline(time.Now().Add(pongWait)) 63 | c.conn.SetPongHandler(func(string) error { c.conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) 64 | for { 65 | _, message, err := c.conn.ReadMessage() 66 | if err != nil { 67 | if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { 68 | log.Printf("error: %v", err) 69 | } 70 | break 71 | } 72 | message = bytes.TrimSpace(bytes.Replace(message, newline, space, -1)) 73 | c.hub.broadcast <- message 74 | } 75 | } 76 | 77 | // writePump pumps messages from the hub to the websocket connection. 78 | // 79 | // A goroutine running writePump is started for each connection. The 80 | // application ensures that there is at most one writer to a connection by 81 | // executing all writes from this goroutine. 82 | func (c *Client) writePump() { 83 | ticker := time.NewTicker(pingPeriod) 84 | defer func() { 85 | ticker.Stop() 86 | c.conn.Close() 87 | }() 88 | for { 89 | select { 90 | case message, ok := <-c.send: 91 | c.conn.SetWriteDeadline(time.Now().Add(writeWait)) 92 | if !ok { 93 | // The hub closed the channel. 94 | c.conn.WriteMessage(websocket.CloseMessage, []byte{}) 95 | return 96 | } 97 | 98 | w, err := c.conn.NextWriter(websocket.TextMessage) 99 | if err != nil { 100 | return 101 | } 102 | w.Write(message) 103 | 104 | // Add queued chat messages to the current websocket message. 
105 | n := len(c.send) 106 | for i := 0; i < n; i++ { 107 | w.Write(newline) 108 | w.Write(<-c.send) 109 | } 110 | 111 | if err := w.Close(); err != nil { 112 | return 113 | } 114 | case <-ticker.C: 115 | c.conn.SetWriteDeadline(time.Now().Add(writeWait)) 116 | if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { 117 | return 118 | } 119 | } 120 | } 121 | } 122 | 123 | // serveWs handles websocket requests from the peer. 124 | func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) { 125 | conn, err := upgrader.Upgrade(w, r, nil) 126 | if err != nil { 127 | log.Println(err) 128 | return 129 | } 130 | client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256)} 131 | client.hub.register <- client 132 | 133 | // Allow collection of memory referenced by the caller by doing all work in 134 | // new goroutines. 135 | go client.writePump() 136 | go client.readPump() 137 | } 138 | -------------------------------------------------------------------------------- /websocket/examples/chat/home.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Chat Example 5 | 53 | 90 | 91 | 92 |
93 |
94 | 95 | 96 |
97 | 98 | 99 | -------------------------------------------------------------------------------- /websocket/examples/chat/hub.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package main 6 | 7 | // Hub maintains the set of active clients and broadcasts messages to the 8 | // clients. 9 | type Hub struct { 10 | // Registered clients. 11 | clients map[*Client]bool 12 | 13 | // Inbound messages from the clients. 14 | broadcast chan []byte 15 | 16 | // Register requests from the clients. 17 | register chan *Client 18 | 19 | // Unregister requests from clients. 20 | unregister chan *Client 21 | } 22 | 23 | func newHub() *Hub { 24 | return &Hub{ 25 | broadcast: make(chan []byte), 26 | register: make(chan *Client), 27 | unregister: make(chan *Client), 28 | clients: make(map[*Client]bool), 29 | } 30 | } 31 | 32 | func (h *Hub) run() { 33 | for { 34 | select { 35 | case client := <-h.register: 36 | h.clients[client] = true 37 | case client := <-h.unregister: 38 | if _, ok := h.clients[client]; ok { 39 | delete(h.clients, client) 40 | close(client.send) 41 | } 42 | case message := <-h.broadcast: 43 | for client := range h.clients { 44 | select { 45 | case client.send <- message: 46 | default: 47 | close(client.send) 48 | delete(h.clients, client) 49 | } 50 | } 51 | } 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /websocket/examples/chat/main.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package main 6 | 7 | import ( 8 | "flag" 9 | "log" 10 | "net/http" 11 | ) 12 | 13 | var addr = flag.String("addr", ":8080", "http service address") 14 | 15 | func serveHome(w http.ResponseWriter, r *http.Request) { 16 | log.Println(r.URL) 17 | if r.URL.Path != "/" { 18 | http.Error(w, "Not found", http.StatusNotFound) 19 | return 20 | } 21 | if r.Method != http.MethodGet { 22 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 23 | return 24 | } 25 | http.ServeFile(w, r, "home.html") 26 | } 27 | 28 | func main() { 29 | flag.Parse() 30 | hub := newHub() 31 | go hub.run() 32 | http.HandleFunc("/", serveHome) 33 | http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { 34 | serveWs(hub, w, r) 35 | }) 36 | err := http.ListenAndServe(*addr, nil) 37 | if err != nil { 38 | log.Fatal("ListenAndServe: ", err) 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /websocket/examples/command/README.md: -------------------------------------------------------------------------------- 1 | # Command example 2 | 3 | This example connects a websocket connection to stdin and stdout of a command. 4 | Received messages are written to stdin followed by a `\n`. Each line read from 5 | standard out is sent as a message to the client. 6 | 7 | $ go get github.com/gorilla/websocket 8 | $ cd `go list -f '{{.Dir}}' github.com/gorilla/websocket/examples/command` 9 | $ go run main.go cat 10 | # Open http://localhost:8080/ . 11 | 12 | A command to run is required. Try the following commands. 13 | 14 | # Echo sent messages to the output area. 15 | $ go run main.go cat 16 | 17 | # Run a shell. Try sending "ls" and "cat main.go". 18 | $ go run main.go sh 19 | 20 | -------------------------------------------------------------------------------- /websocket/examples/command/home.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Command Example 5 | 53 | 94 | 95 | 96 |
97 |
98 | 99 | 100 |
101 | 102 | 103 | -------------------------------------------------------------------------------- /websocket/examples/command/main.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package main 6 | 7 | import ( 8 | "bufio" 9 | "flag" 10 | "io" 11 | "log" 12 | "net/http" 13 | "os" 14 | "os/exec" 15 | "time" 16 | 17 | "github.com/gorilla/websocket" 18 | ) 19 | 20 | var ( 21 | addr = flag.String("addr", "127.0.0.1:8080", "http service address") 22 | cmdPath string 23 | ) 24 | 25 | const ( 26 | // Time allowed to write a message to the peer. 27 | writeWait = 10 * time.Second 28 | 29 | // Maximum message size allowed from peer. 30 | maxMessageSize = 8192 31 | 32 | // Time allowed to read the next pong message from the peer. 33 | pongWait = 60 * time.Second 34 | 35 | // Send pings to peer with this period. Must be less than pongWait. 36 | pingPeriod = (pongWait * 9) / 10 37 | 38 | // Time to wait before force close on connection. 
39 | closeGracePeriod = 10 * time.Second 40 | ) 41 | 42 | func pumpStdin(ws *websocket.Conn, w io.Writer) { 43 | defer ws.Close() 44 | ws.SetReadLimit(maxMessageSize) 45 | ws.SetReadDeadline(time.Now().Add(pongWait)) 46 | ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(pongWait)); return nil }) 47 | for { 48 | _, message, err := ws.ReadMessage() 49 | if err != nil { 50 | break 51 | } 52 | message = append(message, '\n') 53 | if _, err := w.Write(message); err != nil { 54 | break 55 | } 56 | } 57 | } 58 | 59 | func pumpStdout(ws *websocket.Conn, r io.Reader, done chan struct{}) { 60 | // Each scanned output line is sent as one text message. A failed write closes ws so the 61 | // read loop in pumpStdin fails too and serveWs can move on to shutting the process down. 62 | s := bufio.NewScanner(r) 63 | for s.Scan() { 64 | ws.SetWriteDeadline(time.Now().Add(writeWait)) 65 | if err := ws.WriteMessage(websocket.TextMessage, s.Bytes()); err != nil { 66 | ws.Close() 67 | break 68 | } 69 | } 70 | if s.Err() != nil { 71 | log.Println("scan:", s.Err()) 72 | } 73 | close(done) 74 | 75 | ws.SetWriteDeadline(time.Now().Add(writeWait)) 76 | ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) 77 | time.Sleep(closeGracePeriod) 78 | ws.Close() 79 | } 80 | 81 | func ping(ws *websocket.Conn, done chan struct{}) { 82 | ticker := time.NewTicker(pingPeriod) 83 | defer ticker.Stop() 84 | for { 85 | select { 86 | case <-ticker.C: 87 | if err := ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil { 88 | log.Println("ping:", err) 89 | } 90 | case <-done: 91 | return 92 | } 93 | } 94 | } 95 | 96 | func internalError(ws *websocket.Conn, msg string, err error) { 97 | log.Println(msg, err) 98 | ws.WriteMessage(websocket.TextMessage, []byte("Internal server error.")) 99 | } 100 | 101 | var upgrader = websocket.Upgrader{} 102 | 103 | func serveWs(w http.ResponseWriter, r *http.Request) { 104 | ws, err := upgrader.Upgrade(w, r, nil) 105 | if err != nil { 106 | log.Println("upgrade:", err) 107 | return 108 | } 109 | 110 | defer ws.Close() 111 |
112 | outr, outw, err := os.Pipe() 113 | if err != nil { 114 | internalError(ws, "stdout:", err) 115 | return 116 | } 117 | defer outr.Close() 118 | defer outw.Close() 119 | 120 | inr, inw, err := os.Pipe() 121 | if err != nil { 122 | internalError(ws, "stdin:", err) 123 | return 124 | } 125 | defer inr.Close() 126 | defer inw.Close() 127 | 128 | proc, err := os.StartProcess(cmdPath, flag.Args(), &os.ProcAttr{ 129 | Files: []*os.File{inr, outw, outw}, 130 | }) 131 | if err != nil { 132 | internalError(ws, "start:", err) 133 | return 134 | } 135 | 136 | inr.Close() 137 | outw.Close() 138 | 139 | stdoutDone := make(chan struct{}) 140 | go pumpStdout(ws, outr, stdoutDone) 141 | go ping(ws, stdoutDone) 142 | 143 | pumpStdin(ws, inw) 144 | 145 | // Some commands will exit when stdin is closed. 146 | inw.Close() 147 | 148 | // Other commands need a bonk on the head. 149 | if err := proc.Signal(os.Interrupt); err != nil { 150 | log.Println("inter:", err) 151 | } 152 | 153 | select { 154 | case <-stdoutDone: 155 | case <-time.After(time.Second): 156 | // A bigger bonk on the head. 
157 | if err := proc.Signal(os.Kill); err != nil { 158 | log.Println("term:", err) 159 | } 160 | <-stdoutDone 161 | } 162 | 163 | if _, err := proc.Wait(); err != nil { 164 | log.Println("wait:", err) 165 | } 166 | } 167 | 168 | func serveHome(w http.ResponseWriter, r *http.Request) { 169 | if r.URL.Path != "/" { 170 | http.Error(w, "Not found", http.StatusNotFound) 171 | return 172 | } 173 | if r.Method != http.MethodGet { 174 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 175 | return 176 | } 177 | http.ServeFile(w, r, "home.html") 178 | } 179 | 180 | func main() { 181 | flag.Parse() 182 | if len(flag.Args()) < 1 { 183 | log.Fatal("must specify at least one argument") 184 | } 185 | var err error 186 | cmdPath, err = exec.LookPath(flag.Args()[0]) 187 | if err != nil { 188 | log.Fatal(err) 189 | } 190 | http.HandleFunc("/", serveHome) 191 | http.HandleFunc("/ws", serveWs) 192 | log.Fatal(http.ListenAndServe(*addr, nil)) 193 | } 194 | -------------------------------------------------------------------------------- /websocket/examples/echo/README.md: -------------------------------------------------------------------------------- 1 | # Client and server example 2 | 3 | This example shows a simple client and server. 4 | 5 | The server echoes messages sent to it. The client sends a message every second 6 | and prints all messages received. 7 | 8 | To run the example, start the server: 9 | 10 | $ go run server.go 11 | 12 | Next, start the client: 13 | 14 | $ go run client.go 15 | 16 | The server includes a simple web client. To use the client, open 17 | http://127.0.0.1:8080 in the browser and follow the instructions on the page. 18 | -------------------------------------------------------------------------------- /websocket/examples/echo/client.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 The Gorilla WebSocket Authors. All rights reserved. 
2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | //go:build ignore 6 | // +build ignore 7 | 8 | package main 9 | 10 | import ( 11 | "flag" 12 | "log" 13 | "net/url" 14 | "os" 15 | "os/signal" 16 | "time" 17 | 18 | "github.com/gorilla/websocket" 19 | ) 20 | 21 | var addr = flag.String("addr", "localhost:8080", "http service address") 22 | 23 | func main() { 24 | flag.Parse() 25 | log.SetFlags(0) 26 | 27 | interrupt := make(chan os.Signal, 1) 28 | signal.Notify(interrupt, os.Interrupt) 29 | 30 | u := url.URL{Scheme: "ws", Host: *addr, Path: "/echo"} 31 | log.Printf("connecting to %s", u.String()) 32 | 33 | c, _, err := websocket.DefaultDialer.Dial(u.String(), nil) 34 | if err != nil { 35 | log.Fatal("dial:", err) 36 | } 37 | defer c.Close() 38 | 39 | done := make(chan struct{}) 40 | 41 | go func() { 42 | defer close(done) 43 | for { 44 | _, message, err := c.ReadMessage() 45 | if err != nil { 46 | log.Println("read:", err) 47 | return 48 | } 49 | log.Printf("recv: %s", message) 50 | } 51 | }() 52 | 53 | ticker := time.NewTicker(time.Second) 54 | defer ticker.Stop() 55 | 56 | for { 57 | select { 58 | case <-done: 59 | return 60 | case t := <-ticker.C: 61 | err := c.WriteMessage(websocket.TextMessage, []byte(t.String())) 62 | if err != nil { 63 | log.Println("write:", err) 64 | return 65 | } 66 | case <-interrupt: 67 | log.Println("interrupt") 68 | 69 | // Cleanly close the connection by sending a close message and then 70 | // waiting (with timeout) for the server to close the connection. 
71 | err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) 72 | if err != nil { 73 | log.Println("write close:", err) 74 | return 75 | } 76 | select { 77 | case <-done: 78 | case <-time.After(time.Second): 79 | } 80 | return 81 | } 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /websocket/examples/echo/server.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | //go:build ignore 6 | // +build ignore 7 | 8 | package main 9 | 10 | import ( 11 | "flag" 12 | "html/template" 13 | "log" 14 | "net/http" 15 | 16 | "github.com/gorilla/websocket" 17 | ) 18 | 19 | var addr = flag.String("addr", "localhost:8080", "http service address") 20 | 21 | var upgrader = websocket.Upgrader{} // use default options 22 | 23 | func echo(w http.ResponseWriter, r *http.Request) { 24 | c, err := upgrader.Upgrade(w, r, nil) 25 | if err != nil { 26 | log.Print("upgrade:", err) 27 | return 28 | } 29 | defer c.Close() 30 | for { 31 | mt, message, err := c.ReadMessage() 32 | if err != nil { 33 | log.Println("read:", err) 34 | break 35 | } 36 | log.Printf("recv: %s", message) 37 | err = c.WriteMessage(mt, message) 38 | if err != nil { 39 | log.Println("write:", err) 40 | break 41 | } 42 | } 43 | } 44 | 45 | func home(w http.ResponseWriter, r *http.Request) { 46 | homeTemplate.Execute(w, "ws://"+r.Host+"/echo") 47 | } 48 | 49 | func main() { 50 | flag.Parse() 51 | log.SetFlags(0) 52 | http.HandleFunc("/echo", echo) 53 | http.HandleFunc("/", home) 54 | log.Fatal(http.ListenAndServe(*addr, nil)) 55 | } 56 | 57 | var homeTemplate = template.Must(template.New("").Parse(` 58 | 59 | 60 | 61 | 62 | 116 | 117 | 118 | 119 |
120 |

Click "Open" to create a connection to the server, 121 | "Send" to send a message to the server and "Close" to close the connection. 122 | You can change the message and send multiple times. 123 |

124 |

125 | 126 | 127 |

128 | 129 |

130 |
131 |
132 |
133 | 134 | 135 | `)) 136 | -------------------------------------------------------------------------------- /websocket/examples/filewatch/README.md: -------------------------------------------------------------------------------- 1 | # File Watch example. 2 | 3 | This example sends a file to the browser client for display whenever the file is modified. 4 | 5 | $ go get github.com/gorilla/websocket 6 | $ cd `go list -f '{{.Dir}}' github.com/gorilla/websocket/examples/filewatch` 7 | $ go run main.go 8 | # Open http://localhost:8080/ . 9 | # Modify the file to see it update in the browser. 10 | -------------------------------------------------------------------------------- /websocket/examples/filewatch/main.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package main 6 | 7 | import ( 8 | "flag" 9 | "html/template" 10 | "io/ioutil" 11 | "log" 12 | "net/http" 13 | "os" 14 | "strconv" 15 | "time" 16 | 17 | "github.com/gorilla/websocket" 18 | ) 19 | 20 | const ( 21 | // Time allowed to write the file to the client. 22 | writeWait = 10 * time.Second 23 | 24 | // Time allowed to read the next pong message from the client. 25 | pongWait = 60 * time.Second 26 | 27 | // Send pings to client with this period. Must be less than pongWait. 28 | pingPeriod = (pongWait * 9) / 10 29 | 30 | // Poll file for changes with this period. 
31 | filePeriod = 10 * time.Second 32 | ) 33 | 34 | var ( 35 | addr = flag.String("addr", ":8080", "http service address") 36 | homeTempl = template.Must(template.New("").Parse(homeHTML)) 37 | filename string 38 | upgrader = websocket.Upgrader{ 39 | ReadBufferSize: 1024, 40 | WriteBufferSize: 1024, 41 | } 42 | ) 43 | 44 | func readFileIfModified(lastMod time.Time) ([]byte, time.Time, error) { 45 | fi, err := os.Stat(filename) 46 | if err != nil { 47 | return nil, lastMod, err 48 | } 49 | if !fi.ModTime().After(lastMod) { 50 | return nil, lastMod, nil 51 | } 52 | p, err := ioutil.ReadFile(filename) 53 | if err != nil { 54 | return nil, lastMod, err // keep the old lastMod so the failed read is retried on the next poll instead of being treated as already delivered 55 | } 56 | return p, fi.ModTime(), nil 57 | } 58 | 59 | func reader(ws *websocket.Conn) { 60 | defer ws.Close() 61 | ws.SetReadLimit(512) 62 | ws.SetReadDeadline(time.Now().Add(pongWait)) 63 | ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(pongWait)); return nil }) 64 | for { 65 | _, _, err := ws.ReadMessage() 66 | if err != nil { 67 | break 68 | } 69 | } 70 | } 71 | 72 | func writer(ws *websocket.Conn, lastMod time.Time) { 73 | lastError := "" 74 | pingTicker := time.NewTicker(pingPeriod) 75 | fileTicker := time.NewTicker(filePeriod) 76 | defer func() { 77 | pingTicker.Stop() 78 | fileTicker.Stop() 79 | ws.Close() 80 | }() 81 | for { 82 | select { 83 | case <-fileTicker.C: 84 | var p []byte 85 | var err error 86 | 87 | p, lastMod, err = readFileIfModified(lastMod) 88 | 89 | if err != nil { 90 | if s := err.Error(); s != lastError { 91 | lastError = s 92 | p = []byte(lastError) 93 | } 94 | } else { 95 | lastError = "" 96 | } 97 | 98 | if p != nil { 99 | ws.SetWriteDeadline(time.Now().Add(writeWait)) 100 | if err := ws.WriteMessage(websocket.TextMessage, p); err != nil { 101 | return 102 | } 103 | } 104 | case <-pingTicker.C: 105 | ws.SetWriteDeadline(time.Now().Add(writeWait)) 106 | if err := ws.WriteMessage(websocket.PingMessage, []byte{}); err != nil { 107 | return 108 | } 109 | } 110 | }
111 | } 112 | 113 | func serveWs(w http.ResponseWriter, r *http.Request) { 114 | ws, err := upgrader.Upgrade(w, r, nil) 115 | if err != nil { 116 | if _, ok := err.(websocket.HandshakeError); !ok { 117 | log.Println(err) 118 | } 119 | return 120 | } 121 | 122 | var lastMod time.Time 123 | if n, err := strconv.ParseInt(r.FormValue("lastMod"), 16, 64); err == nil { 124 | lastMod = time.Unix(0, n) 125 | } 126 | 127 | go writer(ws, lastMod) 128 | reader(ws) 129 | } 130 | 131 | func serveHome(w http.ResponseWriter, r *http.Request) { 132 | if r.URL.Path != "/" { 133 | http.Error(w, "Not found", http.StatusNotFound) 134 | return 135 | } 136 | if r.Method != http.MethodGet { 137 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 138 | return 139 | } 140 | w.Header().Set("Content-Type", "text/html; charset=utf-8") 141 | p, lastMod, err := readFileIfModified(time.Time{}) 142 | if err != nil { 143 | p = []byte(err.Error()) 144 | lastMod = time.Unix(0, 0) 145 | } 146 | var v = struct { 147 | Host string 148 | Data string 149 | LastMod string 150 | }{ 151 | r.Host, 152 | string(p), 153 | strconv.FormatInt(lastMod.UnixNano(), 16), 154 | } 155 | homeTempl.Execute(w, &v) 156 | } 157 | 158 | func main() { 159 | flag.Parse() 160 | if flag.NArg() != 1 { 161 | log.Fatal("filename not specified") 162 | } 163 | filename = flag.Args()[0] 164 | http.HandleFunc("/", serveHome) 165 | http.HandleFunc("/ws", serveWs) 166 | if err := http.ListenAndServe(*addr, nil); err != nil { 167 | log.Fatal(err) 168 | } 169 | } 170 | 171 | const homeHTML = ` 172 | 173 | 174 | WebSocket Example 175 | 176 | 177 |
{{.Data}}
178 | 191 | 192 | 193 | ` 194 | -------------------------------------------------------------------------------- /websocket/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/gorilla/websocket 2 | 3 | go 1.12 4 | 5 | require github.com/refraction-networking/utls v1.1.5 6 | -------------------------------------------------------------------------------- /websocket/go.sum: -------------------------------------------------------------------------------- 1 | github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= 2 | github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= 3 | github.com/klauspost/compress v1.15.9 h1:wKRjX6JRtDdrE9qwa4b/Cip7ACOshUI4smpCQanqjSY= 4 | github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU= 5 | github.com/refraction-networking/utls v1.1.5 h1:JtrojoNhbUQkBqEg05sP3gDgDj6hIEAAVKbI9lx4n6w= 6 | github.com/refraction-networking/utls v1.1.5/go.mod h1:jRQxtYi7nkq1p28HF2lwOH5zQm9aC8rpK0O9lIIzGh8= 7 | golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90 h1:Y/gsMcFOcR+6S6f3YeMKl5g+dZMEWqcz5Czj/GWYbkM= 8 | golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= 9 | golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= 10 | golang.org/x/net v0.0.0-20220909164309-bea034e7d591/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= 11 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 12 | golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 13 | golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 14 | golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg= 15 | golang.org/x/sys 
v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 16 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 17 | golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= 18 | golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 19 | golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= 20 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 21 | -------------------------------------------------------------------------------- /websocket/join.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package websocket 6 | 7 | import ( 8 | "io" 9 | "strings" 10 | ) 11 | 12 | // JoinMessages concatenates received messages to create a single io.Reader. 13 | // The string term is appended to each message. The returned reader does not 14 | // support concurrent calls to the Read method. 
15 | func JoinMessages(c *Conn, term string) io.Reader { 16 | return &joinReader{c: c, term: term} 17 | } 18 | 19 | type joinReader struct { 20 | c *Conn 21 | term string 22 | r io.Reader 23 | } 24 | 25 | func (r *joinReader) Read(p []byte) (int, error) { 26 | if r.r == nil { 27 | var err error 28 | _, r.r, err = r.c.NextReader() 29 | if err != nil { 30 | return 0, err 31 | } 32 | if r.term != "" { 33 | r.r = io.MultiReader(r.r, strings.NewReader(r.term)) 34 | } 35 | } 36 | n, err := r.r.Read(p) 37 | if err == io.EOF { 38 | err = nil 39 | r.r = nil 40 | } 41 | return n, err 42 | } 43 | -------------------------------------------------------------------------------- /websocket/join_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package websocket 6 | 7 | import ( 8 | "bytes" 9 | "io" 10 | "strings" 11 | "testing" 12 | ) 13 | 14 | func TestJoinMessages(t *testing.T) { 15 | messages := []string{"a", "bc", "def", "ghij", "klmno", "0", "12", "345", "6789"} 16 | for _, readChunk := range []int{1, 2, 3, 4, 5, 6, 7} { 17 | for _, term := range []string{"", ","} { 18 | var connBuf bytes.Buffer 19 | wc := newTestConn(nil, &connBuf, true) 20 | rc := newTestConn(&connBuf, nil, false) 21 | for _, m := range messages { 22 | wc.WriteMessage(BinaryMessage, []byte(m)) 23 | } 24 | 25 | var result bytes.Buffer 26 | _, err := io.CopyBuffer(&result, JoinMessages(rc, term), make([]byte, readChunk)) 27 | if IsUnexpectedCloseError(err, CloseAbnormalClosure) { 28 | t.Errorf("readChunk=%d, term=%q: unexpected error %v", readChunk, term, err) 29 | } 30 | want := strings.Join(messages, term) + term 31 | if result.String() != want { 32 | t.Errorf("readChunk=%d, term=%q, got %q, want %q", readChunk, term, result.String(), want) 33 | } 34 | } 35 | } 36 | } 37 | 
-------------------------------------------------------------------------------- /websocket/json.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package websocket 6 | 7 | import ( 8 | "encoding/json" 9 | "io" 10 | ) 11 | 12 | // WriteJSON writes the JSON encoding of v as a message. 13 | // 14 | // Deprecated: Use c.WriteJSON instead. 15 | func WriteJSON(c *Conn, v interface{}) error { 16 | return c.WriteJSON(v) 17 | } 18 | 19 | // WriteJSON writes the JSON encoding of v as a message. 20 | // 21 | // See the documentation for encoding/json Marshal for details about the 22 | // conversion of Go values to JSON. 23 | func (c *Conn) WriteJSON(v interface{}) error { 24 | w, err := c.NextWriter(TextMessage) 25 | if err != nil { 26 | return err 27 | } 28 | err1 := json.NewEncoder(w).Encode(v) 29 | err2 := w.Close() 30 | if err1 != nil { 31 | return err1 32 | } 33 | return err2 34 | } 35 | 36 | // ReadJSON reads the next JSON-encoded message from the connection and stores 37 | // it in the value pointed to by v. 38 | // 39 | // Deprecated: Use c.ReadJSON instead. 40 | func ReadJSON(c *Conn, v interface{}) error { 41 | return c.ReadJSON(v) 42 | } 43 | 44 | // ReadJSON reads the next JSON-encoded message from the connection and stores 45 | // it in the value pointed to by v. 46 | // 47 | // See the documentation for the encoding/json Unmarshal function for details 48 | // about the conversion of JSON to a Go value. 49 | func (c *Conn) ReadJSON(v interface{}) error { 50 | _, r, err := c.NextReader() 51 | if err != nil { 52 | return err 53 | } 54 | err = json.NewDecoder(r).Decode(v) 55 | if err == io.EOF { 56 | // One value is expected in the message. 
57 | err = io.ErrUnexpectedEOF 58 | } 59 | return err 60 | } 61 | -------------------------------------------------------------------------------- /websocket/json_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package websocket 6 | 7 | import ( 8 | "bytes" 9 | "encoding/json" 10 | "io" 11 | "reflect" 12 | "testing" 13 | ) 14 | 15 | func TestJSON(t *testing.T) { 16 | var buf bytes.Buffer 17 | wc := newTestConn(nil, &buf, true) 18 | rc := newTestConn(&buf, nil, false) 19 | 20 | var actual, expect struct { 21 | A int 22 | B string 23 | } 24 | expect.A = 1 25 | expect.B = "hello" 26 | 27 | if err := wc.WriteJSON(&expect); err != nil { 28 | t.Fatal("write", err) 29 | } 30 | 31 | if err := rc.ReadJSON(&actual); err != nil { 32 | t.Fatal("read", err) 33 | } 34 | 35 | if !reflect.DeepEqual(&actual, &expect) { 36 | t.Fatal("equal", actual, expect) 37 | } 38 | } 39 | 40 | func TestPartialJSONRead(t *testing.T) { 41 | var buf0, buf1 bytes.Buffer 42 | wc := newTestConn(nil, &buf0, true) 43 | rc := newTestConn(&buf0, &buf1, false) 44 | 45 | var v struct { 46 | A int 47 | B string 48 | } 49 | v.A = 1 50 | v.B = "hello" 51 | 52 | messageCount := 0 53 | 54 | // Partial JSON values. 55 | 56 | data, err := json.Marshal(v) 57 | if err != nil { 58 | t.Fatal(err) 59 | } 60 | for i := len(data) - 1; i >= 0; i-- { 61 | if err := wc.WriteMessage(TextMessage, data[:i]); err != nil { 62 | t.Fatal(err) 63 | } 64 | messageCount++ 65 | } 66 | 67 | // Whitespace. 68 | 69 | if err := wc.WriteMessage(TextMessage, []byte(" ")); err != nil { 70 | t.Fatal(err) 71 | } 72 | messageCount++ 73 | 74 | // Close. 
75 | 76 | if err := wc.WriteMessage(CloseMessage, FormatCloseMessage(CloseNormalClosure, "")); err != nil { 77 | t.Fatal(err) 78 | } 79 | 80 | for i := 0; i < messageCount; i++ { 81 | err := rc.ReadJSON(&v) 82 | if err != io.ErrUnexpectedEOF { 83 | t.Error("read", i, err) 84 | } 85 | } 86 | 87 | err = rc.ReadJSON(&v) 88 | if _, ok := err.(*CloseError); !ok { 89 | t.Error("final", err) 90 | } 91 | } 92 | 93 | func TestDeprecatedJSON(t *testing.T) { 94 | var buf bytes.Buffer 95 | wc := newTestConn(nil, &buf, true) 96 | rc := newTestConn(&buf, nil, false) 97 | 98 | var actual, expect struct { 99 | A int 100 | B string 101 | } 102 | expect.A = 1 103 | expect.B = "hello" 104 | 105 | if err := WriteJSON(wc, &expect); err != nil { 106 | t.Fatal("write", err) 107 | } 108 | 109 | if err := ReadJSON(rc, &actual); err != nil { 110 | t.Fatal("read", err) 111 | } 112 | 113 | if !reflect.DeepEqual(&actual, &expect) { 114 | t.Fatal("equal", actual, expect) 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /websocket/mask.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of 2 | // this source code is governed by a BSD-style license that can be found in the 3 | // LICENSE file. 4 | 5 | //go:build !appengine 6 | // +build !appengine 7 | 8 | package websocket 9 | 10 | import "unsafe" 11 | 12 | const wordSize = int(unsafe.Sizeof(uintptr(0))) 13 | 14 | func maskBytes(key [4]byte, pos int, b []byte) int { 15 | // Mask one byte at a time for small buffers. 16 | if len(b) < 2*wordSize { 17 | for i := range b { 18 | b[i] ^= key[pos&3] 19 | pos++ 20 | } 21 | return pos & 3 22 | } 23 | 24 | // Mask one byte at a time to word boundary. 
25 | if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 { 26 | n = wordSize - n 27 | for i := range b[:n] { 28 | b[i] ^= key[pos&3] 29 | pos++ 30 | } 31 | b = b[n:] 32 | } 33 | 34 | // Create aligned word size key. 35 | var k [wordSize]byte 36 | for i := range k { 37 | k[i] = key[(pos+i)&3] 38 | } 39 | kw := *(*uintptr)(unsafe.Pointer(&k)) 40 | 41 | // Mask one word at a time. 42 | n := (len(b) / wordSize) * wordSize 43 | for i := 0; i < n; i += wordSize { 44 | *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw 45 | } 46 | 47 | // Mask one byte at a time for remaining bytes. 48 | b = b[n:] 49 | for i := range b { 50 | b[i] ^= key[pos&3] 51 | pos++ 52 | } 53 | 54 | return pos & 3 55 | } 56 | -------------------------------------------------------------------------------- /websocket/mask_safe.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of 2 | // this source code is governed by a BSD-style license that can be found in the 3 | // LICENSE file. 4 | 5 | //go:build appengine 6 | // +build appengine 7 | 8 | package websocket 9 | 10 | func maskBytes(key [4]byte, pos int, b []byte) int { 11 | for i := range b { 12 | b[i] ^= key[pos&3] 13 | pos++ 14 | } 15 | return pos & 3 16 | } 17 | -------------------------------------------------------------------------------- /websocket/mask_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of 2 | // this source code is governed by a BSD-style license that can be found in the 3 | // LICENSE file. 
4 | 5 | // !appengine 6 | 7 | package websocket 8 | 9 | import ( 10 | "fmt" 11 | "testing" 12 | ) 13 | 14 | func maskBytesByByte(key [4]byte, pos int, b []byte) int { 15 | for i := range b { 16 | b[i] ^= key[pos&3] 17 | pos++ 18 | } 19 | return pos & 3 20 | } 21 | 22 | func notzero(b []byte) int { 23 | for i := range b { 24 | if b[i] != 0 { 25 | return i 26 | } 27 | } 28 | return -1 29 | } 30 | 31 | func TestMaskBytes(t *testing.T) { 32 | key := [4]byte{1, 2, 3, 4} 33 | for size := 1; size <= 1024; size++ { 34 | for align := 0; align < wordSize; align++ { 35 | for pos := 0; pos < 4; pos++ { 36 | b := make([]byte, size+align)[align:] 37 | maskBytes(key, pos, b) 38 | maskBytesByByte(key, pos, b) 39 | if i := notzero(b); i >= 0 { 40 | t.Errorf("size:%d, align:%d, pos:%d, offset:%d", size, align, pos, i) 41 | } 42 | } 43 | } 44 | } 45 | } 46 | 47 | func BenchmarkMaskBytes(b *testing.B) { 48 | for _, size := range []int{2, 4, 8, 16, 32, 512, 1024} { 49 | b.Run(fmt.Sprintf("size-%d", size), func(b *testing.B) { 50 | for _, align := range []int{wordSize / 2} { 51 | b.Run(fmt.Sprintf("align-%d", align), func(b *testing.B) { 52 | for _, fn := range []struct { 53 | name string 54 | fn func(key [4]byte, pos int, b []byte) int 55 | }{ 56 | {"byte", maskBytesByByte}, 57 | {"word", maskBytes}, 58 | } { 59 | b.Run(fn.name, func(b *testing.B) { 60 | key := newMaskKey() 61 | data := make([]byte, size+align)[align:] 62 | for i := 0; i < b.N; i++ { 63 | fn.fn(key, 0, data) 64 | } 65 | b.SetBytes(int64(len(data))) 66 | }) 67 | } 68 | }) 69 | } 70 | }) 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /websocket/prepared.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package websocket 6 | 7 | import ( 8 | "bytes" 9 | "net" 10 | "sync" 11 | "time" 12 | ) 13 | 14 | // PreparedMessage caches on the wire representations of a message payload. 15 | // Use PreparedMessage to efficiently send a message payload to multiple 16 | // connections. PreparedMessage is especially useful when compression is used 17 | // because the CPU and memory expensive compression operation can be executed 18 | // once for a given set of compression options. 19 | type PreparedMessage struct { 20 | messageType int 21 | data []byte 22 | mu sync.Mutex 23 | frames map[prepareKey]*preparedFrame 24 | } 25 | 26 | // prepareKey defines a unique set of options to cache prepared frames in PreparedMessage. 27 | type prepareKey struct { 28 | isServer bool 29 | compress bool 30 | compressionLevel int 31 | } 32 | 33 | // preparedFrame contains data in wire representation. 34 | type preparedFrame struct { 35 | once sync.Once 36 | data []byte 37 | } 38 | 39 | // NewPreparedMessage returns an initialized PreparedMessage. You can then send 40 | // it to connection using WritePreparedMessage method. Valid wire 41 | // representation will be calculated lazily only once for a set of current 42 | // connection options. 43 | func NewPreparedMessage(messageType int, data []byte) (*PreparedMessage, error) { 44 | pm := &PreparedMessage{ 45 | messageType: messageType, 46 | frames: make(map[prepareKey]*preparedFrame), 47 | data: data, 48 | } 49 | 50 | // Prepare a plain server frame. 51 | _, frameData, err := pm.frame(prepareKey{isServer: true, compress: false}) 52 | if err != nil { 53 | return nil, err 54 | } 55 | 56 | // To protect against caller modifying the data argument, remember the data 57 | // copied to the plain server frame. 
58 | pm.data = frameData[len(frameData)-len(data):] 59 | return pm, nil 60 | } 61 | 62 | func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) { 63 | pm.mu.Lock() 64 | frame, ok := pm.frames[key] 65 | if !ok { 66 | frame = &preparedFrame{} 67 | pm.frames[key] = frame 68 | } 69 | pm.mu.Unlock() 70 | 71 | var err error 72 | frame.once.Do(func() { 73 | // Prepare a frame using a 'fake' connection. 74 | // TODO: Refactor code in conn.go to allow more direct construction of 75 | // the frame. 76 | mu := make(chan struct{}, 1) 77 | mu <- struct{}{} 78 | var nc prepareConn 79 | c := &Conn{ 80 | conn: &nc, 81 | mu: mu, 82 | isServer: key.isServer, 83 | compressionLevel: key.compressionLevel, 84 | enableWriteCompression: true, 85 | writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize), 86 | } 87 | if key.compress { 88 | c.newCompressionWriter = compressNoContextTakeover 89 | } 90 | err = c.WriteMessage(pm.messageType, pm.data) 91 | frame.data = nc.buf.Bytes() 92 | }) 93 | return pm.messageType, frame.data, err 94 | } 95 | 96 | type prepareConn struct { 97 | buf bytes.Buffer 98 | net.Conn 99 | } 100 | 101 | func (pc *prepareConn) Write(p []byte) (int, error) { return pc.buf.Write(p) } 102 | func (pc *prepareConn) SetWriteDeadline(t time.Time) error { return nil } 103 | -------------------------------------------------------------------------------- /websocket/prepared_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package websocket 6 | 7 | import ( 8 | "bytes" 9 | "compress/flate" 10 | "math/rand" 11 | "testing" 12 | ) 13 | 14 | var preparedMessageTests = []struct { 15 | messageType int 16 | isServer bool 17 | enableWriteCompression bool 18 | compressionLevel int 19 | }{ 20 | // Server 21 | {TextMessage, true, false, flate.BestSpeed}, 22 | {TextMessage, true, true, flate.BestSpeed}, 23 | {TextMessage, true, true, flate.BestCompression}, 24 | {PingMessage, true, false, flate.BestSpeed}, 25 | {PingMessage, true, true, flate.BestSpeed}, 26 | 27 | // Client 28 | {TextMessage, false, false, flate.BestSpeed}, 29 | {TextMessage, false, true, flate.BestSpeed}, 30 | {TextMessage, false, true, flate.BestCompression}, 31 | {PingMessage, false, false, flate.BestSpeed}, 32 | {PingMessage, false, true, flate.BestSpeed}, 33 | } 34 | 35 | func TestPreparedMessage(t *testing.T) { 36 | for _, tt := range preparedMessageTests { 37 | var data = []byte("this is a test") 38 | var buf bytes.Buffer 39 | c := newTestConn(nil, &buf, tt.isServer) 40 | if tt.enableWriteCompression { 41 | c.newCompressionWriter = compressNoContextTakeover 42 | } 43 | c.SetCompressionLevel(tt.compressionLevel) 44 | 45 | // Seed random number generator for consistent frame mask. 46 | rand.Seed(1234) 47 | 48 | if err := c.WriteMessage(tt.messageType, data); err != nil { 49 | t.Fatal(err) 50 | } 51 | want := buf.String() 52 | 53 | pm, err := NewPreparedMessage(tt.messageType, data) 54 | if err != nil { 55 | t.Fatal(err) 56 | } 57 | 58 | // Scribble on data to ensure that NewPreparedMessage takes a snapshot. 59 | copy(data, "hello world") 60 | 61 | // Seed random number generator for consistent frame mask. 
62 | rand.Seed(1234) 63 | 64 | buf.Reset() 65 | if err := c.WritePreparedMessage(pm); err != nil { 66 | t.Fatal(err) 67 | } 68 | got := buf.String() 69 | 70 | if got != want { 71 | t.Errorf("write message != prepared message for %+v", tt) 72 | } 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /websocket/proxy.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package websocket 6 | 7 | import ( 8 | "bufio" 9 | "encoding/base64" 10 | "errors" 11 | "net" 12 | "net/http" 13 | "net/url" 14 | "strings" 15 | ) 16 | 17 | type netDialerFunc func(network, addr string) (net.Conn, error) 18 | 19 | func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) { 20 | return fn(network, addr) 21 | } 22 | 23 | func init() { 24 | proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) { 25 | return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil 26 | }) 27 | } 28 | 29 | type httpProxyDialer struct { 30 | proxyURL *url.URL 31 | forwardDial func(network, addr string) (net.Conn, error) 32 | } 33 | 34 | func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) { 35 | hostPort, _ := hostPortNoPort(hpd.proxyURL) 36 | conn, err := hpd.forwardDial(network, hostPort) 37 | if err != nil { 38 | return nil, err 39 | } 40 | 41 | connectHeader := make(http.Header) 42 | if user := hpd.proxyURL.User; user != nil { 43 | proxyUser := user.Username() 44 | if proxyPassword, passwordSet := user.Password(); passwordSet { 45 | credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword)) 46 | connectHeader.Set("Proxy-Authorization", "Basic "+credential) 47 | } 48 | } 49 | 50 | connectReq := 
&http.Request{ 51 | Method: http.MethodConnect, 52 | URL: &url.URL{Opaque: addr}, 53 | Host: addr, 54 | Header: connectHeader, 55 | } 56 | 57 | if err := connectReq.Write(conn); err != nil { 58 | conn.Close() 59 | return nil, err 60 | } 61 | 62 | // Read response. It's OK to use and discard buffered reader here becaue 63 | // the remote server does not speak until spoken to. 64 | br := bufio.NewReader(conn) 65 | resp, err := http.ReadResponse(br, connectReq) 66 | if err != nil { 67 | conn.Close() 68 | return nil, err 69 | } 70 | 71 | if resp.StatusCode != 200 { 72 | conn.Close() 73 | f := strings.SplitN(resp.Status, " ", 2) 74 | return nil, errors.New(f[1]) 75 | } 76 | return conn, nil 77 | } 78 | -------------------------------------------------------------------------------- /websocket/server.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package websocket 6 | 7 | import ( 8 | "bufio" 9 | "errors" 10 | "io" 11 | "net/http" 12 | "net/url" 13 | "strings" 14 | "time" 15 | ) 16 | 17 | // HandshakeError describes an error with the handshake from the peer. 18 | type HandshakeError struct { 19 | message string 20 | } 21 | 22 | func (e HandshakeError) Error() string { return e.message } 23 | 24 | // Upgrader specifies parameters for upgrading an HTTP connection to a 25 | // WebSocket connection. 26 | // 27 | // It is safe to call Upgrader's methods concurrently. 28 | type Upgrader struct { 29 | // HandshakeTimeout specifies the duration for the handshake to complete. 30 | HandshakeTimeout time.Duration 31 | 32 | // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer 33 | // size is zero, then buffers allocated by the HTTP server are used. 
The 34 | // I/O buffer sizes do not limit the size of the messages that can be sent 35 | // or received. 36 | ReadBufferSize, WriteBufferSize int 37 | 38 | // WriteBufferPool is a pool of buffers for write operations. If the value 39 | // is not set, then write buffers are allocated to the connection for the 40 | // lifetime of the connection. 41 | // 42 | // A pool is most useful when the application has a modest volume of writes 43 | // across a large number of connections. 44 | // 45 | // Applications should use a single pool for each unique value of 46 | // WriteBufferSize. 47 | WriteBufferPool BufferPool 48 | 49 | // Subprotocols specifies the server's supported protocols in order of 50 | // preference. If this field is not nil, then the Upgrade method negotiates a 51 | // subprotocol by selecting the first match in this list with a protocol 52 | // requested by the client. If there's no match, then no protocol is 53 | // negotiated (the Sec-Websocket-Protocol header is not included in the 54 | // handshake response). 55 | Subprotocols []string 56 | 57 | // Error specifies the function for generating HTTP error responses. If Error 58 | // is nil, then http.Error is used to generate the HTTP response. 59 | Error func(w http.ResponseWriter, r *http.Request, status int, reason error) 60 | 61 | // CheckOrigin returns true if the request Origin header is acceptable. If 62 | // CheckOrigin is nil, then a safe default is used: return false if the 63 | // Origin request header is present and the origin host is not equal to 64 | // request Host header. 65 | // 66 | // A CheckOrigin function should carefully validate the request origin to 67 | // prevent cross-site request forgery. 68 | CheckOrigin func(r *http.Request) bool 69 | 70 | // EnableCompression specify if the server should attempt to negotiate per 71 | // message compression (RFC 7692). Setting this value to true does not 72 | // guarantee that compression will be supported. 
Currently only "no context 73 | // takeover" modes are supported. 74 | EnableCompression bool 75 | } 76 | 77 | func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) { 78 | err := HandshakeError{reason} 79 | if u.Error != nil { 80 | u.Error(w, r, status, err) 81 | } else { 82 | w.Header().Set("Sec-Websocket-Version", "13") 83 | http.Error(w, http.StatusText(status), status) 84 | } 85 | return nil, err 86 | } 87 | 88 | // checkSameOrigin returns true if the origin is not set or is equal to the request host. 89 | func checkSameOrigin(r *http.Request) bool { 90 | origin := r.Header["Origin"] 91 | if len(origin) == 0 { 92 | return true 93 | } 94 | u, err := url.Parse(origin[0]) 95 | if err != nil { 96 | return false 97 | } 98 | return equalASCIIFold(u.Host, r.Host) 99 | } 100 | 101 | func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string { 102 | if u.Subprotocols != nil { 103 | clientProtocols := Subprotocols(r) 104 | for _, serverProtocol := range u.Subprotocols { 105 | for _, clientProtocol := range clientProtocols { 106 | if clientProtocol == serverProtocol { 107 | return clientProtocol 108 | } 109 | } 110 | } 111 | } else if responseHeader != nil { 112 | return responseHeader.Get("Sec-Websocket-Protocol") 113 | } 114 | return "" 115 | } 116 | 117 | // Upgrade upgrades the HTTP server connection to the WebSocket protocol. 118 | // 119 | // The responseHeader is included in the response to the client's upgrade 120 | // request. Use the responseHeader to specify cookies (Set-Cookie). To specify 121 | // subprotocols supported by the server, set Upgrader.Subprotocols directly. 122 | // 123 | // If the upgrade fails, then Upgrade replies to the client with an HTTP error 124 | // response. 
125 | func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) { 126 | const badHandshake = "websocket: the client is not using the websocket protocol: " 127 | 128 | if !tokenListContainsValue(r.Header, "Connection", "upgrade") { 129 | return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header") 130 | } 131 | 132 | if !tokenListContainsValue(r.Header, "Upgrade", "websocket") { 133 | return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header") 134 | } 135 | 136 | if r.Method != http.MethodGet { 137 | return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET") 138 | } 139 | 140 | if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") { 141 | return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header") 142 | } 143 | 144 | if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok { 145 | return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported") 146 | } 147 | 148 | checkOrigin := u.CheckOrigin 149 | if checkOrigin == nil { 150 | checkOrigin = checkSameOrigin 151 | } 152 | if !checkOrigin(r) { 153 | return u.returnError(w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin") 154 | } 155 | 156 | challengeKey := r.Header.Get("Sec-Websocket-Key") 157 | if !isValidChallengeKey(challengeKey) { 158 | return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length") 159 | } 160 | 161 | subprotocol := u.selectSubprotocol(r, responseHeader) 162 | 163 | // Negotiate PMCE 164 | var compress bool 165 | if u.EnableCompression { 166 | for _, ext := range parseExtensions(r.Header) { 167 
| if ext[""] != "permessage-deflate" { 168 | continue 169 | } 170 | compress = true 171 | break 172 | } 173 | } 174 | 175 | h, ok := w.(http.Hijacker) 176 | if !ok { 177 | return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker") 178 | } 179 | var brw *bufio.ReadWriter 180 | netConn, brw, err := h.Hijack() 181 | if err != nil { 182 | return u.returnError(w, r, http.StatusInternalServerError, err.Error()) 183 | } 184 | 185 | if brw.Reader.Buffered() > 0 { 186 | netConn.Close() 187 | return nil, errors.New("websocket: client sent data before handshake is complete") 188 | } 189 | 190 | var br *bufio.Reader 191 | if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 { 192 | // Reuse hijacked buffered reader as connection reader. 193 | br = brw.Reader 194 | } 195 | 196 | buf := bufioWriterBuffer(netConn, brw.Writer) 197 | 198 | var writeBuf []byte 199 | if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 { 200 | // Reuse hijacked write buffer as connection buffer. 201 | writeBuf = buf 202 | } 203 | 204 | c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf) 205 | c.subprotocol = subprotocol 206 | 207 | if compress { 208 | c.newCompressionWriter = compressNoContextTakeover 209 | c.newDecompressionReader = decompressNoContextTakeover 210 | } 211 | 212 | // Use larger of hijacked buffer and connection write buffer for header. 213 | p := buf 214 | if len(c.writeBuf) > len(p) { 215 | p = c.writeBuf 216 | } 217 | p = p[:0] 218 | 219 | p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) 220 | p = append(p, computeAcceptKey(challengeKey)...) 221 | p = append(p, "\r\n"...) 222 | if c.subprotocol != "" { 223 | p = append(p, "Sec-WebSocket-Protocol: "...) 224 | p = append(p, c.subprotocol...) 225 | p = append(p, "\r\n"...) 
226 | } 227 | if compress { 228 | p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) 229 | } 230 | for k, vs := range responseHeader { 231 | if k == "Sec-Websocket-Protocol" { 232 | continue 233 | } 234 | for _, v := range vs { 235 | p = append(p, k...) 236 | p = append(p, ": "...) 237 | for i := 0; i < len(v); i++ { 238 | b := v[i] 239 | if b <= 31 { 240 | // prevent response splitting. 241 | b = ' ' 242 | } 243 | p = append(p, b) 244 | } 245 | p = append(p, "\r\n"...) 246 | } 247 | } 248 | p = append(p, "\r\n"...) 249 | 250 | // Clear deadlines set by HTTP server. 251 | netConn.SetDeadline(time.Time{}) 252 | 253 | if u.HandshakeTimeout > 0 { 254 | netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)) 255 | } 256 | if _, err = netConn.Write(p); err != nil { 257 | netConn.Close() 258 | return nil, err 259 | } 260 | if u.HandshakeTimeout > 0 { 261 | netConn.SetWriteDeadline(time.Time{}) 262 | } 263 | 264 | return c, nil 265 | } 266 | 267 | // Upgrade upgrades the HTTP server connection to the WebSocket protocol. 268 | // 269 | // Deprecated: Use websocket.Upgrader instead. 270 | // 271 | // Upgrade does not perform origin checking. The application is responsible for 272 | // checking the Origin header before calling Upgrade. An example implementation 273 | // of the same origin policy check is: 274 | // 275 | // if req.Header.Get("Origin") != "http://"+req.Host { 276 | // http.Error(w, "Origin not allowed", http.StatusForbidden) 277 | // return 278 | // } 279 | // 280 | // If the endpoint supports subprotocols, then the application is responsible 281 | // for negotiating the protocol used on the connection. Use the Subprotocols() 282 | // function to get the subprotocols requested by the client. Use the 283 | // Sec-Websocket-Protocol response header to specify the subprotocol selected 284 | // by the application. 
285 | // 286 | // The responseHeader is included in the response to the client's upgrade 287 | // request. Use the responseHeader to specify cookies (Set-Cookie) and the 288 | // negotiated subprotocol (Sec-Websocket-Protocol). 289 | // 290 | // The connection buffers IO to the underlying network connection. The 291 | // readBufSize and writeBufSize parameters specify the size of the buffers to 292 | // use. Messages can be larger than the buffers. 293 | // 294 | // If the request is not a valid WebSocket handshake, then Upgrade returns an 295 | // error of type HandshakeError. Applications should handle this error by 296 | // replying to the client with an HTTP error response. 297 | func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) { 298 | u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize} 299 | u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) { 300 | // don't return errors to maintain backwards compatibility 301 | } 302 | u.CheckOrigin = func(r *http.Request) bool { 303 | // allow all connections by default 304 | return true 305 | } 306 | return u.Upgrade(w, r, responseHeader) 307 | } 308 | 309 | // Subprotocols returns the subprotocols requested by the client in the 310 | // Sec-Websocket-Protocol header. 311 | func Subprotocols(r *http.Request) []string { 312 | h := strings.TrimSpace(r.Header.Get("Sec-Websocket-Protocol")) 313 | if h == "" { 314 | return nil 315 | } 316 | protocols := strings.Split(h, ",") 317 | for i := range protocols { 318 | protocols[i] = strings.TrimSpace(protocols[i]) 319 | } 320 | return protocols 321 | } 322 | 323 | // IsWebSocketUpgrade returns true if the client requested upgrade to the 324 | // WebSocket protocol. 
325 | func IsWebSocketUpgrade(r *http.Request) bool { 326 | return tokenListContainsValue(r.Header, "Connection", "upgrade") && 327 | tokenListContainsValue(r.Header, "Upgrade", "websocket") 328 | } 329 | 330 | // bufioReaderSize size returns the size of a bufio.Reader. 331 | func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int { 332 | // This code assumes that peek on a reset reader returns 333 | // bufio.Reader.buf[:0]. 334 | // TODO: Use bufio.Reader.Size() after Go 1.10 335 | br.Reset(originalReader) 336 | if p, err := br.Peek(0); err == nil { 337 | return cap(p) 338 | } 339 | return 0 340 | } 341 | 342 | // writeHook is an io.Writer that records the last slice passed to it vio 343 | // io.Writer.Write. 344 | type writeHook struct { 345 | p []byte 346 | } 347 | 348 | func (wh *writeHook) Write(p []byte) (int, error) { 349 | wh.p = p 350 | return len(p), nil 351 | } 352 | 353 | // bufioWriterBuffer grabs the buffer from a bufio.Writer. 354 | func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte { 355 | // This code assumes that bufio.Writer.buf[:1] is passed to the 356 | // bufio.Writer's underlying writer. 357 | var wh writeHook 358 | bw.Reset(&wh) 359 | bw.WriteByte(0) 360 | bw.Flush() 361 | 362 | bw.Reset(originalWriter) 363 | 364 | return wh.p[:cap(wh.p)] 365 | } 366 | -------------------------------------------------------------------------------- /websocket/server_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 

package websocket

import (
	"bufio"
	"bytes"
	"net"
	"net/http"
	"reflect"
	"strings"
	"testing"
)

// subprotocolTests maps a single Sec-Websocket-Protocol header value to the
// list Subprotocols is expected to return for it.
var subprotocolTests = []struct {
	h         string
	protocols []string
}{
	{"", nil},
	{"foo", []string{"foo"}},
	{"foo,bar", []string{"foo", "bar"}},
	{"foo, bar", []string{"foo", "bar"}},
	{" foo, bar", []string{"foo", "bar"}},
	{" foo, bar ", []string{"foo", "bar"}},
}

// TestSubprotocols runs every subprotocolTests case through Subprotocols.
func TestSubprotocols(t *testing.T) {
	for _, st := range subprotocolTests {
		r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {st.h}}}
		protocols := Subprotocols(&r)
		if !reflect.DeepEqual(st.protocols, protocols) {
			t.Errorf("SubProtocols(%q) returned %#v, want %#v", st.h, protocols, st.protocols)
		}
	}
}

// isWebSocketUpgradeTests: an upgrade needs both the Connection and the
// Upgrade header, and header values are matched case-insensitively.
var isWebSocketUpgradeTests = []struct {
	ok bool
	h  http.Header
}{
	{false, http.Header{"Upgrade": {"websocket"}}},
	{false, http.Header{"Connection": {"upgrade"}}},
	{true, http.Header{"Connection": {"upgRade"}, "Upgrade": {"WebSocket"}}},
}

// TestIsWebSocketUpgrade runs every isWebSocketUpgradeTests case.
func TestIsWebSocketUpgrade(t *testing.T) {
	for _, tt := range isWebSocketUpgradeTests {
		ok := IsWebSocketUpgrade(&http.Request{Header: tt.h})
		if tt.ok != ok {
			t.Errorf("IsWebSocketUpgrade(%v) returned %v, want %v", tt.h, ok, tt.ok)
		}
	}
}

// checkSameOriginTests: the Origin host must match the request Host; the last
// case shows the comparison tolerates a difference in letter case.
var checkSameOriginTests = []struct {
	ok bool
	r  *http.Request
}{
	{false, &http.Request{Host: "example.org", Header: map[string][]string{"Origin": {"https://other.org"}}}},
	{true, &http.Request{Host: "example.org", Header: map[string][]string{"Origin": {"https://example.org"}}}},
	{true, &http.Request{Host: "Example.org", Header: map[string][]string{"Origin": {"https://example.org"}}}},
}

// TestCheckSameOrigin runs every checkSameOriginTests case through
// checkSameOrigin (defined elsewhere in the package).
func TestCheckSameOrigin(t *testing.T) {
	for _, tt := range checkSameOriginTests {
		ok := checkSameOrigin(tt.r)
		if tt.ok != ok {
			t.Errorf("checkSameOrigin(%+v) returned %v, want %v", tt.r, ok, tt.ok)
		}
	}
}

// reuseTestResponseWriter is a hijackable ResponseWriter whose Hijack hands
// back a caller-supplied bufio.ReadWriter. The embedded http.ResponseWriter
// is left nil, so only Hijack is usable; presumably the upgrade path under
// test calls nothing else on it.
type reuseTestResponseWriter struct {
	brw *bufio.ReadWriter
	http.ResponseWriter
}

// Hijack returns an in-memory fake connection (fakeNetConn is defined
// elsewhere in the package tests) together with the configured buffers.
func (resp *reuseTestResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	return fakeNetConn{strings.NewReader(""), &bytes.Buffer{}}, resp.brw, nil
}

// bufioReuseTests pairs a hijacked buffer size with whether Upgrade is
// expected to reuse (true) or replace (false) the hijacked buffers.
var bufioReuseTests = []struct {
	n     int
	reuse bool
}{
	{4096, true},
	{128, false},
}

// TestBufioReuse upgrades a synthetic request over reuseTestResponseWriter and
// checks, by pointer identity, whether the connection kept the hijacked
// reader and the hijacked writer's buffer.
func TestBufioReuse(t *testing.T) {
	for i, tt := range bufioReuseTests {
		br := bufio.NewReaderSize(strings.NewReader(""), tt.n)
		bw := bufio.NewWriterSize(&bytes.Buffer{}, tt.n)
		resp := &reuseTestResponseWriter{
			brw: bufio.NewReadWriter(br, bw),
		}
		upgrader := Upgrader{}
		c, err := upgrader.Upgrade(resp, &http.Request{
			Method: http.MethodGet,
			Header: http.Header{
				"Upgrade":               []string{"websocket"},
				"Connection":            []string{"upgrade"},
				"Sec-Websocket-Key":     []string{"dGhlIHNhbXBsZSBub25jZQ=="},
				"Sec-Websocket-Version": []string{"13"},
			}}, nil)
		if err != nil {
			t.Fatal(err)
		}
		if reuse := c.br == br; reuse != tt.reuse {
			t.Errorf("%d: buffered reader reuse=%v, want %v", i, reuse, tt.reuse)
		}
		// Recover bw's buffer to compare its first element's address with
		// the connection's write buffer.
		writeBuf := bufioWriterBuffer(c.NetConn(), bw)
		if reuse := &c.writeBuf[0] == &writeBuf[0]; reuse != tt.reuse {
			t.Errorf("%d: write buffer reuse=%v, want %v", i, reuse, tt.reuse)
		}
	}
}

--------------------------------------------------------------------------------
/websocket/tls_handshake.go:
--------------------------------------------------------------------------------
//go:build go1.17
// +build go1.17

package websocket

import (
	"context"
)
import tls "github.com/refraction-networking/utls"

// doHandshake runs the uTLS client handshake, honoring ctx, and then, unless
// cfg.InsecureSkipVerify is set, checks the server certificate against
// cfg.ServerName with VerifyHostname. Any error from either step is returned
// unchanged.
func doHandshake(ctx context.Context, tlsConn *tls.UConn, cfg *tls.Config) error {
	if err := tlsConn.HandshakeContext(ctx); err != nil {
		return err
	}
	if !cfg.InsecureSkipVerify {
		if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
			return err
		}
	}
	return nil
}

--------------------------------------------------------------------------------
/websocket/tls_handshake_116.go:
--------------------------------------------------------------------------------
//go:build !go1.17
// +build !go1.17

package websocket

import (
	"context"
	"crypto/tls"
)

// doHandshake is the pre-Go 1.17 variant: it runs a blocking TLS handshake
// and then, unless cfg.InsecureSkipVerify is set, checks the server
// certificate against cfg.ServerName.
//
// NOTE(review): ctx is ignored here (Handshake, not HandshakeContext), and
// this variant takes crypto/tls types while the go1.17 variant takes utls
// types — confirm callers still compile and behave under !go1.17.
func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error {
	if err := tlsConn.Handshake(); err != nil {
		return err
	}
	if !cfg.InsecureSkipVerify {
		if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
			return err
		}
	}
	return nil
}

--------------------------------------------------------------------------------
/websocket/util.go:
--------------------------------------------------------------------------------
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package websocket

import (
	"crypto/rand"
	"crypto/sha1"
	"encoding/base64"
	"io"
	"net/http"
	"strings"
	"unicode/utf8"
)

// keyGUID is the fixed GUID that the WebSocket opening handshake appends to
// the client's challenge key before hashing (see computeAcceptKey).
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

// computeAcceptKey returns base64(SHA-1(challengeKey + keyGUID)), the accept
// value derived from a client's challenge key. hash.Hash.Write never returns
// an error, so its results are not checked.
func computeAcceptKey(challengeKey string) string {
	h := sha1.New()
	h.Write([]byte(challengeKey))
	h.Write(keyGUID)
	return base64.StdEncoding.EncodeToString(h.Sum(nil))
}

// generateChallengeKey returns 16 bytes from crypto/rand, base64-encoded.
// The error is non-nil only if reading random bytes fails.
func generateChallengeKey() (string, error) {
	p := make([]byte, 16)
	if _, err := io.ReadFull(rand.Reader, p); err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(p), nil
}

// Token octets per RFC 2616.
var isTokenOctet = func() (table [256]bool) {
	// An RFC 2616 token is any ASCII letter or digit, or one of the listed
	// separator-free symbols. Every other byte value is left false.
	const (
		digits  = "0123456789"
		upper   = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
		lower   = "abcdefghijklmnopqrstuvwxyz"
		symbols = "!#$%&'*+-.^_`|~"
	)
	for _, set := range []string{digits, upper, lower, symbols} {
		for i := 0; i < len(set); i++ {
			table[set[i]] = true
		}
	}
	return table
}()

// skipSpace returns a slice of the string s with all leading RFC 2616 linear
// whitespace (spaces and horizontal tabs) removed.
func skipSpace(s string) (rest string) {
	return strings.TrimLeft(s, " \t")
}

// nextToken splits s into its leading RFC 2616 token, which is empty when s
// does not start with a token octet, and the remainder of the string.
129 | func nextToken(s string) (token, rest string) { 130 | i := 0 131 | for ; i < len(s); i++ { 132 | if !isTokenOctet[s[i]] { 133 | break 134 | } 135 | } 136 | return s[:i], s[i:] 137 | } 138 | 139 | // nextTokenOrQuoted returns the leading token or quoted string per RFC 2616 140 | // and the string following the token or quoted string. 141 | func nextTokenOrQuoted(s string) (value string, rest string) { 142 | if !strings.HasPrefix(s, "\"") { 143 | return nextToken(s) 144 | } 145 | s = s[1:] 146 | for i := 0; i < len(s); i++ { 147 | switch s[i] { 148 | case '"': 149 | return s[:i], s[i+1:] 150 | case '\\': 151 | p := make([]byte, len(s)-1) 152 | j := copy(p, s[:i]) 153 | escape := true 154 | for i = i + 1; i < len(s); i++ { 155 | b := s[i] 156 | switch { 157 | case escape: 158 | escape = false 159 | p[j] = b 160 | j++ 161 | case b == '\\': 162 | escape = true 163 | case b == '"': 164 | return string(p[:j]), s[i+1:] 165 | default: 166 | p[j] = b 167 | j++ 168 | } 169 | } 170 | return "", "" 171 | } 172 | } 173 | return "", "" 174 | } 175 | 176 | // equalASCIIFold returns true if s is equal to t with ASCII case folding as 177 | // defined in RFC 4790. 178 | func equalASCIIFold(s, t string) bool { 179 | for s != "" && t != "" { 180 | sr, size := utf8.DecodeRuneInString(s) 181 | s = s[size:] 182 | tr, size := utf8.DecodeRuneInString(t) 183 | t = t[size:] 184 | if sr == tr { 185 | continue 186 | } 187 | if 'A' <= sr && sr <= 'Z' { 188 | sr = sr + 'a' - 'A' 189 | } 190 | if 'A' <= tr && tr <= 'Z' { 191 | tr = tr + 'a' - 'A' 192 | } 193 | if sr != tr { 194 | return false 195 | } 196 | } 197 | return s == t 198 | } 199 | 200 | // tokenListContainsValue returns true if the 1#token header with the given 201 | // name contains a token equal to value with ASCII case folding. 
202 | func tokenListContainsValue(header http.Header, name string, value string) bool { 203 | headers: 204 | for _, s := range header[name] { 205 | for { 206 | var t string 207 | t, s = nextToken(skipSpace(s)) 208 | if t == "" { 209 | continue headers 210 | } 211 | s = skipSpace(s) 212 | if s != "" && s[0] != ',' { 213 | continue headers 214 | } 215 | if equalASCIIFold(t, value) { 216 | return true 217 | } 218 | if s == "" { 219 | continue headers 220 | } 221 | s = s[1:] 222 | } 223 | } 224 | return false 225 | } 226 | 227 | // parseExtensions parses WebSocket extensions from a header. 228 | func parseExtensions(header http.Header) []map[string]string { 229 | // From RFC 6455: 230 | // 231 | // Sec-WebSocket-Extensions = extension-list 232 | // extension-list = 1#extension 233 | // extension = extension-token *( ";" extension-param ) 234 | // extension-token = registered-token 235 | // registered-token = token 236 | // extension-param = token [ "=" (token | quoted-string) ] 237 | // ;When using the quoted-string syntax variant, the value 238 | // ;after quoted-string unescaping MUST conform to the 239 | // ;'token' ABNF. 
240 | 241 | var result []map[string]string 242 | headers: 243 | for _, s := range header["Sec-Websocket-Extensions"] { 244 | for { 245 | var t string 246 | t, s = nextToken(skipSpace(s)) 247 | if t == "" { 248 | continue headers 249 | } 250 | ext := map[string]string{"": t} 251 | for { 252 | s = skipSpace(s) 253 | if !strings.HasPrefix(s, ";") { 254 | break 255 | } 256 | var k string 257 | k, s = nextToken(skipSpace(s[1:])) 258 | if k == "" { 259 | continue headers 260 | } 261 | s = skipSpace(s) 262 | var v string 263 | if strings.HasPrefix(s, "=") { 264 | v, s = nextTokenOrQuoted(skipSpace(s[1:])) 265 | s = skipSpace(s) 266 | } 267 | if s != "" && s[0] != ',' && s[0] != ';' { 268 | continue headers 269 | } 270 | ext[k] = v 271 | } 272 | if s != "" && s[0] != ',' { 273 | continue headers 274 | } 275 | result = append(result, ext) 276 | if s == "" { 277 | continue headers 278 | } 279 | s = s[1:] 280 | } 281 | } 282 | return result 283 | } 284 | 285 | // isValidChallengeKey checks if the argument meets RFC6455 specification. 286 | func isValidChallengeKey(s string) bool { 287 | // From RFC6455: 288 | // 289 | // A |Sec-WebSocket-Key| header field with a base64-encoded (see 290 | // Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in 291 | // length. 292 | 293 | if s == "" { 294 | return false 295 | } 296 | decoded, err := base64.StdEncoding.DecodeString(s) 297 | return err == nil && len(decoded) == 16 298 | } 299 | -------------------------------------------------------------------------------- /websocket/util_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2014 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 

package websocket

import (
	"net/http"
	"reflect"
	"testing"
)

// equalASCIIFoldTests: only ASCII letters are folded, so the non-ASCII
// "Öyster"/"öyster" pair must compare unequal.
var equalASCIIFoldTests = []struct {
	t, s string
	eq   bool
}{
	{"WebSocket", "websocket", true},
	{"websocket", "WebSocket", true},
	{"Öyster", "öyster", false},
	{"WebSocket", "WetSocket", false},
}

// TestEqualASCIIFold runs every equalASCIIFoldTests case.
func TestEqualASCIIFold(t *testing.T) {
	for _, tt := range equalASCIIFoldTests {
		eq := equalASCIIFold(tt.s, tt.t)
		if eq != tt.eq {
			t.Errorf("equalASCIIFold(%q, %q) = %v, want %v", tt.s, tt.t, eq, tt.eq)
		}
	}
}

// tokenListContainsValueTests: header values checked for the token
// "websocket". Whole-token, case-insensitive matches only; a token adjacent to
// non-separator text ("x websocket", "websocket x") does not count.
var tokenListContainsValueTests = []struct {
	value string
	ok    bool
}{
	{"WebSocket", true},
	{"WEBSOCKET", true},
	{"websocket", true},
	{"websockets", false},
	{"x websocket", false},
	{"websocket x", false},
	{"other,websocket,more", true},
	{"other, websocket, more", true},
}

// TestTokenListContainsValue puts each value in an Upgrade header and asks
// tokenListContainsValue for "websocket".
func TestTokenListContainsValue(t *testing.T) {
	for _, tt := range tokenListContainsValueTests {
		h := http.Header{"Upgrade": {tt.value}}
		ok := tokenListContainsValue(h, "Upgrade", "websocket")
		if ok != tt.ok {
			t.Errorf("tokenListContainsValue(h, n, %q) = %v, want %v", tt.value, ok, tt.ok)
		}
	}
}

// isValidChallengeKeyTests: only base64 that decodes to exactly 16 bytes is
// valid; the last key decodes to a different length.
var isValidChallengeKeyTests = []struct {
	key string
	ok  bool
}{
	{"dGhlIHNhbXBsZSBub25jZQ==", true},
	{"", false},
	{"InvalidKey", false},
	{"WHQ4eXhscUtKYjBvOGN3WEdtOEQ=", false},
}

// TestIsValidChallengeKey runs every isValidChallengeKeyTests case.
// NOTE(review): the failure message does not include the key under test.
func TestIsValidChallengeKey(t *testing.T) {
	for _, tt := range isValidChallengeKeyTests {
		ok := isValidChallengeKey(tt.key)
		if ok != tt.ok {
			t.Errorf("isValidChallengeKey returns %v, want %v", ok, tt.ok)
		}
	}
}

// parseExtensionTests maps a Sec-WebSocket-Extensions value to the parsed
// result: one map per extension, name under "" and parameters by name.
// Parsing stops at the first malformed element, keeping earlier extensions.
var parseExtensionTests = []struct {
	value      string
	extensions []map[string]string
}{
	{`foo`, []map[string]string{{"": "foo"}}},
	{`foo, bar; baz=2`, []map[string]string{
		{"": "foo"},
		{"": "bar", "baz": "2"}}},
	{`foo; bar="b,a;z"`, []map[string]string{
		{"": "foo", "bar": "b,a;z"}}},
	{`foo , bar; baz = 2`, []map[string]string{
		{"": "foo"},
		{"": "bar", "baz": "2"}}},
	{`foo, bar; baz=2 junk`, []map[string]string{
		{"": "foo"}}},
	{`foo junk, bar; baz=2 junk`, nil},
	{`mux; max-channels=4; flow-control, deflate-stream`, []map[string]string{
		{"": "mux", "max-channels": "4", "flow-control": ""},
		{"": "deflate-stream"}}},
	{`permessage-foo; x="10"`, []map[string]string{
		{"": "permessage-foo", "x": "10"}}},
	{`permessage-foo; use_y, permessage-foo`, []map[string]string{
		{"": "permessage-foo", "use_y": ""},
		{"": "permessage-foo"}}},
	{`permessage-deflate; client_max_window_bits; server_max_window_bits=10 , permessage-deflate; client_max_window_bits`, []map[string]string{
		{"": "permessage-deflate", "client_max_window_bits": "", "server_max_window_bits": "10"},
		{"": "permessage-deflate", "client_max_window_bits": ""}}},
	{"permessage-deflate; server_no_context_takeover; client_max_window_bits=15", []map[string]string{
		{"": "permessage-deflate", "server_no_context_takeover": "", "client_max_window_bits": "15"},
	}},
}

// TestParseExtensions runs every parseExtensionTests case, storing the value
// under the canonical header key that parseExtensions reads.
func TestParseExtensions(t *testing.T) {
	for _, tt := range parseExtensionTests {
		h := http.Header{http.CanonicalHeaderKey("Sec-WebSocket-Extensions"): {tt.value}}
		extensions := parseExtensions(h)
		if !reflect.DeepEqual(extensions, tt.extensions) {
			t.Errorf("parseExtensions(%q)\n = %v,\nwant %v", tt.value, extensions, tt.extensions)
		}
	}
}

--------------------------------------------------------------------------------
/websocket/x_net_proxy.go:
--------------------------------------------------------------------------------
// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT.
//go:generate bundle -o x_net_proxy.go golang.org/x/net/proxy

// Package proxy provides support for a variety of protocols to proxy network
// data.
//

package websocket

// NOTE(review): this file is produced by the go:generate directive above;
// hand edits (including review comments) are lost on regeneration.

import (
	"errors"
	"io"
	"net"
	"net/url"
	"os"
	"strconv"
	"strings"
	"sync"
)

// proxy_direct dials targets itself with net.Dial, without any proxy.
type proxy_direct struct{}

// Direct is a direct proxy: one that makes network connections directly.
var proxy_Direct = proxy_direct{}

func (proxy_direct) Dial(network, addr string) (net.Conn, error) {
	return net.Dial(network, addr)
}

// A PerHost directs connections to a default Dialer unless the host name
// requested matches one of a number of exceptions.
type proxy_PerHost struct {
	// def serves hosts matching no rule; bypass serves matching hosts.
	def, bypass proxy_Dialer

	// bypassNetworks and bypassIPs are consulted only for literal IP hosts;
	// bypassZones (each with a leading ".") and bypassHosts only for names.
	bypassNetworks []*net.IPNet
	bypassIPs      []net.IP
	bypassZones    []string
	bypassHosts    []string
}

// NewPerHost returns a PerHost Dialer that directs connections to either
// defaultDialer or bypass, depending on whether the connection matches one of
// the configured rules.
func proxy_NewPerHost(defaultDialer, bypass proxy_Dialer) *proxy_PerHost {
	return &proxy_PerHost{
		def:    defaultDialer,
		bypass: bypass,
	}
}

// Dial connects to the address addr on the given network through either
// defaultDialer or bypass.
func (p *proxy_PerHost) Dial(network, addr string) (c net.Conn, err error) {
	// addr must be host:port; the SplitHostPort error is returned as-is.
	host, _, err := net.SplitHostPort(addr)
	if err != nil {
		return nil, err
	}

	return p.dialerForRequest(host).Dial(network, addr)
}

// dialerForRequest picks p.bypass when host matches a configured rule and
// p.def otherwise. Literal IPs are checked only against networks and IPs;
// names only against zones and hosts. Name comparisons are case-sensitive as
// written.
func (p *proxy_PerHost) dialerForRequest(host string) proxy_Dialer {
	if ip := net.ParseIP(host); ip != nil {
		// The loop variable net shadows the net package inside this loop.
		for _, net := range p.bypassNetworks {
			if net.Contains(ip) {
				return p.bypass
			}
		}
		for _, bypassIP := range p.bypassIPs {
			if bypassIP.Equal(ip) {
				return p.bypass
			}
		}
		return p.def
	}

	for _, zone := range p.bypassZones {
		if strings.HasSuffix(host, zone) {
			return p.bypass
		}
		if host == zone[1:] {
			// For a zone ".example.com", we match "example.com"
			// too.
			return p.bypass
		}
	}
	for _, bypassHost := range p.bypassHosts {
		if bypassHost == host {
			return p.bypass
		}
	}
	return p.def
}

// AddFromString parses a string that contains comma-separated values
// specifying hosts that should use the bypass proxy. Each value is either an
// IP address, a CIDR range, a zone (*.example.com) or a host name
// (localhost). A best effort is made to parse the string and errors are
// ignored.
func (p *proxy_PerHost) AddFromString(s string) {
	hosts := strings.Split(s, ",")
	for _, host := range hosts {
		host = strings.TrimSpace(host)
		if len(host) == 0 {
			continue
		}
		if strings.Contains(host, "/") {
			// We assume that it's a CIDR address like 127.0.0.0/8
			// (the declared net shadows the package only inside this if).
			if _, net, err := net.ParseCIDR(host); err == nil {
				p.AddNetwork(net)
			}
			continue
		}
		if ip := net.ParseIP(host); ip != nil {
			p.AddIP(ip)
			continue
		}
		if strings.HasPrefix(host, "*.") {
			p.AddZone(host[1:])
			continue
		}
		// NOTE(review): a value like ".example.com" (no "*") falls through
		// to AddHost and is compared by equality, so it never matches a name.
		p.AddHost(host)
	}
}

// AddIP specifies an IP address that will use the bypass proxy. Note that
// this will only take effect if a literal IP address is dialed. A connection
// to a named host will never match an IP.
func (p *proxy_PerHost) AddIP(ip net.IP) {
	p.bypassIPs = append(p.bypassIPs, ip)
}

// AddNetwork specifies an IP range that will use the bypass proxy. Note that
// this will only take effect if a literal IP address is dialed. A connection
// to a named host will never match.
func (p *proxy_PerHost) AddNetwork(net *net.IPNet) {
	p.bypassNetworks = append(p.bypassNetworks, net)
}

// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of
// "example.com" matches "example.com" and all of its subdomains.
func (p *proxy_PerHost) AddZone(zone string) {
	// Normalize to ".example.com": drop one trailing dot, ensure a leading one.
	if strings.HasSuffix(zone, ".") {
		zone = zone[:len(zone)-1]
	}
	if !strings.HasPrefix(zone, ".") {
		zone = "." + zone
	}
	p.bypassZones = append(p.bypassZones, zone)
}

// AddHost specifies a host name that will use the bypass proxy.
func (p *proxy_PerHost) AddHost(host string) {
	// A single trailing dot (fully-qualified form) is dropped before storing.
	if strings.HasSuffix(host, ".") {
		host = host[:len(host)-1]
	}
	p.bypassHosts = append(p.bypassHosts, host)
}

// A Dialer is a means to establish a connection.
type proxy_Dialer interface {
	// Dial connects to the given address via the proxy.
	Dial(network, addr string) (c net.Conn, err error)
}

// Auth contains authentication parameters that specific Dialers may require.
type proxy_Auth struct {
	User, Password string
}

// FromEnvironment returns the dialer specified by the proxy related variables in
// the environment.
//
// It reads ALL_PROXY/all_proxy and falls back to proxy_Direct when that is
// unset, unparsable, or uses an unsupported scheme. When NO_PROXY/no_proxy is
// set, matching hosts are dialed directly through a PerHost dialer.
func proxy_FromEnvironment() proxy_Dialer {
	allProxy := proxy_allProxyEnv.Get()
	if len(allProxy) == 0 {
		return proxy_Direct
	}

	proxyURL, err := url.Parse(allProxy)
	if err != nil {
		return proxy_Direct
	}
	proxy, err := proxy_FromURL(proxyURL, proxy_Direct)
	if err != nil {
		return proxy_Direct
	}

	noProxy := proxy_noProxyEnv.Get()
	if len(noProxy) == 0 {
		return proxy
	}

	perHost := proxy_NewPerHost(proxy, proxy_Direct)
	perHost.AddFromString(noProxy)
	return perHost
}

// proxySchemes is a map from URL schemes to a function that creates a Dialer
// from a URL with such a scheme.
// NOTE(review): nil until the first registration and not guarded by a lock;
// presumably registration only happens during initialization — confirm.
var proxy_proxySchemes map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error)

// RegisterDialerType takes a URL scheme and a function to generate Dialers from
// a URL with that scheme and a forwarding Dialer. Registered schemes are used
// by FromURL.
func proxy_RegisterDialerType(scheme string, f func(*url.URL, proxy_Dialer) (proxy_Dialer, error)) {
	// Lazily create the registry; a later registration replaces an earlier one.
	if proxy_proxySchemes == nil {
		proxy_proxySchemes = make(map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error))
	}
	proxy_proxySchemes[scheme] = f
}

// FromURL returns a Dialer given a URL specification and an underlying
// Dialer for it to make network requests.
//
// Credentials come from the URL's userinfo. Only "socks5" is built in; other
// schemes must have been registered with RegisterDialerType.
func proxy_FromURL(u *url.URL, forward proxy_Dialer) (proxy_Dialer, error) {
	var auth *proxy_Auth
	if u.User != nil {
		auth = new(proxy_Auth)
		auth.User = u.User.Username()
		if p, ok := u.User.Password(); ok {
			auth.Password = p
		}
	}

	switch u.Scheme {
	case "socks5":
		return proxy_SOCKS5("tcp", u.Host, auth, forward)
	}

	// If the scheme doesn't match any of the built-in schemes, see if it
	// was registered by another package.
	if proxy_proxySchemes != nil {
		if f, ok := proxy_proxySchemes[u.Scheme]; ok {
			return f(u, forward)
		}
	}

	return nil, errors.New("proxy: unknown scheme: " + u.Scheme)
}

// Environment lookups used by FromEnvironment; the upper-case name is tried
// first.
var (
	proxy_allProxyEnv = &proxy_envOnce{
		names: []string{"ALL_PROXY", "all_proxy"},
	}
	proxy_noProxyEnv = &proxy_envOnce{
		names: []string{"NO_PROXY", "no_proxy"},
	}
)

// envOnce looks up an environment variable (optionally by multiple
// names) once. It mitigates expensive lookups on some platforms
// (e.g. Windows).
// (Borrowed from net/http/transport.go)
type proxy_envOnce struct {
	// names are tried in order; the first non-empty value wins.
	names []string
	once  sync.Once
	val   string
}

// Get returns the cached value, performing the lookup on first use.
func (e *proxy_envOnce) Get() string {
	e.once.Do(e.init)
	return e.val
}

// init stores the first non-empty variable among e.names (or "" if none).
func (e *proxy_envOnce) init() {
	for _, n := range e.names {
		e.val = os.Getenv(n)
		if e.val != "" {
			return
		}
	}
}

// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given address
// with an optional username and password. See RFC 1928 and RFC 1929.
// The returned error is always nil; no connection is made until Dial.
func proxy_SOCKS5(network, addr string, auth *proxy_Auth, forward proxy_Dialer) (proxy_Dialer, error) {
	s := &proxy_socks5{
		network: network,
		addr:    addr,
		forward: forward,
	}
	if auth != nil {
		s.user = auth.User
		s.password = auth.Password
	}

	return s, nil
}

// proxy_socks5 holds the proxy's address (reached through forward) and the
// optional credentials offered during the handshake.
type proxy_socks5 struct {
	user, password string
	network, addr  string
	forward        proxy_Dialer
}

const proxy_socks5Version = 5

// Authentication method identifiers offered in the greeting.
const (
	proxy_socks5AuthNone     = 0
	proxy_socks5AuthPassword = 2
)

const proxy_socks5Connect = 1

// Address type identifiers used in connect requests and replies.
const (
	proxy_socks5IP4    = 1
	proxy_socks5Domain = 3
	proxy_socks5IP6    = 4
)

// proxy_socks5Errors is indexed by the reply code of a connect reply; index 0
// is the empty string, which connect treats as success.
var proxy_socks5Errors = []string{
	"",
	"general failure",
	"connection forbidden",
	"network unreachable",
	"host unreachable",
	"connection refused",
	"TTL expired",
	"command not supported",
	"address type not supported",
}

// Dial connects to the address addr on the given network via the SOCKS5 proxy.
func (s *proxy_socks5) Dial(network, addr string) (net.Conn, error) {
	// Only TCP can be tunneled by this implementation.
	switch network {
	case "tcp", "tcp6", "tcp4":
	default:
		return nil, errors.New("proxy: no support for SOCKS5 proxy connections of type " + network)
	}

	conn, err := s.forward.Dial(s.network, s.addr)
	if err != nil {
		return nil, err
	}
	// On handshake failure the proxy connection is closed and not returned.
	if err := s.connect(conn, addr); err != nil {
		conn.Close()
		return nil, err
	}
	return conn, nil
}

// connect takes an existing connection to a socks5 proxy server,
// and commands the server to extend that connection to target,
// which must be a canonical address with a host and port.
//
// Steps: greeting/method selection, optional username/password
// sub-negotiation, CONNECT request, then the reply (whose bound address and
// port are read and discarded).
// NOTE(review): no read/write deadlines are set here, so a stalled proxy
// blocks until the underlying connection gives up.
func (s *proxy_socks5) connect(conn net.Conn, target string) error {
	host, portStr, err := net.SplitHostPort(target)
	if err != nil {
		return err
	}

	port, err := strconv.Atoi(portStr)
	if err != nil {
		return errors.New("proxy: failed to parse port number: " + portStr)
	}
	if port < 1 || port > 0xffff {
		return errors.New("proxy: port number out of range: " + portStr)
	}

	// the size here is just an estimate
	buf := make([]byte, 0, 6+len(host))

	// Greeting. Password auth is offered only when the credentials fit in the
	// one-byte length fields written during the sub-negotiation below.
	buf = append(buf, proxy_socks5Version)
	if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 {
		buf = append(buf, 2 /* num auth methods */, proxy_socks5AuthNone, proxy_socks5AuthPassword)
	} else {
		buf = append(buf, 1 /* num auth methods */, proxy_socks5AuthNone)
	}

	if _, err := conn.Write(buf); err != nil {
		return errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error())
	}

	// Method-selection reply: version byte, then the chosen method.
	if _, err := io.ReadFull(conn, buf[:2]); err != nil {
		return errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error())
	}
	if buf[0] != 5 {
		return errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0])))
	}
	if buf[1] == 0xff {
		return errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication")
	}
	// NOTE(review): any other method value (including one not offered) falls
	// through and the code proceeds as if no authentication were needed.

	// See RFC 1929
	if buf[1] == proxy_socks5AuthPassword {
		buf = buf[:0]
		buf = append(buf, 1 /* password protocol version */)
		buf = append(buf, uint8(len(s.user)))
		buf = append(buf, s.user...)
		buf = append(buf, uint8(len(s.password)))
		buf = append(buf, s.password...)

		if _, err := conn.Write(buf); err != nil {
			return errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
		}

		if _, err := io.ReadFull(conn, buf[:2]); err != nil {
			return errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
		}

		// Only the status byte is checked; the reply's version byte is not.
		if buf[1] != 0 {
			return errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password")
		}
	}

	// CONNECT request: version, command, reserved, address type, address, port.
	buf = buf[:0]
	buf = append(buf, proxy_socks5Version, proxy_socks5Connect, 0 /* reserved */)

	if ip := net.ParseIP(host); ip != nil {
		// IPv4 addresses (including IPv4-mapped IPv6) are sent in 4-byte form.
		if ip4 := ip.To4(); ip4 != nil {
			buf = append(buf, proxy_socks5IP4)
			ip = ip4
		} else {
			buf = append(buf, proxy_socks5IP6)
		}
		buf = append(buf, ip...)
	} else {
		if len(host) > 255 {
			return errors.New("proxy: destination host name too long: " + host)
		}
		buf = append(buf, proxy_socks5Domain)
		buf = append(buf, byte(len(host)))
		buf = append(buf, host...)
	}
	// Port in network (big-endian) byte order.
	buf = append(buf, byte(port>>8), byte(port))

	if _, err := conn.Write(buf); err != nil {
		return errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
	}

	// Reply header: buf[1] is the reply code, buf[3] the bound-address type.
	if _, err := io.ReadFull(conn, buf[:4]); err != nil {
		return errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
	}

	failure := "unknown error"
	if int(buf[1]) < len(proxy_socks5Errors) {
		failure = proxy_socks5Errors[buf[1]]
	}

	// An empty message (reply code 0) means success.
	if len(failure) > 0 {
		return errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure)
	}

	// The bound address that follows is not used; work out its length so it
	// can be consumed from the stream.
	bytesToDiscard := 0
	switch buf[3] {
	case proxy_socks5IP4:
		bytesToDiscard = net.IPv4len
	case proxy_socks5IP6:
		bytesToDiscard = net.IPv6len
	case proxy_socks5Domain:
		_, err := io.ReadFull(conn, buf[:1])
		if err != nil {
			return errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error())
		}
		bytesToDiscard = int(buf[0])
	default:
		return errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr)
	}

	if cap(buf) < bytesToDiscard {
		buf = make([]byte, bytesToDiscard)
	} else {
		buf = buf[:bytesToDiscard]
	}
	if _, err := io.ReadFull(conn, buf); err != nil {
		return errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error())
	}

	// Also need to discard the port number
	if _, err := io.ReadFull(conn, buf[:2]); err != nil {
		return errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error())
	}

	return nil
}