├── .circleci └── config.yml ├── .github └── release-drafter.yml ├── .gitignore ├── AUTHORS ├── LICENSE ├── README.md ├── client.go ├── client_proxy_server_test.go ├── client_server_test.go ├── client_test.go ├── compression.go ├── compression_test.go ├── conn.go ├── conn_broadcast_test.go ├── conn_test.go ├── doc.go ├── example_test.go ├── examples ├── autobahn │ ├── README.md │ ├── config │ │ └── fuzzingclient.json │ └── server.go ├── chat │ ├── README.md │ ├── client.go │ ├── home.html │ ├── hub.go │ └── main.go ├── command │ ├── README.md │ ├── home.html │ └── main.go ├── echo │ ├── README.md │ ├── client.go │ └── server.go └── filewatch │ ├── README.md │ └── main.go ├── go.mod ├── go.sum ├── join.go ├── join_test.go ├── json.go ├── json_test.go ├── mask.go ├── mask_safe.go ├── mask_test.go ├── prepared.go ├── prepared_test.go ├── proxy.go ├── server.go ├── server_test.go ├── util.go └── util_test.go /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | 3 | jobs: 4 | "test": 5 | parameters: 6 | version: 7 | type: string 8 | default: "latest" 9 | golint: 10 | type: boolean 11 | default: true 12 | modules: 13 | type: boolean 14 | default: true 15 | goproxy: 16 | type: string 17 | default: "" 18 | docker: 19 | - image: "cimg/go:<< parameters.version >>" 20 | working_directory: /home/circleci/project/go/src/github.com/gorilla/websocket 21 | environment: 22 | GO111MODULE: "on" 23 | GOPROXY: "<< parameters.goproxy >>" 24 | steps: 25 | - checkout 26 | - run: 27 | name: "Print the Go version" 28 | command: > 29 | go version 30 | - run: 31 | name: "Fetch dependencies" 32 | command: > 33 | if [[ << parameters.modules >> = true ]]; then 34 | go mod download 35 | export GO111MODULE=on 36 | else 37 | go get -v ./... 
38 | fi 39 | # Only run gofmt, vet & lint against the latest Go version 40 | - run: 41 | name: "Run golint" 42 | command: > 43 | if [ << parameters.version >> = "latest" ] && [ << parameters.golint >> = true ]; then 44 | go get -u golang.org/x/lint/golint 45 | golint ./... 46 | fi 47 | - run: 48 | name: "Run gofmt" 49 | command: > 50 | if [[ << parameters.version >> = "latest" ]]; then 51 | diff -u <(echo -n) <(gofmt -d -e .) 52 | fi 53 | - run: 54 | name: "Run go vet" 55 | command: > 56 | if [[ << parameters.version >> = "latest" ]]; then 57 | go vet -v ./... 58 | fi 59 | - run: 60 | name: "Run go test (+ race detector)" 61 | command: > 62 | go test -v -race ./... 63 | 64 | workflows: 65 | tests: 66 | jobs: 67 | - test: 68 | matrix: 69 | parameters: 70 | version: ["1.22", "1.21", "1.20"] 71 | -------------------------------------------------------------------------------- /.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | # Config for https://github.com/apps/release-drafter 2 | template: | 3 | 4 | 5 | 6 | ## CHANGELOG 7 | $CHANGES 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | 24 | .idea/ 25 | *.iml 26 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | # This is the official list of Gorilla WebSocket authors for copyright 2 | # purposes. 3 | # 4 | # Please keep the list sorted. 
5 | 6 | Gary Burd 7 | Google LLC (https://opensource.google.com/) 8 | Joachim Bauch 9 | 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are met: 5 | 6 | Redistributions of source code must retain the above copyright notice, this 7 | list of conditions and the following disclaimer. 8 | 9 | Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 12 | 13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 14 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 15 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 16 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 17 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 18 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 19 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 20 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 21 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 22 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Gorilla WebSocket 2 | 3 | [![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket) 4 | [![CircleCI](https://circleci.com/gh/gorilla/websocket.svg?style=svg)](https://circleci.com/gh/gorilla/websocket) 5 | 6 | Gorilla WebSocket is a [Go](http://golang.org/) implementation of the 7 | [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. 8 | 9 | 10 | ### Documentation 11 | 12 | * [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc) 13 | * [Chat example](https://github.com/gorilla/websocket/tree/main/examples/chat) 14 | * [Command example](https://github.com/gorilla/websocket/tree/main/examples/command) 15 | * [Client and server example](https://github.com/gorilla/websocket/tree/main/examples/echo) 16 | * [File watch example](https://github.com/gorilla/websocket/tree/main/examples/filewatch) 17 | 18 | ### Status 19 | 20 | The Gorilla WebSocket package provides a complete and tested implementation of 21 | the [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. The 22 | package API is stable. 23 | 24 | ### Installation 25 | 26 | go get github.com/gorilla/websocket 27 | 28 | ### Protocol Compliance 29 | 30 | The Gorilla WebSocket package passes the server tests in the [Autobahn Test 31 | Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn 32 | subdirectory](https://github.com/gorilla/websocket/tree/main/examples/autobahn). 33 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 
2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package websocket 6 | 7 | import ( 8 | "bytes" 9 | "context" 10 | "crypto/tls" 11 | "errors" 12 | "fmt" 13 | "io" 14 | "net" 15 | "net/http" 16 | "net/http/httptrace" 17 | "net/url" 18 | "strings" 19 | "time" 20 | ) 21 | 22 | // ErrBadHandshake is returned when the server response to opening handshake is 23 | // invalid. 24 | var ErrBadHandshake = errors.New("websocket: bad handshake") 25 | 26 | var errInvalidCompression = errors.New("websocket: invalid compression negotiation") 27 | 28 | // NewClient creates a new client connection using the given net connection. 29 | // The URL u specifies the host and request URI. Use requestHeader to specify 30 | // the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies 31 | // (Cookie). Use the response.Header to get the selected subprotocol 32 | // (Sec-WebSocket-Protocol) and cookies (Set-Cookie). 33 | // 34 | // If the WebSocket handshake fails, ErrBadHandshake is returned along with a 35 | // non-nil *http.Response so that callers can handle redirects, authentication, 36 | // etc. 37 | // 38 | // Deprecated: Use Dialer instead. 39 | func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) { 40 | d := Dialer{ 41 | ReadBufferSize: readBufSize, 42 | WriteBufferSize: writeBufSize, 43 | NetDial: func(net, addr string) (net.Conn, error) { 44 | return netConn, nil 45 | }, 46 | } 47 | return d.Dial(u.String(), requestHeader) 48 | } 49 | 50 | // A Dialer contains options for connecting to WebSocket server. 51 | // 52 | // It is safe to call Dialer's methods concurrently. 53 | type Dialer struct { 54 | // The following custom dial functions can be set to establish 55 | // connections to either the backend server or the proxy (if it 56 | // exists). 
The scheme of the dialed entity (either backend or 57 | // proxy) determines which custom dial function is selected: 58 | // either NetDialTLSContext for HTTPS or NetDialContext/NetDial 59 | // for HTTP. Since the "Proxy" function can determine the scheme 60 | // dynamically, it can make sense to set multiple custom dial 61 | // functions simultaneously. 62 | // 63 | // NetDial specifies the dial function for creating TCP connections. If 64 | // NetDial is nil, net.Dialer DialContext is used. 65 | // If "Proxy" field is also set, this function dials the proxy--not 66 | // the backend server. 67 | NetDial func(network, addr string) (net.Conn, error) 68 | 69 | // NetDialContext specifies the dial function for creating TCP connections. If 70 | // NetDialContext is nil, NetDial is used. 71 | // If "Proxy" field is also set, this function dials the proxy--not 72 | // the backend server. 73 | NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error) 74 | 75 | // NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If 76 | // NetDialTLSContext is nil, NetDialContext is used. 77 | // If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and 78 | // TLSClientConfig is ignored. 79 | // If "Proxy" field is also set, this function dials the proxy (and performs 80 | // the TLS handshake with the proxy, ignoring TLSClientConfig). In this TLS proxy 81 | // dialing case the TLSClientConfig could still be necessary for TLS to the backend server. 82 | NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) 83 | 84 | // Proxy specifies a function to return a proxy for a given 85 | // Request. If the function returns a non-nil error, the 86 | // request is aborted with the provided error. 87 | // If Proxy is nil or returns a nil *URL, no proxy is used. 88 | Proxy func(*http.Request) (*url.URL, error) 89 | 90 | // TLSClientConfig specifies the TLS configuration to use with tls.Client. 
91 | // If nil, the default configuration is used. 92 | // If NetDialTLSContext is set, Dial assumes the TLS handshake 93 | // is done there and TLSClientConfig is ignored. 94 | TLSClientConfig *tls.Config 95 | 96 | // HandshakeTimeout specifies the duration for the handshake to complete. 97 | HandshakeTimeout time.Duration 98 | 99 | // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer 100 | // size is zero, then a useful default size is used. The I/O buffer sizes 101 | // do not limit the size of the messages that can be sent or received. 102 | ReadBufferSize, WriteBufferSize int 103 | 104 | // WriteBufferPool is a pool of buffers for write operations. If the value 105 | // is not set, then write buffers are allocated to the connection for the 106 | // lifetime of the connection. 107 | // 108 | // A pool is most useful when the application has a modest volume of writes 109 | // across a large number of connections. 110 | // 111 | // Applications should use a single pool for each unique value of 112 | // WriteBufferSize. 113 | WriteBufferPool BufferPool 114 | 115 | // Subprotocols specifies the client's requested subprotocols. 116 | Subprotocols []string 117 | 118 | // EnableCompression specifies if the client should attempt to negotiate 119 | // per message compression (RFC 7692). Setting this value to true does not 120 | // guarantee that compression will be supported. Currently only "no context 121 | // takeover" modes are supported. 122 | EnableCompression bool 123 | 124 | // Jar specifies the cookie jar. 125 | // If Jar is nil, cookies are not sent in requests and ignored 126 | // in responses. 127 | Jar http.CookieJar 128 | } 129 | 130 | // Dial creates a new client connection by calling DialContext with a background context. 
131 | func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { 132 | return d.DialContext(context.Background(), urlStr, requestHeader) 133 | } 134 | 135 | var errMalformedURL = errors.New("malformed ws or wss URL") 136 | 137 | func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { 138 | hostPort = u.Host 139 | hostNoPort = u.Host 140 | if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") { 141 | hostNoPort = hostNoPort[:i] 142 | } else { 143 | switch u.Scheme { 144 | case "wss": 145 | hostPort += ":443" 146 | case "https": 147 | hostPort += ":443" 148 | default: 149 | hostPort += ":80" 150 | } 151 | } 152 | return hostPort, hostNoPort 153 | } 154 | 155 | // DefaultDialer is a dialer with all fields set to the default values. 156 | var DefaultDialer = &Dialer{ 157 | Proxy: http.ProxyFromEnvironment, 158 | HandshakeTimeout: 45 * time.Second, 159 | } 160 | 161 | // nilDialer is dialer to use when receiver is nil. 162 | var nilDialer = *DefaultDialer 163 | 164 | // DialContext creates a new client connection. Use requestHeader to specify the 165 | // origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie). 166 | // Use the response.Header to get the selected subprotocol 167 | // (Sec-WebSocket-Protocol) and cookies (Set-Cookie). 168 | // 169 | // The context will be used in the request and in the Dialer. 170 | // 171 | // If the WebSocket handshake fails, ErrBadHandshake is returned along with a 172 | // non-nil *http.Response so that callers can handle redirects, authentication, 173 | // etcetera. The response body may not contain the entire response and does not 174 | // need to be closed by the application. 
175 | func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { 176 | if d == nil { 177 | d = &nilDialer 178 | } 179 | 180 | challengeKey, err := generateChallengeKey() 181 | if err != nil { 182 | return nil, nil, err 183 | } 184 | 185 | u, err := url.Parse(urlStr) 186 | if err != nil { 187 | return nil, nil, err 188 | } 189 | 190 | switch u.Scheme { 191 | case "ws": 192 | u.Scheme = "http" 193 | case "wss": 194 | u.Scheme = "https" 195 | default: 196 | return nil, nil, errMalformedURL 197 | } 198 | 199 | if u.User != nil { 200 | // User name and password are not allowed in websocket URIs. 201 | return nil, nil, errMalformedURL 202 | } 203 | 204 | req := &http.Request{ 205 | Method: http.MethodGet, 206 | URL: u, 207 | Proto: "HTTP/1.1", 208 | ProtoMajor: 1, 209 | ProtoMinor: 1, 210 | Header: make(http.Header), 211 | Host: u.Host, 212 | } 213 | req = req.WithContext(ctx) 214 | 215 | // Set the cookies present in the cookie jar of the dialer 216 | if d.Jar != nil { 217 | for _, cookie := range d.Jar.Cookies(u) { 218 | req.AddCookie(cookie) 219 | } 220 | } 221 | 222 | // Set the request headers using the capitalization for names and values in 223 | // RFC examples. Although the capitalization shouldn't matter, there are 224 | // servers that depend on it. The Header.Set method is not used because the 225 | // method canonicalizes the header names. 
226 | req.Header["Upgrade"] = []string{"websocket"} 227 | req.Header["Connection"] = []string{"Upgrade"} 228 | req.Header["Sec-WebSocket-Key"] = []string{challengeKey} 229 | req.Header["Sec-WebSocket-Version"] = []string{"13"} 230 | if len(d.Subprotocols) > 0 { 231 | req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")} 232 | } 233 | for k, vs := range requestHeader { 234 | switch { 235 | case k == "Host": 236 | if len(vs) > 0 { 237 | req.Host = vs[0] 238 | } 239 | case k == "Upgrade" || 240 | k == "Connection" || 241 | k == "Sec-Websocket-Key" || 242 | k == "Sec-Websocket-Version" || 243 | k == "Sec-Websocket-Extensions" || 244 | (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0): 245 | return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) 246 | case k == "Sec-Websocket-Protocol": 247 | req.Header["Sec-WebSocket-Protocol"] = vs 248 | default: 249 | req.Header[k] = vs 250 | } 251 | } 252 | 253 | if d.EnableCompression { 254 | req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"} 255 | } 256 | 257 | if d.HandshakeTimeout != 0 { 258 | var cancel func() 259 | ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout) 260 | defer cancel() 261 | } 262 | 263 | var proxyURL *url.URL 264 | if d.Proxy != nil { 265 | proxyURL, err = d.Proxy(req) 266 | if err != nil { 267 | return nil, nil, err 268 | } 269 | } 270 | netDial, err := d.netDialFn(ctx, proxyURL, u) 271 | if err != nil { 272 | return nil, nil, err 273 | } 274 | 275 | hostPort, hostNoPort := hostPortNoPort(u) 276 | trace := httptrace.ContextClientTrace(ctx) 277 | if trace != nil && trace.GetConn != nil { 278 | trace.GetConn(hostPort) 279 | } 280 | 281 | netConn, err := netDial(ctx, "tcp", hostPort) 282 | if err != nil { 283 | return nil, nil, err 284 | } 285 | if trace != nil && trace.GotConn != nil { 286 | trace.GotConn(httptrace.GotConnInfo{ 287 | Conn: netConn, 288 | }) 289 | } 
290 | 291 | // Close the network connection when returning an error. The variable 292 | // netConn is set to nil before the success return at the end of the 293 | // function. 294 | defer func() { 295 | if netConn != nil { 296 | // It's safe to ignore the error from Close() because this code is 297 | // only executed when returning a more important error to the 298 | // application. 299 | _ = netConn.Close() 300 | } 301 | }() 302 | 303 | // Do TLS handshake over established connection if a proxy exists. 304 | if proxyURL != nil && u.Scheme == "https" { 305 | 306 | cfg := cloneTLSConfig(d.TLSClientConfig) 307 | if cfg.ServerName == "" { 308 | cfg.ServerName = hostNoPort 309 | } 310 | tlsConn := tls.Client(netConn, cfg) 311 | netConn = tlsConn 312 | 313 | if trace != nil && trace.TLSHandshakeStart != nil { 314 | trace.TLSHandshakeStart() 315 | } 316 | err := doHandshake(ctx, tlsConn, cfg) 317 | if trace != nil && trace.TLSHandshakeDone != nil { 318 | trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) 319 | } 320 | 321 | if err != nil { 322 | return nil, nil, err 323 | } 324 | } 325 | 326 | conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil) 327 | 328 | if err := req.Write(netConn); err != nil { 329 | return nil, nil, err 330 | } 331 | 332 | if trace != nil && trace.GotFirstResponseByte != nil { 333 | if peek, err := conn.br.Peek(1); err == nil && len(peek) == 1 { 334 | trace.GotFirstResponseByte() 335 | } 336 | } 337 | 338 | resp, err := http.ReadResponse(conn.br, req) 339 | if err != nil { 340 | if d.TLSClientConfig != nil { 341 | for _, proto := range d.TLSClientConfig.NextProtos { 342 | if proto != "http/1.1" { 343 | return nil, nil, fmt.Errorf( 344 | "websocket: protocol %q was given but is not supported;"+ 345 | "sharing tls.Config with net/http Transport can cause this error: %w", 346 | proto, err, 347 | ) 348 | } 349 | } 350 | } 351 | return nil, nil, err 352 | } 353 | 354 | if d.Jar != nil { 355 | if rc := 
resp.Cookies(); len(rc) > 0 { 356 | d.Jar.SetCookies(u, rc) 357 | } 358 | } 359 | 360 | if resp.StatusCode != 101 || 361 | !tokenListContainsValue(resp.Header, "Upgrade", "websocket") || 362 | !tokenListContainsValue(resp.Header, "Connection", "upgrade") || 363 | resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { 364 | // Before closing the network connection on return from this 365 | // function, slurp up some of the response to aid application 366 | // debugging. 367 | buf := make([]byte, 1024) 368 | n, _ := io.ReadFull(resp.Body, buf) 369 | resp.Body = io.NopCloser(bytes.NewReader(buf[:n])) 370 | return nil, resp, ErrBadHandshake 371 | } 372 | 373 | for _, ext := range parseExtensions(resp.Header) { 374 | if ext[""] != "permessage-deflate" { 375 | continue 376 | } 377 | _, snct := ext["server_no_context_takeover"] 378 | _, cnct := ext["client_no_context_takeover"] 379 | if !snct || !cnct { 380 | return nil, resp, errInvalidCompression 381 | } 382 | conn.newCompressionWriter = compressNoContextTakeover 383 | conn.newDecompressionReader = decompressNoContextTakeover 384 | break 385 | } 386 | 387 | resp.Body = io.NopCloser(bytes.NewReader([]byte{})) 388 | conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") 389 | 390 | if err := netConn.SetDeadline(time.Time{}); err != nil { 391 | return nil, resp, err 392 | } 393 | 394 | // Success! Set netConn to nil to stop the deferred function above from 395 | // closing the network connection. 396 | netConn = nil 397 | 398 | return conn, resp, nil 399 | } 400 | 401 | // Returns the dial function to establish the connection to either the backend 402 | // server or the proxy (if it exists). If the dialed entity is HTTPS, then the 403 | // returned dial function *also* performs the TLS handshake to the dialed entity. 404 | // NOTE: If a proxy exists, it is possible for a second TLS handshake to be 405 | // necessary over the established connection. 
406 | func (d *Dialer) netDialFn(ctx context.Context, proxyURL *url.URL, backendURL *url.URL) (netDialerFunc, error) { 407 | var netDial netDialerFunc 408 | if proxyURL != nil { 409 | netDial = d.netDialFromURL(proxyURL) 410 | } else { 411 | netDial = d.netDialFromURL(backendURL) 412 | } 413 | // If needed, wrap the dial function to set the connection deadline. 414 | if deadline, ok := ctx.Deadline(); ok { 415 | netDial = netDialWithDeadline(netDial, deadline) 416 | } 417 | // Proxy dialing is wrapped to implement CONNECT method and possibly proxy auth. 418 | if proxyURL != nil { 419 | return proxyFromURL(proxyURL, netDial) 420 | } 421 | return netDial, nil 422 | } 423 | 424 | // Returns function to create the connection depending on the Dialer's 425 | // custom dialing functions and the passed URL of entity connecting to. 426 | func (d *Dialer) netDialFromURL(u *url.URL) netDialerFunc { 427 | var netDial netDialerFunc 428 | switch { 429 | case d.NetDialContext != nil: 430 | netDial = d.NetDialContext 431 | case d.NetDial != nil: 432 | netDial = func(ctx context.Context, net, addr string) (net.Conn, error) { 433 | return d.NetDial(net, addr) 434 | } 435 | default: 436 | netDial = (&net.Dialer{}).DialContext 437 | } 438 | // If dialed entity is HTTPS, then either use custom TLS dialing function (if exists) 439 | // or wrap the previously computed "netDial" to use TLS config for handshake. 440 | if u.Scheme == "https" { 441 | if d.NetDialTLSContext != nil { 442 | netDial = d.NetDialTLSContext 443 | } else { 444 | netDial = netDialWithTLSHandshake(netDial, d.TLSClientConfig, u) 445 | } 446 | } 447 | return netDial 448 | } 449 | 450 | // Returns wrapped "netDial" function, performing TLS handshake after connecting. 
451 | func netDialWithTLSHandshake(netDial netDialerFunc, tlsConfig *tls.Config, u *url.URL) netDialerFunc { 452 | return func(ctx context.Context, unused, addr string) (net.Conn, error) { 453 | hostPort, hostNoPort := hostPortNoPort(u) 454 | trace := httptrace.ContextClientTrace(ctx) 455 | if trace != nil && trace.GetConn != nil { 456 | trace.GetConn(hostPort) 457 | } 458 | // Creates TCP connection to addr using passed "netDial" function. 459 | conn, err := netDial(ctx, "tcp", addr) 460 | if err != nil { 461 | return nil, err 462 | } 463 | cfg := cloneTLSConfig(tlsConfig) 464 | if cfg.ServerName == "" { 465 | cfg.ServerName = hostNoPort 466 | } 467 | tlsConn := tls.Client(conn, cfg) 468 | // Do the TLS handshake using TLSConfig over the wrapped connection. 469 | if trace != nil && trace.TLSHandshakeStart != nil { 470 | trace.TLSHandshakeStart() 471 | } 472 | err = doHandshake(ctx, tlsConn, cfg) 473 | if trace != nil && trace.TLSHandshakeDone != nil { 474 | trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) 475 | } 476 | if err != nil { 477 | tlsConn.Close() 478 | return nil, err 479 | } 480 | return tlsConn, nil 481 | } 482 | } 483 | 484 | // Returns wrapped "netDial" function, setting passed deadline. 
485 | func netDialWithDeadline(netDial netDialerFunc, deadline time.Time) netDialerFunc { 486 | return func(ctx context.Context, network, addr string) (net.Conn, error) { 487 | c, err := netDial(ctx, network, addr) 488 | if err != nil { 489 | return nil, err 490 | } 491 | err = c.SetDeadline(deadline) 492 | if err != nil { 493 | c.Close() 494 | return nil, err 495 | } 496 | return c, nil 497 | } 498 | } 499 | 500 | func cloneTLSConfig(cfg *tls.Config) *tls.Config { 501 | if cfg == nil { 502 | return &tls.Config{} 503 | } 504 | return cfg.Clone() 505 | } 506 | 507 | func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error { 508 | if err := tlsConn.HandshakeContext(ctx); err != nil { 509 | return err 510 | } 511 | if !cfg.InsecureSkipVerify { 512 | if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { 513 | return err 514 | } 515 | } 516 | return nil 517 | } 518 | -------------------------------------------------------------------------------- /client_proxy_server_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package websocket 6 | 7 | import ( 8 | "bytes" 9 | "context" 10 | "crypto/rand" 11 | "crypto/tls" 12 | "crypto/x509" 13 | "errors" 14 | "io" 15 | "net" 16 | "net/http" 17 | "net/http/httptest" 18 | "net/url" 19 | "strings" 20 | "sync/atomic" 21 | "testing" 22 | ) 23 | 24 | // These test cases use a websocket client (Dialer)/proxy/websocket server (Upgrader) 25 | // to validate the cases where a proxy is an intermediary between a websocket client 26 | // and server. 
The test cases usually 1) create a websocket server which echoes any 27 | // data received back to the client, 2) a basic duplex streaming proxy, and 3) a 28 | // websocket client which sends random data to the server through the proxy, 29 | // validating any subsequent data received is the same as the data sent. The various 30 | // permutations include the proxy and backend schemes (HTTP or HTTPS), as well as 31 | // the custom dial functions (e.g NetDialContext, NetDial) set on the Dialer. 32 | 33 | const ( 34 | subprotocolV1 = "subprotocol-version-1" 35 | subprotocolV2 = "subprotocol-version-2" 36 | ) 37 | 38 | // Permutation 1 39 | // 40 | // Backend: HTTP 41 | // Proxy: HTTP 42 | func TestHTTPProxyAndBackend(t *testing.T) { 43 | websocketTLS := false 44 | proxyTLS := false 45 | // Start the websocket server, which echoes data back to sender. 46 | websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) 47 | defer websocketServer.Close() 48 | if err != nil { 49 | t.Fatalf("error starting websocket server: %v", err) 50 | } 51 | // Start the proxy server. 52 | proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) 53 | defer proxyServer.Close() 54 | if err != nil { 55 | t.Fatalf("error starting proxy server: %v", err) 56 | } 57 | // Dial the websocket server through the proxy server. 58 | dialer := Dialer{ 59 | Proxy: http.ProxyURL(proxyServerURL), 60 | Subprotocols: []string{subprotocolV1}, 61 | } 62 | wsClient, _, err := dialer.Dial(websocketURL.String(), nil) 63 | if err != nil { 64 | t.Fatalf("websocket dial error: %v", err) 65 | } 66 | // Send, receive, and validate random data over websocket connection. 67 | sendReceiveData(t, wsClient) 68 | // Validate the proxy server was called. 
69 | if e, a := int64(1), proxyServer.numCalls(); e != a { 70 | t.Errorf("proxy not called") 71 | } 72 | } 73 | 74 | // Permutation 2 75 | // 76 | // Backend: HTTP 77 | // Proxy: HTTP 78 | // DialFn: NetDial (dials proxy) 79 | func TestHTTPProxyWithNetDial(t *testing.T) { 80 | websocketTLS := false 81 | proxyTLS := false 82 | // Start the websocket server, which echoes data back to sender. 83 | websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) 84 | defer websocketServer.Close() 85 | if err != nil { 86 | t.Fatalf("error starting websocket server: %v", err) 87 | } 88 | // Start the proxy server. 89 | proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) 90 | defer proxyServer.Close() 91 | if err != nil { 92 | t.Fatalf("error starting proxy server: %v", err) 93 | } 94 | // Dial the websocket server through the proxy server. 95 | var netDialCalled atomic.Int64 96 | dialer := Dialer{ 97 | NetDial: func(network, addr string) (net.Conn, error) { 98 | netDialCalled.Add(1) 99 | return (&net.Dialer{}).DialContext(context.Background(), network, addr) 100 | }, 101 | Proxy: http.ProxyURL(proxyServerURL), 102 | Subprotocols: []string{subprotocolV1}, 103 | } 104 | wsClient, _, err := dialer.Dial(websocketURL.String(), nil) 105 | if err != nil { 106 | t.Fatalf("websocket dial error: %v", err) 107 | } 108 | // Send, receive, and validate random data over websocket connection. 109 | sendReceiveData(t, wsClient) 110 | if e, a := int64(1), netDialCalled.Load(); e != a { 111 | t.Errorf("netDial not called") 112 | } 113 | // Validate the proxy server was called. 
114 | if e, a := int64(1), proxyServer.numCalls(); e != a { 115 | t.Errorf("proxy not called") 116 | } 117 | } 118 | 119 | // Permutation 3 120 | // 121 | // Backend: HTTP 122 | // Proxy: HTTP 123 | // DialFn: NetDialContext (dials proxy) 124 | func TestHTTPProxyWithNetDialContext(t *testing.T) { 125 | websocketTLS := false 126 | proxyTLS := false 127 | // Start the websocket server, which echoes data back to sender. 128 | websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) 129 | defer websocketServer.Close() 130 | if err != nil { 131 | t.Fatalf("error starting websocket server: %v", err) 132 | } 133 | // Start the proxy server. 134 | proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) 135 | defer proxyServer.Close() 136 | if err != nil { 137 | t.Fatalf("error starting proxy server: %v", err) 138 | } 139 | // Dial the websocket server through the proxy server. 140 | var netDialCalled atomic.Int64 141 | dialer := Dialer{ 142 | NetDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { 143 | netDialCalled.Add(1) 144 | return (&net.Dialer{}).DialContext(ctx, network, addr) 145 | }, 146 | Proxy: http.ProxyURL(proxyServerURL), 147 | Subprotocols: []string{subprotocolV1}, 148 | } 149 | wsClient, _, err := dialer.Dial(websocketURL.String(), nil) 150 | if err != nil { 151 | t.Fatalf("websocket dial error: %v", err) 152 | } 153 | // Send, receive, and validate random data over websocket connection. 154 | sendReceiveData(t, wsClient) 155 | if e, a := int64(1), netDialCalled.Load(); e != a { 156 | t.Errorf("netDial not called") 157 | } 158 | // Validate the proxy server was called. 
159 | if e, a := int64(1), proxyServer.numCalls(); e != a { 160 | t.Errorf("proxy not called") 161 | } 162 | } 163 | 164 | // Permutation 4 165 | // 166 | // Backend: HTTPS 167 | // Proxy: HTTP 168 | // DialFn: NetDialTLSConfig (set but *ignored*) 169 | // TLS Config: set (used for backend TLS) 170 | func TestHTTPProxyWithHTTPSBackend(t *testing.T) { 171 | websocketTLS := true 172 | proxyTLS := false 173 | // Start the websocket server, which echoes data back to sender. 174 | websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) 175 | defer websocketServer.Close() 176 | if err != nil { 177 | t.Fatalf("error starting websocket server: %v", err) 178 | } 179 | // Start the proxy server. 180 | proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) 181 | defer proxyServer.Close() 182 | if err != nil { 183 | t.Fatalf("error starting proxy server: %v", err) 184 | } 185 | var netDialTLSCalled atomic.Int64 186 | dialer := Dialer{ 187 | Proxy: http.ProxyURL(proxyServerURL), 188 | // This function should be ignored, because an HTTP proxy exists 189 | // and the backend TLS handshake should use TLSClientConfig. 190 | NetDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { 191 | netDialTLSCalled.Add(1) 192 | return (&net.Dialer{}).DialContext(ctx, network, addr) 193 | }, 194 | // Used for the backend server TLS handshake. 195 | TLSClientConfig: tlsConfig(websocketTLS, proxyTLS), 196 | Subprotocols: []string{subprotocolV1}, 197 | } 198 | wsClient, _, err := dialer.Dial(websocketURL.String(), nil) 199 | if err != nil { 200 | t.Fatalf("websocket dial error: %v", err) 201 | } 202 | // Send, receive, and validate random data over websocket connection. 203 | sendReceiveData(t, wsClient) 204 | if numTLSDials := netDialTLSCalled.Load(); numTLSDials > 0 { 205 | t.Errorf("NetDialTLS should have been ignored") 206 | } 207 | // Validate the proxy server was called. 
208 | if e, a := int64(1), proxyServer.numCalls(); e != a { 209 | t.Errorf("proxy not called") 210 | } 211 | } 212 | 213 | // Permutation 5 214 | // 215 | // Backend: HTTPS 216 | // Proxy: HTTPS 217 | // TLS Config: set (used for both proxy and backend TLS) 218 | func TestHTTPSProxyAndBackend(t *testing.T) { 219 | websocketTLS := true 220 | proxyTLS := true 221 | // Start the websocket server, which echoes data back to sender. 222 | websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) 223 | defer websocketServer.Close() 224 | if err != nil { 225 | t.Fatalf("error starting websocket server: %v", err) 226 | } 227 | // Start the proxy server. 228 | proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) 229 | defer proxyServer.Close() 230 | if err != nil { 231 | t.Fatalf("error starting proxy server: %v", err) 232 | } 233 | dialer := Dialer{ 234 | Proxy: http.ProxyURL(proxyServerURL), 235 | TLSClientConfig: tlsConfig(websocketTLS, proxyTLS), 236 | Subprotocols: []string{subprotocolV1}, 237 | } 238 | wsClient, _, err := dialer.Dial(websocketURL.String(), nil) 239 | if err != nil { 240 | t.Fatalf("websocket dial error: %v", err) 241 | } 242 | // Send, receive, and validate random data over websocket connection. 243 | sendReceiveData(t, wsClient) 244 | // Validate the proxy server was called. 245 | if e, a := int64(1), proxyServer.numCalls(); e != a { 246 | t.Errorf("proxy not called") 247 | } 248 | } 249 | 250 | // Permutation 6 251 | // 252 | // Backend: HTTPS 253 | // Proxy: HTTPS 254 | // DialFn: NetDial (used to dial proxy) 255 | // TLS Config: set (used for both proxy and backend TLS) 256 | func TestHTTPSProxyUsingNetDial(t *testing.T) { 257 | websocketTLS := true 258 | proxyTLS := true 259 | // Start the websocket server, which echoes data back to sender. 
260 | websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) 261 | defer websocketServer.Close() 262 | if err != nil { 263 | t.Fatalf("error starting websocket server: %v", err) 264 | } 265 | // Start the proxy server. 266 | proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) 267 | defer proxyServer.Close() 268 | if err != nil { 269 | t.Fatalf("error starting proxy server: %v", err) 270 | } 271 | var netDialCalled atomic.Int64 272 | dialer := Dialer{ 273 | NetDial: func(network, addr string) (net.Conn, error) { 274 | netDialCalled.Add(1) 275 | return (&net.Dialer{}).DialContext(context.Background(), network, addr) 276 | }, 277 | Proxy: http.ProxyURL(proxyServerURL), 278 | TLSClientConfig: tlsConfig(websocketTLS, proxyTLS), 279 | Subprotocols: []string{subprotocolV1}, 280 | } 281 | wsClient, _, err := dialer.Dial(websocketURL.String(), nil) 282 | if err != nil { 283 | t.Fatalf("websocket dial error: %v", err) 284 | } 285 | // Send, receive, and validate random data over websocket connection. 286 | sendReceiveData(t, wsClient) 287 | if e, a := int64(1), netDialCalled.Load(); e != a { 288 | t.Errorf("netDial not called") 289 | } 290 | // Validate the proxy server was called. 291 | if e, a := int64(1), proxyServer.numCalls(); e != a { 292 | t.Errorf("proxy not called") 293 | } 294 | } 295 | 296 | // Permutation 7 297 | // 298 | // Backend: HTTPS 299 | // Proxy: HTTPS 300 | // DialFn: NetDialContext (used to dial proxy) 301 | // TLS Config: set (used for both proxy and backend TLS) 302 | func TestHTTPSProxyUsingNetDialContext(t *testing.T) { 303 | websocketTLS := true 304 | proxyTLS := true 305 | // Start the websocket server, which echoes data back to sender. 306 | websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) 307 | defer websocketServer.Close() 308 | if err != nil { 309 | t.Fatalf("error starting websocket server: %v", err) 310 | } 311 | // Start the proxy server. 
312 | proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) 313 | defer proxyServer.Close() 314 | if err != nil { 315 | t.Fatalf("error starting proxy server: %v", err) 316 | } 317 | var netDialCalled atomic.Int64 318 | dialer := Dialer{ 319 | NetDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { 320 | netDialCalled.Add(1) 321 | return (&net.Dialer{}).DialContext(ctx, network, addr) 322 | }, 323 | Proxy: http.ProxyURL(proxyServerURL), 324 | TLSClientConfig: tlsConfig(websocketTLS, proxyTLS), 325 | Subprotocols: []string{subprotocolV1}, 326 | } 327 | wsClient, _, err := dialer.Dial(websocketURL.String(), nil) 328 | if err != nil { 329 | t.Fatalf("websocket dial error: %v", err) 330 | } 331 | // Send, receive, and validate random data over websocket connection. 332 | sendReceiveData(t, wsClient) 333 | if e, a := int64(1), netDialCalled.Load(); e != a { 334 | t.Errorf("netDial not called") 335 | } 336 | // Validate the proxy server was called. 337 | if e, a := int64(1), proxyServer.numCalls(); e != a { 338 | t.Errorf("proxy not called") 339 | } 340 | } 341 | 342 | // Permutation 8 343 | // 344 | // Backend: HTTPS 345 | // Proxy: HTTPS 346 | // DialFn: NetDialTLSContext (used for proxy TLS) 347 | // TLS Config: set (used for backend TLS) 348 | func TestHTTPSProxyUsingNetDialTLSContext(t *testing.T) { 349 | websocketTLS := true 350 | proxyTLS := true 351 | // Start the websocket server, which echoes data back to sender. 352 | websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) 353 | defer websocketServer.Close() 354 | if err != nil { 355 | t.Fatalf("error starting websocket server: %v", err) 356 | } 357 | // Start the proxy server. 
358 | proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) 359 | defer proxyServer.Close() 360 | if err != nil { 361 | t.Fatalf("error starting proxy server: %v", err) 362 | } 363 | // Configure the proxy dialing function which dials the proxy and 364 | // performs the TLS handshake. 365 | var proxyDialCalled atomic.Int64 366 | proxyCerts := x509.NewCertPool() 367 | proxyCerts.AppendCertsFromPEM(proxyServerCert) 368 | proxyTLSConfig := &tls.Config{RootCAs: proxyCerts} 369 | proxyDial := func(ctx context.Context, network, addr string) (net.Conn, error) { 370 | proxyDialCalled.Add(1) 371 | return tls.Dial(network, addr, proxyTLSConfig) 372 | } 373 | // Configure the backend webscocket TLS configuration (handshake occurs 374 | // over the previously created proxy connection). 375 | websocketCerts := x509.NewCertPool() 376 | websocketCerts.AppendCertsFromPEM(websocketServerCert) 377 | websocketTLSConfig := &tls.Config{RootCAs: websocketCerts} 378 | dialer := Dialer{ 379 | Proxy: http.ProxyURL(proxyServerURL), 380 | // Dial and TLS handshake function to proxy. 381 | NetDialTLSContext: proxyDial, 382 | // Used for second TLS handshake to backend server over previously 383 | // established proxy connection. 384 | TLSClientConfig: websocketTLSConfig, 385 | Subprotocols: []string{subprotocolV1}, 386 | } 387 | wsClient, _, err := dialer.Dial(websocketURL.String(), nil) 388 | if err != nil { 389 | t.Fatalf("websocket dial error: %v", err) 390 | } 391 | // Send, receive, and validate random data over websocket connection. 392 | sendReceiveData(t, wsClient) 393 | if e, a := int64(1), proxyDialCalled.Load(); e != a { 394 | t.Errorf("netDial not called") 395 | } 396 | // Validate the proxy server was called. 
397 | if e, a := int64(1), proxyServer.numCalls(); e != a { 398 | t.Errorf("proxy not called") 399 | } 400 | } 401 | 402 | // Permutation 9 403 | // 404 | // Backend: HTTP 405 | // Proxy: HTTPS 406 | // TLS Config: set (used for proxy TLS) 407 | func TestHTTPSProxyHTTPBackend(t *testing.T) { 408 | websocketTLS := false 409 | proxyTLS := true 410 | // Start the websocket server, which echoes data back to sender. 411 | websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) 412 | defer websocketServer.Close() 413 | if err != nil { 414 | t.Fatalf("error starting websocket server: %v", err) 415 | } 416 | // Start the proxy server. 417 | proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) 418 | defer proxyServer.Close() 419 | if err != nil { 420 | t.Fatalf("error starting proxy server: %v", err) 421 | } 422 | dialer := Dialer{ 423 | Proxy: http.ProxyURL(proxyServerURL), 424 | TLSClientConfig: tlsConfig(websocketTLS, proxyTLS), 425 | Subprotocols: []string{subprotocolV1}, 426 | } 427 | wsClient, _, err := dialer.Dial(websocketURL.String(), nil) 428 | if err != nil { 429 | t.Fatalf("websocket dial error: %v", err) 430 | } 431 | // Send, receive, and validate random data over websocket connection. 432 | sendReceiveData(t, wsClient) 433 | // Validate the proxy server was called. 434 | if e, a := int64(1), proxyServer.numCalls(); e != a { 435 | t.Errorf("proxy not called") 436 | } 437 | } 438 | 439 | // Permutation 10 440 | // 441 | // Backend: HTTP 442 | // Proxy: HTTPS 443 | // DialFn: NetDialTLSContext (used for proxy TLS) 444 | // TLS Config: set (ignored) 445 | func TestHTTPSProxyUsingNetDialTLSContextWithHTTPBackend(t *testing.T) { 446 | websocketTLS := false 447 | proxyTLS := true 448 | // Start the websocket server, which echoes data back to sender. 
449 | websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) 450 | defer websocketServer.Close() 451 | if err != nil { 452 | t.Fatalf("error starting websocket server: %v", err) 453 | } 454 | // Start the proxy server. 455 | proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) 456 | defer proxyServer.Close() 457 | if err != nil { 458 | t.Fatalf("error starting proxy server: %v", err) 459 | } 460 | var proxyDialCalled atomic.Int64 461 | dialer := Dialer{ 462 | NetDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { 463 | proxyDialCalled.Add(1) 464 | return tls.Dial(network, addr, tlsConfig(websocketTLS, proxyTLS)) 465 | }, 466 | Proxy: http.ProxyURL(proxyServerURL), 467 | TLSClientConfig: &tls.Config{}, // Misconfigured, but ignored. 468 | Subprotocols: []string{subprotocolV1}, 469 | } 470 | wsClient, _, err := dialer.Dial(websocketURL.String(), nil) 471 | if err != nil { 472 | t.Fatalf("websocket dial error: %v", err) 473 | } 474 | // Send, receive, and validate random data over websocket connection. 475 | sendReceiveData(t, wsClient) 476 | if e, a := int64(1), proxyDialCalled.Load(); e != a { 477 | t.Errorf("netDial not called") 478 | } 479 | // Validate the proxy server was called. 480 | if e, a := int64(1), proxyServer.numCalls(); e != a { 481 | t.Errorf("proxy not called") 482 | } 483 | } 484 | 485 | func TestTLSValidationErrors(t *testing.T) { 486 | // Both websocket and proxy servers are started with TLS. 487 | websocketTLS := true 488 | proxyTLS := true 489 | websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) 490 | defer websocketServer.Close() 491 | if err != nil { 492 | t.Fatalf("error starting websocket server: %v", err) 493 | } 494 | proxyServer, proxyServerURL, err := newProxyServer(proxyTLS) 495 | defer proxyServer.Close() 496 | if err != nil { 497 | t.Fatalf("error starting proxy server: %v", err) 498 | } 499 | // Dialer without proxy CA cert fails TLS verification. 
500 | tlsError := "tls: failed to verify certificate" 501 | dialer := Dialer{ 502 | Proxy: http.ProxyURL(proxyServerURL), 503 | TLSClientConfig: tlsConfig(true, false), 504 | Subprotocols: []string{subprotocolV1}, 505 | } 506 | _, _, err = dialer.Dial(websocketURL.String(), nil) 507 | if err == nil { 508 | t.Errorf("expected proxy TLS verification error did not arrive") 509 | } else if !strings.Contains(err.Error(), tlsError) { 510 | t.Errorf("expected proxy TLS error (%s), got (%s)", err.Error(), tlsError) 511 | } 512 | // Validate the proxy handler was *NOT* called (because proxy 513 | // server TLS validation failed). 514 | if e, a := int64(0), proxyServer.numCalls(); e != a { 515 | t.Errorf("proxy should not have been called") 516 | } 517 | // Dialer without websocket CA cert fails TLS verification. 518 | dialer = Dialer{ 519 | Proxy: http.ProxyURL(proxyServerURL), 520 | TLSClientConfig: tlsConfig(false, true), 521 | Subprotocols: []string{subprotocolV1}, 522 | } 523 | _, _, err = dialer.Dial(websocketURL.String(), nil) 524 | if err == nil { 525 | t.Errorf("expected websocket TLS verification error did not arrive") 526 | } else if !strings.Contains(err.Error(), tlsError) { 527 | t.Errorf("expected websocket TLS error (%s), got (%s)", err.Error(), tlsError) 528 | } 529 | // Validate the proxy server *was* called (but subsequent 530 | // websocket server failed TLS validation). 531 | if e, a := int64(1), proxyServer.numCalls(); e != a { 532 | t.Errorf("proxy have been called") 533 | } 534 | } 535 | 536 | func TestProxyFnErrorIsPropagated(t *testing.T) { 537 | websocketServer, websocketURL, err := newWebsocketServer(false) 538 | defer websocketServer.Close() 539 | if err != nil { 540 | t.Fatalf("error starting websocket server: %v", err) 541 | } 542 | // Create a Dialer where Proxy function always returns an error. 
543 | proxyURLError := errors.New("proxy URL generation error") 544 | dialer := Dialer{ 545 | Proxy: func(r *http.Request) (*url.URL, error) { 546 | return nil, proxyURLError 547 | }, 548 | Subprotocols: []string{subprotocolV1}, 549 | } 550 | // Proxy URL generation error should halt request and be propagated. 551 | _, _, err = dialer.Dial(websocketURL.String(), nil) 552 | if err == nil { 553 | t.Fatalf("expected websocket dial error, received none") 554 | } else if !errors.Is(proxyURLError, err) { 555 | t.Fatalf("expected error (%s), got (%s)", proxyURLError, err) 556 | } 557 | } 558 | 559 | func TestProxyFnNilMeansNoProxy(t *testing.T) { 560 | // Both websocket and proxy servers are started. 561 | websocketTLS := false 562 | proxyTLS := false 563 | websocketServer, websocketURL, err := newWebsocketServer(websocketTLS) 564 | defer websocketServer.Close() 565 | if err != nil { 566 | t.Fatalf("error starting websocket server: %v", err) 567 | } 568 | proxyServer, _, err := newProxyServer(proxyTLS) 569 | defer proxyServer.Close() 570 | if err != nil { 571 | t.Fatalf("error starting proxy server: %v", err) 572 | } 573 | // Dialer created with Proxy URL generation function returning nil 574 | // proxy URL, which continues with backend server connection without 575 | // proxying. 576 | dialer := Dialer{ 577 | Proxy: func(r *http.Request) (*url.URL, error) { 578 | return nil, nil 579 | }, 580 | Subprotocols: []string{subprotocolV1}, 581 | } 582 | wsClient, _, err := dialer.Dial(websocketURL.String(), nil) 583 | if err != nil { 584 | t.Fatalf("websocket dial error: %v", err) 585 | } 586 | sendReceiveData(t, wsClient) 587 | // Validate the proxy handler was *NOT* called (because proxy 588 | // URL generation returned nil). 
589 | if e, a := int64(0), proxyServer.numCalls(); e != a { 590 | t.Errorf("proxy should not have been called") 591 | } 592 | } 593 | 594 | // "counter" interface can be implemented by a server to keep track 595 | // of the number of times a handler was called, as well as "Close". 596 | type counter interface { 597 | increment() 598 | numCalls() int64 599 | closer 600 | } 601 | 602 | type closer interface { 603 | Close() 604 | } 605 | 606 | // testServer implements "counter" interface. 607 | type testServer struct { 608 | server *httptest.Server 609 | numHandled atomic.Int64 610 | } 611 | 612 | func (ts *testServer) numCalls() int64 { 613 | return ts.numHandled.Load() 614 | } 615 | 616 | func (ts *testServer) increment() { 617 | ts.numHandled.Add(1) 618 | } 619 | 620 | func (ts *testServer) Close() { 621 | if ts.server != nil { 622 | ts.server.Close() 623 | } 624 | } 625 | 626 | // websocketEchoHandler upgrades the connection associated with the request, and 627 | // echoes binary messages read off the websocket connection back to the client. 
628 | var websocketEchoHandler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 629 | upgrader := Upgrader{ 630 | CheckOrigin: func(r *http.Request) bool { 631 | return true // Accepting all requests 632 | }, 633 | Subprotocols: []string{ 634 | subprotocolV1, 635 | subprotocolV2, 636 | }, 637 | } 638 | wsConn, err := upgrader.Upgrade(w, req, nil) 639 | if err != nil { 640 | http.Error(w, err.Error(), http.StatusInternalServerError) 641 | } 642 | defer wsConn.Close() 643 | for { 644 | writer, err := wsConn.NextWriter(BinaryMessage) 645 | if err != nil { 646 | break 647 | } 648 | messageType, reader, err := wsConn.NextReader() 649 | if err != nil { 650 | break 651 | } 652 | if messageType != BinaryMessage { 653 | http.Error(w, "websocket reader not binary message type", 654 | http.StatusInternalServerError) 655 | } 656 | _, err = io.Copy(writer, reader) 657 | if err != nil { 658 | http.Error(w, "websocket server io copy error", 659 | http.StatusInternalServerError) 660 | } 661 | } 662 | }) 663 | 664 | // Returns a test backend websocket server as well as the URL pointing 665 | // to the server, or an error if one occurred. Sets up a TLS endpoint 666 | // on the server if the passed "tlsServer" is true. 667 | // func newWebsocketServer(tlsServer bool) (*httptest.Server, *url.URL, error) { 668 | func newWebsocketServer(tlsServer bool) (closer, *url.URL, error) { 669 | // Start the websocket server, which echoes data back to sender. 
670 | websocketServer := httptest.NewUnstartedServer(websocketEchoHandler) 671 | if tlsServer { 672 | websocketKeyPair, err := tls.X509KeyPair(websocketServerCert, websocketServerKey) 673 | if err != nil { 674 | return nil, nil, err 675 | } 676 | websocketServer.TLS = &tls.Config{ 677 | Certificates: []tls.Certificate{websocketKeyPair}, 678 | } 679 | websocketServer.StartTLS() 680 | } else { 681 | websocketServer.Start() 682 | } 683 | websocketURL, err := url.Parse(websocketServer.URL) 684 | if err != nil { 685 | return nil, nil, err 686 | } 687 | if tlsServer { 688 | websocketURL.Scheme = "wss" 689 | } else { 690 | websocketURL.Scheme = "ws" 691 | } 692 | return websocketServer, websocketURL, nil 693 | } 694 | 695 | // proxyHandler creates a full duplex streaming connection between the client 696 | // (hijacking the http request connection), and an "upstream" dialed connection 697 | // to the "Host". Creates two goroutines to copy between connections in each direction. 698 | var proxyHandler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 699 | // Validate the CONNECT method. 700 | if req.Method != http.MethodConnect { 701 | http.Error(w, "method not allowed", http.StatusMethodNotAllowed) 702 | return 703 | } 704 | // Dial upstream server. 705 | upstream, err := (&net.Dialer{}).DialContext(req.Context(), "tcp", req.URL.Host) 706 | if err != nil { 707 | http.Error(w, err.Error(), http.StatusInternalServerError) 708 | return 709 | } 710 | defer upstream.Close() 711 | // Return 200 OK to client. 712 | w.WriteHeader(http.StatusOK) 713 | // Hijack client connection. 714 | client, _, err := w.(http.Hijacker).Hijack() 715 | if err != nil { 716 | http.Error(w, err.Error(), http.StatusInternalServerError) 717 | return 718 | } 719 | defer client.Close() 720 | // Create duplex streaming between client and upstream connections. 
721 | done := make(chan struct{}, 2) 722 | go func() { 723 | _, _ = io.Copy(upstream, client) 724 | done <- struct{}{} 725 | }() 726 | go func() { 727 | _, _ = io.Copy(client, upstream) 728 | done <- struct{}{} 729 | }() 730 | <-done 731 | }) 732 | 733 | // Returns a new test HTTP server, as well as the URL to that server, or 734 | // an error if one occurred. numProxyCalls keeps track of the number of 735 | // times the proxy handler was called with this server. 736 | func newProxyServer(tlsServer bool) (counter, *url.URL, error) { 737 | // Start the proxy server, keeping track of how many times the handler is called. 738 | ts := &testServer{} 739 | proxyServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 740 | ts.increment() 741 | proxyHandler.ServeHTTP(w, req) 742 | })) 743 | if tlsServer { 744 | proxyKeyPair, err := tls.X509KeyPair(proxyServerCert, proxyServerKey) 745 | if err != nil { 746 | return nil, nil, err 747 | } 748 | proxyServer.TLS = &tls.Config{ 749 | Certificates: []tls.Certificate{proxyKeyPair}, 750 | } 751 | proxyServer.StartTLS() 752 | } else { 753 | proxyServer.Start() 754 | } 755 | proxyURL, err := url.Parse(proxyServer.URL) 756 | if err != nil { 757 | return nil, nil, err 758 | } 759 | return ts, proxyURL, nil 760 | } 761 | 762 | // Returns the TLS config with the RootCAs cert pool set. If 763 | // neither websocket nor proxy server uses TLS, returns nil. 
764 | func tlsConfig(websocketTLS bool, proxyTLS bool) *tls.Config { 765 | if !websocketTLS && !proxyTLS { 766 | return nil 767 | } 768 | certPool := x509.NewCertPool() 769 | tlsConfig := &tls.Config{ 770 | RootCAs: certPool, 771 | } 772 | if websocketTLS { 773 | tlsConfig.RootCAs.AppendCertsFromPEM(websocketServerCert) 774 | } 775 | if proxyTLS { 776 | tlsConfig.RootCAs.AppendCertsFromPEM(proxyServerCert) 777 | } 778 | return tlsConfig 779 | } 780 | 781 | // Sends, receives, and validates random data sent and received 782 | // over the passed websocket connection. 783 | const randomDataSize = 128 * 1024 784 | 785 | func sendReceiveData(t *testing.T, wsConn *Conn) { 786 | // Create the random data. 787 | randomData := make([]byte, randomDataSize) 788 | if _, err := rand.Read(randomData); err != nil { 789 | t.Errorf("unexpected error reading random data: %v", err) 790 | } 791 | // Send the random data. 792 | err := wsConn.WriteMessage(BinaryMessage, randomData) 793 | if err != nil { 794 | t.Errorf("websocket write error: %v", err) 795 | } 796 | // Read from the websocket connection, and validate the 797 | // read data is the same as the previously sent data. 798 | _, received, err := wsConn.ReadMessage() 799 | if !bytes.Equal(randomData, received) { 800 | t.Errorf("unexpected data received: %d bytes sent, %d bytes received", 801 | len(received), len(randomData)) 802 | } 803 | } 804 | 805 | // proxyServerCert was generated from crypto/tls/generate_cert.go with the following command: 806 | // 807 | // go run generate_cert.go --rsa-bits 2048 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h 808 | // 809 | // proxyServerCert is a self-signed. 
var proxyServerCert = []byte(`-----BEGIN CERTIFICATE-----
MIIDGTCCAgGgAwIBAgIRALL5AZcefF4kkYV1SEG6YrMwDQYJKoZIhvcNAQELBQAw
EjEQMA4GA1UEChMHQWNtZSBDbzAgFw03MDAxMDEwMDAwMDBaGA8yMDg0MDEyOTE2
MDAwMFowEjEQMA4GA1UEChMHQWNtZSBDbzCCASIwDQYJKoZIhvcNAQEBBQADggEP
ADCCAQoCggEBALQ/FHcyVwdFHxARbbD2KBtDUT7Eni+8ioNdjtGcmtXqBv45EC1C
JOqqGJTroFGJ6Q9kQIZ9FqH5IJR2fOOJD9kOTueG4Vt1JY1rj1Kbpjefu8XleZ5L
SBwIWVnN/lEsEbuKmj7N2gLt5AH3zMZiBI1mg1u9Z5ZZHYbCiTpBrwsq6cTlvR9g
dyo1YkM5hRESCzsrL0aUByoo0qRMD8ZsgANJwgsiO0/M6idbxDwv1BnGwGmRYvOE
Hxpy3v0Jg7GJYrvnpnifJTs4nw91N5X9pXxR7FFzi/6HTYDWRljvTb0w6XciKYAz
bWZ0+cJr5F7wB7ovlbm7HrQIR7z7EIIu2d8CAwEAAaNoMGYwDgYDVR0PAQH/BAQD
AgKkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1UdEwEB/wQFMAMBAf8wLgYDVR0R
BCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAAAAAAAAAAAAAAAAEwDQYJKoZI
hvcNAQELBQADggEBAFPPWopNEJtIA2VFAQcqN6uJK+JVFOnjGRoCrM6Xgzdm0wxY
XCGjsxY5dl+V7KzdGqu858rCaq5osEBqypBpYAnS9C38VyCDA1vPS1PsN8SYv48z
DyBwj+7R2qar0ADBhnhWxvYO9M72lN/wuCqFKYMeFSnJdQLv3AsrrHe9lYqOa36s
8wxSwVTFTYXBzljPEnSaaJMPqFD8JXaZK1ryJPkO5OsCNQNGtatNiWAf3DcmwHAT
MGYMzP0u4nw47aRz9shB8w+taPKHx2BVwE1m/yp3nHVioOjXqA1fwRQVGclCJSH1
D2iq3hWVHRENgjTjANBPICLo9AZ4JfN6PH19mnU=
-----END CERTIFICATE-----`)

// proxyServerKey is the PEM-encoded RSA private key (PKCS #1 "RSA PRIVATE
// KEY" block) for proxyServerCert. newProxyServer loads the pair with
// tls.X509KeyPair to serve TLS from the proxy test server.
// NOTE: test-only key material committed to the repository; never reuse it
// outside these tests.
var proxyServerKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIEogIBAAKCAQEAtD8UdzJXB0UfEBFtsPYoG0NRPsSeL7yKg12O0Zya1eoG/jkQ
LUIk6qoYlOugUYnpD2RAhn0WofkglHZ844kP2Q5O54bhW3UljWuPUpumN5+7xeV5
nktIHAhZWc3+USwRu4qaPs3aAu3kAffMxmIEjWaDW71nllkdhsKJOkGvCyrpxOW9
H2B3KjViQzmFERILOysvRpQHKijSpEwPxmyAA0nCCyI7T8zqJ1vEPC/UGcbAaZFi
84QfGnLe/QmDsYliu+emeJ8lOzifD3U3lf2lfFHsUXOL/odNgNZGWO9NvTDpdyIp
gDNtZnT5wmvkXvAHui+VubsetAhHvPsQgi7Z3wIDAQABAoIBAGmw93IxjYCQ0ncc
kSKMJNZfsdtJdaxuNRZ0nNNirhQzR2h403iGaZlEpmdkhzxozsWcto1l+gh+SdFk
bTUK4MUZM8FlgO2dEqkLYh5BcMT7ICMZvSfJ4v21E5eqR68XVUqQKoQbNvQyxFk3
EddeEGdNrkb0GDK8DKlBlzAW5ep4gjG85wSTjR+J+muUv3R0BgLBFSuQnIDM/IMB
LWqsja/QbtB7yppe7jL5u8UCFdZG8BBKT9fcvFIu5PRLO3MO0uOI7LTc8+W1Xm23
uv+j3SY0+v+6POjK0UlJFFi/wkSPTFIfrQO1qFBkTDQHhQ6q/7GnILYYOiGbIRg2
NNuP52ECgYEAzXEoy50wSYh8xfFaBuxbm3ruuG2W49jgop7ZfoFrPWwOQKAZS441
VIwV4+e5IcA6KkuYbtGSdTYqK1SMkgnUyD/VevwAqH5TJoEIGu0pDuKGwVuwqioZ
frCIAV5GllKyUJ55VZNbRr2vY2fCsWbaCSCHETn6C16DNuTCe5C0JBECgYEA4JqY
5GpNbMG8fOt4H7hU0Fbm2yd6SHJcQ3/9iimef7xG6ajxsYrIhg1ft+3IPHMjVI0+
9brwHDnWg4bOOx/VO4VJBt6Dm/F33bndnZRkuIjfSNpLM51P+EnRdaFVHOJHwKqx
uF69kihifCAG7YATgCveeXImzBUSyZUz9UrETu8CgYARNBimdFNG1RcdvEg9rC0/
p9u1tfecvNySwZqU7WF9kz7eSonTueTdX521qAHowaAdSpdJMGODTTXaywm6cPhQ
jIfj9JZZhbqQzt1O4+08Qdvm9TamCUB5S28YLjza+bHU7nBaqixKkDfPqzCyilpX
yVGGL8SwjwmN3zop/sQXAQKBgC0JMsESQ6YcDsRpnrOVjYQc+LtW5iEitTdfsaID
iGGKihmOI7B66IxgoCHMTws39wycKdSyADVYr5e97xpR3rrJlgQHmBIrz+Iow7Q2
LiAGaec8xjl6QK/DdXmFuQBKqyKJ14rljFODP4QuE9WJid94bGqjpf3j99ltznZP
4J8HAoGAJb4eb4lu4UGwifDzqfAPzLGCoi0fE1/hSx34lfuLcc1G+LEu9YDKoOVJ
9suOh0b5K/bfEy9KrVMBBriduvdaERSD8S3pkIQaitIz0B029AbE4FLFf9lKQpP2
KR8NJEkK99Vh/tew6jAMll70xFrE7aF8VLXJVE7w4sQzuvHxl9Q=
-----END RSA PRIVATE KEY-----
`)

// websocketServerCert is a self-signed certificate presented by the TLS
// variant of the websocket test server (newWebsocketServer). tlsConfig adds
// it to the client's RootCAs so dialers can verify the backend.
var websocketServerCert = []byte(`-----BEGIN CERTIFICATE-----
MIIDOTCCAiGgAwIBAgIQYSN1VY/favsLUo+B7gJ5tTANBgkqhkiG9w0BAQsFADAS
MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw
MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A
MIIBCgKCAQEApBlintjkL1fO1Sk2pzNvl862CtTwU7/Jy6EZqWzI17wEbPn4sbSD
bHhfDlPl2nmw3hVkc6LNK+eqzm2GX/ai4tgMiaH7kyyNit1K3g7y7GISMf9poWIa
POJhid2wmhKHbEtHECSdQ5c/jEN1UVzB4go5LO7MEEVo9kyQ+yBqS6gISyFmfaT4
qOsPJBir33bBpptSend1JSXaRTXqRa1p+oudw2ILa4U7KfuKK3emp21m5/HYAuSf
CV4WqqDoDiBPMpsQ0kPEPugWZKFeF3qanmqFFvptYx+zJbOznWYY2D3idWsvcg6q
VLPEB19oXaVBV0HXPFtObm5m1jCpl8FI1wIDAQABo4GIMIGFMA4GA1UdDwEB/wQE
AwICpDATBgNVHSUEDDAKBggrBgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MB0GA1Ud
DgQWBBQcSkjqA9rgos1daegNj49BpRCA0jAuBgNVHREEJzAlggtleGFtcGxlLmNv
bYcEfwAAAYcQAAAAAAAAAAAAAAAAAAAAATANBgkqhkiG9w0BAQsFAAOCAQEAnk9i
9rogNTi9B1pn+Fbk3WALKdEjv/uyePsTnwdyvswVbeYbQweU9TrhYT2+eXbMA5kY
7TaQm46idRqxCKMgc3Ip3DADJdm8cJX9p2ExU4fKdkPc1KD/J+4QHHx1W2Ml5S2o
foOo6j1F0UdZP/rBj0UumEZp32qW+4DhVV/QQjUB8J0gaDC7yZBMdyMIeClR0RqE
YfZdCJbQHqtTwBXN+imQUHPGmksYkRDpFRvw/4crpcMIE04mVVd99nOpFCQnK61t
9US1y17VW1lYpkqlCS+rkcAtor4Z5naSf9/oLGCxEAwyW0pwHGO6MXtMxvB/JD20
hJdlz1I7wlSfF4MiRQ==
-----END CERTIFICATE-----`)

// websocketServerKey is the PEM-encoded private key (PKCS #8 "PRIVATE KEY"
// block) for websocketServerCert. newWebsocketServer loads the pair with
// tls.X509KeyPair. NOTE: test-only key material; never reuse it elsewhere.
var websocketServerKey = []byte(`-----BEGIN PRIVATE KEY-----
MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCkGWKe2OQvV87V
KTanM2+XzrYK1PBTv8nLoRmpbMjXvARs+fixtINseF8OU+XaebDeFWRzos0r56rO
bYZf9qLi2AyJofuTLI2K3UreDvLsYhIx/2mhYho84mGJ3bCaEodsS0cQJJ1Dlz+M
Q3VRXMHiCjks7swQRWj2TJD7IGpLqAhLIWZ9pPio6w8kGKvfdsGmm1J6d3UlJdpF
NepFrWn6i53DYgtrhTsp+4ord6anbWbn8dgC5J8JXhaqoOgOIE8ymxDSQ8Q+6BZk
oV4XepqeaoUW+m1jH7Mls7OdZhjYPeJ1ay9yDqpUs8QHX2hdpUFXQdc8W05ubmbW
MKmXwUjXAgMBAAECggEAE6BkTDgH//rnkP/Ej/Y17Zkv6qxnMLe/4evwZB7PsrBu
cxOUAWUOpvA1UO215bh87+2XvcDbUISnyC1kpKDyAGGeC5llER2DXE11VokWgtvZ
Q0OXavw5w83A+WVGFFdiUmXP0l10CxEm7OwQjFz6D21GQ1qC65tG9NZZghTxbFTe
iZKqgWqyHsaAWLOuDQbj1FTEBMFrY8f9RbclSh0luPZnzGc4BVI/t34jKPZBpH2N
NCkr8aB7MMHGhrNZFHAu/KAvq8UBrDTX+O8ERMwcwQWB4nne2+GOTN0MdcAUc72i
GryzIa8TgO+TpQOYoZ4NPnzFrsa+m3G2Tug3vbt62QKBgQDOPfM4/5/x/h/ggxQn
aRvEOC+8ldeqEOS1VTGiuDKJMWXrNkG+d+AsxfNP4k0QVNrpEAZSYcf0gnS9Odcl
luEsi/yPZDDnPg/cS+Z3336VKsggly7BWFs1Ct/9I+ZfSCl88TkVpIfeCBC34XEb
0mFUq/RdLqXj/mVLbBfr+H8cEwKBgQDLsJUm8lkWFAPJ8UMto8xeUMGk44VukYwx
+oI6KhplFntiI0C1Dd9wrxyCjySlJcc0NFt6IPN84d7pI9LQSbiKXQ1jMvsBzd4G
EMtG8SHpIY/mMU+KzWLHYVFS0FA4PvXXvPRNLOXas7hbALZdLshVKd7aDlkQAb5C
KWFHeIFwrQKBgA8r5Xl67HQrwoKMge4IQF+l1nUj/LJo/boNI1KaBDWtaZbs7dcq
EFaa1TQ6LHsYEuZ0JFLpGIF3G0lUOOxt9fCF97VApIxON3J4LuMAkNo+RGyJUoos
isETJLkFbAv0TgD/6bga21fM9hXgwqZOSpSk9ZvpM5DbBO6QbA4SwJ77AoGAX7h1
/z14XAW/2hDE7xfAnLn6plA9jj5b0cjVlhvfF44/IVlLuUnxrPS9wyUdpXZhbMkG
DBicFB3ZMVqiYTuju3ILLojwqGJkahlOTeJXe0VIaHbX2HS4bNXw76fxat07jsy/
Sd1Fj0dR5YIqMRQhFNR+Y57Gf90x2cm0a2/X9GkCgYANawYx9bNfcX0HMVG7vktK
6/80omnoBM0JUxA+V7DxS8kr9Cj2Y/kcS+VHb4yyoSkDgnsSdnCr1ZTctcj828MJ
8AUwskAtEjPkHRXEgRRnEl2oJGD1TT5iwBNnuPAQDXwzkGCRYBnlfZNbILbOoSUz
m+VDcqT5XzcRADa/TLlEXA==
-----END PRIVATE KEY-----
`)

-------------------------------------------------------------------------------- /client_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2014 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package websocket 6 | 7 | import ( 8 | "net/url" 9 | "testing" 10 | ) 11 | 12 | var hostPortNoPortTests = []struct { 13 | u *url.URL 14 | hostPort, hostNoPort string 15 | }{ 16 | {&url.URL{Scheme: "ws", Host: "example.com"}, "example.com:80", "example.com"}, 17 | {&url.URL{Scheme: "wss", Host: "example.com"}, "example.com:443", "example.com"}, 18 | {&url.URL{Scheme: "ws", Host: "example.com:7777"}, "example.com:7777", "example.com"}, 19 | {&url.URL{Scheme: "wss", Host: "example.com:7777"}, "example.com:7777", "example.com"}, 20 | } 21 | 22 | func TestHostPortNoPort(t *testing.T) { 23 | for _, tt := range hostPortNoPortTests { 24 | hostPort, hostNoPort := hostPortNoPort(tt.u) 25 | if hostPort != tt.hostPort { 26 | t.Errorf("hostPortNoPort(%v) returned hostPort %q, want %q", tt.u, hostPort, tt.hostPort) 27 | } 28 | if hostNoPort != tt.hostNoPort { 29 | t.Errorf("hostPortNoPort(%v) returned hostNoPort %q, want %q", tt.u, hostNoPort, tt.hostNoPort) 30 | } 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /compression.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package websocket 6 | 7 | import ( 8 | "compress/flate" 9 | "errors" 10 | "io" 11 | "strings" 12 | "sync" 13 | ) 14 | 15 | const ( 16 | minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6 17 | maxCompressionLevel = flate.BestCompression 18 | defaultCompressionLevel = 1 19 | ) 20 | 21 | var ( 22 | flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool 23 | flateReaderPool = sync.Pool{New: func() interface{} { 24 | return flate.NewReader(nil) 25 | }} 26 | ) 27 | 28 | func decompressNoContextTakeover(r io.Reader) io.ReadCloser { 29 | const tail = 30 | // Add four bytes as specified in RFC 31 | "\x00\x00\xff\xff" + 32 | // Add final block to squelch unexpected EOF error from flate reader. 33 | "\x01\x00\x00\xff\xff" 34 | 35 | fr, _ := flateReaderPool.Get().(io.ReadCloser) 36 | mr := io.MultiReader(r, strings.NewReader(tail)) 37 | if err := fr.(flate.Resetter).Reset(mr, nil); err != nil { 38 | // Reset never fails, but handle error in case that changes. 39 | fr = flate.NewReader(mr) 40 | } 41 | return &flateReadWrapper{fr} 42 | } 43 | 44 | func isValidCompressionLevel(level int) bool { 45 | return minCompressionLevel <= level && level <= maxCompressionLevel 46 | } 47 | 48 | func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser { 49 | p := &flateWriterPools[level-minCompressionLevel] 50 | tw := &truncWriter{w: w} 51 | fw, _ := p.Get().(*flate.Writer) 52 | if fw == nil { 53 | fw, _ = flate.NewWriter(tw, level) 54 | } else { 55 | fw.Reset(tw) 56 | } 57 | return &flateWriteWrapper{fw: fw, tw: tw, p: p} 58 | } 59 | 60 | // truncWriter is an io.Writer that writes all but the last four bytes of the 61 | // stream to another io.Writer. 62 | type truncWriter struct { 63 | w io.WriteCloser 64 | n int 65 | p [4]byte 66 | } 67 | 68 | func (w *truncWriter) Write(p []byte) (int, error) { 69 | n := 0 70 | 71 | // fill buffer first for simplicity. 
72 | if w.n < len(w.p) { 73 | n = copy(w.p[w.n:], p) 74 | p = p[n:] 75 | w.n += n 76 | if len(p) == 0 { 77 | return n, nil 78 | } 79 | } 80 | 81 | m := len(p) 82 | if m > len(w.p) { 83 | m = len(w.p) 84 | } 85 | 86 | if nn, err := w.w.Write(w.p[:m]); err != nil { 87 | return n + nn, err 88 | } 89 | 90 | copy(w.p[:], w.p[m:]) 91 | copy(w.p[len(w.p)-m:], p[len(p)-m:]) 92 | nn, err := w.w.Write(p[:len(p)-m]) 93 | return n + nn, err 94 | } 95 | 96 | type flateWriteWrapper struct { 97 | fw *flate.Writer 98 | tw *truncWriter 99 | p *sync.Pool 100 | } 101 | 102 | func (w *flateWriteWrapper) Write(p []byte) (int, error) { 103 | if w.fw == nil { 104 | return 0, errWriteClosed 105 | } 106 | return w.fw.Write(p) 107 | } 108 | 109 | func (w *flateWriteWrapper) Close() error { 110 | if w.fw == nil { 111 | return errWriteClosed 112 | } 113 | err1 := w.fw.Flush() 114 | w.p.Put(w.fw) 115 | w.fw = nil 116 | if w.tw.p != [4]byte{0, 0, 0xff, 0xff} { 117 | return errors.New("websocket: internal error, unexpected bytes at end of flate stream") 118 | } 119 | err2 := w.tw.w.Close() 120 | if err1 != nil { 121 | return err1 122 | } 123 | return err2 124 | } 125 | 126 | type flateReadWrapper struct { 127 | fr io.ReadCloser 128 | } 129 | 130 | func (r *flateReadWrapper) Read(p []byte) (int, error) { 131 | if r.fr == nil { 132 | return 0, io.ErrClosedPipe 133 | } 134 | n, err := r.fr.Read(p) 135 | if err == io.EOF { 136 | // Preemptively place the reader back in the pool. This helps with 137 | // scenarios where the application does not call NextReader() soon after 138 | // this final read. 
139 | r.Close() 140 | } 141 | return n, err 142 | } 143 | 144 | func (r *flateReadWrapper) Close() error { 145 | if r.fr == nil { 146 | return io.ErrClosedPipe 147 | } 148 | err := r.fr.Close() 149 | flateReaderPool.Put(r.fr) 150 | r.fr = nil 151 | return err 152 | } 153 | -------------------------------------------------------------------------------- /compression_test.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "testing" 8 | ) 9 | 10 | type nopCloser struct{ io.Writer } 11 | 12 | func (nopCloser) Close() error { return nil } 13 | 14 | func TestTruncWriter(t *testing.T) { 15 | const data = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijlkmnopqrstuvwxyz987654321" 16 | for n := 1; n <= 10; n++ { 17 | var b bytes.Buffer 18 | w := &truncWriter{w: nopCloser{&b}} 19 | p := []byte(data) 20 | for len(p) > 0 { 21 | m := len(p) 22 | if m > n { 23 | m = n 24 | } 25 | _, _ = w.Write(p[:m]) 26 | p = p[m:] 27 | } 28 | if b.String() != data[:len(data)-len(w.p)] { 29 | t.Errorf("%d: %q", n, b.String()) 30 | } 31 | } 32 | } 33 | 34 | func textMessages(num int) [][]byte { 35 | messages := make([][]byte, num) 36 | for i := 0; i < num; i++ { 37 | msg := fmt.Sprintf("planet: %d, country: %d, city: %d, street: %d", i, i, i, i) 38 | messages[i] = []byte(msg) 39 | } 40 | return messages 41 | } 42 | 43 | func BenchmarkWriteNoCompression(b *testing.B) { 44 | w := io.Discard 45 | c := newTestConn(nil, w, false) 46 | messages := textMessages(100) 47 | b.ResetTimer() 48 | for i := 0; i < b.N; i++ { 49 | _ = c.WriteMessage(TextMessage, messages[i%len(messages)]) 50 | } 51 | b.ReportAllocs() 52 | } 53 | 54 | func BenchmarkWriteWithCompression(b *testing.B) { 55 | w := io.Discard 56 | c := newTestConn(nil, w, false) 57 | messages := textMessages(100) 58 | c.enableWriteCompression = true 59 | c.newCompressionWriter = compressNoContextTakeover 60 | b.ResetTimer() 61 | for i := 0; i < b.N; i++ { 
62 | _ = c.WriteMessage(TextMessage, messages[i%len(messages)]) 63 | } 64 | b.ReportAllocs() 65 | } 66 | 67 | func TestValidCompressionLevel(t *testing.T) { 68 | c := newTestConn(nil, nil, false) 69 | for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} { 70 | if err := c.SetCompressionLevel(level); err == nil { 71 | t.Errorf("no error for level %d", level) 72 | } 73 | } 74 | for _, level := range []int{minCompressionLevel, maxCompressionLevel} { 75 | if err := c.SetCompressionLevel(level); err != nil { 76 | t.Errorf("error for level %d", level) 77 | } 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /conn_broadcast_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package websocket 6 | 7 | import ( 8 | "io" 9 | "sync/atomic" 10 | "testing" 11 | ) 12 | 13 | // broadcastBench allows to run broadcast benchmarks. 14 | // In every broadcast benchmark we create many connections, then send the same 15 | // message into every connection and wait for all writes complete. This emulates 16 | // an application where many connections listen to the same data - i.e. PUB/SUB 17 | // scenarios with many subscribers in one channel. 
18 | type broadcastBench struct { 19 | w io.Writer 20 | closeCh chan struct{} 21 | doneCh chan struct{} 22 | count int32 23 | conns []*broadcastConn 24 | compression bool 25 | usePrepared bool 26 | } 27 | 28 | type broadcastMessage struct { 29 | payload []byte 30 | prepared *PreparedMessage 31 | } 32 | 33 | type broadcastConn struct { 34 | conn *Conn 35 | msgCh chan *broadcastMessage 36 | } 37 | 38 | func newBroadcastConn(c *Conn) *broadcastConn { 39 | return &broadcastConn{ 40 | conn: c, 41 | msgCh: make(chan *broadcastMessage, 1), 42 | } 43 | } 44 | 45 | func newBroadcastBench(usePrepared, compression bool) *broadcastBench { 46 | bench := &broadcastBench{ 47 | w: io.Discard, 48 | doneCh: make(chan struct{}), 49 | closeCh: make(chan struct{}), 50 | usePrepared: usePrepared, 51 | compression: compression, 52 | } 53 | bench.makeConns(10000) 54 | return bench 55 | } 56 | 57 | func (b *broadcastBench) makeConns(numConns int) { 58 | conns := make([]*broadcastConn, numConns) 59 | 60 | for i := 0; i < numConns; i++ { 61 | c := newTestConn(nil, b.w, true) 62 | if b.compression { 63 | c.enableWriteCompression = true 64 | c.newCompressionWriter = compressNoContextTakeover 65 | } 66 | conns[i] = newBroadcastConn(c) 67 | go func(c *broadcastConn) { 68 | for { 69 | select { 70 | case msg := <-c.msgCh: 71 | if msg.prepared != nil { 72 | _ = c.conn.WritePreparedMessage(msg.prepared) 73 | } else { 74 | _ = c.conn.WriteMessage(TextMessage, msg.payload) 75 | } 76 | val := atomic.AddInt32(&b.count, 1) 77 | if val%int32(numConns) == 0 { 78 | b.doneCh <- struct{}{} 79 | } 80 | case <-b.closeCh: 81 | return 82 | } 83 | } 84 | }(conns[i]) 85 | } 86 | b.conns = conns 87 | } 88 | 89 | func (b *broadcastBench) close() { 90 | close(b.closeCh) 91 | } 92 | 93 | func (b *broadcastBench) broadcastOnce(msg *broadcastMessage) { 94 | for _, c := range b.conns { 95 | c.msgCh <- msg 96 | } 97 | <-b.doneCh 98 | } 99 | 100 | func BenchmarkBroadcast(b *testing.B) { 101 | benchmarks := []struct { 102 | 
name string 103 | usePrepared bool 104 | compression bool 105 | }{ 106 | {"NoCompression", false, false}, 107 | {"Compression", false, true}, 108 | {"NoCompressionPrepared", true, false}, 109 | {"CompressionPrepared", true, true}, 110 | } 111 | payload := textMessages(1)[0] 112 | for _, bm := range benchmarks { 113 | b.Run(bm.name, func(b *testing.B) { 114 | bench := newBroadcastBench(bm.usePrepared, bm.compression) 115 | defer bench.close() 116 | b.ResetTimer() 117 | for i := 0; i < b.N; i++ { 118 | message := &broadcastMessage{ 119 | payload: payload, 120 | } 121 | if bench.usePrepared { 122 | pm, _ := NewPreparedMessage(TextMessage, message.payload) 123 | message.prepared = pm 124 | } 125 | bench.broadcastOnce(message) 126 | } 127 | b.ReportAllocs() 128 | }) 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /conn_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package websocket 6 | 7 | import ( 8 | "bufio" 9 | "bytes" 10 | "errors" 11 | "fmt" 12 | "io" 13 | "net" 14 | "reflect" 15 | "sync" 16 | "testing" 17 | "testing/iotest" 18 | "time" 19 | ) 20 | 21 | var _ net.Error = errWriteTimeout 22 | 23 | type fakeNetConn struct { 24 | io.Reader 25 | io.Writer 26 | } 27 | 28 | func (c fakeNetConn) Close() error { return nil } 29 | func (c fakeNetConn) LocalAddr() net.Addr { return localAddr } 30 | func (c fakeNetConn) RemoteAddr() net.Addr { return remoteAddr } 31 | func (c fakeNetConn) SetDeadline(t time.Time) error { return nil } 32 | func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil } 33 | func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil } 34 | 35 | type fakeAddr int 36 | 37 | var ( 38 | localAddr = fakeAddr(1) 39 | remoteAddr = fakeAddr(2) 40 | ) 41 | 42 | func (a fakeAddr) Network() string { 43 | return "net" 44 | } 45 | 46 | func (a fakeAddr) String() string { 47 | return "str" 48 | } 49 | 50 | // newTestConn creates a connection backed by a fake network connection using 51 | // default values for buffering. 
52 | func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn { 53 | return newConn(fakeNetConn{Reader: r, Writer: w}, isServer, 1024, 1024, nil, nil, nil) 54 | } 55 | 56 | func TestFraming(t *testing.T) { 57 | frameSizes := []int{ 58 | 0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 59 | // 65536, 65537 60 | } 61 | var readChunkers = []struct { 62 | name string 63 | f func(io.Reader) io.Reader 64 | }{ 65 | {"half", iotest.HalfReader}, 66 | {"one", iotest.OneByteReader}, 67 | {"asis", func(r io.Reader) io.Reader { return r }}, 68 | } 69 | writeBuf := make([]byte, 65537) 70 | for i := range writeBuf { 71 | writeBuf[i] = byte(i) 72 | } 73 | var writers = []struct { 74 | name string 75 | f func(w io.Writer, n int) (int, error) 76 | }{ 77 | {"iocopy", func(w io.Writer, n int) (int, error) { 78 | nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n])) 79 | return int(nn), err 80 | }}, 81 | {"write", func(w io.Writer, n int) (int, error) { 82 | return w.Write(writeBuf[:n]) 83 | }}, 84 | {"string", func(w io.Writer, n int) (int, error) { 85 | return io.WriteString(w, string(writeBuf[:n])) 86 | }}, 87 | } 88 | 89 | for _, compress := range []bool{false, true} { 90 | for _, isServer := range []bool{true, false} { 91 | for _, chunker := range readChunkers { 92 | 93 | var connBuf bytes.Buffer 94 | wc := newTestConn(nil, &connBuf, isServer) 95 | rc := newTestConn(chunker.f(&connBuf), nil, !isServer) 96 | if compress { 97 | wc.newCompressionWriter = compressNoContextTakeover 98 | rc.newDecompressionReader = decompressNoContextTakeover 99 | } 100 | for _, n := range frameSizes { 101 | for _, writer := range writers { 102 | name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name) 103 | 104 | w, err := wc.NextWriter(TextMessage) 105 | if err != nil { 106 | t.Errorf("%s: wc.NextWriter() returned %v", name, err) 107 | continue 108 | } 109 | nn, err := writer.f(w, n) 110 | if err != nil || nn != n { 111 | t.Errorf("%s: 
w.Write(writeBuf[:n]) returned %d, %v", name, nn, err) 112 | continue 113 | } 114 | err = w.Close() 115 | if err != nil { 116 | t.Errorf("%s: w.Close() returned %v", name, err) 117 | continue 118 | } 119 | 120 | opCode, r, err := rc.NextReader() 121 | if err != nil || opCode != TextMessage { 122 | t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err) 123 | continue 124 | } 125 | 126 | t.Logf("frame size: %d", n) 127 | rbuf, err := io.ReadAll(r) 128 | if err != nil { 129 | t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) 130 | continue 131 | } 132 | 133 | if len(rbuf) != n { 134 | t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n) 135 | continue 136 | } 137 | 138 | for i, b := range rbuf { 139 | if byte(i) != b { 140 | t.Errorf("%s: bad byte at offset %d", name, i) 141 | break 142 | } 143 | } 144 | } 145 | } 146 | } 147 | } 148 | } 149 | } 150 | 151 | func TestWriteControlDeadline(t *testing.T) { 152 | t.Parallel() 153 | message := []byte("hello") 154 | var connBuf bytes.Buffer 155 | c := newTestConn(nil, &connBuf, true) 156 | if err := c.WriteControl(PongMessage, message, time.Time{}); err != nil { 157 | t.Errorf("WriteControl(..., zero deadline) = %v, want nil", err) 158 | } 159 | if err := c.WriteControl(PongMessage, message, time.Now().Add(time.Second)); err != nil { 160 | t.Errorf("WriteControl(..., future deadline) = %v, want nil", err) 161 | } 162 | if err := c.WriteControl(PongMessage, message, time.Now().Add(-time.Second)); err == nil { 163 | t.Errorf("WriteControl(..., past deadline) = nil, want timeout error") 164 | } 165 | } 166 | 167 | func TestConcurrencyWriteControl(t *testing.T) { 168 | const message = "this is a ping/pong messsage" 169 | loop := 10 170 | workers := 10 171 | for i := 0; i < loop; i++ { 172 | var connBuf bytes.Buffer 173 | 174 | wg := sync.WaitGroup{} 175 | wc := newTestConn(nil, &connBuf, true) 176 | 177 | for i := 0; i < workers; i++ { 178 | wg.Add(1) 179 | go func() { 180 | defer wg.Done() 181 | if err 
:= wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)); err != nil { 182 | t.Errorf("concurrently wc.WriteControl() returned %v", err) 183 | } 184 | }() 185 | } 186 | 187 | wg.Wait() 188 | wc.Close() 189 | } 190 | } 191 | 192 | func TestControl(t *testing.T) { 193 | const message = "this is a ping/pong message" 194 | for _, isServer := range []bool{true, false} { 195 | for _, isWriteControl := range []bool{true, false} { 196 | name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl) 197 | var connBuf bytes.Buffer 198 | wc := newTestConn(nil, &connBuf, isServer) 199 | rc := newTestConn(&connBuf, nil, !isServer) 200 | if isWriteControl { 201 | _ = wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)) 202 | } else { 203 | w, err := wc.NextWriter(PongMessage) 204 | if err != nil { 205 | t.Errorf("%s: wc.NextWriter() returned %v", name, err) 206 | continue 207 | } 208 | if _, err := w.Write([]byte(message)); err != nil { 209 | t.Errorf("%s: w.Write() returned %v", name, err) 210 | continue 211 | } 212 | if err := w.Close(); err != nil { 213 | t.Errorf("%s: w.Close() returned %v", name, err) 214 | continue 215 | } 216 | var actualMessage string 217 | rc.SetPongHandler(func(s string) error { actualMessage = s; return nil }) 218 | _, _, _ = rc.NextReader() 219 | if actualMessage != message { 220 | t.Errorf("%s: pong=%q, want %q", name, actualMessage, message) 221 | continue 222 | } 223 | } 224 | } 225 | } 226 | } 227 | 228 | // simpleBufferPool is an implementation of BufferPool for TestWriteBufferPool. 229 | type simpleBufferPool struct { 230 | v interface{} 231 | } 232 | 233 | func (p *simpleBufferPool) Get() interface{} { 234 | v := p.v 235 | p.v = nil 236 | return v 237 | } 238 | 239 | func (p *simpleBufferPool) Put(v interface{}) { 240 | p.v = v 241 | } 242 | 243 | func TestWriteBufferPool(t *testing.T) { 244 | const message = "Now is the time for all good people to come to the aid of the party." 
245 | 246 | var buf bytes.Buffer 247 | var pool simpleBufferPool 248 | rc := newTestConn(&buf, nil, false) 249 | 250 | // Specify writeBufferSize smaller than message size to ensure that pooling 251 | // works with fragmented messages. 252 | wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, len(message)-1, &pool, nil, nil) 253 | 254 | if wc.writeBuf != nil { 255 | t.Fatal("writeBuf not nil after create") 256 | } 257 | 258 | // Part 1: test NextWriter/Write/Close 259 | 260 | w, err := wc.NextWriter(TextMessage) 261 | if err != nil { 262 | t.Fatalf("wc.NextWriter() returned %v", err) 263 | } 264 | 265 | if wc.writeBuf == nil { 266 | t.Fatal("writeBuf is nil after NextWriter") 267 | } 268 | 269 | writeBufAddr := &wc.writeBuf[0] 270 | 271 | if _, err := io.WriteString(w, message); err != nil { 272 | t.Fatalf("io.WriteString(w, message) returned %v", err) 273 | } 274 | 275 | if err := w.Close(); err != nil { 276 | t.Fatalf("w.Close() returned %v", err) 277 | } 278 | 279 | if wc.writeBuf != nil { 280 | t.Fatal("writeBuf not nil after w.Close()") 281 | } 282 | 283 | if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { 284 | t.Fatal("writeBuf not returned to pool") 285 | } 286 | 287 | opCode, p, err := rc.ReadMessage() 288 | if opCode != TextMessage || err != nil { 289 | t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) 290 | } 291 | 292 | if s := string(p); s != message { 293 | t.Fatalf("message is %s, want %s", s, message) 294 | } 295 | 296 | // Part 2: Test WriteMessage. 
297 | 298 | if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil { 299 | t.Fatalf("wc.WriteMessage() returned %v", err) 300 | } 301 | 302 | if wc.writeBuf != nil { 303 | t.Fatal("writeBuf not nil after wc.WriteMessage()") 304 | } 305 | 306 | if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { 307 | t.Fatal("writeBuf not returned to pool after WriteMessage") 308 | } 309 | 310 | opCode, p, err = rc.ReadMessage() 311 | if opCode != TextMessage || err != nil { 312 | t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) 313 | } 314 | 315 | if s := string(p); s != message { 316 | t.Fatalf("message is %s, want %s", s, message) 317 | } 318 | } 319 | 320 | // TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool. 321 | func TestWriteBufferPoolSync(t *testing.T) { 322 | var buf bytes.Buffer 323 | var pool sync.Pool 324 | wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil) 325 | rc := newTestConn(&buf, nil, false) 326 | 327 | const message = "Hello World!" 328 | for i := 0; i < 3; i++ { 329 | if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil { 330 | t.Fatalf("wc.WriteMessage() returned %v", err) 331 | } 332 | opCode, p, err := rc.ReadMessage() 333 | if opCode != TextMessage || err != nil { 334 | t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) 335 | } 336 | if s := string(p); s != message { 337 | t.Fatalf("message is %s, want %s", s, message) 338 | } 339 | } 340 | } 341 | 342 | // errorWriter is an io.Writer than returns an error on all writes. 343 | type errorWriter struct{} 344 | 345 | func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("error") } 346 | 347 | // TestWriteBufferPoolError ensures that buffer is returned to pool after error 348 | // on write. 
func TestWriteBufferPoolError(t *testing.T) {

	// Part 1: Test NextWriter/Write/Close

	// Every write to the underlying errorWriter fails, so Close must report an
	// error, yet the buffer taken by NextWriter must still end up in the pool.
	var pool simpleBufferPool
	wc := newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)

	w, err := wc.NextWriter(TextMessage)
	if err != nil {
		t.Fatalf("wc.NextWriter() returned %v", err)
	}

	if wc.writeBuf == nil {
		t.Fatal("writeBuf is nil after NextWriter")
	}

	writeBufAddr := &wc.writeBuf[0]

	if _, err := io.WriteString(w, "Hello"); err != nil {
		t.Fatalf("io.WriteString(w, message) returned %v", err)
	}

	if err := w.Close(); err == nil {
		t.Fatalf("w.Close() did not return error")
	}

	if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
		t.Fatal("writeBuf not returned to pool")
	}

	// Part 2: Test WriteMessage

	// A fresh connection sharing the same pool is expected to reuse the buffer
	// returned in part 1, and to return it again after the failed write.
	wc = newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)

	if err := wc.WriteMessage(TextMessage, []byte("Hello")); err == nil {
		t.Fatalf("wc.WriteMessage did not return error")
	}

	if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
		t.Fatal("writeBuf not returned to pool")
	}
}

// TestCloseFrameBeforeFinalMessageFrame writes 1.5x the writer's buffer size
// of a binary message, sends a close frame with WriteControl, then closes the
// message writer. The write buffer is smaller than the payload, so presumably
// a non-final fragment is flushed before the close frame. The reader must
// return the *CloseError both from reading the message and from the next
// NextReader call.
func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
	const bufSize = 512

	expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}

	var b1, b2 bytes.Buffer
	wc := newConn(&fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
	rc := newTestConn(&b1, &b2, true)

	w, _ := wc.NextWriter(BinaryMessage)
	_, _ = w.Write(make([]byte, bufSize+bufSize/2))
	_ = wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second))
	w.Close()

	op, r, err := rc.NextReader()
	if op != BinaryMessage || err != nil {
		t.Fatalf("NextReader() returned %d, %v", op, err)
	}
	_, err = io.Copy(io.Discard, r)
	if !reflect.DeepEqual(err, expectedErr) {
		t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
	}
	_, _, err = rc.NextReader()
	if !reflect.DeepEqual(err, expectedErr) {
		t.Fatalf("NextReader() returned %v, want %v", err, expectedErr)
	}
}

// TestEOFWithinFrame encodes one 64-byte binary message and then truncates
// the encoded bytes to every length n shorter than the full encoding. Each
// truncation must surface errUnexpectedEOF, either from NextReader directly,
// or from reading the message and from the following NextReader call.
func TestEOFWithinFrame(t *testing.T) {
	const bufSize = 64

	for n := 0; ; n++ {
		var b bytes.Buffer
		wc := newTestConn(nil, &b, false)
		rc := newTestConn(&b, nil, true)

		w, _ := wc.NextWriter(BinaryMessage)
		_, _ = w.Write(make([]byte, bufSize))
		w.Close()

		// Stop once n would keep the whole encoded message.
		if n >= b.Len() {
			break
		}
		b.Truncate(n)

		op, r, err := rc.NextReader()
		if err == errUnexpectedEOF {
			continue
		}
		if op != BinaryMessage || err != nil {
			t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
		}
		_, err = io.Copy(io.Discard, r)
		if err != errUnexpectedEOF {
			t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
		}
		_, _, err = rc.NextReader()
		if err != errUnexpectedEOF {
			t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF)
		}
	}
}

// TestEOFBeforeFinalFrame writes part of a message larger than the write
// buffer and never closes the writer, so the final frame is never sent. The
// reader must report errUnexpectedEOF for the message body and for the next
// NextReader call.
func TestEOFBeforeFinalFrame(t *testing.T) {
	const bufSize = 512

	var b1, b2 bytes.Buffer
	wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
	rc := newTestConn(&b1, &b2, true)

	w, _ := wc.NextWriter(BinaryMessage)
	_, _ = w.Write(make([]byte, bufSize+bufSize/2))

	op, r, err := rc.NextReader()
	if op != BinaryMessage || err != nil {
		t.Fatalf("NextReader() returned %d, %v", op, err)
	}
	_, err = io.Copy(io.Discard, r)
	if err != errUnexpectedEOF {
		t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
	}
	_, _, err = rc.NextReader()
	if err != errUnexpectedEOF {
		t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF)
	}
}

// TestWriteAfterMessageWriterClose checks that a message writer rejects
// writes once it has been closed, whether by an explicit Close or implicitly
// by requesting the next writer.
func TestWriteAfterMessageWriterClose(t *testing.T) {
	wc := newTestConn(nil, &bytes.Buffer{}, false)
	w, _ := wc.NextWriter(BinaryMessage)
	_, _ = io.WriteString(w, "hello")
	if err := w.Close(); err != nil {
		t.Fatalf("unexpected error closing message writer, %v", err)
	}

	if _, err := io.WriteString(w, "world"); err == nil {
		t.Fatalf("no error writing after close")
	}

	w, _ = wc.NextWriter(BinaryMessage)
	_, _ = io.WriteString(w, "hello")

	// close w by getting next writer
	_, err := wc.NextWriter(BinaryMessage)
	if err != nil {
		t.Fatalf("unexpected error getting next writer, %v", err)
	}

	if _, err := io.WriteString(w, "world"); err == nil {
		t.Fatalf("no error writing after close")
	}
}

// TestReadLimit covers SetReadLimit: a message exactly at the limit (with a
// pong interleaved between its fragments) is readable and a larger one fails
// with ErrReadLimit; and hand-built frames whose lengths would overflow the
// running message size must not let a reader get past the limit without an
// error.
func TestReadLimit(t *testing.T) {
	t.Run("Test ReadLimit is enforced", func(t *testing.T) {
		const readLimit = 512
		message := make([]byte, readLimit+1)

		var b1, b2 bytes.Buffer
		wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil)
		rc := newTestConn(&b1, &b2, true)
		rc.SetReadLimit(readLimit)

		// Send message at the limit with interleaved pong.
		w, _ := wc.NextWriter(BinaryMessage)
		_, _ = w.Write(message[:readLimit-1])
		_ = wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
		_, _ = w.Write(message[:1])
		w.Close()

		// Send message larger than the limit.
		_ = wc.WriteMessage(BinaryMessage, message[:readLimit+1])

		op, _, err := rc.NextReader()
		if op != BinaryMessage || err != nil {
			t.Fatalf("1: NextReader() returned %d, %v", op, err)
		}
		op, r, err := rc.NextReader()
		if op != BinaryMessage || err != nil {
			t.Fatalf("2: NextReader() returned %d, %v", op, err)
		}
		_, err = io.Copy(io.Discard, r)
		if err != ErrReadLimit {
			t.Fatalf("io.Copy() returned %v", err)
		}
	})

	t.Run("Test that ReadLimit cannot be overflowed", func(t *testing.T) {
		const readLimit = 1

		var b1, b2 bytes.Buffer
		rc := newTestConn(&b1, &b2, true)
		rc.SetReadLimit(readLimit)

		// First, send a non-final binary message
		// (0x02: FIN clear, opcode binary; 0x81: mask bit set, 1-byte payload).
		b1.Write([]byte("\x02\x81"))

		// Mask key
		b1.Write([]byte("\x00\x00\x00\x00"))

		// First payload
		b1.Write([]byte("A"))

		// Next, send a negative-length, non-final continuation frame
		// (0xFF: mask bit plus 127, i.e. a 64-bit length whose top bit is set).
		b1.Write([]byte("\x00\xFF\x80\x00\x00\x00\x00\x00\x00\x00"))

		// Mask key
		b1.Write([]byte("\x00\x00\x00\x00"))

		// Next, send a too long, final continuation frame
		// (0x80: FIN set, continuation; 64-bit length of 5).
		b1.Write([]byte("\x80\xFF\x00\x00\x00\x00\x00\x00\x00\x05"))

		// Mask key
		b1.Write([]byte("\x00\x00\x00\x00"))

		// Too-long payload
		b1.Write([]byte("BCDEF"))

		op, r, err := rc.NextReader()
		if op != BinaryMessage || err != nil {
			t.Fatalf("1: NextReader() returned %d, %v", op, err)
		}

		var buf [10]byte
		var read int
		n, err := r.Read(buf[:])
		if err != nil && err != ErrReadLimit {
			t.Fatalf("unexpected error testing read limit: %v", err)
		}
		read += n

		n, err = r.Read(buf[:])
		if err != nil && err != ErrReadLimit {
			t.Fatalf("unexpected error testing read limit: %v", err)
		}
		read += n

		// Reading past the limit is a failure only if no error was reported.
		if err == nil && read > readLimit {
			t.Fatalf("read limit exceeded: limit %d, read %d", readLimit, read)
		}
	})
}

// TestAddrs checks that LocalAddr and RemoteAddr pass through the fake
// connection's addresses.
func TestAddrs(t *testing.T) {
	c := newTestConn(nil, nil, true)
	if c.LocalAddr() != localAddr {
		t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr)
	}
	if c.RemoteAddr() != remoteAddr {
		t.Errorf("RemoteAddr = %v, want %v", c.RemoteAddr(), remoteAddr)
	}
}

// TestDeprecatedUnderlyingConn checks that UnderlyingConn returns the
// net.Conn passed to newConn.
func TestDeprecatedUnderlyingConn(t *testing.T) {
	var b1, b2 bytes.Buffer
	fc := fakeNetConn{Reader: &b1, Writer: &b2}
	c := newConn(fc, true, 1024, 1024, nil, nil, nil)
	ul := c.UnderlyingConn()
	if ul != fc {
		t.Fatalf("Underlying conn is not what it should be.")
	}
}

// TestNetConn checks that NetConn returns the net.Conn passed to newConn.
func TestNetConn(t *testing.T) {
	var b1, b2 bytes.Buffer
	fc := fakeNetConn{Reader: &b1, Writer: &b2}
	c := newConn(fc, true, 1024, 1024, nil, nil, nil)
	ul := c.NetConn()
	if ul != fc {
		t.Fatalf("Underlying conn is not what it should be.")
	}
}

func TestBufioReadBytes(t *testing.T) {
	// Test calling bufio.ReadBytes for value longer than read buffer size.

	m := make([]byte, 512)
	m[len(m)-1] = '\n'

	var b1, b2 bytes.Buffer
	wc := newConn(fakeNetConn{Writer: &b1}, false, len(m)+64, len(m)+64, nil, nil, nil)
	rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil)

	w, _ := wc.NextWriter(BinaryMessage)
	_, _ = w.Write(m)
	w.Close()

	op, r, err := rc.NextReader()
	if op != BinaryMessage || err != nil {
		t.Fatalf("NextReader() returned %d, %v", op, err)
	}

	br := bufio.NewReader(r)
	p, err := br.ReadBytes('\n')
	if err != nil {
		t.Fatalf("ReadBytes() returned %v", err)
	}
	if len(p) != len(m) {
		t.Fatalf("read returned %d bytes, want %d bytes", len(p), len(m))
	}
}

// closeErrorTests lists inputs to IsCloseError and the expected result: true
// only when err is a *CloseError whose code appears in codes.
var closeErrorTests = []struct {
	err   error
	codes []int
	ok    bool
}{
	{&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true},
	{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false},
	{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true},
	{errors.New("hello"), []int{CloseNormalClosure}, false},
}

func TestCloseError(t *testing.T) {
	for _, tt := range closeErrorTests {
		ok := IsCloseError(tt.err, tt.codes...)
		if ok != tt.ok {
			t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
		}
	}
}

// unexpectedCloseErrorTests lists inputs to IsUnexpectedCloseError: true only
// when err is a *CloseError whose code is NOT in codes; non-close errors are
// never "unexpected close" errors.
var unexpectedCloseErrorTests = []struct {
	err   error
	codes []int
	ok    bool
}{
	{&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false},
	{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true},
	{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false},
	{errors.New("hello"), []int{CloseNormalClosure}, false},
}

func TestUnexpectedCloseErrors(t *testing.T) {
	for _, tt := range unexpectedCloseErrorTests {
		ok := IsUnexpectedCloseError(tt.err, tt.codes...)
		if ok != tt.ok {
			t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
		}
	}
}

// blockingWriter is an io.Writer whose first Write closes c1 and then blocks
// until c2 is closed. A second Write would panic on the already-closed c1.
type blockingWriter struct {
	c1, c2 chan struct{}
}

func (w blockingWriter) Write(p []byte) (int, error) {
	// Allow main to continue
	close(w.c1)
	// Wait for panic in main
	<-w.c2
	return len(p), nil
}

// TestConcurrentWritePanic parks one WriteMessage inside blockingWriter and
// then calls WriteMessage again from the test goroutine. That second call is
// expected to panic; the deferred recover absorbs the panic and releases the
// blocked writer. Reaching t.Fatal means no panic happened.
func TestConcurrentWritePanic(t *testing.T) {
	w := blockingWriter{make(chan struct{}), make(chan struct{})}
	c := newTestConn(nil, w, false)
	go func() {
		_ = c.WriteMessage(TextMessage, []byte{})
	}()

	// wait for goroutine to block in write.
	<-w.c1

	defer func() {
		close(w.c2)
		if v := recover(); v != nil {
			return
		}
	}()

	_ = c.WriteMessage(TextMessage, []byte{})
	t.Fatal("should not get here")
}

// failingReader is an io.Reader that always returns io.EOF.
type failingReader struct{}

func (r failingReader) Read(p []byte) (int, error) {
	return 0, io.EOF
}

// TestFailedConnectionReadPanic reads repeatedly from a connection whose
// reader always fails. It expects a panic (recovered by the deferred func)
// before the loop finishes; reaching t.Fatal means none occurred.
func TestFailedConnectionReadPanic(t *testing.T) {
	c := newTestConn(failingReader{}, nil, false)

	defer func() {
		if v := recover(); v != nil {
			return
		}
	}()

	for i := 0; i < 20000; i++ {
		_, _, _ = c.ReadMessage()
	}
	t.Fatal("should not get here")
}

--------------------------------------------------------------------------------
/doc.go:
--------------------------------------------------------------------------------
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package websocket implements the WebSocket protocol defined in RFC 6455.
//
// Overview
//
// The Conn type represents a WebSocket connection. A server application calls
// the Upgrader.Upgrade method from an HTTP request handler to get a *Conn:
//
//  var upgrader = websocket.Upgrader{
//      ReadBufferSize:  1024,
//      WriteBufferSize: 1024,
//  }
//
//  func handler(w http.ResponseWriter, r *http.Request) {
//      conn, err := upgrader.Upgrade(w, r, nil)
//      if err != nil {
//          log.Println(err)
//          return
//      }
//      ... Use conn to send and receive messages.
//  }
//
// Call the connection's WriteMessage and ReadMessage methods to send and
// receive messages as a slice of bytes.
This snippet of code shows how to echo 28 | // messages using these methods: 29 | // 30 | // for { 31 | // messageType, p, err := conn.ReadMessage() 32 | // if err != nil { 33 | // log.Println(err) 34 | // return 35 | // } 36 | // if err := conn.WriteMessage(messageType, p); err != nil { 37 | // log.Println(err) 38 | // return 39 | // } 40 | // } 41 | // 42 | // In the above snippet of code, p is a []byte and messageType is an int with value 43 | // websocket.BinaryMessage or websocket.TextMessage. 44 | // 45 | // An application can also send and receive messages using the io.WriteCloser 46 | // and io.Reader interfaces. To send a message, call the connection NextWriter 47 | // method to get an io.WriteCloser, write the message to the writer and close 48 | // the writer when done. To receive a message, call the connection NextReader 49 | // method to get an io.Reader and read until io.EOF is returned. This snippet 50 | // shows how to echo messages using the NextWriter and NextReader methods: 51 | // 52 | // for { 53 | // messageType, r, err := conn.NextReader() 54 | // if err != nil { 55 | // return err 56 | // } 57 | // w, err := conn.NextWriter(messageType) 58 | // if err != nil { 59 | // return err 60 | // } 61 | // if _, err := io.Copy(w, r); err != nil { 62 | // return err 63 | // } 64 | // if err := w.Close(); err != nil { 65 | // return err 66 | // } 67 | // } 68 | // 69 | // Data Messages 70 | // 71 | // The WebSocket protocol distinguishes between text and binary data messages. 72 | // Text messages are interpreted as UTF-8 encoded text. The interpretation of 73 | // binary messages is left to the application. 74 | // 75 | // This package uses the TextMessage and BinaryMessage integer constants to 76 | // identify the two data message types. The ReadMessage and NextReader methods 77 | // return the type of the received message. The messageType argument to the 78 | // WriteMessage and NextWriter methods specifies the type of a sent message.
79 | // 80 | // It is the application's responsibility to ensure that text messages are 81 | // valid UTF-8 encoded text. 82 | // 83 | // Control Messages 84 | // 85 | // The WebSocket protocol defines three types of control messages: close, ping 86 | // and pong. Call the connection WriteControl, WriteMessage or NextWriter 87 | // methods to send a control message to the peer. 88 | // 89 | // Connections handle received close messages by calling the handler function 90 | // set with the SetCloseHandler method and by returning a *CloseError from the 91 | // NextReader, ReadMessage or the message Read method. The default close 92 | // handler sends a close message to the peer. 93 | // 94 | // Connections handle received ping messages by calling the handler function 95 | // set with the SetPingHandler method. The default ping handler sends a pong 96 | // message to the peer. 97 | // 98 | // Connections handle received pong messages by calling the handler function 99 | // set with the SetPongHandler method. The default pong handler does nothing. 100 | // If an application sends ping messages, then the application should set a 101 | // pong handler to receive the corresponding pong. 102 | // 103 | // The control message handler functions are called from the NextReader, 104 | // ReadMessage and message reader Read methods. The default close and ping 105 | // handlers can block these methods for a short time when the handler writes to 106 | // the connection. 107 | // 108 | // The application must read the connection to process close, ping and pong 109 | // messages sent from the peer. If the application is not otherwise interested 110 | // in messages from the peer, then the application should start a goroutine to 111 | // read and discard messages from the peer. 
A simple example is: 112 | // 113 | // func readLoop(c *websocket.Conn) { 114 | // for { 115 | // if _, _, err := c.NextReader(); err != nil { 116 | // c.Close() 117 | // break 118 | // } 119 | // } 120 | // } 121 | // 122 | // Concurrency 123 | // 124 | // Connections support one concurrent reader and one concurrent writer. 125 | // 126 | // Applications are responsible for ensuring that no more than one goroutine 127 | // calls the write methods (NextWriter, SetWriteDeadline, WriteMessage, 128 | // WriteJSON, EnableWriteCompression, SetCompressionLevel) concurrently and 129 | // that no more than one goroutine calls the read methods (NextReader, 130 | // SetReadDeadline, ReadMessage, ReadJSON, SetPongHandler, SetPingHandler) 131 | // concurrently. 132 | // 133 | // The Close and WriteControl methods can be called concurrently with all other 134 | // methods. 135 | // 136 | // Origin Considerations 137 | // 138 | // Web browsers allow Javascript applications to open a WebSocket connection to 139 | // any host. It's up to the server to enforce an origin policy using the Origin 140 | // request header sent by the browser. 141 | // 142 | // The Upgrader calls the function specified in the CheckOrigin field to check 143 | // the origin. If the CheckOrigin function returns false, then the Upgrade 144 | // method fails the WebSocket handshake with HTTP status 403. 145 | // 146 | // If the CheckOrigin field is nil, then the Upgrader uses a safe default: fail 147 | // the handshake if the Origin request header is present and the Origin host is 148 | // not equal to the Host request header. 149 | // 150 | // The deprecated package-level Upgrade function does not perform origin 151 | // checking. The application is responsible for checking the Origin header 152 | // before calling the Upgrade function. 153 | // 154 | // Buffers 155 | // 156 | // Connections buffer network input and output to reduce the number 157 | // of system calls when reading or writing messages. 
158 | // 159 | // Write buffers are also used for constructing WebSocket frames. See RFC 6455, 160 | // Section 5 for a discussion of message framing. A WebSocket frame header is 161 | // written to the network each time a write buffer is flushed to the network. 162 | // Decreasing the size of the write buffer can increase the amount of framing 163 | // overhead on the connection. 164 | // 165 | // The buffer sizes in bytes are specified by the ReadBufferSize and 166 | // WriteBufferSize fields in the Dialer and Upgrader. The Dialer uses a default 167 | // size of 4096 when a buffer size field is set to zero. The Upgrader reuses 168 | // buffers created by the HTTP server when a buffer size field is set to zero. 169 | // The HTTP server buffers have a size of 4096 at the time of this writing. 170 | // 171 | // The buffer sizes do not limit the size of a message that can be read or 172 | // written by a connection. 173 | // 174 | // Buffers are held for the lifetime of the connection by default. If the 175 | // Dialer or Upgrader WriteBufferPool field is set, then a connection holds the 176 | // write buffer only when writing a message. 177 | // 178 | // Applications should tune the buffer sizes to balance memory use and 179 | // performance. Increasing the buffer size uses more memory, but can reduce the 180 | // number of system calls to read or write the network. In the case of writing, 181 | // increasing the buffer size can reduce the number of frame headers written to 182 | // the network. 183 | // 184 | // Some guidelines for setting buffer parameters are: 185 | // 186 | // Limit the buffer sizes to the maximum expected message size. Buffers larger 187 | // than the largest message do not provide any benefit. 188 | // 189 | // Depending on the distribution of message sizes, setting the buffer size to 190 | // a value less than the maximum expected message size can greatly reduce memory 191 | // use with a small impact on performance. 
Here's an example: If 99% of the 192 | // messages are smaller than 256 bytes and the maximum message size is 512 193 | // bytes, then a buffer size of 256 bytes will result in 1.01 times as many system calls 194 | // as a buffer size of 512 bytes. The memory savings are 50%. 195 | // 196 | // A write buffer pool is useful when the application has a modest number of 197 | // writes over a large number of connections. When buffers are pooled, a larger 198 | // buffer size has a reduced impact on total memory use and has the benefit of 199 | // reducing system calls and frame overhead. 200 | // 201 | // Compression EXPERIMENTAL 202 | // 203 | // Per message compression extensions (RFC 7692) are experimentally supported 204 | // by this package in a limited capacity. Setting the EnableCompression option 205 | // to true in Dialer or Upgrader will attempt to negotiate per message deflate 206 | // support. 207 | // 208 | // var upgrader = websocket.Upgrader{ 209 | // EnableCompression: true, 210 | // } 211 | // 212 | // If compression was successfully negotiated with the connection's peer, any 213 | // message received in compressed form will be automatically decompressed. 214 | // All Read methods will return uncompressed bytes. 215 | // 216 | // Per message compression of messages written to a connection can be enabled 217 | // or disabled by calling the corresponding Conn method: 218 | // 219 | // conn.EnableWriteCompression(false) 220 | // 221 | // Currently this package does not support compression with "context takeover". 222 | // This means that messages must be compressed and decompressed in isolation, 223 | // without retaining sliding window or dictionary state across messages. For 224 | // more details refer to RFC 7692. 225 | // 226 | // Use of compression is experimental and may result in decreased performance.
227 | package websocket 228 | -------------------------------------------------------------------------------- /example_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package websocket_test 6 | 7 | import ( 8 | "log" 9 | "net/http" 10 | "testing" 11 | 12 | "github.com/gorilla/websocket" 13 | ) 14 | 15 | var ( 16 | c *websocket.Conn 17 | req *http.Request 18 | ) 19 | 20 | // The websocket.IsUnexpectedCloseError function is useful for identifying 21 | // application and protocol errors. 22 | // 23 | // This server application works with a client application running in the 24 | // browser. The client application does not explicitly close the websocket. The 25 | // only expected close message from the client has the code 26 | // websocket.CloseGoingAway. All other close messages are likely the 27 | // result of an application or protocol error and are logged to aid debugging. 28 | func ExampleIsUnexpectedCloseError() { 29 | for { 30 | messageType, p, err := c.ReadMessage() 31 | if err != nil { 32 | if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) { 33 | log.Printf("error: %v, user-agent: %v", err, req.Header.Get("User-Agent")) 34 | } 35 | return 36 | } 37 | processMessage(messageType, p) 38 | } 39 | } 40 | 41 | func processMessage(mt int, p []byte) {} 42 | 43 | // TestX prevents godoc from showing this entire file in the example. Remove 44 | // this function when a second example is added. 
45 | func TestX(t *testing.T) {} 46 | -------------------------------------------------------------------------------- /examples/autobahn/README.md: -------------------------------------------------------------------------------- 1 | # Test Server 2 | 3 | This package contains a server for the [Autobahn WebSockets Test Suite](https://github.com/crossbario/autobahn-testsuite). 4 | 5 | To test the server, run 6 | 7 | go run server.go 8 | 9 | and start the client test driver 10 | 11 | mkdir -p reports 12 | docker run -it --rm \ 13 | -v ${PWD}/config:/config \ 14 | -v ${PWD}/reports:/reports \ 15 | crossbario/autobahn-testsuite \ 16 | wstest -m fuzzingclient -s /config/fuzzingclient.json 17 | 18 | When the client completes, it writes a report to reports/index.html. 19 | -------------------------------------------------------------------------------- /examples/autobahn/config/fuzzingclient.json: -------------------------------------------------------------------------------- 1 | { 2 | "cases": ["*"], 3 | "exclude-cases": [], 4 | "exclude-agent-cases": {}, 5 | "outdir": "/reports", 6 | "options": {"failByDrop": false}, 7 | "servers": [ 8 | { 9 | "agent": "ReadAllWriteMessage", 10 | "url": "ws://host.docker.internal:9000/m" 11 | }, 12 | { 13 | "agent": "ReadAllWritePreparedMessage", 14 | "url": "ws://host.docker.internal:9000/p" 15 | }, 16 | { 17 | "agent": "CopyFull", 18 | "url": "ws://host.docker.internal:9000/f" 19 | }, 20 | { 21 | "agent": "ReadAllWrite", 22 | "url": "ws://host.docker.internal:9000/r" 23 | }, 24 | { 25 | "agent": "CopyWriterOnly", 26 | "url": "ws://host.docker.internal:9000/c" 27 | } 28 | ] 29 | } 30 | -------------------------------------------------------------------------------- /examples/autobahn/server.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 
2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Command server is a test server for the Autobahn WebSockets Test Suite. 6 | package main 7 | 8 | import ( 9 | "errors" 10 | "flag" 11 | "io" 12 | "log" 13 | "net/http" 14 | "time" 15 | "unicode/utf8" 16 | 17 | "github.com/gorilla/websocket" 18 | ) 19 | 20 | var upgrader = websocket.Upgrader{ 21 | ReadBufferSize: 4096, 22 | WriteBufferSize: 4096, 23 | EnableCompression: true, 24 | CheckOrigin: func(r *http.Request) bool { 25 | return true 26 | }, 27 | } 28 | 29 | // echoCopy echoes messages from the client using io.Copy. If writerOnly is true, the copy goes through the writer's Write method only. 30 | func echoCopy(w http.ResponseWriter, r *http.Request, writerOnly bool) { 31 | conn, err := upgrader.Upgrade(w, r, nil) 32 | if err != nil { 33 | log.Println("Upgrade:", err) 34 | return 35 | } 36 | defer conn.Close() 37 | for { 38 | mt, r, err := conn.NextReader() 39 | if err != nil { 40 | if err != io.EOF { 41 | log.Println("NextReader:", err) 42 | } 43 | return 44 | } 45 | // Validate text messages exactly once; the validator fails the copy with errInvalidUTF8. 46 | if mt == websocket.TextMessage { 47 | r = &validator{r: r} 48 | } 49 | w, err := conn.NextWriter(mt) 50 | if err != nil { 51 | log.Println("NextWriter:", err) 52 | return 53 | } 54 | // Hiding w behind an io.Writer-only struct stops io.Copy from using any 55 | // io.ReaderFrom fast path w may provide, so only Write is exercised. 56 | if writerOnly { 57 | _, err = io.Copy(struct{ io.Writer }{w}, r) 58 | } else { 59 | _, err = io.Copy(w, r) 60 | } 61 | if err != nil { 62 | if err == errInvalidUTF8 { 63 | conn.WriteControl(websocket.CloseMessage, 64 | websocket.FormatCloseMessage(websocket.CloseInvalidFramePayloadData, ""), 65 | time.Time{}) 66 | } 67 | log.Println("Copy:", err) 68 | return 69 | } 70 | err = w.Close() 71 | if err != nil { 72 | log.Println("Close:", err) 73 | return 74 | } 75 | } 76 | } 77 | 78 | func echoCopyWriterOnly(w http.ResponseWriter, r *http.Request) { 79 | echoCopy(w, r, true) 80 | } 81 | 82 | func echoCopyFull(w http.ResponseWriter, r *http.Request) { 83 | echoCopy(w, r, false) 84 | } 85 | 86 | // echoReadAll
echoes messages from the client by reading each entire message 87 | // with ReadMessage. Each message is written back with WriteMessage, with WritePreparedMessage when writePrepared is also set, or with NextWriter when writeMessage is false. 88 | func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage, writePrepared bool) { 89 | conn, err := upgrader.Upgrade(w, r, nil) 90 | if err != nil { 91 | log.Println("Upgrade:", err) 92 | return 93 | } 94 | defer conn.Close() 95 | for { 96 | mt, b, err := conn.ReadMessage() 97 | if err != nil { 98 | if err != io.EOF { 99 | log.Println("ReadMessage:", err) 100 | } 101 | return 102 | } 103 | if mt == websocket.TextMessage && !utf8.Valid(b) { 104 | // Fail the connection instead of echoing invalid text back to the peer. 105 | conn.WriteControl(websocket.CloseMessage, 106 | websocket.FormatCloseMessage(websocket.CloseInvalidFramePayloadData, ""), 107 | time.Time{}) 108 | log.Println("ReadAll: invalid utf8") 109 | return 110 | } 111 | if writeMessage { 112 | if !writePrepared { 113 | err = conn.WriteMessage(mt, b) 114 | if err != nil { 115 | log.Println("WriteMessage:", err) 116 | } 117 | } else { 118 | pm, err := websocket.NewPreparedMessage(mt, b) 119 | if err != nil { 120 | log.Println("NewPreparedMessage:", err) 121 | return 122 | } 123 | err = conn.WritePreparedMessage(pm) 124 | if err != nil { 125 | log.Println("WritePreparedMessage:", err) 126 | } 127 | } 128 | } else { 129 | w, err := conn.NextWriter(mt) 130 | if err != nil { 131 | log.Println("NextWriter:", err) 132 | return 133 | } 134 | if _, err := w.Write(b); err != nil { 135 | log.Println("Writer:", err) 136 | return 137 | } 138 | if err := w.Close(); err != nil { 139 | log.Println("Close:", err) 140 | return 141 | } 142 | } 143 | } 144 | } 145 | 146 | func echoReadAllWriter(w http.ResponseWriter, r *http.Request) { 147 | echoReadAll(w, r, false, false) 148 | } 149 | 150 | func echoReadAllWriteMessage(w http.ResponseWriter, r *http.Request) { 151 | echoReadAll(w, r, true, false) 152 | } 153 | 154 | func echoReadAllWritePreparedMessage(w http.ResponseWriter, r *http.Request) { 155 | echoReadAll(w, r, true, true) 156 | } 157 | 158 | func serveHome(w http.ResponseWriter, r
*http.Request) { 159 | if r.URL.Path != "/" { 160 | http.Error(w, "Not found.", http.StatusNotFound) 161 | return 162 | } 163 | if r.Method != http.MethodGet { 164 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 165 | return 166 | } 167 | w.Header().Set("Content-Type", "text/html; charset=utf-8") 168 | io.WriteString(w, "Echo Server") 169 | } 170 | 171 | var addr = flag.String("addr", ":9000", "http service address") 172 | 173 | func main() { 174 | flag.Parse() 175 | http.HandleFunc("/", serveHome) 176 | http.HandleFunc("/c", echoCopyWriterOnly) 177 | http.HandleFunc("/f", echoCopyFull) 178 | http.HandleFunc("/r", echoReadAllWriter) 179 | http.HandleFunc("/m", echoReadAllWriteMessage) 180 | http.HandleFunc("/p", echoReadAllWritePreparedMessage) 181 | err := http.ListenAndServe(*addr, nil) 182 | if err != nil { 183 | log.Fatal("ListenAndServe: ", err) 184 | } 185 | } 186 | 187 | type validator struct { 188 | state int 189 | x rune 190 | r io.Reader 191 | } 192 | 193 | var errInvalidUTF8 = errors.New("invalid utf8") 194 | 195 | func (r *validator) Read(p []byte) (int, error) { 196 | n, err := r.r.Read(p) 197 | state := r.state 198 | x := r.x 199 | for _, b := range p[:n] { 200 | state, x = decode(state, x, b) 201 | if state == utf8Reject { 202 | break 203 | } 204 | } 205 | r.state = state 206 | r.x = x 207 | if state == utf8Reject || (err == io.EOF && state != utf8Accept) { 208 | return n, errInvalidUTF8 209 | } 210 | return n, err 211 | } 212 | 213 | // UTF-8 decoder from http://bjoern.hoehrmann.de/utf-8/decoder/dfa/ 214 | // 215 | // Copyright (c) 2008-2009 Bjoern Hoehrmann 216 | // 217 | // Permission is hereby granted, free of charge, to any person obtaining a copy 218 | // of this software and associated documentation files (the "Software"), to 219 | // deal in the Software without restriction, including without limitation the 220 | // rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 221 | // sell copies of the 
Software, and to permit persons to whom the Software is 222 | // furnished to do so, subject to the following conditions: 223 | // 224 | // The above copyright notice and this permission notice shall be included in 225 | // all copies or substantial portions of the Software. 226 | // 227 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 228 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 229 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 230 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 231 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 232 | // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 233 | // IN THE SOFTWARE. 234 | var utf8d = [...]byte{ 235 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 00..1f 236 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 20..3f 237 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 40..5f 238 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 60..7f 239 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, // 80..9f 240 | 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // a0..bf 241 | 8, 8, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // c0..df 242 | 0xa, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x4, 0x3, 0x3, // e0..ef 243 | 0xb, 0x6, 0x6, 0x6, 0x5, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, // f0..ff 244 | 0x0, 0x1, 0x2, 0x3, 0x5, 0x8, 0x7, 0x1, 0x1, 0x1, 0x4, 0x6, 0x1, 0x1, 0x1, 0x1, // s0..s0 245 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 
1, 1, 1, 1, // s1..s2 246 | 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, // s3..s4 247 | 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 3, 1, 1, 1, 1, 1, 1, // s5..s6 248 | 1, 3, 1, 1, 1, 1, 1, 3, 1, 3, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // s7..s8 249 | } 250 | 251 | const ( 252 | utf8Accept = 0 253 | utf8Reject = 1 254 | ) 255 | 256 | func decode(state int, x rune, b byte) (int, rune) { 257 | t := utf8d[b] 258 | if state != utf8Accept { 259 | x = rune(b&0x3f) | (x << 6) 260 | } else { 261 | x = rune((0xff >> t) & b) 262 | } 263 | state = int(utf8d[256+state*16+int(t)]) 264 | return state, x 265 | } 266 | -------------------------------------------------------------------------------- /examples/chat/README.md: -------------------------------------------------------------------------------- 1 | # Chat Example 2 | 3 | This application shows how to use the 4 | [websocket](https://github.com/gorilla/websocket) package to implement a simple 5 | web chat application. 6 | 7 | ## Running the example 8 | 9 | The example requires a working Go development environment. The [Getting 10 | Started](http://golang.org/doc/install) page describes how to install the 11 | development environment. 12 | 13 | Once you have Go up and running, you can download, build and run the example 14 | using the following commands. 15 | 16 | $ go get github.com/gorilla/websocket 17 | $ cd `go list -f '{{.Dir}}' github.com/gorilla/websocket/examples/chat` 18 | $ go run *.go 19 | 20 | To use the chat example, open http://localhost:8080/ in your browser. 21 | 22 | ## Server 23 | 24 | The server application defines two types, `Client` and `Hub`. The server 25 | creates an instance of the `Client` type for each websocket connection. A 26 | `Client` acts as an intermediary between the websocket connection and a single 27 | instance of the `Hub` type. 
The `Hub` maintains a set of registered clients and 28 | broadcasts messages to the clients. 29 | 30 | The application runs one goroutine for the `Hub` and two goroutines for each 31 | `Client`. The goroutines communicate with each other using channels. The `Hub` 32 | has channels for registering clients, unregistering clients and broadcasting 33 | messages. A `Client` has a buffered channel of outbound messages. One of the 34 | client's goroutines reads messages from this channel and writes the messages to 35 | the websocket. The other client goroutine reads messages from the websocket and 36 | sends them to the hub. 37 | 38 | ### Hub 39 | 40 | The code for the `Hub` type is in 41 | [hub.go](https://github.com/gorilla/websocket/blob/main/examples/chat/hub.go). 42 | The application's `main` function starts the hub's `run` method as a goroutine. 43 | Clients send requests to the hub using the `register`, `unregister` and 44 | `broadcast` channels. 45 | 46 | The hub registers clients by adding the client pointer as a key in the 47 | `clients` map. The map value is always true. 48 | 49 | The unregister code is a little more complicated. In addition to deleting the 50 | client pointer from the `clients` map, the hub closes the client's `send` 51 | channel to signal the client that no more messages will be sent to the client. 52 | 53 | The hub handles messages by looping over the registered clients and sending the 54 | message to the client's `send` channel. If the client's `send` buffer is full, 55 | then the hub assumes that the client is dead or stuck. In this case, the hub 56 | unregisters the client and closes its `send` channel, which causes the client's writer goroutine to close the websocket. 57 | 58 | ### Client 59 | 60 | The code for the `Client` type is in [client.go](https://github.com/gorilla/websocket/blob/main/examples/chat/client.go). 61 | 62 | The `serveWs` function is registered by the application's `main` function as 63 | an HTTP handler.
The handler upgrades the HTTP connection to the WebSocket 64 | protocol, creates a client and registers the client with the hub. The handler 65 | then returns, leaving all further work to the two goroutines described below. 66 | 67 | Next, the HTTP handler starts the client's `writePump` method as a goroutine. 68 | This method transfers messages from the client's send channel to the websocket 69 | connection. The writer method exits when the channel is closed by the hub or 70 | there's an error writing to the websocket connection. 71 | 72 | Finally, the HTTP handler starts the client's `readPump` method as a goroutine. This method 73 | transfers inbound messages from the websocket to the hub. When `readPump` returns, a deferred call unregisters the client from the hub and closes the websocket connection. 74 | 75 | WebSocket connections [support one concurrent reader and one concurrent 76 | writer](https://godoc.org/github.com/gorilla/websocket#hdr-Concurrency). The 77 | application ensures that these concurrency requirements are met by executing 78 | all reads from the `readPump` goroutine and all writes from the `writePump` 79 | goroutine. 80 | 81 | To improve efficiency under high load, the `writePump` function coalesces 82 | pending chat messages in the `send` channel to a single WebSocket message. This 83 | reduces the number of system calls and the amount of data sent over the 84 | network. 85 | 86 | ## Frontend 87 | 88 | The frontend code is in [home.html](https://github.com/gorilla/websocket/blob/main/examples/chat/home.html). 89 | 90 | On document load, the script checks for websocket functionality in the browser. 91 | If websocket functionality is available, then the script opens a connection to 92 | the server and registers a callback to handle messages from the server. The 93 | callback appends the message to the chat log using the appendLog function. 94 | 95 | To allow the user to manually scroll through the chat log without interruption 96 | from new messages, the `appendLog` function checks the scroll position before 97 | adding new content.
If the chat log is scrolled to the bottom, then the 98 | function scrolls new content into view after adding the content. Otherwise, the 99 | scroll position is not changed. 100 | 101 | The form handler writes the user input to the websocket and clears the input 102 | field. 103 | -------------------------------------------------------------------------------- /examples/chat/client.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package main 6 | 7 | import ( 8 | "bytes" 9 | "log" 10 | "net/http" 11 | "time" 12 | 13 | "github.com/gorilla/websocket" 14 | ) 15 | 16 | const ( 17 | // Time allowed to write a message to the peer. 18 | writeWait = 10 * time.Second 19 | 20 | // Time allowed to read the next pong message from the peer. 21 | pongWait = 60 * time.Second 22 | 23 | // Send pings to peer with this period. Must be less than pongWait. 24 | pingPeriod = (pongWait * 9) / 10 25 | 26 | // Maximum message size allowed from peer. 27 | maxMessageSize = 512 28 | ) 29 | 30 | var ( 31 | newline = []byte{'\n'} 32 | space = []byte{' '} 33 | ) 34 | 35 | var upgrader = websocket.Upgrader{ 36 | ReadBufferSize: 1024, 37 | WriteBufferSize: 1024, 38 | } 39 | 40 | // Client is a middleman between the websocket connection and the hub. 41 | type Client struct { 42 | hub *Hub 43 | 44 | // The websocket connection. 45 | conn *websocket.Conn 46 | 47 | // Buffered channel of outbound messages. 48 | send chan []byte 49 | } 50 | 51 | // readPump pumps messages from the websocket connection to the hub. 52 | // 53 | // The application runs readPump in a per-connection goroutine. The application 54 | // ensures that there is at most one reader on a connection by executing all 55 | // reads from this goroutine. 
56 | func (c *Client) readPump() { 57 | defer func() { // On any read error: leave the hub, then close the connection. 58 | c.hub.unregister <- c 59 | c.conn.Close() 60 | }() 61 | c.conn.SetReadLimit(maxMessageSize) 62 | c.conn.SetReadDeadline(time.Now().Add(pongWait)) 63 | c.conn.SetPongHandler(func(string) error { c.conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) // Each pong from the peer pushes the read deadline out by another pongWait. 64 | for { 65 | _, message, err := c.conn.ReadMessage() 66 | if err != nil { 67 | if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { 68 | log.Printf("error: %v", err) 69 | } 70 | break 71 | } 72 | message = bytes.TrimSpace(bytes.Replace(message, newline, space, -1)) // Flatten newlines: writePump uses newline to separate batched messages. 73 | c.hub.broadcast <- message 74 | } 75 | } 76 | 77 | // writePump pumps messages from the hub to the websocket connection. 78 | // 79 | // A goroutine running writePump is started for each connection. The 80 | // application ensures that there is at most one writer to a connection by 81 | // executing all writes from this goroutine. 82 | func (c *Client) writePump() { 83 | ticker := time.NewTicker(pingPeriod) 84 | defer func() { 85 | ticker.Stop() 86 | c.conn.Close() 87 | }() 88 | for { 89 | select { 90 | case message, ok := <-c.send: 91 | c.conn.SetWriteDeadline(time.Now().Add(writeWait)) 92 | if !ok { 93 | // The hub closed the channel. 94 | c.conn.WriteMessage(websocket.CloseMessage, []byte{}) 95 | return 96 | } 97 | 98 | w, err := c.conn.NextWriter(websocket.TextMessage) 99 | if err != nil { 100 | return 101 | } 102 | w.Write(message) // NOTE(review): Write errors are not checked here; presumably they surface from w.Close below — confirm. 103 | 104 | // Add queued chat messages to the current websocket message.
105 | n := len(c.send) // Only this goroutine receives from c.send, so these n messages are still queued when read below. 106 | for i := 0; i < n; i++ { 107 | w.Write(newline) 108 | w.Write(<-c.send) 109 | } 110 | 111 | if err := w.Close(); err != nil { 112 | return 113 | } 114 | case <-ticker.C: // Periodic ping; the peer's pong is handled by the pong handler set in readPump. 115 | c.conn.SetWriteDeadline(time.Now().Add(writeWait)) 116 | if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { 117 | return 118 | } 119 | } 120 | } 121 | } 122 | 123 | // serveWs handles websocket requests from the peer. 124 | func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) { 125 | conn, err := upgrader.Upgrade(w, r, nil) 126 | if err != nil { 127 | log.Println(err) 128 | return 129 | } 130 | client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256)} 131 | client.hub.register <- client 132 | 133 | // Allow collection of memory referenced by the caller by doing all work in 134 | // new goroutines. 135 | go client.writePump() 136 | go client.readPump() 137 | } 138 | -------------------------------------------------------------------------------- /examples/chat/home.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Chat Example 5 | 53 | 90 | 91 | 92 |
93 |
94 | 95 | 96 |
97 | 98 | 99 | -------------------------------------------------------------------------------- /examples/chat/hub.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package main 6 | 7 | // Hub maintains the set of active clients and broadcasts messages to the 8 | // clients. 9 | type Hub struct { 10 | // Registered clients. 11 | clients map[*Client]bool 12 | 13 | // Inbound messages from the clients. 14 | broadcast chan []byte 15 | 16 | // Register requests from the clients. 17 | register chan *Client 18 | 19 | // Unregister requests from clients. 20 | unregister chan *Client 21 | } 22 | // newHub returns a Hub whose channels are unbuffered and whose client set is empty. 23 | func newHub() *Hub { 24 | return &Hub{ 25 | broadcast: make(chan []byte), 26 | register: make(chan *Client), 27 | unregister: make(chan *Client), 28 | clients: make(map[*Client]bool), 29 | } 30 | } 31 | // run is the hub's event loop and never returns; main starts it in its own goroutine. In this example it is the only code that touches h.clients after construction. 32 | func (h *Hub) run() { 33 | for { 34 | select { 35 | case client := <-h.register: 36 | h.clients[client] = true 37 | case client := <-h.unregister: 38 | if _, ok := h.clients[client]; ok { // Skip clients already dropped by the broadcast case below; their send channel is already closed. 39 | delete(h.clients, client) 40 | close(client.send) 41 | } 42 | case message := <-h.broadcast: 43 | for client := range h.clients { 44 | select { 45 | case client.send <- message: 46 | default: // The send buffer is full: treat the client as stuck and drop it. Closing send makes its writePump close the connection. 47 | close(client.send) 48 | delete(h.clients, client) 49 | } 50 | } 51 | } 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /examples/chat/main.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file.
4 | 5 | package main 6 | 7 | import ( 8 | "flag" 9 | "log" 10 | "net/http" 11 | ) 12 | 13 | var addr = flag.String("addr", ":8080", "http service address") 14 | 15 | func serveHome(w http.ResponseWriter, r *http.Request) { 16 | log.Println(r.URL) 17 | if r.URL.Path != "/" { 18 | http.Error(w, "Not found", http.StatusNotFound) 19 | return 20 | } 21 | if r.Method != http.MethodGet { 22 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 23 | return 24 | } 25 | http.ServeFile(w, r, "home.html") 26 | } 27 | 28 | func main() { 29 | flag.Parse() 30 | hub := newHub() 31 | go hub.run() 32 | http.HandleFunc("/", serveHome) 33 | http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { 34 | serveWs(hub, w, r) 35 | }) 36 | err := http.ListenAndServe(*addr, nil) 37 | if err != nil { 38 | log.Fatal("ListenAndServe: ", err) 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /examples/command/README.md: -------------------------------------------------------------------------------- 1 | # Command example 2 | 3 | This example connects a websocket connection to stdin and stdout of a command. 4 | Received messages are written to stdin followed by a `\n`. Each line read from 5 | standard out is sent as a message to the client. 6 | 7 | $ go get github.com/gorilla/websocket 8 | $ cd `go list -f '{{.Dir}}' github.com/gorilla/websocket/examples/command` 9 | $ go run main.go 10 | # Open http://localhost:8080/ . 11 | 12 | Try the following commands. 13 | 14 | # Echo sent messages to the output area. 15 | $ go run main.go cat 16 | 17 | # Run a shell.Try sending "ls" and "cat main.go". 18 | $ go run main.go sh 19 | 20 | -------------------------------------------------------------------------------- /examples/command/home.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Command Example 5 | 53 | 94 | 95 | 96 |
97 |
98 | 99 | 100 |
101 | 102 | 103 | -------------------------------------------------------------------------------- /examples/command/main.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package main 6 | 7 | import ( 8 | "bufio" 9 | "flag" 10 | "io" 11 | "log" 12 | "net/http" 13 | "os" 14 | "os/exec" 15 | "time" 16 | 17 | "github.com/gorilla/websocket" 18 | ) 19 | 20 | var ( 21 | addr = flag.String("addr", "127.0.0.1:8080", "http service address") 22 | cmdPath string 23 | ) 24 | 25 | const ( 26 | // Time allowed to write a message to the peer. 27 | writeWait = 10 * time.Second 28 | 29 | // Maximum message size allowed from peer. 30 | maxMessageSize = 8192 31 | 32 | // Time allowed to read the next pong message from the peer. 33 | pongWait = 60 * time.Second 34 | 35 | // Send pings to peer with this period. Must be less than pongWait. 36 | pingPeriod = (pongWait * 9) / 10 37 | 38 | // Time to wait before force close on connection. 
39 | closeGracePeriod = 10 * time.Second 40 | ) 41 | 42 | func pumpStdin(ws *websocket.Conn, w io.Writer) { 43 | defer ws.Close() 44 | ws.SetReadLimit(maxMessageSize) 45 | ws.SetReadDeadline(time.Now().Add(pongWait)) 46 | ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(pongWait)); return nil }) 47 | for { 48 | _, message, err := ws.ReadMessage() 49 | if err != nil { 50 | break 51 | } 52 | message = append(message, '\n') 53 | if _, err := w.Write(message); err != nil { 54 | break 55 | } 56 | } 57 | } 58 | 59 | func pumpStdout(ws *websocket.Conn, r io.Reader, done chan struct{}) { 60 | // No deferred cleanup is needed: ws is closed explicitly below, both on a 61 | // write error and after the close-handshake grace period. 62 | s := bufio.NewScanner(r) 63 | for s.Scan() { 64 | ws.SetWriteDeadline(time.Now().Add(writeWait)) 65 | if err := ws.WriteMessage(websocket.TextMessage, s.Bytes()); err != nil { 66 | ws.Close() 67 | break 68 | } 69 | } 70 | if s.Err() != nil { 71 | log.Println("scan:", s.Err()) 72 | } 73 | close(done) 74 | 75 | ws.SetWriteDeadline(time.Now().Add(writeWait)) 76 | ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) 77 | time.Sleep(closeGracePeriod) 78 | ws.Close() 79 | } 80 | 81 | func ping(ws *websocket.Conn, done chan struct{}) { 82 | ticker := time.NewTicker(pingPeriod) 83 | defer ticker.Stop() 84 | for { 85 | select { 86 | case <-ticker.C: 87 | if err := ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil { 88 | log.Println("ping:", err) 89 | } 90 | case <-done: 91 | return 92 | } 93 | } 94 | } 95 | 96 | func internalError(ws *websocket.Conn, msg string, err error) { 97 | log.Println(msg, err) 98 | ws.WriteMessage(websocket.TextMessage, []byte("Internal server error.")) 99 | } 100 | 101 | var upgrader = websocket.Upgrader{} 102 | 103 | func serveWs(w http.ResponseWriter, r *http.Request) { 104 | ws, err := upgrader.Upgrade(w, r, nil) 105 | if err != nil { 106 | log.Println("upgrade:", err) 107 | return 108 | } 109 | 110 | defer ws.Close() 111 |
112 | outr, outw, err := os.Pipe() 113 | if err != nil { 114 | internalError(ws, "stdout:", err) 115 | return 116 | } 117 | defer outr.Close() 118 | defer outw.Close() 119 | 120 | inr, inw, err := os.Pipe() 121 | if err != nil { 122 | internalError(ws, "stdin:", err) 123 | return 124 | } 125 | defer inr.Close() 126 | defer inw.Close() 127 | 128 | proc, err := os.StartProcess(cmdPath, flag.Args(), &os.ProcAttr{ 129 | Files: []*os.File{inr, outw, outw}, 130 | }) 131 | if err != nil { 132 | internalError(ws, "start:", err) 133 | return 134 | } 135 | 136 | inr.Close() 137 | outw.Close() 138 | 139 | stdoutDone := make(chan struct{}) 140 | go pumpStdout(ws, outr, stdoutDone) 141 | go ping(ws, stdoutDone) 142 | 143 | pumpStdin(ws, inw) 144 | 145 | // Some commands will exit when stdin is closed. 146 | inw.Close() 147 | 148 | // Other commands need a bonk on the head. 149 | if err := proc.Signal(os.Interrupt); err != nil { 150 | log.Println("inter:", err) 151 | } 152 | 153 | select { 154 | case <-stdoutDone: 155 | case <-time.After(time.Second): 156 | // A bigger bonk on the head. 
157 | if err := proc.Signal(os.Kill); err != nil { 158 | log.Println("term:", err) 159 | } 160 | <-stdoutDone 161 | } 162 | 163 | if _, err := proc.Wait(); err != nil { 164 | log.Println("wait:", err) 165 | } 166 | } 167 | 168 | func serveHome(w http.ResponseWriter, r *http.Request) { 169 | if r.URL.Path != "/" { 170 | http.Error(w, "Not found", http.StatusNotFound) 171 | return 172 | } 173 | if r.Method != http.MethodGet { 174 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 175 | return 176 | } 177 | http.ServeFile(w, r, "home.html") 178 | } 179 | 180 | func main() { 181 | flag.Parse() 182 | if len(flag.Args()) < 1 { 183 | log.Fatal("must specify at least one argument") 184 | } 185 | var err error 186 | cmdPath, err = exec.LookPath(flag.Args()[0]) 187 | if err != nil { 188 | log.Fatal(err) 189 | } 190 | http.HandleFunc("/", serveHome) 191 | http.HandleFunc("/ws", serveWs) 192 | log.Fatal(http.ListenAndServe(*addr, nil)) 193 | } 194 | -------------------------------------------------------------------------------- /examples/echo/README.md: -------------------------------------------------------------------------------- 1 | # Client and server example 2 | 3 | This example shows a simple client and server. 4 | 5 | The server echoes messages sent to it. The client sends a message every second 6 | and prints all messages received. 7 | 8 | To run the example, start the server: 9 | 10 | $ go run server.go 11 | 12 | Next, start the client: 13 | 14 | $ go run client.go 15 | 16 | The server includes a simple web client. To use the client, open 17 | http://127.0.0.1:8080 in the browser and follow the instructions on the page. 18 | -------------------------------------------------------------------------------- /examples/echo/client.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 The Gorilla WebSocket Authors. All rights reserved. 
2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | //go:build ignore 6 | // +build ignore 7 | 8 | package main 9 | 10 | import ( 11 | "flag" 12 | "log" 13 | "net/url" 14 | "os" 15 | "os/signal" 16 | "time" 17 | 18 | "github.com/gorilla/websocket" 19 | ) 20 | 21 | var addr = flag.String("addr", "localhost:8080", "http service address") 22 | 23 | func main() { 24 | flag.Parse() 25 | log.SetFlags(0) 26 | 27 | interrupt := make(chan os.Signal, 1) 28 | signal.Notify(interrupt, os.Interrupt) 29 | 30 | u := url.URL{Scheme: "ws", Host: *addr, Path: "/echo"} 31 | log.Printf("connecting to %s", u.String()) 32 | 33 | c, _, err := websocket.DefaultDialer.Dial(u.String(), nil) 34 | if err != nil { 35 | log.Fatal("dial:", err) 36 | } 37 | defer c.Close() 38 | 39 | done := make(chan struct{}) 40 | 41 | go func() { 42 | defer close(done) 43 | for { 44 | _, message, err := c.ReadMessage() 45 | if err != nil { 46 | log.Println("read:", err) 47 | return 48 | } 49 | log.Printf("recv: %s", message) 50 | } 51 | }() 52 | 53 | ticker := time.NewTicker(time.Second) 54 | defer ticker.Stop() 55 | 56 | for { 57 | select { 58 | case <-done: 59 | return 60 | case t := <-ticker.C: 61 | err := c.WriteMessage(websocket.TextMessage, []byte(t.String())) 62 | if err != nil { 63 | log.Println("write:", err) 64 | return 65 | } 66 | case <-interrupt: 67 | log.Println("interrupt") 68 | 69 | // Cleanly close the connection by sending a close message and then 70 | // waiting (with timeout) for the server to close the connection. 
71 | err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) 72 | if err != nil { 73 | log.Println("write close:", err) 74 | return 75 | } 76 | select { 77 | case <-done: 78 | case <-time.After(time.Second): 79 | } 80 | return 81 | } 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /examples/echo/server.go: -------------------------------------------------------------------------------- 1 | // Copyright 2015 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | //go:build ignore 6 | // +build ignore 7 | 8 | package main 9 | 10 | import ( 11 | "flag" 12 | "html/template" 13 | "log" 14 | "net/http" 15 | 16 | "github.com/gorilla/websocket" 17 | ) 18 | 19 | var addr = flag.String("addr", "localhost:8080", "http service address") 20 | 21 | var upgrader = websocket.Upgrader{} // use default options 22 | 23 | func echo(w http.ResponseWriter, r *http.Request) { 24 | c, err := upgrader.Upgrade(w, r, nil) 25 | if err != nil { 26 | log.Print("upgrade:", err) 27 | return 28 | } 29 | defer c.Close() 30 | for { 31 | mt, message, err := c.ReadMessage() 32 | if err != nil { 33 | log.Println("read:", err) 34 | break 35 | } 36 | log.Printf("recv: %s", message) 37 | err = c.WriteMessage(mt, message) 38 | if err != nil { 39 | log.Println("write:", err) 40 | break 41 | } 42 | } 43 | } 44 | 45 | func home(w http.ResponseWriter, r *http.Request) { 46 | homeTemplate.Execute(w, "ws://"+r.Host+"/echo") 47 | } 48 | 49 | func main() { 50 | flag.Parse() 51 | log.SetFlags(0) 52 | http.HandleFunc("/echo", echo) 53 | http.HandleFunc("/", home) 54 | log.Fatal(http.ListenAndServe(*addr, nil)) 55 | } 56 | 57 | var homeTemplate = template.Must(template.New("").Parse(` 58 | 59 | 60 | 61 | 62 | 116 | 117 | 118 | 119 |
120 |

Click "Open" to create a connection to the server, 121 | "Send" to send a message to the server and "Close" to close the connection. 122 | You can change the message and send multiple times. 123 |

124 |

125 | 126 | 127 |

128 | 129 |

130 |
131 |
132 |
133 | 134 | 135 | `)) 136 | -------------------------------------------------------------------------------- /examples/filewatch/README.md: -------------------------------------------------------------------------------- 1 | # File Watch example. 2 | 3 | This example sends a file to the browser client for display whenever the file is modified. 4 | 5 | $ go get github.com/gorilla/websocket 6 | $ cd `go list -f '{{.Dir}}' github.com/gorilla/websocket/examples/filewatch` 7 | $ go run main.go 8 | # Open http://localhost:8080/ . 9 | # Modify the file to see it update in the browser. 10 | -------------------------------------------------------------------------------- /examples/filewatch/main.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package main 6 | 7 | import ( 8 | "flag" 9 | "html/template" 10 | "log" 11 | "net/http" 12 | "os" 13 | "strconv" 14 | "time" 15 | 16 | "github.com/gorilla/websocket" 17 | ) 18 | 19 | const ( 20 | // Time allowed to write the file to the client. 21 | writeWait = 10 * time.Second 22 | 23 | // Time allowed to read the next pong message from the client. 24 | pongWait = 60 * time.Second 25 | 26 | // Send pings to client with this period. Must be less than pongWait. 27 | pingPeriod = (pongWait * 9) / 10 28 | 29 | // Poll file for changes with this period. 
30 | filePeriod = 10 * time.Second 31 | ) 32 | 33 | var ( 34 | addr = flag.String("addr", ":8080", "http service address") 35 | homeTempl = template.Must(template.New("").Parse(homeHTML)) 36 | filename string 37 | upgrader = websocket.Upgrader{ 38 | ReadBufferSize: 1024, 39 | WriteBufferSize: 1024, 40 | } 41 | ) 42 | 43 | func readFileIfModified(lastMod time.Time) ([]byte, time.Time, error) { 44 | fi, err := os.Stat(filename) 45 | if err != nil { 46 | return nil, lastMod, err 47 | } 48 | if !fi.ModTime().After(lastMod) { 49 | return nil, lastMod, nil 50 | } 51 | p, err := os.ReadFile(filename) 52 | if err != nil { 53 | return nil, fi.ModTime(), err 54 | } 55 | return p, fi.ModTime(), nil 56 | } 57 | 58 | func reader(ws *websocket.Conn) { 59 | defer ws.Close() 60 | ws.SetReadLimit(512) 61 | ws.SetReadDeadline(time.Now().Add(pongWait)) 62 | ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(pongWait)); return nil }) 63 | for { 64 | _, _, err := ws.ReadMessage() 65 | if err != nil { 66 | break 67 | } 68 | } 69 | } 70 | 71 | func writer(ws *websocket.Conn, lastMod time.Time) { 72 | lastError := "" 73 | pingTicker := time.NewTicker(pingPeriod) 74 | fileTicker := time.NewTicker(filePeriod) 75 | defer func() { 76 | pingTicker.Stop() 77 | fileTicker.Stop() 78 | ws.Close() 79 | }() 80 | for { 81 | select { 82 | case <-fileTicker.C: 83 | var p []byte 84 | var err error 85 | 86 | p, lastMod, err = readFileIfModified(lastMod) 87 | 88 | if err != nil { 89 | if s := err.Error(); s != lastError { 90 | lastError = s 91 | p = []byte(lastError) 92 | } 93 | } else { 94 | lastError = "" 95 | } 96 | 97 | if p != nil { 98 | ws.SetWriteDeadline(time.Now().Add(writeWait)) 99 | if err := ws.WriteMessage(websocket.TextMessage, p); err != nil { 100 | return 101 | } 102 | } 103 | case <-pingTicker.C: 104 | ws.SetWriteDeadline(time.Now().Add(writeWait)) 105 | if err := ws.WriteMessage(websocket.PingMessage, []byte{}); err != nil { 106 | return 107 | } 108 | } 109 | } 110 
| } 111 | 112 | func serveWs(w http.ResponseWriter, r *http.Request) { 113 | ws, err := upgrader.Upgrade(w, r, nil) 114 | if err != nil { 115 | if _, ok := err.(websocket.HandshakeError); !ok { 116 | log.Println(err) 117 | } 118 | return 119 | } 120 | 121 | var lastMod time.Time 122 | if n, err := strconv.ParseInt(r.FormValue("lastMod"), 16, 64); err == nil { 123 | lastMod = time.Unix(0, n) 124 | } 125 | 126 | go writer(ws, lastMod) 127 | reader(ws) 128 | } 129 | 130 | func serveHome(w http.ResponseWriter, r *http.Request) { 131 | if r.URL.Path != "/" { 132 | http.Error(w, "Not found", http.StatusNotFound) 133 | return 134 | } 135 | if r.Method != http.MethodGet { 136 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 137 | return 138 | } 139 | w.Header().Set("Content-Type", "text/html; charset=utf-8") 140 | p, lastMod, err := readFileIfModified(time.Time{}) 141 | if err != nil { 142 | p = []byte(err.Error()) 143 | lastMod = time.Unix(0, 0) 144 | } 145 | var v = struct { 146 | Host string 147 | Data string 148 | LastMod string 149 | }{ 150 | r.Host, 151 | string(p), 152 | strconv.FormatInt(lastMod.UnixNano(), 16), 153 | } 154 | homeTempl.Execute(w, &v) 155 | } 156 | 157 | func main() { 158 | flag.Parse() 159 | if flag.NArg() != 1 { 160 | log.Fatal("filename not specified") 161 | } 162 | filename = flag.Args()[0] 163 | http.HandleFunc("/", serveHome) 164 | http.HandleFunc("/ws", serveWs) 165 | if err := http.ListenAndServe(*addr, nil); err != nil { 166 | log.Fatal(err) 167 | } 168 | } 169 | 170 | const homeHTML = ` 171 | 172 | 173 | WebSocket Example 174 | 175 | 176 |
{{.Data}}
177 | 190 | 191 | 192 | ` 193 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/gorilla/websocket 2 | 3 | go 1.20 4 | 5 | retract ( 6 | v1.5.2 // tag accidentally overwritten 7 | ) 8 | 9 | require golang.org/x/net v0.26.0 10 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= 2 | golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= 3 | -------------------------------------------------------------------------------- /join.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package websocket 6 | 7 | import ( 8 | "io" 9 | "strings" 10 | ) 11 | 12 | // JoinMessages concatenates received messages to create a single io.Reader. 13 | // The string term is appended to each message. The returned reader does not 14 | // support concurrent calls to the Read method. 
15 | func JoinMessages(c *Conn, term string) io.Reader { 16 | return &joinReader{c: c, term: term} 17 | } 18 | 19 | // joinReader is the io.Reader returned by JoinMessages. r is the reader for 20 | // the message currently being consumed, or nil when the next message must be 21 | // fetched from c. 22 | type joinReader struct { 23 | c *Conn 24 | term string 25 | r io.Reader 26 | } 27 | 28 | // Read reads from the current message, calling c.NextReader to start a new 29 | // one when none is in progress. Errors from NextReader are returned 30 | // unchanged. 31 | func (r *joinReader) Read(p []byte) (int, error) { 32 | if r.r == nil { 33 | var err error 34 | _, r.r, err = r.c.NextReader() 35 | if err != nil { 36 | return 0, err 37 | } 38 | if r.term != "" { 39 | r.r = io.MultiReader(r.r, strings.NewReader(r.term)) 40 | } 41 | } 42 | n, err := r.r.Read(p) 43 | if err == io.EOF { 44 | // End of one message only: hide io.EOF and advance on the next call. 45 | err = nil 46 | r.r = nil 47 | } 48 | return n, err 49 | } 50 | -------------------------------------------------------------------------------- /join_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package websocket 6 | 7 | import ( 8 | "bytes" 9 | "io" 10 | "strings" 11 | "testing" 12 | ) 13 | 14 | func TestJoinMessages(t *testing.T) { 15 | messages := []string{"a", "bc", "def", "ghij", "klmno", "0", "12", "345", "6789"} 16 | for _, readChunk := range []int{1, 2, 3, 4, 5, 6, 7} { 17 | for _, term := range []string{"", ","} { 18 | var connBuf bytes.Buffer 19 | wc := newTestConn(nil, &connBuf, true) 20 | rc := newTestConn(&connBuf, nil, false) 21 | for _, m := range messages { 22 | _ = wc.WriteMessage(BinaryMessage, []byte(m)) 23 | } 24 | 25 | var result bytes.Buffer 26 | _, err := io.CopyBuffer(&result, JoinMessages(rc, term), make([]byte, readChunk)) 27 | if IsUnexpectedCloseError(err, CloseAbnormalClosure) { 28 | t.Errorf("readChunk=%d, term=%q: unexpected error %v", readChunk, term, err) 29 | } 30 | want := strings.Join(messages, term) + term 31 | if result.String() != want { 32 | t.Errorf("readChunk=%d, term=%q, got %q, want %q", readChunk, term, result.String(), want) 33 | } 34 | } 35 | } 36 | } 37 |
-------------------------------------------------------------------------------- /json.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package websocket 6 | 7 | import ( 8 | "encoding/json" 9 | "io" 10 | ) 11 | 12 | // WriteJSON writes the JSON encoding of v as a message. 13 | // 14 | // Deprecated: Use c.WriteJSON instead. 15 | func WriteJSON(c *Conn, v interface{}) error { 16 | return c.WriteJSON(v) 17 | } 18 | 19 | // WriteJSON writes the JSON encoding of v as a message. 20 | // 21 | // See the documentation for encoding/json Marshal for details about the 22 | // conversion of Go values to JSON. 23 | func (c *Conn) WriteJSON(v interface{}) error { 24 | w, err := c.NextWriter(TextMessage) 25 | if err != nil { 26 | return err 27 | } 28 | err1 := json.NewEncoder(w).Encode(v) 29 | err2 := w.Close() 30 | if err1 != nil { 31 | return err1 32 | } 33 | return err2 34 | } 35 | 36 | // ReadJSON reads the next JSON-encoded message from the connection and stores 37 | // it in the value pointed to by v. 38 | // 39 | // Deprecated: Use c.ReadJSON instead. 40 | func ReadJSON(c *Conn, v interface{}) error { 41 | return c.ReadJSON(v) 42 | } 43 | 44 | // ReadJSON reads the next JSON-encoded message from the connection and stores 45 | // it in the value pointed to by v. 46 | // 47 | // See the documentation for the encoding/json Unmarshal function for details 48 | // about the conversion of JSON to a Go value. 49 | func (c *Conn) ReadJSON(v interface{}) error { 50 | _, r, err := c.NextReader() 51 | if err != nil { 52 | return err 53 | } 54 | err = json.NewDecoder(r).Decode(v) 55 | if err == io.EOF { 56 | // One value is expected in the message. 
57 | err = io.ErrUnexpectedEOF 58 | } 59 | return err 60 | } 61 | -------------------------------------------------------------------------------- /json_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package websocket 6 | 7 | import ( 8 | "bytes" 9 | "encoding/json" 10 | "io" 11 | "reflect" 12 | "testing" 13 | ) 14 | 15 | func TestJSON(t *testing.T) { 16 | var buf bytes.Buffer 17 | wc := newTestConn(nil, &buf, true) 18 | rc := newTestConn(&buf, nil, false) 19 | 20 | var actual, expect struct { 21 | A int 22 | B string 23 | } 24 | expect.A = 1 25 | expect.B = "hello" 26 | 27 | if err := wc.WriteJSON(&expect); err != nil { 28 | t.Fatal("write", err) 29 | } 30 | 31 | if err := rc.ReadJSON(&actual); err != nil { 32 | t.Fatal("read", err) 33 | } 34 | 35 | if !reflect.DeepEqual(&actual, &expect) { 36 | t.Fatal("equal", actual, expect) 37 | } 38 | } 39 | 40 | func TestPartialJSONRead(t *testing.T) { 41 | var buf0, buf1 bytes.Buffer 42 | wc := newTestConn(nil, &buf0, true) 43 | rc := newTestConn(&buf0, &buf1, false) 44 | 45 | var v struct { 46 | A int 47 | B string 48 | } 49 | v.A = 1 50 | v.B = "hello" 51 | 52 | messageCount := 0 53 | 54 | // Partial JSON values. 55 | 56 | data, err := json.Marshal(v) 57 | if err != nil { 58 | t.Fatal(err) 59 | } 60 | for i := len(data) - 1; i >= 0; i-- { 61 | if err := wc.WriteMessage(TextMessage, data[:i]); err != nil { 62 | t.Fatal(err) 63 | } 64 | messageCount++ 65 | } 66 | 67 | // Whitespace. 68 | 69 | if err := wc.WriteMessage(TextMessage, []byte(" ")); err != nil { 70 | t.Fatal(err) 71 | } 72 | messageCount++ 73 | 74 | // Close. 
75 | 76 | if err := wc.WriteMessage(CloseMessage, FormatCloseMessage(CloseNormalClosure, "")); err != nil { 77 | t.Fatal(err) 78 | } 79 | 80 | for i := 0; i < messageCount; i++ { 81 | err := rc.ReadJSON(&v) 82 | if err != io.ErrUnexpectedEOF { 83 | t.Error("read", i, err) 84 | } 85 | } 86 | 87 | err = rc.ReadJSON(&v) 88 | if _, ok := err.(*CloseError); !ok { 89 | t.Error("final", err) 90 | } 91 | } 92 | 93 | func TestDeprecatedJSON(t *testing.T) { 94 | var buf bytes.Buffer 95 | wc := newTestConn(nil, &buf, true) 96 | rc := newTestConn(&buf, nil, false) 97 | 98 | var actual, expect struct { 99 | A int 100 | B string 101 | } 102 | expect.A = 1 103 | expect.B = "hello" 104 | 105 | if err := WriteJSON(wc, &expect); err != nil { 106 | t.Fatal("write", err) 107 | } 108 | 109 | if err := ReadJSON(rc, &actual); err != nil { 110 | t.Fatal("read", err) 111 | } 112 | 113 | if !reflect.DeepEqual(&actual, &expect) { 114 | t.Fatal("equal", actual, expect) 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /mask.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of 2 | // this source code is governed by a BSD-style license that can be found in the 3 | // LICENSE file. 4 | 5 | //go:build !appengine 6 | // +build !appengine 7 | 8 | package websocket 9 | 10 | import "unsafe" 11 | 12 | const wordSize = int(unsafe.Sizeof(uintptr(0))) 13 | 14 | func maskBytes(key [4]byte, pos int, b []byte) int { 15 | // Mask one byte at a time for small buffers. 16 | if len(b) < 2*wordSize { 17 | for i := range b { 18 | b[i] ^= key[pos&3] 19 | pos++ 20 | } 21 | return pos & 3 22 | } 23 | 24 | // Mask one byte at a time to word boundary. 
25 | if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 { 26 | n = wordSize - n 27 | for i := range b[:n] { 28 | b[i] ^= key[pos&3] 29 | pos++ 30 | } 31 | b = b[n:] 32 | } 33 | 34 | // Create aligned word size key. 35 | var k [wordSize]byte 36 | for i := range k { 37 | k[i] = key[(pos+i)&3] 38 | } 39 | kw := *(*uintptr)(unsafe.Pointer(&k)) 40 | 41 | // Mask one word at a time. 42 | n := (len(b) / wordSize) * wordSize 43 | for i := 0; i < n; i += wordSize { 44 | *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw 45 | } 46 | 47 | // Mask one byte at a time for remaining bytes. 48 | b = b[n:] 49 | for i := range b { 50 | b[i] ^= key[pos&3] 51 | pos++ 52 | } 53 | 54 | return pos & 3 55 | } 56 | -------------------------------------------------------------------------------- /mask_safe.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of 2 | // this source code is governed by a BSD-style license that can be found in the 3 | // LICENSE file. 4 | 5 | //go:build appengine 6 | // +build appengine 7 | 8 | package websocket 9 | 10 | func maskBytes(key [4]byte, pos int, b []byte) int { 11 | for i := range b { 12 | b[i] ^= key[pos&3] 13 | pos++ 14 | } 15 | return pos & 3 16 | } 17 | -------------------------------------------------------------------------------- /mask_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of 2 | // this source code is governed by a BSD-style license that can be found in the 3 | // LICENSE file. 
4 | 5 | //go:build !appengine 6 | 7 | package websocket 8 | 9 | import ( 10 | "fmt" 11 | "testing" 12 | ) 13 | 14 | func maskBytesByByte(key [4]byte, pos int, b []byte) int { 15 | for i := range b { 16 | b[i] ^= key[pos&3] 17 | pos++ 18 | } 19 | return pos & 3 20 | } 21 | 22 | func notzero(b []byte) int { 23 | for i := range b { 24 | if b[i] != 0 { 25 | return i 26 | } 27 | } 28 | return -1 29 | } 30 | 31 | func TestMaskBytes(t *testing.T) { 32 | key := [4]byte{1, 2, 3, 4} 33 | for size := 1; size <= 1024; size++ { 34 | for align := 0; align < wordSize; align++ { 35 | for pos := 0; pos < 4; pos++ { 36 | b := make([]byte, size+align)[align:] 37 | maskBytes(key, pos, b) 38 | maskBytesByByte(key, pos, b) 39 | if i := notzero(b); i >= 0 { 40 | t.Errorf("size:%d, align:%d, pos:%d, offset:%d", size, align, pos, i) 41 | } 42 | } 43 | } 44 | } 45 | } 46 | 47 | func BenchmarkMaskBytes(b *testing.B) { 48 | for _, size := range []int{2, 4, 8, 16, 32, 512, 1024} { 49 | b.Run(fmt.Sprintf("size-%d", size), func(b *testing.B) { 50 | for _, align := range []int{wordSize / 2} { 51 | b.Run(fmt.Sprintf("align-%d", align), func(b *testing.B) { 52 | for _, fn := range []struct { 53 | name string 54 | fn func(key [4]byte, pos int, b []byte) int 55 | }{ 56 | {"byte", maskBytesByByte}, 57 | {"word", maskBytes}, 58 | } { 59 | b.Run(fn.name, func(b *testing.B) { 60 | key := newMaskKey() 61 | data := make([]byte, size+align)[align:] 62 | for i := 0; i < b.N; i++ { 63 | fn.fn(key, 0, data) 64 | } 65 | b.SetBytes(int64(len(data))) 66 | }) 67 | } 68 | }) 69 | } 70 | }) 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /prepared.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file.
4 | 5 | package websocket 6 | 7 | import ( 8 | "bytes" 9 | "net" 10 | "sync" 11 | "time" 12 | ) 13 | 14 | // PreparedMessage caches on the wire representations of a message payload. 15 | // Use PreparedMessage to efficiently send a message payload to multiple 16 | // connections. PreparedMessage is especially useful when compression is used 17 | // because the CPU and memory expensive compression operation can be executed 18 | // once for a given set of compression options. 19 | type PreparedMessage struct { 20 | messageType int 21 | data []byte 22 | mu sync.Mutex 23 | frames map[prepareKey]*preparedFrame 24 | } 25 | 26 | // prepareKey defines a unique set of options to cache prepared frames in PreparedMessage. 27 | type prepareKey struct { 28 | isServer bool 29 | compress bool 30 | compressionLevel int 31 | } 32 | 33 | // preparedFrame contains data in wire representation. 34 | type preparedFrame struct { 35 | once sync.Once 36 | data []byte 37 | } 38 | 39 | // NewPreparedMessage returns an initialized PreparedMessage. You can then send 40 | // it to connection using WritePreparedMessage method. Valid wire 41 | // representation will be calculated lazily only once for a set of current 42 | // connection options. 43 | func NewPreparedMessage(messageType int, data []byte) (*PreparedMessage, error) { 44 | pm := &PreparedMessage{ 45 | messageType: messageType, 46 | frames: make(map[prepareKey]*preparedFrame), 47 | data: data, 48 | } 49 | 50 | // Prepare a plain server frame. 51 | _, frameData, err := pm.frame(prepareKey{isServer: true, compress: false}) 52 | if err != nil { 53 | return nil, err 54 | } 55 | 56 | // To protect against caller modifying the data argument, remember the data 57 | // copied to the plain server frame. 
58 | pm.data = frameData[len(frameData)-len(data):] 59 | return pm, nil 60 | } 61 | 62 | func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) { 63 | pm.mu.Lock() 64 | frame, ok := pm.frames[key] 65 | if !ok { 66 | frame = &preparedFrame{} 67 | pm.frames[key] = frame 68 | } 69 | pm.mu.Unlock() 70 | 71 | var err error 72 | frame.once.Do(func() { 73 | // Prepare a frame using a 'fake' connection. 74 | // TODO: Refactor code in conn.go to allow more direct construction of 75 | // the frame. 76 | mu := make(chan struct{}, 1) 77 | mu <- struct{}{} 78 | var nc prepareConn 79 | c := &Conn{ 80 | conn: &nc, 81 | mu: mu, 82 | isServer: key.isServer, 83 | compressionLevel: key.compressionLevel, 84 | enableWriteCompression: true, 85 | writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize), 86 | } 87 | if key.compress { 88 | c.newCompressionWriter = compressNoContextTakeover 89 | } 90 | err = c.WriteMessage(pm.messageType, pm.data) 91 | frame.data = nc.buf.Bytes() 92 | }) 93 | return pm.messageType, frame.data, err 94 | } 95 | 96 | type prepareConn struct { 97 | buf bytes.Buffer 98 | net.Conn 99 | } 100 | 101 | func (pc *prepareConn) Write(p []byte) (int, error) { return pc.buf.Write(p) } 102 | func (pc *prepareConn) SetWriteDeadline(t time.Time) error { return nil } 103 | -------------------------------------------------------------------------------- /prepared_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package websocket 6 | 7 | import ( 8 | "bytes" 9 | "compress/flate" 10 | "math/rand" 11 | "testing" 12 | ) 13 | 14 | var preparedMessageTests = []struct { 15 | messageType int 16 | isServer bool 17 | enableWriteCompression bool 18 | compressionLevel int 19 | }{ 20 | // Server 21 | {TextMessage, true, false, flate.BestSpeed}, 22 | {TextMessage, true, true, flate.BestSpeed}, 23 | {TextMessage, true, true, flate.BestCompression}, 24 | {PingMessage, true, false, flate.BestSpeed}, 25 | {PingMessage, true, true, flate.BestSpeed}, 26 | 27 | // Client 28 | {TextMessage, false, false, flate.BestSpeed}, 29 | {TextMessage, false, true, flate.BestSpeed}, 30 | {TextMessage, false, true, flate.BestCompression}, 31 | {PingMessage, false, false, flate.BestSpeed}, 32 | {PingMessage, false, true, flate.BestSpeed}, 33 | } 34 | 35 | func TestPreparedMessage(t *testing.T) { 36 | testRand := rand.New(rand.NewSource(99)) 37 | prevMaskRand := maskRand 38 | maskRand = testRand 39 | defer func() { maskRand = prevMaskRand }() 40 | 41 | for _, tt := range preparedMessageTests { 42 | var data = []byte("this is a test") 43 | var buf bytes.Buffer 44 | c := newTestConn(nil, &buf, tt.isServer) 45 | if tt.enableWriteCompression { 46 | c.newCompressionWriter = compressNoContextTakeover 47 | } 48 | if err := c.SetCompressionLevel(tt.compressionLevel); err != nil { 49 | t.Fatal(err) 50 | } 51 | 52 | // Seed random number generator for consistent frame mask. 53 | testRand.Seed(1234) 54 | 55 | if err := c.WriteMessage(tt.messageType, data); err != nil { 56 | t.Fatal(err) 57 | } 58 | want := buf.String() 59 | 60 | pm, err := NewPreparedMessage(tt.messageType, data) 61 | if err != nil { 62 | t.Fatal(err) 63 | } 64 | 65 | // Scribble on data to ensure that NewPreparedMessage takes a snapshot. 66 | copy(data, "hello world") 67 | 68 | // Seed random number generator for consistent frame mask. 
69 | testRand.Seed(1234) 70 | 71 | buf.Reset() 72 | if err := c.WritePreparedMessage(pm); err != nil { 73 | t.Fatal(err) 74 | } 75 | got := buf.String() 76 | 77 | if got != want { 78 | t.Errorf("write message != prepared message for %+v", tt) 79 | } 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /proxy.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package websocket 6 | 7 | import ( 8 | "bufio" 9 | "bytes" 10 | "context" 11 | "encoding/base64" 12 | "errors" 13 | "net" 14 | "net/http" 15 | "net/url" 16 | "strings" 17 | 18 | "golang.org/x/net/proxy" 19 | ) 20 | 21 | type netDialerFunc func(ctx context.Context, network, addr string) (net.Conn, error) 22 | 23 | func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) { 24 | return fn(context.Background(), network, addr) 25 | } 26 | 27 | func (fn netDialerFunc) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { 28 | return fn(ctx, network, addr) 29 | } 30 | 31 | func proxyFromURL(proxyURL *url.URL, forwardDial netDialerFunc) (netDialerFunc, error) { 32 | if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { 33 | return (&httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDial}).DialContext, nil 34 | } 35 | dialer, err := proxy.FromURL(proxyURL, forwardDial) 36 | if err != nil { 37 | return nil, err 38 | } 39 | if d, ok := dialer.(proxy.ContextDialer); ok { 40 | return d.DialContext, nil 41 | } 42 | return func(ctx context.Context, net, addr string) (net.Conn, error) { 43 | return dialer.Dial(net, addr) 44 | }, nil 45 | } 46 | 47 | type httpProxyDialer struct { 48 | proxyURL *url.URL 49 | forwardDial netDialerFunc 50 | } 51 | 52 | func (hpd *httpProxyDialer) DialContext(ctx context.Context, 
network string, addr string) (net.Conn, error) { 53 | hostPort, _ := hostPortNoPort(hpd.proxyURL) 54 | conn, err := hpd.forwardDial(ctx, network, hostPort) 55 | if err != nil { 56 | return nil, err 57 | } 58 | 59 | connectHeader := make(http.Header) 60 | if user := hpd.proxyURL.User; user != nil { 61 | proxyUser := user.Username() 62 | if proxyPassword, passwordSet := user.Password(); passwordSet { 63 | credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword)) 64 | connectHeader.Set("Proxy-Authorization", "Basic "+credential) 65 | } 66 | } 67 | connectReq := &http.Request{ 68 | Method: http.MethodConnect, 69 | URL: &url.URL{Opaque: addr}, 70 | Host: addr, 71 | Header: connectHeader, 72 | } 73 | 74 | if err := connectReq.Write(conn); err != nil { 75 | conn.Close() 76 | return nil, err 77 | } 78 | 79 | // Read response. It's OK to use and discard buffered reader here because 80 | // the remote server does not speak until spoken to. 81 | br := bufio.NewReader(conn) 82 | resp, err := http.ReadResponse(br, connectReq) 83 | if err != nil { 84 | conn.Close() 85 | return nil, err 86 | } 87 | 88 | // Close the response body to silence false positives from linters. Reset 89 | // the buffered reader first to ensure that Close() does not read from 90 | // conn. 91 | // Note: Applications must call resp.Body.Close() on a response returned 92 | // http.ReadResponse to inspect trailers or read another response from the 93 | // buffered reader. The call to resp.Body.Close() does not release 94 | // resources. 
95 | br.Reset(bytes.NewReader(nil)) 96 | _ = resp.Body.Close() 97 | 98 | if resp.StatusCode != http.StatusOK { 99 | _ = conn.Close() 100 | f := strings.SplitN(resp.Status, " ", 2) 101 | return nil, errors.New(f[1]) 102 | } 103 | return conn, nil 104 | } 105 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package websocket 6 | 7 | import ( 8 | "bufio" 9 | "net" 10 | "net/http" 11 | "net/url" 12 | "strings" 13 | "time" 14 | ) 15 | 16 | // HandshakeError describes an error with the handshake from the peer. 17 | type HandshakeError struct { 18 | message string 19 | } 20 | 21 | func (e HandshakeError) Error() string { return e.message } 22 | 23 | // Upgrader specifies parameters for upgrading an HTTP connection to a 24 | // WebSocket connection. 25 | // 26 | // It is safe to call Upgrader's methods concurrently. 27 | type Upgrader struct { 28 | // HandshakeTimeout specifies the duration for the handshake to complete. 29 | HandshakeTimeout time.Duration 30 | 31 | // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer 32 | // size is zero, then buffers allocated by the HTTP server are used. The 33 | // I/O buffer sizes do not limit the size of the messages that can be sent 34 | // or received. 35 | ReadBufferSize, WriteBufferSize int 36 | 37 | // WriteBufferPool is a pool of buffers for write operations. If the value 38 | // is not set, then write buffers are allocated to the connection for the 39 | // lifetime of the connection. 40 | // 41 | // A pool is most useful when the application has a modest volume of writes 42 | // across a large number of connections. 
43 | // 44 | // Applications should use a single pool for each unique value of 45 | // WriteBufferSize. 46 | WriteBufferPool BufferPool 47 | 48 | // Subprotocols specifies the server's supported protocols in order of 49 | // preference. If this field is not nil, then the Upgrade method negotiates a 50 | // subprotocol by selecting the first match in this list with a protocol 51 | // requested by the client. If there's no match, then no protocol is 52 | // negotiated (the Sec-Websocket-Protocol header is not included in the 53 | // handshake response). 54 | Subprotocols []string 55 | 56 | // Error specifies the function for generating HTTP error responses. If Error 57 | // is nil, then http.Error is used to generate the HTTP response. 58 | Error func(w http.ResponseWriter, r *http.Request, status int, reason error) 59 | 60 | // CheckOrigin returns true if the request Origin header is acceptable. If 61 | // CheckOrigin is nil, then a safe default is used: return false if the 62 | // Origin request header is present and the origin host is not equal to 63 | // request Host header. 64 | // 65 | // A CheckOrigin function should carefully validate the request origin to 66 | // prevent cross-site request forgery. 67 | CheckOrigin func(r *http.Request) bool 68 | 69 | // EnableCompression specify if the server should attempt to negotiate per 70 | // message compression (RFC 7692). Setting this value to true does not 71 | // guarantee that compression will be supported. Currently only "no context 72 | // takeover" modes are supported. 
73 | EnableCompression bool 74 | } 75 | 76 | func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) { 77 | err := HandshakeError{reason} 78 | if u.Error != nil { 79 | u.Error(w, r, status, err) 80 | } else { 81 | w.Header().Set("Sec-Websocket-Version", "13") 82 | http.Error(w, http.StatusText(status), status) 83 | } 84 | return nil, err 85 | } 86 | 87 | // checkSameOrigin returns true if the origin is not set or is equal to the request host. 88 | func checkSameOrigin(r *http.Request) bool { 89 | origin := r.Header["Origin"] 90 | if len(origin) == 0 { 91 | return true 92 | } 93 | u, err := url.Parse(origin[0]) 94 | if err != nil { 95 | return false 96 | } 97 | return equalASCIIFold(u.Host, r.Host) 98 | } 99 | 100 | func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string { 101 | if u.Subprotocols != nil { 102 | clientProtocols := Subprotocols(r) 103 | for _, clientProtocol := range clientProtocols { 104 | for _, serverProtocol := range u.Subprotocols { 105 | if clientProtocol == serverProtocol { 106 | return clientProtocol 107 | } 108 | } 109 | } 110 | } else if responseHeader != nil { 111 | return responseHeader.Get("Sec-Websocket-Protocol") 112 | } 113 | return "" 114 | } 115 | 116 | // Upgrade upgrades the HTTP server connection to the WebSocket protocol. 117 | // 118 | // The responseHeader is included in the response to the client's upgrade 119 | // request. Use the responseHeader to specify cookies (Set-Cookie). To specify 120 | // subprotocols supported by the server, set Upgrader.Subprotocols directly. 121 | // 122 | // If the upgrade fails, then Upgrade replies to the client with an HTTP error 123 | // response. 
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
	// Phase 1: validate the client's opening handshake. Each failure here goes
	// through u.returnError before the connection is hijacked, so the client
	// still gets an HTTP error response (or whatever u.Error writes).
	const badHandshake = "websocket: the client is not using the websocket protocol: "

	if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
		return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header")
	}

	if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
		// Advertise the protocol the client must upgrade to with the 426 reply.
		w.Header().Set("Upgrade", "websocket")
		return u.returnError(w, r, http.StatusUpgradeRequired, badHandshake+"'websocket' token not found in 'Upgrade' header")
	}

	if r.Method != http.MethodGet {
		return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
	}

	if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") {
		return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header")
	}

	// Extension negotiation is owned by this function (permessage-deflate
	// only, below); the application may not inject its own extensions header.
	if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
		return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported")
	}

	checkOrigin := u.CheckOrigin
	if checkOrigin == nil {
		checkOrigin = checkSameOrigin
	}
	if !checkOrigin(r) {
		return u.returnError(w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin")
	}

	challengeKey := r.Header.Get("Sec-Websocket-Key")
	if !isValidChallengeKey(challengeKey) {
		return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length")
	}

	subprotocol := u.selectSubprotocol(r, responseHeader)

	// Negotiate PMCE: enable permessage-deflate if compression is allowed and
	// the client offered it anywhere in its extension list. Offered parameters
	// are not inspected; the response always requests no context takeover.
	var compress bool
	if u.EnableCompression {
		for _, ext := range parseExtensions(r.Header) {
			if ext[""] != "permessage-deflate" {
				continue
			}
			compress = true
			break
		}
	}

	// Phase 2: take over the raw connection. From here on, errors are returned
	// without an HTTP response because the HTTP server no longer owns the
	// connection.
	netConn, brw, err := http.NewResponseController(w).Hijack()
	if err != nil {
		return u.returnError(w, r, http.StatusInternalServerError,
			"websocket: hijack: "+err.Error())
	}

	// Close the network connection when returning an error. The variable
	// netConn is set to nil before the success return at the end of the
	// function.
	defer func() {
		if netConn != nil {
			// It's safe to ignore the error from Close() because this code is
			// only executed when returning a more important error to the
			// application.
			_ = netConn.Close()
		}
	}()

	// Choose the connection's reader. If br stays nil it is passed to newConn
	// as nil (presumably newConn then allocates its own reader — confirm in
	// conn.go).
	var br *bufio.Reader
	if u.ReadBufferSize == 0 && brw.Reader.Size() > 256 {
		// Use hijacked buffered reader as the connection reader.
		br = brw.Reader
	} else if brw.Reader.Buffered() > 0 {
		// Wrap the network connection to read buffered data in brw.Reader
		// before reading from the network connection. This should be rare
		// because a client must not send message data before receiving the
		// handshake response.
		netConn = &brNetConn{br: brw.Reader, Conn: netConn}
	}

	// buf is the unused space of the hijacked writer's buffer; it is reused
	// below both as a candidate connection write buffer and as scratch space
	// for the handshake response.
	buf := brw.Writer.AvailableBuffer()

	var writeBuf []byte
	if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
		// Reuse hijacked write buffer as connection buffer.
		writeBuf = buf
	}

	c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf)
	c.subprotocol = subprotocol

	if compress {
		c.newCompressionWriter = compressNoContextTakeover
		c.newDecompressionReader = decompressNoContextTakeover
	}

	// Use larger of hijacked buffer and connection write buffer for header.
	// NOTE(review): p may alias the connection's write buffer; this is only
	// safe because the response is fully written before c is returned.
	p := buf
	if len(c.writeBuf) > len(p) {
		p = c.writeBuf
	}
	p = p[:0]

	// Phase 3: build the 101 response by hand.
	p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
	p = append(p, computeAcceptKey(challengeKey)...)
	p = append(p, "\r\n"...)
	if c.subprotocol != "" {
		p = append(p, "Sec-WebSocket-Protocol: "...)
		p = append(p, c.subprotocol...)
		p = append(p, "\r\n"...)
	}
	if compress {
		p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
	}
	for k, vs := range responseHeader {
		// The selected subprotocol was already written above.
		if k == "Sec-Websocket-Protocol" {
			continue
		}
		for _, v := range vs {
			// NOTE(review): only header values are sanitized; header names
			// are copied verbatim.
			p = append(p, k...)
			p = append(p, ": "...)
			for i := 0; i < len(v); i++ {
				b := v[i]
				if b <= 31 {
					// prevent response splitting.
					b = ' '
				}
				p = append(p, b)
			}
			p = append(p, "\r\n"...)
		}
	}
	p = append(p, "\r\n"...)

	if u.HandshakeTimeout > 0 {
		if err := netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)); err != nil {
			return nil, err
		}
	} else {
		// Clear deadlines set by HTTP server.
		if err := netConn.SetDeadline(time.Time{}); err != nil {
			return nil, err
		}
	}

	// The response goes straight to netConn, bypassing brw.Writer.
	if _, err = netConn.Write(p); err != nil {
		return nil, err
	}
	if u.HandshakeTimeout > 0 {
		if err := netConn.SetWriteDeadline(time.Time{}); err != nil {
			return nil, err
		}
	}

	// Success! Set netConn to nil to stop the deferred function above from
	// closing the network connection.
	netConn = nil

	return c, nil
}

// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
//
// Deprecated: Use websocket.Upgrader instead.
//
// Upgrade does not perform origin checking. The application is responsible for
// checking the Origin header before calling Upgrade. An example implementation
// of the same origin policy check is:
//
//	if req.Header.Get("Origin") != "http://"+req.Host {
//		http.Error(w, "Origin not allowed", http.StatusForbidden)
//		return
//	}
//
// If the endpoint supports subprotocols, then the application is responsible
// for negotiating the protocol used on the connection. Use the Subprotocols()
// function to get the subprotocols requested by the client. Use the
// Sec-Websocket-Protocol response header to specify the subprotocol selected
// by the application.
//
// The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
// negotiated subprotocol (Sec-Websocket-Protocol).
//
// The connection buffers IO to the underlying network connection. The
// readBufSize and writeBufSize parameters specify the size of the buffers to
// use. Messages can be larger than the buffers.
//
// If the request is not a valid WebSocket handshake, then Upgrade returns an
// error of type HandshakeError. Applications should handle this error by
// replying to the client with an HTTP error response.
316 | func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) { 317 | u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize} 318 | u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) { 319 | // don't return errors to maintain backwards compatibility 320 | } 321 | u.CheckOrigin = func(r *http.Request) bool { 322 | // allow all connections by default 323 | return true 324 | } 325 | return u.Upgrade(w, r, responseHeader) 326 | } 327 | 328 | // Subprotocols returns the subprotocols requested by the client in the 329 | // Sec-Websocket-Protocol header. 330 | func Subprotocols(r *http.Request) []string { 331 | h := strings.TrimSpace(r.Header.Get("Sec-Websocket-Protocol")) 332 | if h == "" { 333 | return nil 334 | } 335 | protocols := strings.Split(h, ",") 336 | for i := range protocols { 337 | protocols[i] = strings.TrimSpace(protocols[i]) 338 | } 339 | return protocols 340 | } 341 | 342 | // IsWebSocketUpgrade returns true if the client requested upgrade to the 343 | // WebSocket protocol. 344 | func IsWebSocketUpgrade(r *http.Request) bool { 345 | return tokenListContainsValue(r.Header, "Connection", "upgrade") && 346 | tokenListContainsValue(r.Header, "Upgrade", "websocket") 347 | } 348 | 349 | type brNetConn struct { 350 | br *bufio.Reader 351 | net.Conn 352 | } 353 | 354 | func (b *brNetConn) Read(p []byte) (n int, err error) { 355 | if b.br != nil { 356 | // Limit read to buferred data. 357 | if n := b.br.Buffered(); len(p) > n { 358 | p = p[:n] 359 | } 360 | n, err = b.br.Read(p) 361 | if b.br.Buffered() == 0 { 362 | b.br = nil 363 | } 364 | return n, err 365 | } 366 | return b.Conn.Read(p) 367 | } 368 | 369 | // NetConn returns the underlying connection that is wrapped by b. 
370 | func (b *brNetConn) NetConn() net.Conn { 371 | return b.Conn 372 | } 373 | 374 | -------------------------------------------------------------------------------- /server_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package websocket 6 | 7 | import ( 8 | "bufio" 9 | "bytes" 10 | "errors" 11 | "net" 12 | "net/http" 13 | "net/http/httptest" 14 | "reflect" 15 | "strings" 16 | "testing" 17 | ) 18 | 19 | var subprotocolTests = []struct { 20 | h string 21 | protocols []string 22 | }{ 23 | {"", nil}, 24 | {"foo", []string{"foo"}}, 25 | {"foo,bar", []string{"foo", "bar"}}, 26 | {"foo, bar", []string{"foo", "bar"}}, 27 | {" foo, bar", []string{"foo", "bar"}}, 28 | {" foo, bar ", []string{"foo", "bar"}}, 29 | } 30 | 31 | func TestSubprotocols(t *testing.T) { 32 | for _, st := range subprotocolTests { 33 | r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {st.h}}} 34 | protocols := Subprotocols(&r) 35 | if !reflect.DeepEqual(st.protocols, protocols) { 36 | t.Errorf("SubProtocols(%q) returned %#v, want %#v", st.h, protocols, st.protocols) 37 | } 38 | } 39 | } 40 | 41 | var isWebSocketUpgradeTests = []struct { 42 | ok bool 43 | h http.Header 44 | }{ 45 | {false, http.Header{"Upgrade": {"websocket"}}}, 46 | {false, http.Header{"Connection": {"upgrade"}}}, 47 | {true, http.Header{"Connection": {"upgRade"}, "Upgrade": {"WebSocket"}}}, 48 | } 49 | 50 | func TestIsWebSocketUpgrade(t *testing.T) { 51 | for _, tt := range isWebSocketUpgradeTests { 52 | ok := IsWebSocketUpgrade(&http.Request{Header: tt.h}) 53 | if tt.ok != ok { 54 | t.Errorf("IsWebSocketUpgrade(%v) returned %v, want %v", tt.h, ok, tt.ok) 55 | } 56 | } 57 | } 58 | 59 | func TestSubProtocolSelection(t *testing.T) { 60 | upgrader := Upgrader{ 61 | Subprotocols: 
[]string{"foo", "bar", "baz"}, 62 | } 63 | 64 | r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"foo", "bar"}}} 65 | s := upgrader.selectSubprotocol(&r, nil) 66 | if s != "foo" { 67 | t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "foo") 68 | } 69 | 70 | r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"bar", "foo"}}} 71 | s = upgrader.selectSubprotocol(&r, nil) 72 | if s != "bar" { 73 | t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "bar") 74 | } 75 | 76 | r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"baz"}}} 77 | s = upgrader.selectSubprotocol(&r, nil) 78 | if s != "baz" { 79 | t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "baz") 80 | } 81 | 82 | r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"quux"}}} 83 | s = upgrader.selectSubprotocol(&r, nil) 84 | if s != "" { 85 | t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "empty string") 86 | } 87 | } 88 | 89 | var checkSameOriginTests = []struct { 90 | ok bool 91 | r *http.Request 92 | }{ 93 | {false, &http.Request{Host: "example.org", Header: map[string][]string{"Origin": {"https://other.org"}}}}, 94 | {true, &http.Request{Host: "example.org", Header: map[string][]string{"Origin": {"https://example.org"}}}}, 95 | {true, &http.Request{Host: "Example.org", Header: map[string][]string{"Origin": {"https://example.org"}}}}, 96 | } 97 | 98 | func TestCheckSameOrigin(t *testing.T) { 99 | for _, tt := range checkSameOriginTests { 100 | ok := checkSameOrigin(tt.r) 101 | if tt.ok != ok { 102 | t.Errorf("checkSameOrigin(%+v) returned %v, want %v", tt.r, ok, tt.ok) 103 | } 104 | } 105 | } 106 | 107 | type reuseTestResponseWriter struct { 108 | brw *bufio.ReadWriter 109 | http.ResponseWriter 110 | } 111 | 112 | func (resp *reuseTestResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { 113 | return fakeNetConn{strings.NewReader(""), &bytes.Buffer{}}, resp.brw, nil 114 | } 115 | 116 | var 
bufioReuseTests = []struct { 117 | n int 118 | reuse bool 119 | }{ 120 | {4096, true}, 121 | {128, false}, 122 | } 123 | 124 | func xTestBufioReuse(t *testing.T) { 125 | for i, tt := range bufioReuseTests { 126 | br := bufio.NewReaderSize(strings.NewReader(""), tt.n) 127 | bw := bufio.NewWriterSize(&bytes.Buffer{}, tt.n) 128 | resp := &reuseTestResponseWriter{ 129 | brw: bufio.NewReadWriter(br, bw), 130 | } 131 | upgrader := Upgrader{} 132 | c, err := upgrader.Upgrade(resp, &http.Request{ 133 | Method: http.MethodGet, 134 | Header: http.Header{ 135 | "Upgrade": []string{"websocket"}, 136 | "Connection": []string{"upgrade"}, 137 | "Sec-Websocket-Key": []string{"dGhlIHNhbXBsZSBub25jZQ=="}, 138 | "Sec-Websocket-Version": []string{"13"}, 139 | }}, nil) 140 | if err != nil { 141 | t.Fatal(err) 142 | } 143 | if reuse := c.br == br; reuse != tt.reuse { 144 | t.Errorf("%d: buffered reader reuse=%v, want %v", i, reuse, tt.reuse) 145 | } 146 | writeBuf := bw.AvailableBuffer() 147 | if reuse := &c.writeBuf[0] == &writeBuf[0]; reuse != tt.reuse { 148 | t.Errorf("%d: write buffer reuse=%v, want %v", i, reuse, tt.reuse) 149 | } 150 | } 151 | } 152 | 153 | func TestHijack_NotSupported(t *testing.T) { 154 | t.Parallel() 155 | 156 | req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) 157 | req.Header.Set("Upgrade", "websocket") 158 | req.Header.Set("Connection", "upgrade") 159 | req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") 160 | req.Header.Set("Sec-Websocket-Version", "13") 161 | 162 | recorder := httptest.NewRecorder() 163 | 164 | upgrader := Upgrader{} 165 | _, err := upgrader.Upgrade(recorder, req, nil) 166 | 167 | if want := (HandshakeError{}); !errors.As(err, &want) || recorder.Code != http.StatusInternalServerError { 168 | t.Errorf("want %T and status_code=%d", want, http.StatusInternalServerError) 169 | t.Fatalf("got err=%T and status_code=%d", err, recorder.Code) 170 | } 171 | } 172 | 
-------------------------------------------------------------------------------- /util.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package websocket 6 | 7 | import ( 8 | "crypto/rand" 9 | "crypto/sha1" 10 | "encoding/base64" 11 | "io" 12 | "net/http" 13 | "strings" 14 | "unicode/utf8" 15 | ) 16 | 17 | var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") 18 | 19 | func computeAcceptKey(challengeKey string) string { 20 | h := sha1.New() 21 | h.Write([]byte(challengeKey)) 22 | h.Write(keyGUID) 23 | return base64.StdEncoding.EncodeToString(h.Sum(nil)) 24 | } 25 | 26 | func generateChallengeKey() (string, error) { 27 | p := make([]byte, 16) 28 | if _, err := io.ReadFull(rand.Reader, p); err != nil { 29 | return "", err 30 | } 31 | return base64.StdEncoding.EncodeToString(p), nil 32 | } 33 | 34 | // Token octets per RFC 2616. 
var isTokenOctet = func() [256]bool {
	// tchar: DIGIT / ALPHA / one of the listed punctuation characters.
	var table [256]bool
	for c := '0'; c <= '9'; c++ {
		table[c] = true
	}
	for c := 'a'; c <= 'z'; c++ {
		table[c] = true
		table[c-'a'+'A'] = true
	}
	for _, c := range "!#$%&'*+-.^_`|~" {
		table[c] = true
	}
	return table
}()

// skipSpace returns a slice of the string s with all leading RFC 2616 linear
// whitespace removed. Only SP and HTAB are treated as whitespace.
func skipSpace(s string) (rest string) {
	return strings.TrimLeft(s, " \t")
}

// nextToken returns the leading RFC 2616 token of s and the string following
// the token.
129 | func nextToken(s string) (token, rest string) { 130 | i := 0 131 | for ; i < len(s); i++ { 132 | if !isTokenOctet[s[i]] { 133 | break 134 | } 135 | } 136 | return s[:i], s[i:] 137 | } 138 | 139 | // nextTokenOrQuoted returns the leading token or quoted string per RFC 2616 140 | // and the string following the token or quoted string. 141 | func nextTokenOrQuoted(s string) (value string, rest string) { 142 | if !strings.HasPrefix(s, "\"") { 143 | return nextToken(s) 144 | } 145 | s = s[1:] 146 | for i := 0; i < len(s); i++ { 147 | switch s[i] { 148 | case '"': 149 | return s[:i], s[i+1:] 150 | case '\\': 151 | p := make([]byte, len(s)-1) 152 | j := copy(p, s[:i]) 153 | escape := true 154 | for i = i + 1; i < len(s); i++ { 155 | b := s[i] 156 | switch { 157 | case escape: 158 | escape = false 159 | p[j] = b 160 | j++ 161 | case b == '\\': 162 | escape = true 163 | case b == '"': 164 | return string(p[:j]), s[i+1:] 165 | default: 166 | p[j] = b 167 | j++ 168 | } 169 | } 170 | return "", "" 171 | } 172 | } 173 | return "", "" 174 | } 175 | 176 | // equalASCIIFold returns true if s is equal to t with ASCII case folding as 177 | // defined in RFC 4790. 178 | func equalASCIIFold(s, t string) bool { 179 | for s != "" && t != "" { 180 | sr, size := utf8.DecodeRuneInString(s) 181 | s = s[size:] 182 | tr, size := utf8.DecodeRuneInString(t) 183 | t = t[size:] 184 | if sr == tr { 185 | continue 186 | } 187 | if 'A' <= sr && sr <= 'Z' { 188 | sr = sr + 'a' - 'A' 189 | } 190 | if 'A' <= tr && tr <= 'Z' { 191 | tr = tr + 'a' - 'A' 192 | } 193 | if sr != tr { 194 | return false 195 | } 196 | } 197 | return s == t 198 | } 199 | 200 | // tokenListContainsValue returns true if the 1#token header with the given 201 | // name contains a token equal to value with ASCII case folding. 
func tokenListContainsValue(header http.Header, name string, value string) bool {
	// Each header value is scanned as a comma-separated token list. A
	// malformed element abandons the rest of that header value (continue
	// headers), but later values of the same header are still examined.
headers:
	for _, s := range header[name] {
		for {
			var t string
			t, s = nextToken(skipSpace(s))
			if t == "" {
				continue headers
			}
			s = skipSpace(s)
			// A token must be followed by end-of-value or a comma.
			if s != "" && s[0] != ',' {
				continue headers
			}
			if equalASCIIFold(t, value) {
				return true
			}
			if s == "" {
				continue headers
			}
			s = s[1:] // skip the comma
		}
	}
	return false
}

// parseExtensions parses WebSocket extensions from a header.
//
// Each returned map describes one extension: the "" key holds the extension
// token and the other keys hold its parameters (a parameter without "=" maps
// to ""; a repeated parameter keeps the last value). When an element of a
// header value is malformed, that extension and the rest of the value are
// discarded, while extensions already parsed from the value are kept.
func parseExtensions(header http.Header) []map[string]string {
	// From RFC 6455:
	//
	//  Sec-WebSocket-Extensions = extension-list
	//  extension-list = 1#extension
	//  extension = extension-token *( ";" extension-param )
	//  extension-token = registered-token
	//  registered-token = token
	//  extension-param = token [ "=" (token | quoted-string) ]
	//     ;When using the quoted-string syntax variant, the value
	//     ;after quoted-string unescaping MUST conform to the
	//     ;'token' ABNF.

	var result []map[string]string
headers:
	for _, s := range header["Sec-Websocket-Extensions"] {
		for {
			var t string
			t, s = nextToken(skipSpace(s))
			if t == "" {
				continue headers
			}
			ext := map[string]string{"": t}
			// Consume ";" extension-param repetitions.
			for {
				s = skipSpace(s)
				if !strings.HasPrefix(s, ";") {
					break
				}
				var k string
				k, s = nextToken(skipSpace(s[1:]))
				if k == "" {
					continue headers
				}
				s = skipSpace(s)
				var v string
				if strings.HasPrefix(s, "=") {
					v, s = nextTokenOrQuoted(skipSpace(s[1:]))
					s = skipSpace(s)
				}
				if s != "" && s[0] != ',' && s[0] != ';' {
					continue headers
				}
				ext[k] = v
			}
			// An extension must be followed by end-of-value or a comma.
			if s != "" && s[0] != ',' {
				continue headers
			}
			result = append(result, ext)
			if s == "" {
				continue headers
			}
			s = s[1:] // skip the comma
		}
	}
	return result
}

// isValidChallengeKey checks if the argument meets RFC6455 specification.
// It reports whether s is standard (padded) base64 that decodes to exactly
// 16 bytes.
func isValidChallengeKey(s string) bool {
	// From RFC6455:
	//
	// A |Sec-WebSocket-Key| header field with a base64-encoded (see
	// Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in
	// length.

	if s == "" {
		return false
	}
	decoded, err := base64.StdEncoding.DecodeString(s)
	return err == nil && len(decoded) == 16
}

--------------------------------------------------------------------------------
/util_test.go:
--------------------------------------------------------------------------------
// Copyright 2014 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4 | 5 | package websocket 6 | 7 | import ( 8 | "net/http" 9 | "reflect" 10 | "testing" 11 | ) 12 | 13 | var equalASCIIFoldTests = []struct { 14 | t, s string 15 | eq bool 16 | }{ 17 | {"WebSocket", "websocket", true}, 18 | {"websocket", "WebSocket", true}, 19 | {"Öyster", "öyster", false}, 20 | {"WebSocket", "WetSocket", false}, 21 | } 22 | 23 | func TestEqualASCIIFold(t *testing.T) { 24 | for _, tt := range equalASCIIFoldTests { 25 | eq := equalASCIIFold(tt.s, tt.t) 26 | if eq != tt.eq { 27 | t.Errorf("equalASCIIFold(%q, %q) = %v, want %v", tt.s, tt.t, eq, tt.eq) 28 | } 29 | } 30 | } 31 | 32 | var tokenListContainsValueTests = []struct { 33 | value string 34 | ok bool 35 | }{ 36 | {"WebSocket", true}, 37 | {"WEBSOCKET", true}, 38 | {"websocket", true}, 39 | {"websockets", false}, 40 | {"x websocket", false}, 41 | {"websocket x", false}, 42 | {"other,websocket,more", true}, 43 | {"other, websocket, more", true}, 44 | } 45 | 46 | func TestTokenListContainsValue(t *testing.T) { 47 | for _, tt := range tokenListContainsValueTests { 48 | h := http.Header{"Upgrade": {tt.value}} 49 | ok := tokenListContainsValue(h, "Upgrade", "websocket") 50 | if ok != tt.ok { 51 | t.Errorf("tokenListContainsValue(h, n, %q) = %v, want %v", tt.value, ok, tt.ok) 52 | } 53 | } 54 | } 55 | 56 | var isValidChallengeKeyTests = []struct { 57 | key string 58 | ok bool 59 | }{ 60 | {"dGhlIHNhbXBsZSBub25jZQ==", true}, 61 | {"", false}, 62 | {"InvalidKey", false}, 63 | {"WHQ4eXhscUtKYjBvOGN3WEdtOEQ=", false}, 64 | } 65 | 66 | func TestIsValidChallengeKey(t *testing.T) { 67 | for _, tt := range isValidChallengeKeyTests { 68 | ok := isValidChallengeKey(tt.key) 69 | if ok != tt.ok { 70 | t.Errorf("isValidChallengeKey returns %v, want %v", ok, tt.ok) 71 | } 72 | } 73 | } 74 | 75 | var parseExtensionTests = []struct { 76 | value string 77 | extensions []map[string]string 78 | }{ 79 | {`foo`, []map[string]string{{"": "foo"}}}, 80 | {`foo, bar; baz=2`, []map[string]string{ 81 | {"": "foo"}, 82 | {"": 
"bar", "baz": "2"}}}, 83 | {`foo; bar="b,a;z"`, []map[string]string{ 84 | {"": "foo", "bar": "b,a;z"}}}, 85 | {`foo , bar; baz = 2`, []map[string]string{ 86 | {"": "foo"}, 87 | {"": "bar", "baz": "2"}}}, 88 | {`foo, bar; baz=2 junk`, []map[string]string{ 89 | {"": "foo"}}}, 90 | {`foo junk, bar; baz=2 junk`, nil}, 91 | {`mux; max-channels=4; flow-control, deflate-stream`, []map[string]string{ 92 | {"": "mux", "max-channels": "4", "flow-control": ""}, 93 | {"": "deflate-stream"}}}, 94 | {`permessage-foo; x="10"`, []map[string]string{ 95 | {"": "permessage-foo", "x": "10"}}}, 96 | {`permessage-foo; use_y, permessage-foo`, []map[string]string{ 97 | {"": "permessage-foo", "use_y": ""}, 98 | {"": "permessage-foo"}}}, 99 | {`permessage-deflate; client_max_window_bits; server_max_window_bits=10 , permessage-deflate; client_max_window_bits`, []map[string]string{ 100 | {"": "permessage-deflate", "client_max_window_bits": "", "server_max_window_bits": "10"}, 101 | {"": "permessage-deflate", "client_max_window_bits": ""}}}, 102 | {"permessage-deflate; server_no_context_takeover; client_max_window_bits=15", []map[string]string{ 103 | {"": "permessage-deflate", "server_no_context_takeover": "", "client_max_window_bits": "15"}, 104 | }}, 105 | } 106 | 107 | func TestParseExtensions(t *testing.T) { 108 | for _, tt := range parseExtensionTests { 109 | h := http.Header{http.CanonicalHeaderKey("Sec-WebSocket-Extensions"): {tt.value}} 110 | extensions := parseExtensions(h) 111 | if !reflect.DeepEqual(extensions, tt.extensions) { 112 | t.Errorf("parseExtensions(%q)\n = %v,\nwant %v", tt.value, extensions, tt.extensions) 113 | } 114 | } 115 | } 116 | --------------------------------------------------------------------------------