├── .gitignore ├── LICENSE ├── README.md ├── dialer.go ├── dialer_test.go ├── dialer_utils.go ├── dialer_utils_test.go ├── doc.go ├── errors.go ├── errors_test.go ├── examples └── chat │ ├── main.go │ ├── manager.go │ └── public │ ├── css │ └── style.css │ ├── index.html │ └── js │ └── app.js ├── frame.go ├── frame_test.go ├── frame_utils.go ├── frame_utils_test.go ├── request.go ├── request_test.go ├── request_utils.go ├── request_utils_test.go ├── socket.go ├── socket_test.go ├── utils.go └── utils_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | *.prof 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Luca Tabone 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Go Report Card](https://goreportcard.com/badge/github.com/tabone/websocket)](https://goreportcard.com/report/github.com/tabone/websocket) 2 | [![GoDoc](https://godoc.org/github.com/tabone/websocket?status.svg)](https://godoc.org/github.com/tabone/websocket) 3 | 4 | # websocket 5 | Package websocket implements the websocket protocol defined in rfc6455 6 | 7 | ## Installation 8 | go get github.com/tabone/websocket 9 | 10 | ## Documentation 11 | - [API Reference](https://godoc.org/github.com/tabone/websocket) 12 | - [Examples](https://github.com/tabone/websocket/tree/master/examples) 13 | -------------------------------------------------------------------------------- /dialer.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "bufio" 5 | "crypto/tls" 6 | "net" 7 | "net/http" 8 | "net/url" 9 | "regexp" 10 | "strings" 11 | "sync" 12 | ) 13 | 14 | // Dialer is a websocket client. 15 | type Dialer struct { 16 | /* 17 | Header to be included in the opening handshake request. 18 | */ 19 | Header http.Header 20 | 21 | /* 22 | SubProtocols which the client supports. 23 | */ 24 | SubProtocols []string 25 | 26 | /* 27 | TLSConfig is used to configure the TLS client. 28 | */ 29 | TLSConfig *tls.Config 30 | } 31 | 32 | // Dial is the method used to start the websocket connection. 33 | func (d *Dialer) Dial(u string) (*Socket, *http.Response, error) { 34 | // Parse URL to return a valid URL instance. 35 | l, err := parseURL(u) 36 | if err != nil { 37 | return nil, nil, err 38 | } 39 | 40 | // Get a valid websocket opening handshake request instance. 41 | q := d.createRequest(l) 42 | 43 | // Connect with the websocket server. 44 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-3 45 | conn, err := net.Dial("tcp", l.Host+"/"+l.Path+"?"+l.RawQuery) 46 | if err != nil { 47 | return nil, nil, err 48 | } 49 | 50 | // When the connection will be over TLS, we need to do the TLS handshake. 51 | if l.Scheme == "wss" { 52 | g := d.TLSConfig 53 | 54 | // Create tls config instance if user hasn't specified one since it is 55 | // required. 56 | if g == nil { 57 | g = &tls.Config{} 58 | } 59 | 60 | // If ServerName is empty, use the host provided by the user. 61 | if g.ServerName == "" { 62 | g.ServerName = strings.Split(l.Host, ":")[0] 63 | } 64 | 65 | // Change the current conenction to a secure one. 66 | c := tls.Client(conn, g) 67 | 68 | // Do the handshake. 69 | if err := c.Handshake(); err != nil { 70 | return nil, nil, err 71 | } 72 | 73 | conn = c 74 | } 75 | 76 | // Send request 77 | if err := q.Write(conn); err != nil { 78 | return nil, nil, err 79 | } 80 | 81 | // Buffer connection. 82 | b := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) 83 | 84 | // Read response 85 | r, err := http.ReadResponse(b.Reader, q) 86 | 87 | if err != nil { 88 | return nil, nil, err 89 | } 90 | 91 | // Validate response. 92 | if err := validateResponse(r); err != nil { 93 | return nil, nil, err 94 | } 95 | 96 | return &Socket{ 97 | conn: conn, 98 | buf: b, 99 | writeMutex: &sync.Mutex{}, 100 | }, r, nil 101 | } 102 | 103 | // createOpeningHandshakeRequest is used to return a valid websocket opening 104 | // handshake client request. 105 | // 106 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.1 107 | func (d *Dialer) createRequest(l *url.URL) *http.Request { 108 | // Initialize header if not already initialized. 109 | if d.Header == nil { 110 | d.Header = make(http.Header) 111 | } 112 | 113 | // When using the default port the Host header field should only consist of 114 | // the host (no port is shown). 115 | t := l.Host 116 | 117 | switch l.Scheme { 118 | case "ws": 119 | { 120 | re := regexp.MustCompile(":22$") 121 | t = re.ReplaceAllString(t, "") 122 | } 123 | case "wss": 124 | { 125 | re := regexp.MustCompile(":443$") 126 | t = re.ReplaceAllString(t, "") 127 | } 128 | } 129 | 130 | // Include headers 131 | d.Header.Set("Host", t) 132 | d.Header.Set("Upgrade", "websocket") 133 | d.Header.Set("Connection", "upgrade") 134 | d.Header.Set("Sec-WebSocket-Version", "13") 135 | d.Header.Set("Sec-WebSocket-Key", makeChallengeKey()) 136 | d.Header.Set("Sec-WebSocket-Protocol", strings.Join(d.SubProtocols, ", ")) 137 | 138 | // Create request instance 139 | q := &http.Request{ 140 | Method: "GET", 141 | URL: l, 142 | Proto: "HTTP/1.1", 143 | ProtoMajor: 1, 144 | ProtoMinor: 1, 145 | Header: d.Header, 146 | Host: l.Host, 147 | } 148 | 149 | return q 150 | } 151 | -------------------------------------------------------------------------------- /dialer_test.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "encoding/base64" 5 | "net/http" 6 | "net/url" 7 | "strings" 8 | "testing" 9 | ) 10 | 11 | func TestDialerCreateRequestNilHeader(t *testing.T) { 12 | d := &Dialer{Header: nil} 13 | 14 | q := d.createRequest(&url.URL{}) 15 | 16 | if q.Header == nil { 17 | t.Errorf("expected header to be initialized") 18 | } 19 | } 20 | 21 | func TestDialerCreateRequestNonNilHeader(t *testing.T) { 22 | h := make(http.Header) 23 | k := "testKey" 24 | v := "testValue" 25 | 26 | h.Add(k, v) 27 | 28 | d := &Dialer{Header: h} 29 | 30 | q := d.createRequest(&url.URL{}) 31 | 32 | if q.Header.Get(k) != v { 33 | t.Errorf("expected header to be the one provided in dialer instance") 34 | } 35 | } 36 | 37 | func TestDialerCreateRequestHostHeader(t *testing.T) { 38 | d := &Dialer{} 39 | 40 | type testCase struct { 41 | u *url.URL 42 | v string 43 | } 44 | 45 | testCases := []testCase{ 46 | {u: &url.URL{Scheme: "ws", Host: "localhost:22"}, v: "localhost"}, 47 | {u: &url.URL{Scheme: "wss", Host: "localhost:443"}, v: "localhost"}, 48 | {u: &url.URL{Scheme: "ws", Host: "localhost:80"}, v: "localhost:80"}, 49 | {u: &url.URL{Scheme: "wss", Host: "localhost:80"}, v: "localhost:80"}, 50 | } 51 | 52 | for i, c := range testCases { 53 | q := d.createRequest(c.u) 54 | v := q.Header.Get("Host") 55 | 56 | if v != c.v { 57 | t.Errorf(`test case %d: expected Host header field value to be "%s", but it is "%s"`, i, c.v, v) 58 | } 59 | } 60 | } 61 | 62 | func TestDialerCreateRequestHeaders(t *testing.T) { 63 | d := &Dialer{ 64 | SubProtocols: []string{"chat", "v1"}, 65 | } 66 | 67 | q := d.createRequest(&url.URL{Scheme: "ws", Host: "localhost"}) 68 | 69 | v := q.Header.Get("Upgrade") 70 | e := "websocket" 71 | if strings.ToLower(v) != e { 72 | t.Errorf(`expected Upgrade header field value to be "%s", but it is "%s"`, v, e) 73 | } 74 | 75 | v = q.Header.Get("Connection") 76 | e = "upgrade" 77 | if strings.ToLower(v) != e { 78 | t.Errorf(`expected Connection header field value to be "%s", but it is "%s"`, v, e) 79 | } 80 | 81 | v = q.Header.Get("Sec-WebSocket-Version") 82 | e = "13" 83 | if strings.ToLower(v) != e { 84 | t.Errorf(`expected Sec-WebSocket-Version header field value to be "%s", but it is "%s"`, v, e) 85 | } 86 | 87 | v = q.Header.Get("Sec-WebSocket-Protocol") 88 | e = strings.Join(d.SubProtocols, ", ") 89 | if strings.ToLower(v) != e { 90 | t.Errorf(`expected Sec-WebSocket-Protocol header field value to be "%s", but it is "%s"`, v, e) 91 | } 92 | 93 | l, err := base64.StdEncoding.DecodeString(q.Header.Get("Sec-WebSocket-Key")) 94 | 95 | if err != nil { 96 | t.Errorf(`unexpected error returned when decoding Sec-WebSocket-Key %s`, err) 97 | } 98 | 99 | if len(l) != 16 { 100 | t.Errorf(`expected Sec-WebSocket-Protocol header field value to be '%d' in length, but it is '%d'`, len(l), 16) 101 | } 102 | } 103 | 104 | func TestDialerCreateRequestRequest(t *testing.T) { 105 | d := &Dialer{} 106 | u := &url.URL{ 107 | Scheme: "ws", 108 | Host: "localhost:8080", 109 | } 110 | 111 | q := d.createRequest(u) 112 | 113 | if q.URL != u { 114 | t.Errorf("expected URL instance to be the one provided") 115 | } 116 | 117 | if q.Method != "GET" { 118 | t.Errorf(`expected method to be "GET", but it is "%s"`, q.Method) 119 | } 120 | 121 | if q.Proto != "HTTP/1.1" { 122 | t.Errorf(`expected http protocol to be "HTTP/1.1", but it is "%s"`, q.Proto) 123 | } 124 | 125 | if !q.ProtoAtLeast(1, 1) { 126 | t.Errorf("expected http protocol used to be at least version 1.1, but it is %d.%d", q.ProtoMajor, q.ProtoMinor) 127 | } 128 | 129 | if q.Host != u.Host { 130 | t.Errorf(`expected host to be "%s", but it is "%s"`, u.Host, q.Host) 131 | } 132 | } 133 | -------------------------------------------------------------------------------- /dialer_utils.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "encoding/base64" 5 | "errors" 6 | "net/http" 7 | "net/url" 8 | "regexp" 9 | "strings" 10 | ) 11 | 12 | // validateResponse is used to determine whether the servers handshake request 13 | // conforms with the WebSocket spec. When it doesn't the client fails the 14 | // websocket connection. 15 | // 16 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.1 17 | func validateResponse(r *http.Response) *OpenError { 18 | validations := []func(*http.Response) *OpenError{ 19 | validateResponseStatus, 20 | validateResponseUpgradeHeader, 21 | validateResponseConnectionHeader, 22 | validateResponseSecWebsocketAcceptHeader, 23 | } 24 | 25 | for _, v := range validations { 26 | if err := v(r); err != nil { 27 | return err 28 | } 29 | } 30 | 31 | return nil 32 | } 33 | 34 | // validateResponseStatus verifies that status code of the server's opening 35 | // handshake response is '101'. If it is not, it means that the handshake has 36 | // been rejected and thus the endpoints are still communicating using http. 37 | // 38 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.1 39 | func validateResponseStatus(r *http.Response) *OpenError { 40 | if r.StatusCode != 101 { 41 | return &OpenError{ 42 | Reason: "http status not 101", 43 | } 44 | } 45 | return nil 46 | } 47 | 48 | // validateResponseUpgradeHeader verifies that the Upgrade HTTP Header value 49 | // in the servers's opening handshake response is "websocket". 50 | // 51 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.1 52 | func validateResponseUpgradeHeader(r *http.Response) *OpenError { 53 | if strings.ToLower(r.Header.Get("Upgrade")) != "websocket" { 54 | return &OpenError{ 55 | Reason: `"Upgrade" Header should have the value of "websocket"`, 56 | } 57 | } 58 | return nil 59 | } 60 | 61 | // validateResponseConnectionHeader verifies that the Connection HTTP Header 62 | // value in the servers's opening handshake response is "upgrade". 63 | // 64 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.1 65 | func validateResponseConnectionHeader(r *http.Response) *OpenError { 66 | if strings.ToLower(r.Header.Get("Connection")) != "upgrade" { 67 | return &OpenError{ 68 | Reason: `"Connection" Header should have the value of "upgrade"`, 69 | } 70 | } 71 | return nil 72 | } 73 | 74 | // validateResponseSecWebsocketAcceptHeader verifies that the 75 | // Sec-WebSocket-Accept HTTP Header value in the server's opening handshake 76 | // response is the base64-encoded SHA-1 of the concatenation of the 77 | // Sec-WebSocket-Key value (sent with the opening handshake request) (as a 78 | // string, not base64-decoded) with the websocket accept key. 79 | // 80 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.1 81 | func validateResponseSecWebsocketAcceptHeader(r *http.Response) *OpenError { 82 | if r.Header.Get("Sec-WebSocket-Accept") != makeAcceptKey(r.Request.Header.Get("Sec-Websocket-Key")) { 83 | return &OpenError{ 84 | Reason: `challenge key failure`, 85 | } 86 | } 87 | return nil 88 | } 89 | 90 | // validateResponseSecWebsocketProtocol verifies that the sub protocol the 91 | // server has agreed to use (Sec-WebSocket-Protocol Header) was in the list the 92 | // client has sent in the opening handshake request. 93 | // 94 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.1 95 | func validateResponseSecWebsocketProtocol(r *http.Response) *OpenError { 96 | // Sub protocols sent by the client. 97 | c := headerToSlice(r.Request.Header.Get("Sec-WebSocket-Protocol")) 98 | // Sub protocol the server has agreed to use. 99 | s := r.Header.Get("Sec-WebSocket-Protocol") 100 | 101 | // If the server hasn't agreed to use anything, stop process. 102 | if len(s) == 0 { 103 | return nil 104 | } 105 | 106 | // Loop through the lists of sub protocols the client has sent in its 107 | // opening handshake request and if the sub protocol the server argeed to 108 | // use is found stop the process. 109 | for _, cv := range c { 110 | if cv == s { 111 | return nil 112 | } 113 | } 114 | 115 | // At this point the server has agreed to use a sub protocol which the 116 | // client doesn't support and thus return an error. 117 | return &OpenError{ 118 | Reason: `server choose a sub protocol which was not in the list sent by the client`, 119 | } 120 | } 121 | 122 | // makeChallengeKey is used to generate the key to be sent with the client's 123 | // opening handshake using the Sec-Websocket-Key header field. 124 | // 125 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.1 126 | func makeChallengeKey() string { 127 | // return Base64 encode version of the byte generated. 128 | return base64.StdEncoding.EncodeToString(randomByteSlice(4)) 129 | } 130 | 131 | // parseURL is used to parse the URL string provided and verifies that it 132 | // conforms with the websocket spec. If it does it will create and return a URL 133 | // instance representing the URL string provided. 134 | // 135 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-3 136 | func parseURL(u string) (*url.URL, error) { 137 | // Parse scheme. 138 | if err := parseURLScheme(&u); err != nil { 139 | return nil, err 140 | } 141 | 142 | // Create URL Instance. 143 | l, err := url.Parse(u) 144 | if err != nil { 145 | return nil, err 146 | } 147 | 148 | // Parse Host. 149 | if err := parseURLHost(l); err != nil { 150 | return nil, err 151 | } 152 | 153 | return l, nil 154 | } 155 | 156 | // parseURLScheme is used to parse the Scheme portion of a URL string. If the 157 | // scheme provided is not a valid websocket scheme an error is returned. If no 158 | // scheme is given it will be defaulted to "ws". 159 | // 160 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-3 161 | func parseURLScheme(u *string) error { 162 | // Regex to retrieve Scheme portion of a URL string. 163 | re := regexp.MustCompile("^([a-zA-Z]+)://") 164 | m := re.FindStringSubmatch(*u) 165 | 166 | // If m is smaller than 2 it means that the user hasn't provided one and 167 | // thus the default sheme (ws) is used. 168 | if len(m) < 2 { 169 | *u = "ws://" + *u 170 | return nil 171 | } 172 | 173 | // If a sheme was captured, make sure it is valid. 174 | if !schemeValid(m[1]) { 175 | return errors.New("invalid scheme: " + m[1]) 176 | } 177 | 178 | return nil 179 | } 180 | 181 | // parseURLHost is used to parse the Host portion of a URL instance to 182 | // determine whether it has a port or not. When no port is found this method 183 | // will assign a port based on the URL instance scheme (ws = 22, wss = 443). If 184 | // the scheme is not a valid scheme for websocket an error is returned. 185 | // 186 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-3 187 | func parseURLHost(u *url.URL) error { 188 | // If scheme is invalid throw an error 189 | if !schemeValid(u.Scheme) { 190 | return errors.New("invalid scheme: " + u.Scheme) 191 | } 192 | 193 | // Regex to retrieve the Port portion of the URL. 194 | re := regexp.MustCompile(":(\\d+)") 195 | m := re.FindStringSubmatch(u.Host) 196 | 197 | // If the length of m is greater than or equals to 2 it means that there is 198 | // a submatch, meaning that the user has provided a port and thus there is 199 | // no need to include the default ports. 200 | if len(m) >= 2 { 201 | return nil 202 | } 203 | 204 | // Based on the scheme, set the port. 205 | switch u.Scheme { 206 | case "ws": 207 | { 208 | u.Host += ":22" 209 | } 210 | case "wss": 211 | { 212 | u.Host += ":443" 213 | } 214 | } 215 | 216 | return nil 217 | } 218 | 219 | // schemeValid is used to determine whether the scheme provided is a valid 220 | // scheme for the websocket protocol. 221 | func schemeValid(s string) bool { 222 | return s == "ws" || s == "wss" 223 | } 224 | -------------------------------------------------------------------------------- /dialer_utils_test.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "encoding/base64" 5 | "net/http" 6 | "net/url" 7 | "testing" 8 | ) 9 | 10 | func TestParseURLScheme(t *testing.T) { 11 | type testCase struct { 12 | u string 13 | f string 14 | } 15 | 16 | testCases := []testCase{ 17 | {u: "ws://localhost:8080", f: "ws://localhost:8080"}, 18 | {u: "wss://localhost:8080", f: "wss://localhost:8080"}, 19 | {u: "localhost:8080", f: "ws://localhost:8080"}, 20 | } 21 | 22 | for i, c := range testCases { 23 | n := c.u 24 | 25 | if err := parseURLScheme(&n); err != nil { 26 | t.Errorf("test case %d: unexpected error was returned %s", i, err) 27 | } 28 | 29 | if n != c.f { 30 | t.Errorf(`test case %d: expected url to be "%s", but it is "%s"`, i, c.f, n) 31 | } 32 | } 33 | } 34 | 35 | func TestParseURLSchemeError(t *testing.T) { 36 | u := "http://localhost:8080" 37 | 38 | if err := parseURLScheme(&u); err == nil { 39 | t.Error("expected an error for", u) 40 | } 41 | } 42 | 43 | func TestParseURLHost(t *testing.T) { 44 | type testCase struct { 45 | u *url.URL 46 | h string 47 | } 48 | 49 | testCases := []testCase{ 50 | {u: &url.URL{Scheme: "ws", Host: "localhost:80"}, h: "localhost:80"}, 51 | {u: &url.URL{Scheme: "wss", Host: "localhost:80"}, h: "localhost:80"}, 52 | {u: &url.URL{Scheme: "ws", Host: "localhost"}, h: "localhost:22"}, 53 | {u: &url.URL{Scheme: "wss", Host: "localhost"}, h: "localhost:443"}, 54 | } 55 | 56 | for i, c := range testCases { 57 | if err := parseURLHost(c.u); err != nil { 58 | t.Errorf("test case %d: unexpected error was returned %s", i, err) 59 | } 60 | 61 | if c.u.Host != c.h { 62 | t.Errorf(`test case %d: expected host to be "%s", but it is "%s"`, i, c.h, c.u.Host) 63 | } 64 | } 65 | } 66 | 67 | func TestParseURLHostError(t *testing.T) { 68 | u := &url.URL{ 69 | Scheme: "http", 70 | Host: "localhost", 71 | } 72 | 73 | if err := parseURLHost(u); err == nil { 74 | t.Errorf("expected an error to be returned") 75 | } 76 | } 77 | 78 | func TestMakeChallengeKey(t *testing.T) { 79 | k := makeChallengeKey() 80 | b, err := base64.StdEncoding.DecodeString(k) 81 | 82 | if err != nil { 83 | t.Errorf("unexpected error was returned while decoding value: %s", err) 84 | } 85 | 86 | if len(b) != 16 { 87 | t.Errorf("expected length of decoded challenge key to be 16, but it is %d", len(b)) 88 | } 89 | } 90 | 91 | func TestValidateResponseStatus(t *testing.T) { 92 | type testCase struct { 93 | s int 94 | e bool 95 | } 96 | 97 | testCases := []testCase{ 98 | {s: 101, e: false}, 99 | {s: 200, e: true}, 100 | } 101 | 102 | for i, c := range testCases { 103 | r := &http.Response{ 104 | StatusCode: c.s, 105 | } 106 | 107 | err := validateResponseStatus(r) 108 | 109 | if c.e && err == nil { 110 | t.Errorf(`test case %d: expected an error for '%d'`, i, c.s) 111 | } 112 | 113 | if !c.e && err != nil { 114 | t.Errorf(`test case %d: unexpected error returned for '%d'`, i, c.s) 115 | } 116 | } 117 | } 118 | 119 | func TestValidateResponseUpgradeHeader(t *testing.T) { 120 | type testCase struct { 121 | v string 122 | e bool 123 | } 124 | 125 | testCases := []testCase{ 126 | {v: "websocket", e: false}, 127 | {v: "WebSocket", e: false}, 128 | {v: "wrong", e: true}, 129 | } 130 | 131 | for i, c := range testCases { 132 | r := &http.Response{ 133 | Header: make(http.Header), 134 | } 135 | 136 | r.Header.Add("Upgrade", c.v) 137 | err := validateResponseUpgradeHeader(r) 138 | 139 | if c.e && err == nil { 140 | t.Errorf(`test case %d: expected an error for "%s"`, i, c.v) 141 | } 142 | 143 | if !c.e && err != nil { 144 | t.Errorf(`test case %d: unexpected error returned for "%s"`, i, c.v) 145 | } 146 | } 147 | } 148 | 149 | func TestValidateResponseConnectionHeader(t *testing.T) { 150 | type testCase struct { 151 | v string 152 | e bool 153 | } 154 | 155 | testCases := []testCase{ 156 | {v: "upgrade", e: false}, 157 | {v: "UpgrADE", e: false}, 158 | {v: "wrong", e: true}, 159 | } 160 | 161 | for i, c := range testCases { 162 | r := &http.Response{ 163 | Header: make(http.Header), 164 | } 165 | 166 | r.Header.Add("Connection", c.v) 167 | err := validateResponseConnectionHeader(r) 168 | 169 | if c.e && err == nil { 170 | t.Errorf(`test case %d: expected an error for "%s"`, i, c.v) 171 | } 172 | 173 | if !c.e && err != nil { 174 | t.Errorf(`test case %d: unexpected error returned for "%s"`, i, c.v) 175 | } 176 | } 177 | } 178 | 179 | func TestValidateResponseSecWebsocketProtocol(t *testing.T) { 180 | type testCase struct { 181 | c string 182 | s string 183 | e bool 184 | } 185 | 186 | testCases := []testCase{ 187 | {c: "client, v1", s: "", e: false}, 188 | {c: "client, v1", s: "v1", e: false}, 189 | {c: "client, v1", s: "v2", e: true}, 190 | } 191 | 192 | for i, c := range testCases { 193 | // Headers sent by client 194 | hq := make(http.Header) 195 | hq.Set("Sec-WebSocket-Protocol", c.c) 196 | 197 | // Headers sent by server 198 | hr := make(http.Header) 199 | hr.Set("Sec-WebSocket-Protocol", c.s) 200 | 201 | q := &http.Request{ 202 | Header: hq, 203 | } 204 | 205 | r := &http.Response{ 206 | Header: hr, 207 | Request: q, 208 | } 209 | 210 | err := validateResponseSecWebsocketProtocol(r) 211 | 212 | if c.e && err == nil { 213 | t.Errorf(`test case %d: expected an error when the client sent "%s" as supported protocols and the server agreed to use "%s"`, i, c.c, c.s) 214 | } 215 | 216 | if !c.e && err != nil { 217 | t.Errorf(`test case %d: unexpected error was returned when the client sent "%s" as supported protocols and the server agreed to use "%s"`, i, c.c, c.s) 218 | } 219 | } 220 | } 221 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // Package websocket implements the websocket protocol defined in rfc6455. 2 | package websocket 3 | -------------------------------------------------------------------------------- /errors.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "encoding/binary" 5 | "errors" 6 | "fmt" 7 | ) 8 | 9 | // CloseError represents errors related to the websocket closing handshake. 10 | type CloseError struct { 11 | Code int 12 | Reason string 13 | } 14 | 15 | // Error implements the built in error interface. 16 | func (c *CloseError) Error() string { 17 | return fmt.Sprintf("Close Error: %d %s", c.Code, c.Reason) 18 | } 19 | 20 | // ToBytes returns the representation of a CloseError instance in a []bytes 21 | // that conforms with the way the websocket rfc expects the payload data of 22 | // CLOSE FRAMES to be. 23 | // 24 | // While generating the []bytes, if the CloseError instance has an invalid 25 | // error code, it will instead create the representation of a 'No Status 26 | // Received Error' (i.e. 1005). 27 | // 28 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-5.5.1 29 | func (c *CloseError) ToBytes() ([]byte, error) { 30 | // Validate Error Code 31 | if !closeErrorExist(c.Code) { 32 | // If it is not valid, return bytes for No Status Received error. 33 | n := &CloseError{ 34 | Code: CloseNoStatusReceived, 35 | Reason: "no status recieved", 36 | } 37 | b, _ := n.ToBytes() 38 | return b, errors.New("invalid error code") 39 | } 40 | 41 | return append(c.toBytesCode(), []byte(c.Reason)...), nil 42 | } 43 | 44 | // toBytesCode is used to get a representation of the CloseError instance 45 | // status code in []bytes. 46 | func (c *CloseError) toBytesCode() []byte { 47 | b := make([]byte, 2) 48 | binary.BigEndian.PutUint16(b, uint16(c.Code)) 49 | return b 50 | } 51 | 52 | // NewCloseError is used to create a new CloseError instance by parsing 'b'. In 53 | // order for this to happen the []bytes needs to conform with the way the 54 | // websocket rfc expects the payload data of CLOSE FRAMES to be. 55 | // 56 | // While parsing if the error code (i.e. first two bytes) is invalid, it will 57 | // default the CloseError instance returned to represent a 'No Status Received 58 | // Error' (i.e. 1005). 59 | // 60 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-5.5.1 61 | func NewCloseError(b []byte) (*CloseError, error) { 62 | var c int 63 | 64 | if len(b) >= 2 { 65 | cb := b[:2] 66 | c = int(binary.BigEndian.Uint16(cb)) 67 | } 68 | 69 | if !closeErrorExist(c) { 70 | return &CloseError{ 71 | Code: CloseNoStatusReceived, 72 | Reason: "no status recieved", 73 | }, errors.New("invalid error code") 74 | } 75 | 76 | return &CloseError{ 77 | Code: c, 78 | Reason: string(b[2:]), 79 | }, nil 80 | } 81 | 82 | // OpenError represents errors related to the websocket opening handshake. 83 | type OpenError struct { 84 | Reason string 85 | } 86 | 87 | // Error implements the built in error interface. 88 | func (h *OpenError) Error() string { 89 | return "Handshake Error: " + h.Reason 90 | } 91 | -------------------------------------------------------------------------------- /errors_test.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestCloseErrorToBytes(t *testing.T) { 8 | type testCase struct { 9 | c int 10 | r string 11 | b []byte 12 | } 13 | 14 | testCases := []testCase{ 15 | {c: 1001, r: "normal closure", b: []byte{3, 233, 110, 111, 114, 109, 97, 108, 32, 99, 108, 111, 115, 117, 114, 101}}, 16 | {c: 1001, r: "", b: []byte{3, 233}}, 17 | } 18 | 19 | for i, c := range testCases { 20 | e := &CloseError{Code: c.c, Reason: c.r} 21 | 22 | b, err := e.ToBytes() 23 | 24 | if err != nil { 25 | t.Errorf(`test case %d: unexpected error`, i) 26 | } 27 | 28 | if len(b) != len(c.b) { 29 | t.Errorf(`test case %d: unexpected slice of bytes`, i) 30 | } 31 | 32 | same := true 33 | for bi, bv := range b { 34 | if bv != c.b[bi] { 35 | same = false 36 | break 37 | } 38 | } 39 | 40 | if !same { 41 | t.Errorf(`test case %d: unexpected slice of bytes`, i) 42 | } 43 | } 44 | } 45 | 46 | func TestCloseErrorToBytesError(t *testing.T) { 47 | b := []byte{3, 237, 110, 111, 32, 115, 116, 97, 116, 117, 115, 32, 114, 101, 99, 105, 101, 118, 101, 100} 48 | 49 | c := &CloseError{Code: 0, Reason: "woops"} 50 | e, err := c.ToBytes() 51 | 52 | if err == nil { 53 | t.Error("expected an error") 54 | } 55 | 56 | same := true 57 | for bi, bv := range b { 58 | if bv != e[bi] { 59 | same = false 60 | break 61 | } 62 | } 63 | 64 | if !same { 65 | t.Errorf(`unexpected slice of bytes`) 66 | } 67 | } 68 | 69 | func TestNewCloseError(t *testing.T) { 70 | type testCase struct { 71 | c int 72 | r string 73 | b []byte 74 | } 75 | 76 | testCases := []testCase{ 77 | {c: 1001, r: "normal closure", b: []byte{3, 233, 110, 111, 114, 109, 97, 108, 32, 99, 108, 111, 115, 117, 114, 101}}, 78 | {c: 1001, r: "", b: []byte{3, 233}}, 79 | } 80 | 81 | for i, c := range testCases { 82 | e, err := NewCloseError(c.b) 83 | 84 | if err != nil { 85 | t.Errorf(`test case %d: unexpected error`, i) 86 | } 87 | 88 | if e.Code != c.c { 89 | t.Errorf("test case %d: expected Code to be '%d', but it is '%d'", i, c.c, e.Code) 90 | } 91 | 92 | if e.Reason != c.r { 93 | t.Errorf(`test case %d: expected Reason to be "%s", but it is "%s"`, i, c.r, e.Reason) 94 | } 95 | } 96 | } 97 | 98 | func TestNewCloseErrorError(t *testing.T) { 99 | type testCase struct { 100 | p []byte 101 | } 102 | 103 | testCases := []testCase{ 104 | {p: make([]byte, 0)}, 105 | {p: []byte{3, 133}}, 106 | } 107 | 108 | for i, c := range testCases { 109 | c, err := NewCloseError(c.p) 110 | r := "no status recieved" 111 | 112 | if err == nil { 113 | t.Errorf("test case %d: expected an error", i) 114 | } 115 | 116 | if c.Code != CloseNoStatusReceived { 117 | t.Errorf("test case %d, expected Code to be '%d', but it is '%d'", i, CloseNoStatusReceived, c.Code) 118 | } 119 | 120 | if c.Reason != r { 121 | t.Errorf(`test case %d, expected Reason to be "%s", but it is "%s"`, i, r, c.Reason) 122 | } 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /examples/chat/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/tabone/websocket" 5 | "log" 6 | "net/http" 7 | ) 8 | 9 | var m *manager 10 | 11 | func main() { 12 | m = &manager{ 13 | users: make(map[int]*websocket.Socket), 14 | } 15 | http.HandleFunc("/ws", wsHandler) 16 | http.Handle("/", http.FileServer(http.Dir("public/"))) 17 | 18 | log.Println("listening on localhost:8080.") 19 | http.ListenAndServe("localhost:8080", nil) 20 | } 21 | 22 | func wsHandler(w http.ResponseWriter, r *http.Request) { 23 | log.Println("new connection") 24 | 25 | // Create a new websocket request 26 | q := &websocket.Request{ 27 | CheckOrigin: func(r *http.Request) bool { 28 | // Accept all requests. 29 | return true 30 | }, 31 | } 32 | 33 | // Try to upgrade the http request. 34 | s, err := q.Upgrade(w, r) 35 | 36 | if err != nil { 37 | log.Println("upgrade failed:", err) 38 | } 39 | 40 | // If upgrade has been successfull, include the socket with the other online 41 | // sockets 42 | m.addSocket(s) 43 | } 44 | -------------------------------------------------------------------------------- /examples/chat/manager.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "github.com/tabone/websocket" 6 | "log" 7 | "time" 8 | ) 9 | 10 | type manager struct { 11 | /* 12 | seq is a sequence which will be used to assign a unique id to each 13 | socket added to the list of online users. 14 | */ 15 | seq int 16 | 17 | /* 18 | users will contain a reference to all the online sockets. 19 | */ 20 | users map[int]*websocket.Socket 21 | } 22 | 23 | /* 24 | addSocket is used to add a socket to the online list of users. 25 | */ 26 | func (m *manager) addSocket(s *websocket.Socket) { 27 | m.seq++ 28 | log.Println("user", m.seq, "has logged in") 29 | m.users[m.seq] = s 30 | m.config(m.seq) 31 | 32 | j := fmt.Sprintf(`{"type":"login","data":{"user": %d, "count":%d}}`, m.seq, len(m.users)) 33 | m.broadcast([]byte(j)) 34 | 35 | go m.ping(s) 36 | 37 | // Start listening for new data. 38 | s.Listen() 39 | } 40 | 41 | func (m *manager) ping(s *websocket.Socket) { 42 | t := time.NewTicker(time.Second * 5) 43 | 44 | for { 45 | <-t.C 46 | if err := s.WriteMessage(websocket.OpcodePing, nil); err != nil { 47 | log.Println(err) 48 | break 49 | } 50 | } 51 | t.Stop() 52 | } 53 | 54 | /* 55 | removeSocket is used to remove a socket from the online list of users using 56 | its id. 57 | */ 58 | func (m *manager) removeSocket(i int) { 59 | log.Println("user", i, "has logged out") 60 | delete(m.users, i) 61 | } 62 | 63 | /* 64 | config is used to configure the socket instance. 65 | */ 66 | func (m *manager) config(i int) { 67 | s := m.users[i] 68 | 69 | s.ReadHandler = func(o int, p []byte) { 70 | log.Println("user", i, "sent a message:", string(p)) 71 | j := fmt.Sprintf(`{"type":"message","data":"%s"}`, p) 72 | m.broadcast([]byte(j)) 73 | } 74 | 75 | s.CloseHandler = func(err error) { 76 | log.Println("user", i, "disconnected:", err) 77 | m.removeSocket(i) 78 | j := fmt.Sprintf(`{"type":"logout","data":{"user": %d, "count":%d}}`, i, len(m.users)) 79 | m.broadcast([]byte(j)) 80 | } 81 | 82 | s.PongHandler = func(p []byte) { 83 | log.Println("user", i, "pong recieved") 84 | } 85 | } 86 | 87 | /* 88 | broadcast is used to send a message to all the connected users. 89 | */ 90 | func (m *manager) broadcast(p []byte) { 91 | for _, s := range m.users { 92 | s.WriteMessage(websocket.OpcodeText, p) 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /examples/chat/public/css/style.css: -------------------------------------------------------------------------------- 1 | * { 2 | margin:0px; 3 | padding:0px; 4 | box-sizing:border-box; 5 | -moz-box-sizing:border-box; 6 | -webkit-box-sizing:border-box; 7 | } 8 | 9 | body { 10 | font-size:10px; 11 | font-family:sans-serif; 12 | background-color: #F7F7F7; 13 | color:#555; 14 | } 15 | 16 | header h1 { 17 | text-align:center; 18 | padding:30px; 19 | } 20 | 21 | header h1 a { 22 | color:inherit; 23 | text-decoration:none; 24 | font-size:1.5em; 25 | } 26 | 27 | header h1 a:hover { 28 | color:#5E97D6; 29 | } 30 | 31 | .cntr { 32 | padding: 7px 14px; 33 | border-radius:3px; 34 | -moz-border-radius:3px; 35 | -webkit-border-radius:3px; 36 | box-shadow: 0px 1px 2px 0px #AFAFAF; 37 | -moz-box-shadow: 0px 1px 2px 0px #AFAFAF; 38 | -webkit-box-shadow: 0px 1px 2px 0px #AFAFAF; 39 | } 40 | 41 | main #conversation { 42 | overflow: scroll; 43 | margin-bottom:10px; 44 | background-color:#fff; 45 | height:400px; 46 | } 47 | 48 | main { 49 | margin:0px auto; 50 | width:400px; 51 | } 52 | 53 | main #conversation .message { 54 | background-color: #f1f1f1; 55 | padding: 7px 10px; 56 | margin-bottom: 8px; 57 | font-size: 1.3em; 58 | border-radius: 3px; 59 | -moz-border-radius: 3px; 60 | -webkit-border-radius: 3px; 61 | } 62 | 63 | main #conversation .message.logout { 64 | background-color:#E8BABA; 65 | } 66 | 67 | main #conversation .message.login { 68 | background-color:#C0DAB7; 69 | } 70 | 71 | main input[type="text"] { 72 | width: 100%; 73 | border:none; 74 | font-size: 1.3em; 75 | outline:none; 76 | } 77 | 78 | footer { 79 | text-align:center; 80 | margin-top:10px; 81 | } 82 | 83 | footer #count { 84 | font-size: 2.3em; 85 | font-weight: bold; 86 | } -------------------------------------------------------------------------------- /examples/chat/public/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | Chat Example 4 | 5 | 6 | 7 | 8 |
9 |

10 | 11 | github.com/tabone/websocket 12 | 13 |

14 |
15 | 16 |
17 |
18 |
19 | Hello, this is a test 20 |
21 |
22 |
23 | 24 |
25 |
26 | 27 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /examples/chat/public/js/app.js: -------------------------------------------------------------------------------- 1 | 'use strict' 2 | 3 | function Application() { 4 | this._socket = new WebSocket("ws://localhost:8080/ws") 5 | 6 | /** 7 | * Object containing references to important dom elements. 8 | * @type {Object} 9 | */ 10 | this._dom = { 11 | /** 12 | * Element which will contain all messages. 13 | * @type {HTML Element} 14 | */ 15 | conversation: document.getElementById("conversation"), 16 | 17 | /** 18 | * Input element to be used to send new messages. 19 | * @type {HTML Element} 20 | */ 21 | textbox: document.getElementById("textbox"), 22 | 23 | /** 24 | * Element to display information about the number of online users. 25 | * @type {HTML Element} 26 | */ 27 | count: document.getElementById("count") 28 | } 29 | 30 | this._init() 31 | } 32 | 33 | /** 34 | * Initializer. 35 | */ 36 | Application.prototype._init = function () { 37 | this._setupSocket() 38 | ._setupTextbox() 39 | } 40 | 41 | /** 42 | * Connect with the websocket server and setup listeners. 43 | * @return {Application} The instance. 44 | */ 45 | Application.prototype._setupSocket = function () { 46 | var self = this 47 | this._socket.onmessage = function (resp) { 48 | var msg = JSON.parse(resp.data) 49 | 50 | switch (msg.type) { 51 | case "message": { 52 | self._onMessage(msg.data) 53 | break 54 | } 55 | case "login": { 56 | self._onLogin(msg.data) 57 | break 58 | } 59 | case "logout": { 60 | self._onLogout(msg.data) 61 | break 62 | } 63 | } 64 | } 65 | return this 66 | } 67 | 68 | /** 69 | * Add a listener on the Textbox element which when the user clicks on the Enter 70 | * key the text within the input field is sent to the websocket server. 71 | * @return {Application} The instance. 72 | */ 73 | Application.prototype._setupTextbox = function () { 74 | var self = this 75 | this._dom.textbox.onkeydown = function (ev) { 76 | if (ev.keyCode == 13 && this.value !== "") { 77 | self._socket.send(this.value) 78 | this.value = "" 79 | } 80 | } 81 | return this 82 | } 83 | 84 | /** 85 | * Method used to create a comment box. 86 | * @return {HTML Element} The comment box element. 87 | */ 88 | Application.prototype._createCommentBox = function () { 89 | var elem = document.createElement("div") 90 | elem.className = "message" 91 | return elem 92 | } 93 | 94 | /** 95 | * Method triggered when a new message is recieved from the server. 96 | * @param {String} msg The message to be displayed. 97 | */ 98 | Application.prototype._onMessage = function (msg) { 99 | var elem = this._createCommentBox() 100 | elem.innerHTML = msg 101 | this._dom.conversation.appendChild(elem) 102 | } 103 | 104 | /** 105 | * Method triggered when the message recieved is of type 'login' which means a 106 | * new user has joined to conversation. 107 | * @param {Number} Object.user The id of the user. 108 | * @param {Number} Object.count The number of online users. 109 | */ 110 | Application.prototype._onLogin = function (msg) { 111 | var elem = this._createCommentBox() 112 | elem.className = "message login" 113 | elem.innerHTML = "User " + msg.user + " Logged in" 114 | this._dom.conversation.appendChild(elem) 115 | this._dom.count.innerHTML = msg.count 116 | } 117 | 118 | /** 119 | * Method triggered when the message recieved is of type 'logout' which means a 120 | * new user has exited to conversation. 121 | * @param {Number} Object.user The id of the user. 122 | * @param {Number} Object.count The number of online users. 123 | */ 124 | Application.prototype._onLogout = function (msg) { 125 | var elem = this._createCommentBox() 126 | elem.className = "message logout" 127 | elem.innerHTML = "User " + msg.user + " Logged out" 128 | this._dom.conversation.appendChild(elem) 129 | this._dom.count.innerHTML = msg.count 130 | } 131 | 132 | ;(function () { 133 | var app = new Application() 134 | }()) -------------------------------------------------------------------------------- /frame.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "bufio" 5 | "encoding/binary" 6 | "fmt" 7 | ) 8 | 9 | // WebSocket Opcodes. 10 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-5.2 11 | const ( 12 | OpcodeContinuation int = 0 13 | OpcodeText int = 1 14 | OpcodeBinary int = 2 15 | OpcodeClose int = 8 16 | OpcodePing int = 9 17 | OpcodePong int = 10 18 | ) 19 | 20 | // frame represents a Websocket Data Frame. 21 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-5.2 22 | type frame struct { 23 | /* 24 | fin indicates that the frame is the final fragment. 25 | */ 26 | fin bool 27 | 28 | /* 29 | opcode defines the interpretation of the payload data. 30 | */ 31 | opcode int 32 | 33 | /* 34 | masked defines whether the payload data is masked. 35 | */ 36 | masked bool 37 | 38 | /* 39 | length specifies the length of the payload data in bytes. 40 | */ 41 | length uint64 42 | 43 | /* 44 | key contains the masking key to be used to decode the payload data (if 45 | data is masked). It is 32 bits in length. 46 | */ 47 | key []byte 48 | 49 | /* 50 | payload contains the data received from the client. 51 | */ 52 | payload []byte 53 | } 54 | 55 | // newFrame is a constructor function to create a new instance of frame by 56 | // reading from a buffer. The construction of the websocket frame is divided 57 | // into four sections: 58 | // 1. Parsing of first 2 bytes. 59 | // 2. Parsing of 'payload length' if 'payload length' parsed in first 60 | // section is greater 125. 61 | // 3. Parsing of 'masking key' if 'masked' value parsed in first section is 62 | // set to true. 63 | // 4. Parsing of payload data. 64 | func newFrame(b *bufio.Reader) (*frame, error) { 65 | // Create frame instance. 66 | f := &frame{} 67 | 68 | reads := []func(*bufio.Reader) error{ 69 | f.readInitial, 70 | f.readLength, 71 | f.readMaskKey, 72 | f.readPayload, 73 | } 74 | 75 | for _, read := range reads { 76 | if err := read(b); err != nil { 77 | return nil, err 78 | } 79 | } 80 | 81 | return f, nil 82 | } 83 | 84 | // readInitial is the first method that should be invoked to create the frame 85 | // instance based on the contents from a buffer. This method reads first 2 86 | // bytes of a websocket frame which includes: fin (1 bit), rsv 1-3 (3 bits), 87 | // opcode (4 bits), mask (1 bit) and payload length (7 bits). It accepts a 88 | // buffer as an argument which will be used to read the frame from. 89 | func (f *frame) readInitial(b *bufio.Reader) error { 90 | // Read first 2 bytes. 91 | p, err := readFromBuffer(b, 2) 92 | 93 | if err != nil { 94 | return err 95 | } 96 | 97 | // Reading 'fin' 98 | if p[0]>>7 == 1 { 99 | f.fin = true 100 | } 101 | 102 | // Since library doesn't support extensions if RSV1-3 are non zeros, fail 103 | // connection 104 | if p[0]&112 /* 01110000 */ != 0 { 105 | return &CloseError{ 106 | Code: CloseProtocolError, 107 | Reason: "no support for extensions", 108 | } 109 | } 110 | 111 | // Reading 'opcode' 112 | f.opcode = int(p[0]) & 15 /* 00001111 */ 113 | 114 | // if opcode doesn't exists, must stop connection 115 | if !opcodeExist(f.opcode) { 116 | return &CloseError{ 117 | Code: CloseProtocolError, 118 | Reason: fmt.Sprintf("unsupported opcode: %d", f.opcode), 119 | } 120 | } 121 | 122 | // Reading 'mask' 123 | if p[1]>>7 == 1 { 124 | f.masked = true 125 | } 126 | 127 | // Reading 'payload len' 128 | f.length = uint64(p[1]) & 127 /* 01111111 */ 129 | 130 | return nil 131 | } 132 | 133 | // readLength should be invoked after readInitial method and is used to read 134 | // the next 2 (if f.length == 126) or 8 (if f.length == 127) bytes. If f.length 135 | // is <= 125, no read operations are done to the buffer provided as an 136 | // argument. 137 | func (f *frame) readLength(b *bufio.Reader) error { 138 | // If f.length is <= 125 it means that we already have the payload length, 139 | // thus stop read operation. 140 | if f.length <= 125 { 141 | return nil 142 | } 143 | 144 | // For when f.length == 126, read next 2 bytes. 145 | var l uint64 = 2 146 | 147 | // If f.length == 127, read next 8 bytes. 148 | if f.length == 127 { 149 | l = 8 150 | } 151 | 152 | // Read number of bytes based on f.length. 153 | u, err := readFromBuffer(b, l) 154 | 155 | if err != nil { 156 | return err 157 | } 158 | 159 | // Reset length 160 | f.length = 0 161 | 162 | // At this point the bytes that represent the real payload length has been 163 | // retrieved from the buffer. So the next thing to do is to convert the byte 164 | // slice (representing the length) to an integer by combining the bytes 165 | // together. 166 | // 167 | // Example: Let say the slice of bytes repesenting the payload length is 168 | // [134, 129] (or [10000110, 10000001] in binary). 169 | // 170 | // loop 1: f.length == 0 171 | // line 1: Bitwise left shift of 8 172 | // length = 0 173 | // line 2: Add the byte being traversed to f.length. 174 | // length = 1310000110 175 | // 176 | // loop 2: f.length == 134 (or 10000110) 177 | // line 1: Bitwise left shift of 8 178 | // length = 10000110 00000000 (i.e. 34304) 179 | // line 2: Add the byte being traversed to f.length. 180 | // length = 10000110 10000001 (i.e. 34433) 181 | for _, v := range u { 182 | f.length = f.length << 8 183 | f.length += uint64(v) 184 | } 185 | 186 | // Most Significant Bit must be 0. 187 | f.length = f.length & 9223372036854775807 188 | 189 | return nil 190 | } 191 | 192 | // readMaskKey should be invoked after readLength method and is used to read 193 | // the next 4 bytes from the buffer to retrieve the masking key. Note that if 194 | // the payload data is not masked (f.masked == false - info retrieved from 195 | // readInitial) no read operations are done to the buffer provided as an 196 | // argument. 197 | func (f *frame) readMaskKey(b *bufio.Reader) error { 198 | // If payload is not masked, stop process 199 | if !f.masked { 200 | return nil 201 | } 202 | 203 | // Read 4 bytes for masking key 204 | p, err := readFromBuffer(b, 4) 205 | 206 | if err != nil { 207 | return err 208 | } 209 | 210 | // Store key in frame instance 211 | f.key = p 212 | 213 | return nil 214 | } 215 | 216 | // readPayload should be invoked after readMaskKey method and is used to read 217 | // the payload data from the buffer. The number of bytes to read are known from 218 | // f.length (info retrieved from either readInitial or readLength). In addition 219 | // to this if the payload data is masked (f.masked == true - info retrieved 220 | // from readInitial) the payload data will also be decoded using the masking 221 | // key provided with the frame (f.key - info retrieved from readMaskKey). 222 | func (f *frame) readPayload(b *bufio.Reader) error { 223 | // Read f.length bytes 224 | p, err := readFromBuffer(b, f.length) 225 | 226 | if err != nil { 227 | return err 228 | } 229 | 230 | if f.masked { 231 | // Unmask (decode) payload data 232 | mask(p, f.key) 233 | } 234 | 235 | // Store payload in frame instance. 236 | f.payload = p 237 | 238 | return nil 239 | } 240 | 241 | // toBytes returns a representation of the frame instance as a slice of bytes. 242 | // This method does not consider the values assigned to f.length and f.masked 243 | // since these are calculated using the length of f.payload and value of f.key 244 | // respectively. 245 | func (f *frame) toBytes() ([]byte, error) { 246 | if err := f.validate(); err != nil { 247 | return nil, err 248 | } 249 | 250 | // Slice of bytes used to contain the payload data. 251 | p := make([]byte, 2) 252 | 253 | // Include info for FIN bit. 254 | f.toBytesFin(p) 255 | 256 | // Include info for OPCODE bits. 257 | f.toBytesOpcode(p) 258 | 259 | // Include info for MASK bit. 260 | f.toBytesMasked(p) 261 | 262 | // Include info for PAYLOAD LEN bits. 263 | f.toBytesPayloadLength(p) 264 | 265 | // Append (if any) info for PAYLOAD LENGTH EXTENDED bits. 266 | p = append(p, f.toBytesPayloadLengthExt()...) 267 | 268 | // Append (if any) MASK KEY bits. 269 | p = append(p, f.key...) 270 | 271 | // Append (Masked) Payload data. bits 272 | p = append(p, f.toBytesPayloadData()...) 273 | 274 | // Append and PAYLOAD DATA bits and return whole payload 275 | return p, nil 276 | } 277 | 278 | // validate verifies that the data of the frame instance will result in a valid 279 | // websocket data frame. 280 | func (f *frame) validate() *CloseError { 281 | switch { 282 | // Opcode must exists. 283 | case !opcodeExist(f.opcode): 284 | { 285 | return &CloseError{ 286 | Code: CloseProtocolError, 287 | Reason: fmt.Sprintf("unsupported opcode: %d", f.opcode), 288 | } 289 | } 290 | // Masking key must have a valid length. 291 | case !validateKey(f.key): 292 | { 293 | return &CloseError{ 294 | Code: CloseProtocolError, 295 | Reason: "masking key must either be 0 or 4 bytes long", 296 | } 297 | } 298 | // Payload data must have a valid length. 299 | case !validatePayload(f.payload): 300 | { 301 | return &CloseError{ 302 | Code: CloseMessageTooBig, 303 | Reason: "maximum payload data exceeded", 304 | } 305 | } 306 | } 307 | return nil 308 | } 309 | 310 | // toBytesFin is used by toBytes to include info in 'p' about the FIN bit of 311 | // the frame instance. Note that this method should be invoked before 312 | // toBytesOpcode method. 313 | func (f *frame) toBytesFin(p []byte) { 314 | if f.fin { 315 | p[0] = 128 316 | } 317 | } 318 | 319 | // toBytesOpcode is used by toBytes to include info in 'p' about the OPCODE 320 | // bits of the frame instance. Note that this method should be invoked after 321 | // toBytesFin. 322 | func (f *frame) toBytesOpcode(p []byte) { 323 | p[0] += byte(f.opcode) 324 | } 325 | 326 | // toBytesMasked is used by toBytes to include info in 'p' about the MASK bit 327 | // of the frame instance. This method does not consider f.masked but instead it 328 | // calculates the MASK bit value based on f.key. Note that this method should 329 | // be invoked before toBytesPayloadLength. 330 | func (f *frame) toBytesMasked(p []byte) { 331 | if len(f.key) != 0 { 332 | p[1] = 128 333 | } 334 | } 335 | 336 | // toBytesPayloadLength is used by toBytes to include info in 'p' about the 337 | // PAYLOAD LENGTH bits of the frame instance. This method does not consider 338 | // f.length but instead it calculates the PAYLOAD LENGTH value based on the 339 | // payload that will be sent (f.payload). Note that this method should 340 | // be invoked after toBytesMasked. 341 | func (f *frame) toBytesPayloadLength(p []byte) { 342 | l := len(f.payload) 343 | 344 | switch { 345 | case l <= 125: 346 | { 347 | p[1] += byte(l) 348 | return 349 | } 350 | case l <= 65535: 351 | { 352 | p[1] += 126 353 | } 354 | case l <= 9223372036854775807: 355 | { 356 | p[1] += 127 357 | } 358 | } 359 | } 360 | 361 | // toBytesPayloadLengthExt is used by toBytes to include info about the PAYLOAD 362 | // LENGTH EXTENDED bits. Just like toBytesPayloadLength, this method does not 363 | // consider f.length but instead it calculates the PAYLOAD LENGTH EXTENDED bits 364 | // using the payload that will be sent (f.payload). 365 | func (f *frame) toBytesPayloadLengthExt() []byte { 366 | l := len(f.payload) 367 | 368 | // If <= 125, stop process since the true length is already known. 369 | if l <= 125 { 370 | return nil 371 | } 372 | 373 | var p []byte 374 | 375 | switch { 376 | case l <= 65535: 377 | { 378 | // Convert to binary. 379 | p = make([]byte, 2) 380 | binary.BigEndian.PutUint16(p, uint16(l)) 381 | } 382 | case l <= 9223372036854775807: 383 | { 384 | // Convert to binary. 385 | p = make([]byte, 8) 386 | binary.BigEndian.PutUint64(p, uint64(l)) 387 | } 388 | } 389 | 390 | return p 391 | } 392 | 393 | // toBytesPayloadData is used by toBytes to include info about the PAYLOAD 394 | // DATA. This method also handles the masking of the payload data (f.payload). 395 | // Note that just like toBytesMasked, this method does not consider f.masked 396 | // but instead it directly checks for the masking key (f.key). 397 | func (f *frame) toBytesPayloadData() []byte { 398 | // Put payload into another slice of bytes - so that the payload in the 399 | // frame instance is left untouched. 400 | p := append([]byte{}, f.payload...) 401 | 402 | // If masking key is present, use it to mask the payload data. 403 | if len(f.key) == 4 { 404 | mask(p, f.key) 405 | } 406 | 407 | return p 408 | } 409 | -------------------------------------------------------------------------------- /frame_test.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "bufio" 5 | "testing" 6 | ) 7 | 8 | func TestReadInitialForFin(t *testing.T) { 9 | type testCase struct { 10 | b *bufio.Reader 11 | v bool 12 | } 13 | 14 | testCases := []testCase{ 15 | // When fin bit is '0' should set fin to false. 16 | {b: newBuffer([]byte{1 /* 00000001 */, 0}), v: false}, 17 | // When fin bit is '1' should set fin to true. 18 | {b: newBuffer([]byte{129 /* 10000001 */, 0}), v: true}, 19 | } 20 | 21 | for i, c := range testCases { 22 | f := &frame{} 23 | 24 | if err := f.readInitial(c.b); err != nil { 25 | t.Errorf("test case %d: unexpected error returned: %v", i, err) 26 | } 27 | 28 | if f.fin != c.v { 29 | t.Errorf("test case %d: expected 'fin' to be '%t'", i, c.v) 30 | } 31 | } 32 | } 33 | 34 | func TestReadInitialForOpcode(t *testing.T) { 35 | type testCase struct { 36 | b *bufio.Reader 37 | v int 38 | } 39 | 40 | // When opcode is valid, should not return an error. 41 | testCases := []testCase{ 42 | // Without mask bit. 43 | {b: newBuffer([]byte{0, 0}), v: OpcodeContinuation}, 44 | {b: newBuffer([]byte{1, 0}), v: OpcodeText}, 45 | {b: newBuffer([]byte{2, 0}), v: OpcodeBinary}, 46 | {b: newBuffer([]byte{8, 0}), v: OpcodeClose}, 47 | {b: newBuffer([]byte{9, 0}), v: OpcodePing}, 48 | {b: newBuffer([]byte{10, 0}), v: OpcodePong}, 49 | 50 | // With mask bit. 51 | {b: newBuffer([]byte{128, 0}), v: OpcodeContinuation}, 52 | {b: newBuffer([]byte{129, 0}), v: OpcodeText}, 53 | {b: newBuffer([]byte{130, 0}), v: OpcodeBinary}, 54 | {b: newBuffer([]byte{136, 0}), v: OpcodeClose}, 55 | {b: newBuffer([]byte{137, 0}), v: OpcodePing}, 56 | {b: newBuffer([]byte{138, 0}), v: OpcodePong}, 57 | } 58 | 59 | for i, c := range testCases { 60 | f := &frame{} 61 | 62 | if err := f.readInitial(c.b); err != nil { 63 | t.Errorf("test case %d: unexpected error returned: %v", i, err) 64 | } 65 | 66 | if f.opcode != c.v { 67 | t.Errorf("test case %d: expected 'opcode' to be '%d'", i, c.v) 68 | } 69 | } 70 | } 71 | 72 | func TestReadInitialForRSVError(t *testing.T) { 73 | type testCase struct { 74 | b *bufio.Reader 75 | } 76 | 77 | // Library doesn't support extensions thus when extension bits are used, 78 | // lib should return an error. 79 | testCases := []testCase{ 80 | {b: newBuffer([]byte{17 /* 00010001 */, 0})}, 81 | {b: newBuffer([]byte{33 /* 00100001 */, 0})}, 82 | {b: newBuffer([]byte{49 /* 00110001 */, 0})}, 83 | {b: newBuffer([]byte{65 /* 01000001 */, 0})}, 84 | {b: newBuffer([]byte{81 /* 01010001 */, 0})}, 85 | {b: newBuffer([]byte{97 /* 01100001 */, 0})}, 86 | {b: newBuffer([]byte{113 /* 01110001 */, 0})}, 87 | } 88 | 89 | for i, c := range testCases { 90 | f := &frame{} 91 | 92 | err := f.readInitial(c.b) 93 | 94 | if err == nil { 95 | t.Errorf("test case %d: an error was expected.", i) 96 | } 97 | 98 | e, k := err.(*CloseError) 99 | 100 | if !k { 101 | t.Errorf("test case %d: expected error to be of type '*CloseError' but it is '%T'.", i, e) 102 | } 103 | 104 | if e.Reason != "no support for extensions" { 105 | t.Errorf(`test case %d: expected error to have reason "no support for extensions", instead it got "%s".`, i, e.Reason) 106 | } 107 | } 108 | } 109 | 110 | // Should return an error if opcode is invalid. 111 | func TestReadInitialForOpcodeError(t *testing.T) { 112 | f := &frame{} 113 | b := newBuffer([]byte{15, 0}) 114 | 115 | err := f.readInitial(b) 116 | 117 | if err == nil { 118 | t.Error("unexpected error returned") 119 | } 120 | 121 | e, k := err.(*CloseError) 122 | 123 | if !k { 124 | t.Fatalf("expected error to be of type '*websocket.CloseError', but it is '%T'.", e) 125 | } 126 | 127 | if e.Reason != "unsupported opcode: 15" { 128 | t.Errorf(`expected error to have reason "unsupported opcode: 15", but it got "%s".`, e.Reason) 129 | } 130 | } 131 | 132 | func TestReadInitialForMasked(t *testing.T) { 133 | type testCase struct { 134 | b *bufio.Reader 135 | v bool 136 | } 137 | 138 | testCases := []testCase{ 139 | // When masked bit is '0' should set masked to false. 140 | {b: newBuffer([]byte{1, 0}), v: false}, 141 | // When masked bit is '1' should set masked to true. 142 | {b: newBuffer([]byte{1, 128}), v: true}, 143 | } 144 | 145 | for i, c := range testCases { 146 | f := &frame{} 147 | 148 | if err := f.readInitial(c.b); err != nil { 149 | t.Errorf("test case %d: unexpected error returned: %v", i, err) 150 | } 151 | 152 | if f.masked != c.v { 153 | t.Errorf("test case %d: expected 'masked' to be '%t'", i, c.v) 154 | } 155 | } 156 | } 157 | 158 | func TestReadInitialForLength(t *testing.T) { 159 | type testCase struct { 160 | b *bufio.Reader 161 | v uint64 162 | } 163 | 164 | testCases := []testCase{ 165 | // Should set length to 124 when payload len is 124. 166 | {b: newBuffer([]byte{1, 124}), v: 124}, 167 | {b: newBuffer([]byte{1, 252}), v: 124}, 168 | // Should set length to 125 when payload len is 125. 169 | {b: newBuffer([]byte{1, 125}), v: 125}, 170 | {b: newBuffer([]byte{1, 253}), v: 125}, 171 | // Should set length to 126 when payload len is 126. 172 | {b: newBuffer([]byte{1, 126}), v: 126}, 173 | {b: newBuffer([]byte{1, 254}), v: 126}, 174 | // Should set length to 127 when payload len is 127. 175 | {b: newBuffer([]byte{1, 127}), v: 127}, 176 | {b: newBuffer([]byte{1, 255}), v: 127}, 177 | } 178 | 179 | for i, c := range testCases { 180 | f := &frame{} 181 | 182 | if err := f.readInitial(c.b); err != nil { 183 | t.Errorf("test case %d: unexpected error returned: %v", i, err) 184 | } 185 | 186 | if f.length != c.v { 187 | t.Errorf("test case %d: expected 'length' to be '%d'", i, c.v) 188 | } 189 | } 190 | } 191 | 192 | func TestReadMaskKey(t *testing.T) { 193 | f := &frame{} 194 | p := []byte{102, 100, 1, 54} 195 | b := newBuffer(p) 196 | 197 | // When f.masked is false, it means that the payload is not masked and 198 | // therefore no key has been sent. For this reason f.key should be left 199 | // untouched. 200 | f.masked = false 201 | if err := f.readMaskKey(b); err != nil { 202 | t.Error("unexpected error returned:", err) 203 | } 204 | 205 | if len(f.key) != 0 { 206 | t.Error("expected f.key to be empty but it is:", len(f.key)) 207 | } 208 | 209 | // When f.masked is true, it means that the payload is masked and therefore 210 | // the key must be read and stored in f.key. 211 | f.masked = true 212 | f.key = nil 213 | if err := f.readMaskKey(b); err != nil { 214 | t.Error("unexpected error returned:", err) 215 | } 216 | 217 | if len(f.key) != 4 { 218 | t.Errorf("expected f.key to be '4 bytes' long but it is '%d bytes'", len(f.key)) 219 | } 220 | 221 | for i, v := range p { 222 | if v != f.key[i] { 223 | t.Fatalf("expected mask key to be '%v' but it is '%v'", p, f.key) 224 | } 225 | } 226 | } 227 | 228 | func TestReadLength(t *testing.T) { 229 | f := &frame{} 230 | 231 | type testCase struct { 232 | // initial length 233 | i uint64 234 | // final length 235 | l uint64 236 | } 237 | 238 | testCases := []testCase{ 239 | {i: 124, l: 124}, 240 | {i: 125, l: 125}, 241 | {i: 126, l: 65535}, 242 | {i: 127, l: 9223372036854775807}, 243 | } 244 | 245 | for i, c := range testCases { 246 | f.length = c.i 247 | 248 | b := newBuffer([]byte{255, 255, 255, 255, 255, 255, 255, 255}) 249 | if err := f.readLength(b); err != nil { 250 | t.Errorf("test case %d: unexpected error returned: %v", i, err) 251 | } 252 | 253 | if f.length != c.l { 254 | t.Errorf("test case %d: expected f.length to be '%d', but it is '%d'", i, c.l, f.length) 255 | } 256 | } 257 | } 258 | 259 | func TestReadPayload(t *testing.T) { 260 | type testCase struct { 261 | // Masked or not 262 | m bool 263 | } 264 | 265 | testCases := []testCase{ 266 | {m: false}, 267 | {m: true}, 268 | } 269 | 270 | for i, c := range testCases { 271 | // Data Frame Received 272 | p := []byte{120, 15, 17} 273 | b := newBuffer(p) 274 | 275 | // Creation and config of frame instance. 276 | f := &frame{} 277 | f.key = []byte{10, 15, 1, 120} 278 | f.length = 2 279 | 280 | // Setting Masked. 281 | f.masked = c.m 282 | 283 | if err := f.readPayload(b); err != nil { 284 | t.Fatalf("test case %d: unexpected error was returned: %v", i, err) 285 | } 286 | 287 | // If masked unmask it. 288 | if f.masked { 289 | mask(f.payload, f.key) 290 | } 291 | 292 | if uint64(len(f.payload)) != f.length { 293 | t.Errorf("test case %d: expected length of f.payload to be '%d', but it is '%d'", i, f.length, len(f.payload)) 294 | } 295 | 296 | for i, v := range f.payload { 297 | if v != p[i] { 298 | t.Fatalf("test case %d: expected slice of bytes to be '%v', but it is '%v'.", i, p[:f.length], f.payload) 299 | } 300 | } 301 | } 302 | } 303 | 304 | func TestToBytesFin(t *testing.T) { 305 | type testCase struct { 306 | v bool 307 | r byte 308 | } 309 | 310 | testCases := []testCase{ 311 | // When f.fin is false first byte should not be affected 312 | {v: false, r: 0}, 313 | // When f.fin is true first byte should has its MSB == 1. 314 | {v: true, r: 128}, 315 | } 316 | 317 | for i, c := range testCases { 318 | f := &frame{fin: c.v} 319 | p := make([]byte, 1) 320 | 321 | f.toBytesFin(p) 322 | 323 | if p[0] != c.r { 324 | t.Errorf("test case %d: expected slice of bytes to be [%d] but it is [%d]", i, c.r, p[0]) 325 | } 326 | } 327 | } 328 | 329 | func TestToBytesOpcode(t *testing.T) { 330 | type testCase struct { 331 | // Fin Value 332 | v bool 333 | // Opcode Value 334 | o int 335 | // Resultant byte 336 | r byte 337 | } 338 | 339 | testCases := []testCase{ 340 | // With Fin == false 341 | {v: false, o: OpcodeText, r: byte(OpcodeText)}, 342 | // With Fin == true 343 | {v: true, o: OpcodeText, r: 128 + byte(OpcodeText)}, 344 | } 345 | 346 | for i, c := range testCases { 347 | f := &frame{fin: c.v, opcode: c.o} 348 | p := make([]byte, 1) 349 | 350 | f.toBytesFin(p) 351 | f.toBytesOpcode(p) 352 | 353 | if p[0] != c.r { 354 | t.Errorf("test case %d: expected slice of bytes to be [%d] but it is [%d]", i, c.r, p[0]) 355 | } 356 | } 357 | } 358 | 359 | func TestToBytesMasked(t *testing.T) { 360 | type testCase struct { 361 | // Value of f.key. 362 | v []byte 363 | // Resultant byte. 364 | r byte 365 | } 366 | 367 | testCases := []testCase{ 368 | {v: nil, r: 0}, 369 | {v: []byte{1, 2, 3, 4}, r: 128}, 370 | } 371 | 372 | for i, c := range testCases { 373 | f := frame{key: c.v} 374 | p := make([]byte, 2) 375 | 376 | f.toBytesMasked(p) 377 | 378 | if p[1] != c.r { 379 | t.Errorf("test case %d: expected slice of bytes to be [%d] but it is [%d]", i, c.r, p[1]) 380 | } 381 | } 382 | } 383 | 384 | func TestToBytesPayloadLength(t *testing.T) { 385 | type testCase struct { 386 | m bool 387 | r byte 388 | l int 389 | } 390 | 391 | testCases := []testCase{ 392 | // With Mask Bit (f.masked) set to false 393 | {m: false, r: 124, l: 124}, 394 | {m: false, r: 125, l: 125}, 395 | {m: false, r: 126, l: 30000}, 396 | {m: false, r: 126, l: 65535}, 397 | {m: false, r: 127, l: 700000}, 398 | // With Mask Bit (f.masked) set to true 399 | {m: true, r: 128 + 124, l: 124}, 400 | {m: true, r: 128 + 125, l: 125}, 401 | {m: true, r: 128 + 126, l: 30000}, 402 | {m: true, r: 128 + 126, l: 65535}, 403 | {m: true, r: 128 + 127, l: 700000}, 404 | // testCase{m: false, r: 127, l: 9223372036854775807}, 405 | } 406 | 407 | for i, c := range testCases { 408 | p := make([]byte, 2) 409 | f := frame{payload: make([]byte, c.l)} 410 | 411 | if c.m { 412 | f.key = []byte{1, 2, 3, 4} 413 | } 414 | 415 | f.toBytesMasked(p) 416 | f.toBytesPayloadLength(p) 417 | 418 | if p[1] != c.r { 419 | t.Errorf("test case %d: expected slice of bytes to be [%d] but it is [%d]", i, c.r, p[1]) 420 | } 421 | } 422 | } 423 | 424 | func TestToBytesPayloadLengthExt(t *testing.T) { 425 | type testCase struct { 426 | l int 427 | r []byte 428 | } 429 | 430 | testCases := []testCase{ 431 | // Length Known. 432 | {l: 124, r: nil}, 433 | // Length Known. 434 | {l: 125, r: nil}, 435 | // Read next 2 bytes. 436 | {l: 30000, r: []byte{117, 48}}, 437 | // Read next 2 bytes. 438 | {l: 65535, r: []byte{255, 255}}, 439 | // Read next 8 bytes. 440 | {l: 700000, r: []byte{0, 0, 0, 0, 0, 10, 174, 96}}, 441 | } 442 | 443 | for i, c := range testCases { 444 | f := frame{payload: make([]byte, c.l)} 445 | p := f.toBytesPayloadLengthExt() 446 | 447 | if len(p) != len(c.r) { 448 | t.Errorf("test case %d: expected length to be '%d' but it is '%d'", i, len(c.r), len(p)) 449 | } 450 | 451 | for ci, cv := range c.r { 452 | if cv != p[ci] { 453 | t.Errorf("test case %d: Expected slice of bytes to be %v but it is %v", i, c.r, p) 454 | break 455 | } 456 | } 457 | } 458 | } 459 | 460 | func TestToBytesPayloadData(t *testing.T) { 461 | type testCase struct { 462 | m []byte 463 | p []byte 464 | } 465 | 466 | testCases := []testCase{ 467 | // When masking key is present and valid, payload must be masked. 468 | {p: []byte{3, 4, 5, 6}, m: nil}, 469 | // When masking key is not present, payload must not be masked. 470 | {p: []byte{3, 4, 5, 6}, m: []byte{1, 2, 3, 4}}, 471 | } 472 | 473 | for i, c := range testCases { 474 | f := &frame{key: c.m, payload: c.p} 475 | 476 | p := f.toBytesPayloadData() 477 | 478 | if c.m != nil { 479 | mask(p, c.m) 480 | } 481 | 482 | for ci, cv := range c.p { 483 | if cv != p[ci] { 484 | t.Errorf("test case %d: Expected slice of bytes to be %v but it is %v", i, c.p, p) 485 | } 486 | } 487 | } 488 | } 489 | -------------------------------------------------------------------------------- /frame_utils.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | // mask is used to mask or unmask an array of bytes. It accepts two arguments, 4 | // p the data that will be masked (usually the application data), k the masking 5 | // key. 6 | // 7 | // From spec: https://tools.ietf.org/html/rfc6455#section-5.3 8 | func mask(p, k []byte) { 9 | for i := range p { 10 | p[i] ^= k[i%4] 11 | } 12 | } 13 | 14 | // opcodeExist returns whether the opcode number provided as an argument is a 15 | // valid opcode or not. 16 | func opcodeExist(i int) bool { 17 | switch i { 18 | case OpcodeContinuation, OpcodeText, OpcodeBinary, OpcodeClose, OpcodePing, OpcodePong: 19 | { 20 | return true 21 | } 22 | } 23 | return false 24 | } 25 | 26 | // validateKey returns whether the masking key is a valid key or not. Note that 27 | // a masking key can either be of length 0 or 4. 28 | // 29 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-5.2 30 | func validateKey(k []byte) bool { 31 | return len(k) == 0 || len(k) == 4 32 | } 33 | 34 | // validatePayload returns whether the payload data is valid or not. Note that 35 | // the maximum size of payload data can be 9223372036854775807 bits. 36 | // 37 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-5.2 38 | func validatePayload(p []byte) bool { 39 | return len(p) <= 9223372036854775807 40 | } 41 | -------------------------------------------------------------------------------- /frame_utils_test.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import "testing" 4 | 5 | func TestOpcodeExist(t *testing.T) { 6 | type testCase struct { 7 | o int 8 | v bool 9 | } 10 | 11 | testCases := []testCase{ 12 | // Should return false when opcode is invalid 13 | {o: 15, v: false}, 14 | // Should return true when opcode is valid. 15 | {o: OpcodeText, v: true}, 16 | } 17 | 18 | for i, c := range testCases { 19 | if v := opcodeExist(c.o); v != c.v { 20 | t.Errorf("test case %d: expected '%t' for '%d'", i, c.v, c.o) 21 | } 22 | } 23 | } 24 | 25 | func TestValidateKey(t *testing.T) { 26 | type testCase struct { 27 | k []byte 28 | r bool 29 | } 30 | 31 | testCases := []testCase{ 32 | {k: []byte{1, 2, 3, 4}, r: true}, 33 | {k: []byte{}, r: true}, 34 | {k: []byte{1, 2, 3, 4, 5}, r: false}, 35 | {k: []byte{1, 2, 3}, r: false}, 36 | } 37 | 38 | for i, c := range testCases { 39 | if validateKey(c.k) != c.r { 40 | t.Errorf("test case %d: expected '%t' for %v", i, c.r, c.k) 41 | } 42 | } 43 | } 44 | 45 | func TestValidatePayload(t *testing.T) { 46 | type testCase struct { 47 | l uint64 48 | r bool 49 | } 50 | 51 | testCases := []testCase{ 52 | {l: 125, r: true}, 53 | // testCase{l: 9223372036854775807, r: true}, 54 | // testCase{l: 9223372036854775808, r: false}, 55 | } 56 | 57 | for i, c := range testCases { 58 | b := make([]byte, c.l) 59 | if validatePayload(b) != c.r { 60 | t.Errorf("test case %d: expected '%t' for payload of size '%d'", i, c.r, c.l) 61 | } 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /request.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "net/http" 5 | "sync" 6 | ) 7 | 8 | // wsVersion is the websocket version this library supports. 9 | const wsVersion = "13" 10 | 11 | // Request represents the HTTP Request that will be upgraded to the WebSocket 12 | // protocol once it is validated. 13 | type Request struct { 14 | /* 15 | request is the http request to be upgraded. 16 | */ 17 | request *http.Request 18 | 19 | /* 20 | CheckOrigin is the function which will be used to validate the ORIGIN 21 | HTTP Header of the request. By default this method will fail the opening 22 | handshake when the origin is not the same. This method can be overridden 23 | during the initiation of the Request struct. 24 | */ 25 | CheckOrigin func(r *http.Request) bool 26 | 27 | /* 28 | SubProtocol name which the server has agreed to use from the list 29 | provided by the client (through the Sec-WebSocket-Protocol HTTP Header 30 | Field). Before sending the servers opening handshake response, checks 31 | are made to verify that the chosen protocol was indeed been provided as 32 | an option from the client. If this is not the case, the HTTP 33 | Sec-WebSocket-Protocol HTTP Response Header Field is not sent 34 | */ 35 | SubProtocol string 36 | } 37 | 38 | // Upgrade is used to upgrade the HTTP connection to use the WS protocol once 39 | // the client request is validated. 40 | func (q *Request) Upgrade(w http.ResponseWriter, r *http.Request) (*Socket, error) { 41 | // Store a reference to the HTTP Request. 42 | q.request = r 43 | 44 | // Check origin. 45 | // Ref spec: https://tools.ietf.org/html/rfc6455#section-4.2.2 46 | if err := q.handleOrigin(); err != nil { 47 | http.Error(w, "Forbidden", http.StatusForbidden) 48 | return nil, err 49 | } 50 | 51 | // Check websocket version. 52 | // Ref spec: https://tools.ietf.org/html/rfc6455#section-4.2.2 53 | if err := validateWSVersionHeader(r); err != nil { 54 | w.Header().Set("Sec-WebSocket-Version", wsVersion) 55 | http.Error(w, "Upgrade Required", 426) 56 | return nil, err 57 | } 58 | 59 | // Check handshake request. 60 | // Ref spec: https://tools.ietf.org/html/rfc6455#section-4.2.2 61 | if err := validateRequest(r); err != nil { 62 | http.Error(w, "Bad Request", http.StatusBadRequest) 63 | return nil, err 64 | } 65 | 66 | // At this point, the clients handshake request is valid and therefore the 67 | // connection can be upgraded to use the ws protocol. 68 | s, err := q.upgrade(w) 69 | 70 | if err != nil { 71 | http.Error(w, "Internal Server Error", http.StatusInternalServerError) 72 | return nil, err 73 | } 74 | 75 | return s, nil 76 | } 77 | 78 | func (q *Request) upgrade(w http.ResponseWriter) (*Socket, error) { 79 | // Take control of the net.Conn instance. 80 | h, k := w.(http.Hijacker) 81 | 82 | if !k { 83 | return nil, &OpenError{Reason: "assertion failed with current http.ResponseWriter instance"} 84 | } 85 | 86 | conn, buf, err := h.Hijack() 87 | if err != nil { 88 | return nil, err 89 | } 90 | 91 | // Build the HTTP Header response code required for the ws opening 92 | // handshake. 93 | // From RFC2616: https://www.w3.org/Protocols/rfc2616/rfc2616-sec6.html 94 | resp := "HTTP/1.1 101 Switching Protocols\n" 95 | resp += "Upgrade: websocket\n" 96 | resp += "Connection: upgrade\n" 97 | resp += "Sec-WebSocket-Version: " + wsVersion + "\n" 98 | 99 | // If server has agreed to use a sub-protocol, the chosen sub-protocol needs 100 | // to be an option provided by the clients endpoint. If not, the 101 | // Sec-WebSocket-Protocol HTTP Header field is not sent. 102 | if q.SubProtocol != "" && stringExists(q.ClientSubProtocols(), q.SubProtocol) != -1 { 103 | resp += "Sec-WebSocket-Protocol: " + q.SubProtocol + "\n" 104 | } 105 | 106 | // Generate the accept key based on the challenge key provided by the 107 | // client and include it inside 'Sec-WebSocket-Accept' response header 108 | // field. 109 | acceptKey := makeAcceptKey(q.request.Header.Get("Sec-WebSocket-Key")) 110 | resp += "Sec-WebSocket-Accept: " + acceptKey + "\n\n" 111 | 112 | // Send response 113 | buf.WriteString(resp) 114 | buf.Flush() 115 | 116 | // Create and return socket. 117 | return &Socket{ 118 | conn: conn, 119 | buf: buf, 120 | server: true, 121 | writeMutex: &sync.Mutex{}, 122 | }, nil 123 | } 124 | 125 | // handleOrigin is used to invoke either the CheckOrigin method provided by the 126 | // user or the default method (if the user doesn't provide one). 127 | func (q *Request) handleOrigin() *OpenError { 128 | fn := q.CheckOrigin 129 | 130 | if fn == nil { 131 | fn = checkOrigin 132 | } 133 | 134 | if !fn(q.request) { 135 | return &OpenError{Reason: `failure due to origin.`} 136 | } 137 | 138 | return nil 139 | } 140 | 141 | // ClientSubProtocols returns the list of Sub Protocols the client can interact 142 | // with. 143 | // 144 | // From spec: https://tools.ietf.org/html/rfc6455#section-4.2.1 145 | func (q *Request) ClientSubProtocols() []string { 146 | return headerToSlice(q.request.Header.Get("Sec-WebSocket-Protocol")) 147 | } 148 | 149 | // ClientExtensions returns the list of Extensions the client can interact 150 | // with. 151 | // 152 | // From spec: https://tools.ietf.org/html/rfc6455#section-4.2.1 153 | func (q *Request) ClientExtensions() []string { 154 | return headerToSlice(q.request.Header.Get("Sec-WebSocket-Extensions")) 155 | } 156 | -------------------------------------------------------------------------------- /request_test.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "strings" 7 | "testing" 8 | ) 9 | 10 | func makeRequestValid(r *http.Request) { 11 | r.Header.Set("Sec-WebSocket-Version", wsVersion) 12 | r.Header.Set("Upgrade", "websocket") 13 | r.Header.Set("Connection", "upgrade") 14 | r.Header.Set("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") 15 | } 16 | 17 | func TestUpgradeResponseWhenInvalidOrigin(t *testing.T) { 18 | r, err := http.NewRequest("GET", "example.com", nil) 19 | 20 | if err != nil { 21 | t.Fatal("error occured while creating request:", err) 22 | } 23 | 24 | w := httptest.NewRecorder() 25 | 26 | h := func(w http.ResponseWriter, r *http.Request) { 27 | wsr := &Request{ 28 | CheckOrigin: func(r *http.Request) bool { 29 | return false 30 | }, 31 | } 32 | 33 | makeRequestValid(r) 34 | 35 | s, err := wsr.Upgrade(w, r) 36 | 37 | if err == nil { 38 | t.Error("expected Upgrade() to return a OpenError") 39 | } 40 | 41 | if s != nil { 42 | t.Error("expected Upgrade() to return a nil Socket instance") 43 | } 44 | } 45 | 46 | h(w, r) 47 | 48 | if w.Code != 403 { 49 | t.Errorf(`expected HTTP Status '403'. '%d' was returned.`, w.Code) 50 | } 51 | } 52 | 53 | func TestUpgradeResponseWhenInvalidWSVersion(t *testing.T) { 54 | r, err := http.NewRequest("GET", "example.com", nil) 55 | 56 | if err != nil { 57 | t.Fatal("error occured while creating request:", err) 58 | } 59 | 60 | w := httptest.NewRecorder() 61 | 62 | h := func(w http.ResponseWriter, r *http.Request) { 63 | wsr := &Request{} 64 | 65 | makeRequestValid(r) 66 | r.Header.Set("Sec-WebSocket-Version", "14") 67 | 68 | s, err := wsr.Upgrade(w, r) 69 | 70 | if err == nil { 71 | t.Error("expected Upgrade() to return a OpenError") 72 | } 73 | 74 | if s != nil { 75 | t.Error("expected Upgrade() to return a nil Socket instance") 76 | } 77 | } 78 | 79 | h(w, r) 80 | 81 | if w.Code != 426 { 82 | t.Errorf(`expected HTTP Status '426'. '%d' was returned.`, w.Code) 83 | } 84 | 85 | if w.Header().Get("Sec-WebSocket-Version") != wsVersion { 86 | t.Errorf(`expected "Sec-WebSocket-Version" HTTP Header field value to be %s`, wsVersion) 87 | } 88 | } 89 | 90 | func TestUpgradeResponseWhenNotValid(t *testing.T) { 91 | r, err := http.NewRequest("POST", "example.com", nil) 92 | 93 | if err != nil { 94 | t.Fatal("error occured while creating request:", err) 95 | } 96 | 97 | w := httptest.NewRecorder() 98 | 99 | h := func(w http.ResponseWriter, r *http.Request) { 100 | wsr := &Request{ 101 | CheckOrigin: func(r *http.Request) bool { 102 | return true 103 | }, 104 | } 105 | 106 | makeRequestValid(r) 107 | 108 | s, err := wsr.Upgrade(w, r) 109 | 110 | if err == nil { 111 | t.Error("expected Upgrade() to return a OpenError.") 112 | } 113 | 114 | if s != nil { 115 | t.Error("expected Upgrade() to return a nil Socket instance.") 116 | } 117 | } 118 | 119 | h(w, r) 120 | 121 | if w.Code != 400 { 122 | t.Errorf(`expected HTTP Status '400'. '%d' was returned.`, w.Code) 123 | } 124 | } 125 | 126 | func TestUpgradeGoodRequest(t *testing.T) { 127 | h := func(w http.ResponseWriter, r *http.Request) { 128 | wsr := &Request{ 129 | CheckOrigin: func(r *http.Request) bool { 130 | return true 131 | }, 132 | } 133 | 134 | makeRequestValid(r) 135 | 136 | s, err := wsr.Upgrade(w, r) 137 | 138 | if err != nil { 139 | t.Error("unexpected error from Upgrade():", err) 140 | } 141 | 142 | if s == nil { 143 | t.Error("expected Upgrade() to return a non-nil Socket instance") 144 | } 145 | 146 | if !s.server { 147 | t.Error("expected socket to have 'server' property set to 'true'") 148 | } 149 | } 150 | 151 | s := httptest.NewServer(http.HandlerFunc(h)) 152 | defer s.Close() 153 | 154 | w, err := http.Get(s.URL) 155 | 156 | if err != nil { 157 | t.Error("unexpected error when requesting the test server:", err) 158 | } 159 | 160 | if w.StatusCode != 101 { 161 | t.Errorf("expected HTTP Status to be '101' but it is '%d'", w.StatusCode) 162 | } 163 | 164 | if w.Header.Get("Upgrade") != "websocket" { 165 | t.Errorf(`expected "Upgrade" HTTP Header value to be "websocket" but it is "%s"`, w.Header.Get("Upgrade")) 166 | } 167 | 168 | if w.Header.Get("Connection") != "upgrade" { 169 | t.Errorf(`expected "Connection" HTTP Header value to be "upgrade" but it is "%s"`, w.Header.Get("Connection")) 170 | } 171 | 172 | if w.Header.Get("Sec-WebSocket-Accept") != "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" { 173 | t.Errorf(`expected "Sec-WebSocket-Accept" HTTP Header value to be "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" but it is "%s"`, w.Header.Get("Sec-WebSocket-Accept")) 174 | } 175 | } 176 | 177 | func TestUpgradeWithSubProtocols(t *testing.T) { 178 | h := func(w http.ResponseWriter, r *http.Request) { 179 | wsr := &Request{ 180 | CheckOrigin: func(r *http.Request) bool { 181 | return true 182 | }, 183 | } 184 | 185 | makeRequestValid(r) 186 | wsr.SubProtocol = "one" 187 | 188 | s, err := wsr.Upgrade(w, r) 189 | 190 | if err != nil { 191 | t.Error("unexpected error from Upgrade():", err) 192 | } 193 | 194 | if s == nil { 195 | t.Error("expected Upgrade() to return a non-nil Socket instance") 196 | } 197 | } 198 | 199 | s := httptest.NewServer(http.HandlerFunc(h)) 200 | defer s.Close() 201 | 202 | type testCase struct { 203 | p string 204 | v bool 205 | } 206 | 207 | testCases := []testCase{ 208 | {p: "one, two, three", v: true}, 209 | {p: "two, three", v: false}, 210 | {p: "", v: false}, 211 | } 212 | 213 | for i, c := range testCases { 214 | _ = i 215 | r, err := http.NewRequest("GET", s.URL, nil) 216 | 217 | if err != nil { 218 | t.Error("unexpected error returned while trying to create a request instance:", err) 219 | } 220 | 221 | if c.p != "" { 222 | r.Header.Set("Sec-WebSocket-Protocol", c.p) 223 | } 224 | 225 | l := &http.Client{} 226 | w, err := l.Do(r) 227 | 228 | if err != nil { 229 | t.Error("unexpected error returned while trying to create a client instance:", err) 230 | } 231 | 232 | if c.v { 233 | v := w.Header.Get("Sec-WebSocket-Protocol") 234 | if w.Header.Get("Sec-WebSocket-Protocol") != "one" { 235 | t.Errorf(`expected 'Sec-WebSocket-Protocol' Response Header to be "one", but it is "%v".`, v) 236 | } 237 | } else { 238 | v := w.Header.Get("Sec-WebSocket-Protocol") 239 | if w.Header.Get("Sec-WebSocket-Protocol") != "" { 240 | t.Errorf(`expected 'Sec-WebSocket-Protocol' Response Header to be "", but it is "%v".`, v) 241 | } 242 | } 243 | } 244 | } 245 | 246 | func TestClientSubProtocols(t *testing.T) { 247 | r := &http.Request{} 248 | 249 | l := []string{"one", "two", "three"} 250 | 251 | r.Header = make(http.Header) 252 | r.Header.Set("Sec-WebSocket-Protocol", strings.Join(l, ", ")) 253 | 254 | q := &Request{ 255 | request: r, 256 | } 257 | 258 | p := q.ClientSubProtocols() 259 | 260 | if len(l) != len(p) { 261 | t.Errorf("The length of the list of header value assigned to Sec-WebSocket-Protocol HTTP Header are not the same. %d != %d", len(l), len(p)) 262 | } 263 | 264 | for _, v := range p { 265 | k := false 266 | for _, h := range l { 267 | if v == h { 268 | k = true 269 | break 270 | } 271 | } 272 | if !k { 273 | t.Errorf(`"%s" was not returned in the slice of Sub Protocols.`, v) 274 | } 275 | } 276 | } 277 | 278 | func TestClientExtensions(t *testing.T) { 279 | r := &http.Request{} 280 | 281 | l := []string{"one", "two", "three"} 282 | 283 | r.Header = make(http.Header) 284 | r.Header.Set("Sec-WebSocket-Extensions", strings.Join(l, ", ")) 285 | 286 | q := &Request{ 287 | request: r, 288 | } 289 | 290 | p := q.ClientExtensions() 291 | 292 | if len(l) != len(p) { 293 | t.Errorf("The length of the list of header value assigned to Sec-WebSocket-Extensions HTTP Header are not the same. '%d' != '%d'", len(l), len(p)) 294 | } 295 | 296 | for _, v := range p { 297 | k := false 298 | for _, h := range l { 299 | if v == h { 300 | k = true 301 | break 302 | } 303 | } 304 | if !k { 305 | t.Errorf(`"%s" was not returned in the slice of Extensions.`, v) 306 | } 307 | } 308 | } 309 | -------------------------------------------------------------------------------- /request_utils.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "encoding/base64" 5 | "net/http" 6 | "strings" 7 | ) 8 | 9 | // validateRequest is used to determine whether the client handshake request 10 | // conforms with the WebSocket spec. When it doesn't the server should respond 11 | // with an HTTP Status 400 Bad Request. 12 | // 13 | // Note that this method doesn't validate the websocket version 14 | // ("Sec-WebSocket-Version" HTTP Header Field) and origin ("Origin" HTTP 15 | // Header Field) since these require specific HTTP Status Code (427 and 403 16 | // respectively). 17 | // 18 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.2.1 19 | // https://tools.ietf.org/html/rfc6455#section-4.2.2 20 | 21 | func validateRequest(r *http.Request) *OpenError { 22 | validations := []func(*http.Request) *OpenError{ 23 | // Check HTTP version to be at least v1.1. 24 | validateRequestVersion, 25 | // Check HTTP method to be 'GET'. 26 | validateRequestMethod, 27 | // Validate 'Upgrade' header field. 28 | validateRequestUpgradeHeader, 29 | // Validate 'Connection' header field. 30 | validateRequestConnectionHeader, 31 | // Validate 'Sec-WebSocket-Key' header field. 32 | validateRequestSecWebsocketKeyHeader, 33 | } 34 | 35 | for _, v := range validations { 36 | if err := v(r); err != nil { 37 | return err 38 | } 39 | } 40 | 41 | return nil 42 | } 43 | 44 | // validateRequestVersion verifies that the HTTP Version used in the client's 45 | // opening handshake request is at least v1.1. 46 | // 47 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.2.1 48 | func validateRequestVersion(r *http.Request) *OpenError { 49 | if !r.ProtoAtLeast(1, 1) { 50 | return &OpenError{Reason: `HTTP must be v1.1 or higher`} 51 | } 52 | return nil 53 | } 54 | 55 | // validateRequestMethod verifies that the HTTP Method used in the client's 56 | // opening handshake request is 'GET'. 57 | // 58 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.2.1 59 | func validateRequestMethod(r *http.Request) *OpenError { 60 | if r.Method != "GET" { 61 | return &OpenError{Reason: `HTTP method must be "GET"`} 62 | } 63 | return nil 64 | } 65 | 66 | // validateRequestUpgradeHeader verifies that the Upgrade HTTP Header value in the 67 | // client's opening handshake request is "websocket". 68 | // 69 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.2.1 70 | func validateRequestUpgradeHeader(r *http.Request) *OpenError { 71 | h := r.Header.Get("Upgrade") 72 | 73 | if strings.ToLower(h) != "websocket" { 74 | return &OpenError{Reason: `"Upgrade" Header should have the value of "websocket"`} 75 | } 76 | 77 | return nil 78 | } 79 | 80 | // validateRequestConnectionHeader verfies that the Connection HTTP Header value in 81 | // the client's opening handshake request is "upgrade". 82 | // 83 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.2.1 84 | func validateRequestConnectionHeader(r *http.Request) *OpenError { 85 | h := r.Header.Get("Connection") 86 | 87 | if strings.ToLower(h) != "upgrade" { 88 | return &OpenError{Reason: `"Connection" Header should have the value of "upgrade"`} 89 | } 90 | 91 | return nil 92 | } 93 | 94 | // validateRequestSecWebsocketKeyHeader verifies that the Sec-WebSocket-Key HTTP Header value in 95 | // the client's opening handshake request is of length 16 when base64 decoded. 96 | // 97 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.2.1 98 | func validateRequestSecWebsocketKeyHeader(r *http.Request) *OpenError { 99 | h := r.Header.Get("Sec-WebSocket-Key") 100 | d, err := base64.StdEncoding.DecodeString(h) 101 | 102 | // Check for decoding errors. 103 | if err != nil { 104 | return &OpenError{Reason: `an error had occured while validating "Sec-WebSocket-Key" header`} 105 | } 106 | 107 | // Check that the length of the decoded Sec-WebSocket-Key value is 16 108 | // (bytes). 109 | if len(d) != 16 { 110 | return &OpenError{Reason: `"Sec-WebSocket-Key" must be 16 bytes in length when decoded`} 111 | } 112 | 113 | return nil 114 | } 115 | 116 | // validateWSVersionHeader verifies that the Sec-WebSocket-Verion HTTP Header 117 | // value in the client's opening handshake request is "13". 118 | // 119 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.2.1 120 | func validateWSVersionHeader(r *http.Request) *OpenError { 121 | if r.Header.Get("Sec-WebSocket-Version") != wsVersion { 122 | return &OpenError{Reason: "upgrade required"} 123 | } 124 | 125 | return nil 126 | } 127 | 128 | // checkOrigin is the default CheckOrigin handler used by the Request struct. 129 | // This method will allow requests that are either coming from a non-browser 130 | // client (Origin HTTP Header field omitted) or are not cross origin requests. 131 | // 132 | // Ref spec: https://tools.ietf.org/html/rfc6455#section-4.2.1 133 | func checkOrigin(r *http.Request) bool { 134 | h := r.Header.Get("Origin") 135 | 136 | if strings.HasPrefix(h, "http://") { 137 | h = strings.Replace(h, "http://", "", 1) 138 | } else if strings.HasPrefix(h, "https://") { 139 | h = strings.Replace(h, "https://", "", 1) 140 | } 141 | 142 | return h == "" || h == r.Host 143 | } 144 | -------------------------------------------------------------------------------- /request_utils_test.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "net/http" 5 | "testing" 6 | ) 7 | 8 | func TestValidateRequestVersion(t *testing.T) { 9 | r := &http.Request{} 10 | r.Header = make(http.Header) 11 | 12 | type testCase struct { 13 | a int 14 | i int 15 | r bool 16 | } 17 | 18 | testCases := []testCase{ 19 | // HTTP v1.1 should be valid. 20 | {a: 1, i: 1, r: true}, 21 | // HTTP v2.1 should be valid. 22 | {a: 2, i: 1, r: true}, 23 | // HTTP v1.0 should be not valid. 24 | {a: 1, i: 0, r: false}, 25 | // HTTP v0.1 should be not valid. 26 | {a: 0, i: 1, r: false}, 27 | } 28 | 29 | for i, c := range testCases { 30 | r.ProtoMajor = c.a 31 | r.ProtoMinor = c.i 32 | 33 | err := validateRequestVersion(r) 34 | 35 | if c.r && err != nil { 36 | t.Errorf(`test case %d: unexpected error retured for "v%d.%d"`, i, c.a, c.i) 37 | } 38 | 39 | if !c.r && err == nil { 40 | t.Errorf(`test case %d: expected an error for "v%d.%d"`, i, c.a, c.i) 41 | } 42 | } 43 | } 44 | 45 | func TestValidateRequestMethod(t *testing.T) { 46 | r := &http.Request{} 47 | r.Header = make(http.Header) 48 | 49 | type testCase struct { 50 | m string 51 | r bool 52 | } 53 | 54 | testCases := []testCase{ 55 | // HTTP GET should be valid. 56 | {m: "GET", r: true}, 57 | // HTTP POST should be not valid. 58 | {m: "POST", r: false}, 59 | } 60 | 61 | for i, c := range testCases { 62 | r.Method = c.m 63 | 64 | err := validateRequestMethod(r) 65 | 66 | if c.r && err != nil { 67 | t.Errorf(`test case %d: unexpected error retured for "%s" request`, i, c.m) 68 | } 69 | 70 | if !c.r && err == nil { 71 | t.Errorf(`test case %d: expected an error for "%s" request`, i, c.m) 72 | } 73 | } 74 | } 75 | 76 | func TestValidateRequestUpgradeHeader(t *testing.T) { 77 | r := &http.Request{} 78 | r.Header = make(http.Header) 79 | 80 | type testCase struct { 81 | v string 82 | r bool 83 | } 84 | 85 | testCases := []testCase{ 86 | // When value is "websocket" should be valid. 87 | {v: "websocket", r: true}, 88 | // When value is "webSocket" should be valid. 89 | {v: "webSocket", r: true}, 90 | // When value is not "websocket" should not be valid. 91 | {v: "ValueOtherThanWebsocket", r: false}, 92 | } 93 | 94 | for i, c := range testCases { 95 | r.Header.Set("Upgrade", c.v) 96 | 97 | err := validateRequestUpgradeHeader(r) 98 | 99 | if c.r && err != nil { 100 | t.Errorf(`test case %d: unexpected error retured for "%s"`, i, c.v) 101 | } 102 | 103 | if !c.r && err == nil { 104 | t.Errorf(`test case %d: expected an error for "%s"`, i, c.v) 105 | } 106 | } 107 | } 108 | 109 | func TestValidateRequestConnectionHeader(t *testing.T) { 110 | r := &http.Request{} 111 | r.Header = make(http.Header) 112 | 113 | type testCase struct { 114 | v string 115 | r bool 116 | } 117 | 118 | testCases := []testCase{ 119 | // When value is "upgrade" should be valid. 120 | {v: "upgrade", r: true}, 121 | // When value is "Upgrade" should be valid. 122 | {v: "Upgrade", r: true}, 123 | // When value is not "upgrade" should not be valid. 124 | {v: "ValueOtherThanUpgrade", r: false}, 125 | } 126 | 127 | for i, c := range testCases { 128 | r.Header.Set("Connection", c.v) 129 | 130 | err := validateRequestConnectionHeader(r) 131 | 132 | if c.r && err != nil { 133 | t.Errorf(`test case %d: unexpected error retured for "%s"`, i, c.v) 134 | } 135 | 136 | if !c.r && err == nil { 137 | t.Errorf(`test case %d: expected an error for "%s"`, i, c.v) 138 | } 139 | } 140 | } 141 | 142 | func TestValidateRequestSecWebsocketKeyHeader(t *testing.T) { 143 | r := &http.Request{} 144 | r.Header = make(http.Header) 145 | 146 | type testCase struct { 147 | v string 148 | r bool 149 | } 150 | 151 | testCases := []testCase{ 152 | // Valid key. 153 | {v: "FlBPpXKmN36AUZxV0tYHYw==", r: true}, 154 | // Invalid decoded length. 155 | {v: "InvalidKey==", r: false}, 156 | // Invalid encoded data. 157 | {v: "InvalidKeyError", r: false}, 158 | } 159 | 160 | for i, c := range testCases { 161 | r.Header.Set("Sec-WebSocket-Key", c.v) 162 | 163 | err := validateRequestSecWebsocketKeyHeader(r) 164 | 165 | if c.r && err != nil { 166 | t.Errorf(`test case %d: unexpected error retured for "%s"`, i, c.v) 167 | } 168 | 169 | if !c.r && err == nil { 170 | t.Errorf(`test case %d: expected an error for "%s"`, i, c.v) 171 | } 172 | } 173 | } 174 | 175 | func TestValidateWSVersionHeader(t *testing.T) { 176 | r := &http.Request{} 177 | r.Header = make(http.Header) 178 | 179 | type testCase struct { 180 | v string 181 | r bool 182 | } 183 | 184 | testCases := []testCase{ 185 | // Valid when value is the same as the version of the ws supported. 186 | {v: wsVersion, r: true}, 187 | // Not valid when value is not the same as the version of the ws 188 | // supported. 189 | {v: "14", r: false}, 190 | } 191 | 192 | for i, c := range testCases { 193 | r.Header.Set("Sec-WebSocket-Version", c.v) 194 | 195 | err := validateWSVersionHeader(r) 196 | 197 | if c.r && err != nil { 198 | t.Errorf(`test case %d: unexpected error retured for "%s"`, i, c.v) 199 | } 200 | 201 | if !c.r && err == nil { 202 | t.Errorf(`test case %d: expected an error for "%s"`, i, c.v) 203 | } 204 | } 205 | } 206 | 207 | func TestCheckOrigin(t *testing.T) { 208 | r := &http.Request{} 209 | r.Header = make(http.Header) 210 | r.Host = "example.com:8080" 211 | 212 | type testCase struct { 213 | v string 214 | r bool 215 | } 216 | 217 | testCases := []testCase{ 218 | // Valid when origin is omitted (non-browser client). 219 | {v: "", r: true}, 220 | // Valid when same origin. 221 | {v: r.Host, r: true}, 222 | {v: "example.com:8080", r: true}, 223 | {v: "http://example.com:8080", r: true}, 224 | {v: "https://example.com:8080", r: true}, 225 | } 226 | 227 | for i, c := range testCases { 228 | r.Header.Set("Origin", c.v) 229 | 230 | if checkOrigin(r) != c.r { 231 | t.Errorf(`Test Case %d: Expected checkOrigin() to return '%t' when 'Origin' header == "%s" and Host is at "%s".`, i, c.r, c.v, r.Host) 232 | } 233 | } 234 | } 235 | -------------------------------------------------------------------------------- /socket.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "bufio" 5 | "errors" 6 | "io" 7 | "net" 8 | "sync" 9 | "time" 10 | ) 11 | 12 | // ErrSocketClosed is the error returned when a user tries to send a frame with 13 | // a closed socket. 14 | var ErrSocketClosed = errors.New("socket has been closed") 15 | 16 | // WebSocket Error codes. 17 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-7.4.1 18 | const ( 19 | CloseNormalClosure int = 1000 20 | CloseGoingAway int = 1001 21 | CloseProtocolError int = 1002 22 | CloseUnsupportedData int = 1003 23 | CloseNoStatusReceived int = 1005 24 | CloseAbnormalClosure int = 1006 25 | CloseInvalidFramePayloadData int = 1007 26 | ClosePolicyViolation int = 1008 27 | CloseMessageTooBig int = 1009 28 | CloseMandatoryExtension int = 1010 29 | CloseInternalServerErr int = 1011 30 | CloseTLSHandshake int = 1015 31 | ) 32 | 33 | // Represents the state of the Socket instance 34 | const ( 35 | /* 36 | stateOpened will be the state when the socket instance is open. 37 | */ 38 | stateOpened int = 0 39 | 40 | /* 41 | stateClosing will be the state when the socket instance is in the middle 42 | of the closing handshake. 43 | */ 44 | stateClosing int = 1 45 | 46 | /* 47 | stateClosed will be the state when the socket instance is closed. 48 | */ 49 | stateClosed int = 2 50 | ) 51 | 52 | // Socket represents a socket endpoint. 53 | type Socket struct { 54 | /* 55 | conn is the underlying tcp connection. 56 | */ 57 | conn net.Conn 58 | 59 | /* 60 | buf is a buffered version of the underlying tcp connection. 61 | */ 62 | buf *bufio.ReadWriter 63 | 64 | /* 65 | server indicates whether the socket instance represents a server or a 66 | client endpoint. 67 | */ 68 | server bool 69 | 70 | /* 71 | state is the current state of the socket instance. 72 | */ 73 | state int 74 | 75 | /* 76 | closeDelay is the duration the socket instance will wait until it closes 77 | the underlying tcp connection once the closing handshake has been 78 | completed. 79 | 80 | The websocket rfc suggests that when the closing handshake is completed 81 | the underlying tcp connection should first be terminated by the server 82 | endpoint. Having said this it doesn't restrict the client endpoint to do 83 | so itself. CloseDelay is the maximum time the socket instance will wait 84 | before it closes the tcp connection. 85 | 86 | Note: Server endpoints should always have this property set to 0. 87 | 88 | Ref Spec: https://tools.ietf.org/html/rfc6455#section-5.5.1 89 | */ 90 | CloseDelay time.Duration 91 | 92 | /* 93 | readHandler is invoked whenever a text or binary frame is received. The 94 | opcode and payload data are provided as args respectively. 95 | */ 96 | ReadHandler func(int, []byte) 97 | 98 | /* 99 | pingHandler is invoked whenever a ping frame is received. The payload 100 | data is provided as arg. 101 | */ 102 | PingHandler func([]byte) 103 | 104 | /* 105 | pongHandler is invoked whenever a pong frame is received. The payload 106 | data is provided as arg. 107 | */ 108 | PongHandler func([]byte) 109 | 110 | /* 111 | closeHandler is invoked whenever the websocket connection is closed. The 112 | reason for the closure is provided as an arg. 113 | */ 114 | CloseHandler func(error) 115 | 116 | /* 117 | closeError contains the error which caused the websocket connection to 118 | terminate. This is then provided as an arg when invoking the close 119 | handler once the underlying tcp connection is terminated. 120 | */ 121 | closeError error 122 | 123 | /* 124 | writeMutex is used to queue the write functionality of a socket 125 | instance. 126 | */ 127 | writeMutex *sync.Mutex 128 | } 129 | 130 | // Listen is used to start listening for new frames sent by the connected 131 | // endpoint. 132 | func (s *Socket) Listen() { 133 | s.read() 134 | } 135 | 136 | func (s *Socket) read() { 137 | Read: 138 | for { 139 | // Read frame 140 | f, err := newFrame(s.buf.Reader) 141 | 142 | if s.state == stateClosed { 143 | break Read 144 | } 145 | 146 | if err != nil { 147 | // If an error occurred due to something which doesn't conform with 148 | // the websocket rfc, use the error itself as a reason. 149 | if c, k := err.(*CloseError); k { 150 | s.CloseWithError(c) 151 | return 152 | } 153 | 154 | // When EOF returns it means that the other endpoint isn't reachable 155 | // and thus there won't be the need to initate the closing 156 | // handshake. 157 | if err == io.EOF { 158 | s.closeError = &CloseError{ 159 | Code: CloseAbnormalClosure, 160 | Reason: "abnormal closure", 161 | } 162 | s.TCPClose() 163 | break Read 164 | } 165 | 166 | // When Read times out or connection is closed the other endpoing 167 | // won't be reachable and thus there won't be the need to initiate 168 | // the closing handshake. 169 | if _, k := err.(*net.OpError); k { 170 | s.closeError = &CloseError{ 171 | Code: CloseAbnormalClosure, 172 | Reason: "abnormal closure", 173 | } 174 | s.TCPClose() 175 | break Read 176 | } 177 | 178 | // Else use a generic error. 179 | s.CloseWithError(&CloseError{ 180 | Code: CloseProtocolError, 181 | Reason: "protocol error", 182 | }) 183 | 184 | return 185 | } 186 | 187 | // If Socket instance represents a server endpoint, payload data must be 188 | // masked. 189 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-5.1 190 | if s.server && !f.masked { 191 | s.CloseWithError(&CloseError{ 192 | Code: CloseProtocolError, 193 | Reason: "expected payload to be masked", 194 | }) 195 | return 196 | } 197 | 198 | // If Socket instance represents a client endpoint, payload data must 199 | // not be masked. 200 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-5.1 201 | if !s.server && f.masked { 202 | s.CloseWithError(&CloseError{ 203 | Code: CloseProtocolError, 204 | Reason: "expected payload to not be masked", 205 | }) 206 | return 207 | } 208 | 209 | switch f.opcode { 210 | case OpcodeText, OpcodeBinary: 211 | { 212 | s.callReadHandler(f.opcode, f.payload) 213 | } 214 | case OpcodePing: 215 | { 216 | s.callPingHandler(f.payload) 217 | } 218 | case OpcodePong: 219 | { 220 | s.callPongHandler(f.payload) 221 | } 222 | case OpcodeClose: 223 | { 224 | // Create a new CloseError using the payload data 225 | c, cerr := NewCloseError(f.payload) 226 | 227 | // Store close error for close handler. 228 | s.closeError = c 229 | 230 | // If the state of the socket instance is CLOSING, it means that 231 | // the closing handshake has been initiated from this socket 232 | // instance and the retrieved frame was the acknowledge close 233 | // frame. At this point the closing handshake has been completed 234 | // and therefore the underlying tcp connection can be closed, 235 | // since the connected endpoint won't be waiting for furthur 236 | // frames. 237 | if s.state == stateClosing { 238 | // closing handshake has been finalized therefore close tcp 239 | // connection. 240 | s.tcpClose() 241 | // Stop reading from connection. 242 | break Read 243 | } 244 | 245 | // If the state of the socket instance is not CLOSING, it means 246 | // that the closing handshake has been initiated by the 247 | // connected endpoint and therefore it is still waiting for the 248 | // acknowledgement close frame. 249 | s.state = stateClosing 250 | 251 | // The acknowledgment close frame to be sent will echo the 252 | // status code of the close frame just received. 253 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-5.5.1 254 | var b []byte 255 | 256 | // If the status code of the close frame received is valid, echo 257 | // it. Else leave the payload data of the acknowledgement close 258 | // frame empty. 259 | if cerr == nil { 260 | b = c.toBytesCode() 261 | } 262 | 263 | // Send acknowledgement close frame. 264 | s.WriteMessage(OpcodeClose, b) 265 | 266 | // At this point the closing handshake would have been finalized 267 | // therefore the tcp connection can be closed. 268 | s.tcpClose() 269 | 270 | // Stop reading from connection. 271 | break Read 272 | } 273 | } 274 | } 275 | } 276 | 277 | // WriteMessage is used to send frames to the connected endpoint. It accepts 278 | // two arguments 'o' opcode, 'p' payload data. 279 | func (s *Socket) WriteMessage(o int, p []byte) error { 280 | s.writeMutex.Lock() 281 | defer s.writeMutex.Unlock() 282 | 283 | // Before writing make sure that the socket instance is still in an open 284 | // state. 285 | if s.state == stateClosed { 286 | return ErrSocketClosed 287 | } 288 | 289 | // Create a frame instance which will represent the frame to be sent. 290 | f := &frame{ 291 | fin: true, 292 | opcode: o, 293 | payload: p, 294 | } 295 | 296 | // If the socket instance represents a client endpoint, the payload data 297 | // must be masked. 298 | if !s.server { 299 | // Generate random mask key 300 | f.key = randomByteSlice(1) 301 | } 302 | 303 | // Get a []byte representation of the frame instance. 304 | b, err := f.toBytes() 305 | 306 | // If an error is not nil, since the error doesn't relate with the socket 307 | // connection itself, the error is returned. 308 | if err != nil { 309 | return err 310 | } 311 | 312 | // Send frame 313 | s.buf.Write(b) 314 | if err := s.buf.Flush(); err != nil { 315 | // Store error. 316 | s.closeError = err 317 | 318 | // Close TCP Connection. 319 | s.TCPClose() 320 | 321 | // Since the error is related with the socket connection the error is 322 | // not returned but passed to the close handler. 323 | return nil 324 | } 325 | 326 | // If frame sent is a close frame, change state to closing. 327 | if f.opcode == OpcodeClose { 328 | s.state = stateClosing 329 | } 330 | 331 | return nil 332 | } 333 | 334 | // SetReadDeadline sets the deadline for future Read calls. A zero value for t 335 | // means Read will not time out. 336 | func (s *Socket) SetReadDeadline(t time.Time) { 337 | s.conn.SetReadDeadline(t) 338 | } 339 | 340 | // SetWriteDeadline sets the deadline for future Write calls. Even if write 341 | // times out, it may return n > 0, indicating that some of the data was 342 | // successfully written. A zero value for t means Write will not time out. 343 | func (s *Socket) SetWriteDeadline(t time.Time) { 344 | s.conn.SetWriteDeadline(t) 345 | } 346 | 347 | // callReadHandler invokes the read handler provided by the user (if any). 348 | func (s *Socket) callReadHandler(o int, p []byte) { 349 | if s.ReadHandler != nil { 350 | s.ReadHandler(o, p) 351 | } 352 | } 353 | 354 | // callPingHandler first tries to invoke the ping handler provided by the 355 | // user. If the user hasn't provided one it invokes the default functionality. 356 | func (s *Socket) callPingHandler(p []byte) { 357 | if s.PingHandler != nil { 358 | s.PingHandler(p) 359 | return 360 | } 361 | s.defaultPingHandler(p) 362 | } 363 | 364 | // defaultPingHandler sends a pong frame with the same payload data of the ping 365 | // frame just received. 366 | // 367 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-5.5.3 368 | func (s *Socket) defaultPingHandler(p []byte) { 369 | s.WriteMessage(OpcodePong, p) 370 | } 371 | 372 | // callPongHandler invokes the pong handler provided by the user (if any). 373 | func (s *Socket) callPongHandler(p []byte) { 374 | if s.PongHandler != nil { 375 | s.PongHandler(p) 376 | return 377 | } 378 | } 379 | 380 | // callCloseHandler first tries to invoke the close handler provided by the 381 | // user. 382 | func (s *Socket) callCloseHandler(e error) { 383 | if s.CloseHandler != nil { 384 | s.CloseHandler(e) 385 | } 386 | } 387 | 388 | // TCPClose closes the underlying tcp connection if it hasn't already been 389 | // closed. 390 | func (s *Socket) TCPClose() { 391 | // If socket has already been closed, don't reclose the tcp connection 392 | if s.state == stateClosed { 393 | return 394 | } 395 | 396 | // Change state of socket instance to closed. 397 | s.state = stateClosed 398 | 399 | // Close tcp connection 400 | s.conn.Close() 401 | 402 | // Invoke close handler. 403 | s.callCloseHandler(s.closeError) 404 | } 405 | 406 | // tcpClose closes the underlying tcp connection after s.CloseDelay seconds if 407 | // it hasn't already been closed . More info on why this is needed documented 408 | // in s.CloseDelay. 409 | func (s *Socket) tcpClose() { 410 | // If socket has already been closed, don't reclose the tcp connection 411 | if s.state == stateClosed { 412 | return 413 | } 414 | 415 | if s.CloseDelay > 0 { 416 | t := time.NewTicker(time.Second * s.CloseDelay) 417 | <-t.C 418 | } 419 | 420 | // Close tcp connection 421 | s.TCPClose() 422 | } 423 | 424 | // Close initiates the normal closures (1000) closing handshake. 425 | func (s *Socket) Close() { 426 | s.CloseWithError(&CloseError{ 427 | Code: CloseNormalClosure, 428 | Reason: "normal closure", 429 | }) 430 | } 431 | 432 | // CloseWithError initiates the closing handshake. 433 | func (s *Socket) CloseWithError(e *CloseError) { 434 | // Store error. 435 | s.closeError = e 436 | 437 | // Start the closing handshake 438 | b, _ := e.ToBytes() 439 | s.WriteMessage(OpcodeClose, b) 440 | } 441 | -------------------------------------------------------------------------------- /socket_test.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "strings" 7 | "sync" 8 | "testing" 9 | "time" 10 | ) 11 | 12 | func TestSocketReadTextFrame(t *testing.T) { 13 | payload := "expected payload" 14 | 15 | done := make(chan bool) 16 | timeout := time.NewTicker(time.Second * 2) 17 | 18 | h := func(w http.ResponseWriter, r *http.Request) { 19 | q := Request{} 20 | s, err := q.Upgrade(w, r) 21 | 22 | if err != nil { 23 | t.Fatal("unexpected error was returned", err) 24 | } 25 | 26 | s.ReadHandler = func(o int, p []byte) { 27 | if o != OpcodeText { 28 | t.Errorf("expected opcode to be '%d' but it is '%d'", OpcodeText, o) 29 | } 30 | 31 | if string(p) != payload { 32 | t.Errorf(`expected payload to be "%s" but it is "%s"`, p, payload) 33 | } 34 | 35 | done <- true 36 | } 37 | 38 | s.Listen() 39 | } 40 | 41 | s := httptest.NewServer(http.HandlerFunc(h)) 42 | defer s.Close() 43 | 44 | d := &Dialer{} 45 | c, _, err := d.Dial(adaptURL(s.URL)) 46 | 47 | if err != nil { 48 | t.Fatal("unexpected error returned", err) 49 | } 50 | 51 | defer c.TCPClose() 52 | 53 | f := &frame{ 54 | fin: true, 55 | opcode: OpcodeText, 56 | key: []byte{1, 1, 1, 1}, 57 | payload: []byte(payload), 58 | } 59 | 60 | b, err := f.toBytes() 61 | 62 | if err != nil { 63 | t.Fatal("unexpected error returned", err) 64 | } 65 | 66 | c.buf.Write(b) 67 | if err := c.buf.Flush(); err != nil { 68 | t.Fatal("unexpected error returned", err) 69 | } 70 | 71 | select { 72 | case <-done: 73 | { 74 | 75 | } 76 | case <-timeout.C: 77 | { 78 | t.Error("test case timed out") 79 | } 80 | } 81 | } 82 | 83 | func TestSocketReadBinaryFrame(t *testing.T) { 84 | payload := "expected payload" 85 | 86 | done := make(chan bool) 87 | timeout := time.NewTicker(time.Second * 2) 88 | 89 | h := func(w http.ResponseWriter, r *http.Request) { 90 | q := Request{} 91 | s, err := q.Upgrade(w, r) 92 | 93 | if err != nil { 94 | t.Fatal("unexpected error was returned", err) 95 | } 96 | 97 | s.ReadHandler = func(o int, p []byte) { 98 | if o != OpcodeBinary { 99 | t.Errorf("expected opcode to be '%d' but it is '%d'", OpcodeBinary, o) 100 | } 101 | 102 | if string(p) != payload { 103 | t.Errorf(`expected payload to be "%s" but it is "%s"`, p, payload) 104 | } 105 | 106 | done <- true 107 | } 108 | 109 | s.Listen() 110 | } 111 | 112 | s := httptest.NewServer(http.HandlerFunc(h)) 113 | defer s.Close() 114 | 115 | d := &Dialer{} 116 | c, _, err := d.Dial(adaptURL(s.URL)) 117 | 118 | if err != nil { 119 | t.Fatal("unexpected error returned", err) 120 | } 121 | 122 | defer c.TCPClose() 123 | 124 | f := &frame{ 125 | fin: true, 126 | opcode: OpcodeBinary, 127 | key: []byte{1, 1, 1, 1}, 128 | payload: []byte(payload), 129 | } 130 | 131 | b, err := f.toBytes() 132 | 133 | if err != nil { 134 | t.Fatal("unexpected error returned", err) 135 | } 136 | 137 | c.buf.Write(b) 138 | if err := c.buf.Flush(); err != nil { 139 | t.Fatal("unexpected error returned", err) 140 | } 141 | 142 | select { 143 | case <-done: 144 | { 145 | 146 | } 147 | case <-timeout.C: 148 | { 149 | t.Error("test case timed out") 150 | } 151 | } 152 | } 153 | 154 | func TestSocketReadPingFrame(t *testing.T) { 155 | payload := "expected payload" 156 | 157 | done := make(chan bool) 158 | timeout := time.NewTicker(time.Second * 2) 159 | 160 | h := func(w http.ResponseWriter, r *http.Request) { 161 | q := Request{} 162 | s, err := q.Upgrade(w, r) 163 | 164 | if err != nil { 165 | t.Fatal("unexpected error was returned", err) 166 | } 167 | 168 | s.PingHandler = func(p []byte) { 169 | if string(p) != payload { 170 | t.Errorf(`expected payload to be "%s" but it is "%s"`, p, payload) 171 | } 172 | 173 | done <- true 174 | } 175 | 176 | s.Listen() 177 | } 178 | 179 | s := httptest.NewServer(http.HandlerFunc(h)) 180 | defer s.Close() 181 | 182 | d := &Dialer{} 183 | c, _, err := d.Dial(adaptURL(s.URL)) 184 | 185 | if err != nil { 186 | t.Fatal("unexpected error returned", err) 187 | } 188 | 189 | defer c.TCPClose() 190 | 191 | f := &frame{ 192 | fin: true, 193 | opcode: OpcodePing, 194 | key: []byte{1, 1, 1, 1}, 195 | payload: []byte(payload), 196 | } 197 | 198 | b, err := f.toBytes() 199 | 200 | if err != nil { 201 | t.Fatal("unexpected error returned", err) 202 | } 203 | 204 | c.buf.Write(b) 205 | if err := c.buf.Flush(); err != nil { 206 | t.Fatal("unexpected error returned", err) 207 | } 208 | 209 | select { 210 | case <-done: 211 | { 212 | 213 | } 214 | case <-timeout.C: 215 | { 216 | t.Error("test case timed out") 217 | } 218 | } 219 | } 220 | 221 | func TestSocketReadPongFrame(t *testing.T) { 222 | payload := "expected payload" 223 | 224 | done := make(chan bool) 225 | timeout := time.NewTicker(time.Second * 2) 226 | 227 | h := func(w http.ResponseWriter, r *http.Request) { 228 | q := Request{} 229 | s, err := q.Upgrade(w, r) 230 | 231 | if err != nil { 232 | t.Fatal("unexpected error was returned", err) 233 | } 234 | 235 | s.PongHandler = func(p []byte) { 236 | if string(p) != payload { 237 | t.Errorf(`expected payload to be "%s" but it is "%s"`, p, payload) 238 | } 239 | 240 | done <- true 241 | } 242 | 243 | s.Listen() 244 | } 245 | 246 | s := httptest.NewServer(http.HandlerFunc(h)) 247 | defer s.Close() 248 | 249 | d := &Dialer{} 250 | c, _, err := d.Dial(adaptURL(s.URL)) 251 | 252 | if err != nil { 253 | t.Fatal("unexpected error returned", err) 254 | } 255 | 256 | defer c.TCPClose() 257 | 258 | f := &frame{ 259 | fin: true, 260 | opcode: OpcodePong, 261 | key: []byte{1, 1, 1, 1}, 262 | payload: []byte(payload), 263 | } 264 | 265 | b, err := f.toBytes() 266 | 267 | if err != nil { 268 | t.Fatal("unexpected error returned", err) 269 | } 270 | 271 | c.buf.Write(b) 272 | if err := c.buf.Flush(); err != nil { 273 | t.Fatal("unexpected error returned", err) 274 | } 275 | 276 | select { 277 | case <-done: 278 | { 279 | 280 | } 281 | case <-timeout.C: 282 | { 283 | t.Error("test case timed out") 284 | } 285 | } 286 | } 287 | 288 | func TestSocketdefaultPingHandler(t *testing.T) { 289 | payload := "expected payload" 290 | 291 | done := make(chan bool) 292 | timeout := time.NewTicker(time.Second * 2) 293 | 294 | h := func(w http.ResponseWriter, r *http.Request) { 295 | q := Request{} 296 | s, err := q.Upgrade(w, r) 297 | 298 | if err != nil { 299 | t.Fatal("unexpected error was returned", err) 300 | } 301 | 302 | s.Listen() 303 | } 304 | 305 | s := httptest.NewServer(http.HandlerFunc(h)) 306 | defer s.Close() 307 | 308 | d := &Dialer{} 309 | c, _, err := d.Dial(adaptURL(s.URL)) 310 | 311 | if err != nil { 312 | t.Fatal("unexpected error returned", err) 313 | } 314 | 315 | defer c.TCPClose() 316 | 317 | f := &frame{ 318 | fin: true, 319 | opcode: OpcodePing, 320 | key: []byte{1, 1, 1, 1}, 321 | payload: []byte(payload), 322 | } 323 | 324 | b, err := f.toBytes() 325 | 326 | if err != nil { 327 | t.Fatal("unexpected error returned", err) 328 | } 329 | 330 | c.PongHandler = func(p []byte) { 331 | if string(p) != payload { 332 | t.Errorf(`expected payload to be "%s" but it is "%s"`, p, payload) 333 | } 334 | done <- true 335 | } 336 | 337 | go c.Listen() 338 | 339 | c.buf.Write(b) 340 | if err := c.buf.Flush(); err != nil { 341 | t.Fatal("unexpected error returned", err) 342 | } 343 | 344 | select { 345 | case <-done: 346 | { 347 | 348 | } 349 | case <-timeout.C: 350 | { 351 | t.Error("test case timed out") 352 | } 353 | } 354 | } 355 | 356 | func TestSocketReadInvalidFrame(t *testing.T) { 357 | done := make(chan bool) 358 | timeout := time.NewTicker(time.Second * 2) 359 | 360 | h := func(w http.ResponseWriter, r *http.Request) { 361 | q := Request{} 362 | s, err := q.Upgrade(w, r) 363 | 364 | if err != nil { 365 | t.Fatal("unexpected error was returned", err) 366 | } 367 | 368 | s.ReadHandler = func(o int, p []byte) { 369 | t.Error("unexpected invocation of Read Handler") 370 | } 371 | 372 | s.Listen() 373 | } 374 | 375 | s := httptest.NewServer(http.HandlerFunc(h)) 376 | defer s.Close() 377 | 378 | d := &Dialer{} 379 | c, _, err := d.Dial(adaptURL(s.URL)) 380 | 381 | if err != nil { 382 | t.Fatal("unexpected error returned", err) 383 | } 384 | 385 | defer c.TCPClose() 386 | 387 | c.CloseHandler = func(err error) { 388 | if e, k := err.(*CloseError); k { 389 | if e.Code != CloseProtocolError { 390 | t.Errorf("expected Close Error Code to be '%d', but it is '%d'", CloseProtocolError, e.Code) 391 | } 392 | } else { 393 | t.Errorf("expected error instance to be of type *CloseError") 394 | } 395 | done <- true 396 | } 397 | 398 | go c.Listen() 399 | 400 | c.buf.Write([]byte("bad frame")) 401 | if err := c.buf.Flush(); err != nil { 402 | t.Fatal("unexpected error returned", err) 403 | } 404 | 405 | select { 406 | case <-done: 407 | { 408 | 409 | } 410 | case <-timeout.C: 411 | { 412 | t.Error("test case timed out") 413 | } 414 | } 415 | } 416 | 417 | func TestSocketReadClientUnMaskedFrame(t *testing.T) { 418 | done := make(chan bool) 419 | timeout := time.NewTicker(time.Second * 2) 420 | 421 | h := func(w http.ResponseWriter, r *http.Request) { 422 | q := Request{} 423 | s, err := q.Upgrade(w, r) 424 | 425 | if err != nil { 426 | t.Fatal("unexpected error was returned", err) 427 | } 428 | 429 | s.ReadHandler = func(o int, p []byte) { 430 | t.Errorf("unexpected invocation of Read Handler") 431 | } 432 | 433 | s.Listen() 434 | } 435 | 436 | s := httptest.NewServer(http.HandlerFunc(h)) 437 | defer s.Close() 438 | 439 | d := &Dialer{} 440 | c, _, err := d.Dial(adaptURL(s.URL)) 441 | 442 | if err != nil { 443 | t.Fatal("unexpected error returned", err) 444 | } 445 | 446 | defer c.TCPClose() 447 | 448 | f := &frame{ 449 | fin: true, 450 | opcode: OpcodeText, 451 | payload: []byte("something"), 452 | } 453 | 454 | b, err := f.toBytes() 455 | 456 | if err != nil { 457 | t.Fatal("unexpected error returned", err) 458 | } 459 | 460 | c.CloseHandler = func(err error) { 461 | if e, k := err.(*CloseError); k { 462 | r := "expected payload to be masked" 463 | 464 | if e.Code != CloseProtocolError { 465 | t.Errorf("expected Close Error Code to be '%d', but it is '%d'", CloseProtocolError, e.Code) 466 | } 467 | 468 | if e.Reason != r { 469 | t.Errorf(`expected Close Error Reason to be "%s", but it is "%s"`, r, e.Reason) 470 | } 471 | } else { 472 | t.Errorf("expected error instance to be of type *CloseError") 473 | } 474 | done <- true 475 | } 476 | 477 | go c.Listen() 478 | 479 | c.buf.Write(b) 480 | if err := c.buf.Flush(); err != nil { 481 | t.Fatal("unexpected error returned", err) 482 | } 483 | 484 | select { 485 | case <-done: 486 | { 487 | 488 | } 489 | case <-timeout.C: 490 | { 491 | t.Error("test case timed out") 492 | } 493 | } 494 | } 495 | 496 | func TestSocketReadServerMaskedFrame(t *testing.T) { 497 | done := make(chan bool) 498 | timeout := time.NewTicker(time.Second * 2) 499 | 500 | h := func(w http.ResponseWriter, r *http.Request) { 501 | q := Request{} 502 | s, err := q.Upgrade(w, r) 503 | 504 | if err != nil { 505 | t.Fatal("unexpected error was returned", err) 506 | } 507 | 508 | f := &frame{ 509 | fin: true, 510 | opcode: OpcodeText, 511 | key: []byte{1, 1, 1, 1}, 512 | payload: []byte("something"), 513 | } 514 | 515 | b, err := f.toBytes() 516 | 517 | if err != nil { 518 | t.Fatal("unexpected error returned", err) 519 | } 520 | 521 | s.CloseHandler = func(err error) { 522 | if e, k := err.(*CloseError); k { 523 | r := "expected payload to not be masked" 524 | 525 | if e.Code != CloseProtocolError { 526 | t.Errorf("expected Close Error Code to be '%d', but it is '%d'", CloseProtocolError, e.Code) 527 | } 528 | 529 | if e.Reason != r { 530 | t.Errorf(`expected Close Error Reason to be "%s", but it is "%s"`, r, e.Reason) 531 | } 532 | } else { 533 | t.Errorf("expected error instance to be of type *CloseError") 534 | } 535 | done <- true 536 | } 537 | 538 | s.buf.Write(b) 539 | if err := s.buf.Flush(); err != nil { 540 | t.Error("unexpected error returned", err) 541 | } 542 | 543 | s.Listen() 544 | } 545 | 546 | s := httptest.NewServer(http.HandlerFunc(h)) 547 | defer s.Close() 548 | 549 | d := &Dialer{} 550 | c, _, err := d.Dial(adaptURL(s.URL)) 551 | 552 | if err != nil { 553 | t.Fatal("unexpected error returned", err) 554 | } 555 | 556 | defer c.TCPClose() 557 | 558 | c.ReadHandler = func(o int, p []byte) { 559 | t.Errorf("unexpected invocation of Read Handler") 560 | } 561 | 562 | go c.Listen() 563 | 564 | select { 565 | case <-done: 566 | { 567 | 568 | } 569 | case <-timeout.C: 570 | { 571 | t.Error("test case timed out") 572 | } 573 | } 574 | } 575 | 576 | func TestSocketClose(t *testing.T) { 577 | done := make(chan bool) 578 | timeout := time.NewTicker(time.Second * 2) 579 | 580 | h := func(w http.ResponseWriter, r *http.Request) { 581 | q := Request{} 582 | s, err := q.Upgrade(w, r) 583 | 584 | if err != nil { 585 | t.Fatal("unexpected error was returned", err) 586 | } 587 | 588 | s.Listen() 589 | } 590 | 591 | s := httptest.NewServer(http.HandlerFunc(h)) 592 | defer s.Close() 593 | 594 | d := &Dialer{} 595 | c, _, err := d.Dial(adaptURL(s.URL)) 596 | 597 | if err != nil { 598 | t.Fatal("unexpected error returned", err) 599 | } 600 | 601 | c.CloseHandler = func(err error) { 602 | if e, k := err.(*CloseError); k { 603 | if e.Code != CloseNormalClosure { 604 | t.Errorf("expected Close Error Code to be '%d', but it is '%d'", CloseNormalClosure, e.Code) 605 | } 606 | 607 | if e.Reason != "" { 608 | t.Errorf(`expected Close Error Reason to be empty, but it is "%s"`, e.Reason) 609 | } 610 | } else { 611 | t.Errorf("expected error instance to be of type *CloseError") 612 | } 613 | done <- true 614 | } 615 | 616 | go c.Listen() 617 | 618 | c.Close() 619 | 620 | select { 621 | case <-done: 622 | { 623 | } 624 | case <-timeout.C: 625 | { 626 | t.Error("test case timed out") 627 | } 628 | } 629 | } 630 | 631 | func TestSocketReadEOFError(t *testing.T) { 632 | done := make(chan bool) 633 | timeout := time.NewTicker(time.Second * 2) 634 | 635 | h := func(w http.ResponseWriter, r *http.Request) { 636 | q := Request{} 637 | s, err := q.Upgrade(w, r) 638 | 639 | if err != nil { 640 | t.Fatal("unexpected error was returned", err) 641 | } 642 | 643 | s.TCPClose() 644 | } 645 | 646 | s := httptest.NewServer(http.HandlerFunc(h)) 647 | defer s.Close() 648 | 649 | d := &Dialer{} 650 | c, _, err := d.Dial(adaptURL(s.URL)) 651 | 652 | if err != nil { 653 | t.Fatal("unexpected error returned", err) 654 | } 655 | 656 | defer c.TCPClose() 657 | 658 | c.CloseHandler = func(err error) { 659 | if e, k := err.(*CloseError); k { 660 | r := "abnormal closure" 661 | if e.Code != CloseAbnormalClosure { 662 | t.Errorf("expected Close Error Code to be '%d', but it is '%d'", CloseAbnormalClosure, e.Code) 663 | } 664 | 665 | if e.Reason != r { 666 | t.Errorf(`expected Close Error Reason to be "%s", but it is "%s"`, r, e.Reason) 667 | } 668 | } else { 669 | t.Errorf("expected error instance to be of type *CloseError") 670 | } 671 | done <- true 672 | } 673 | 674 | go c.Listen() 675 | 676 | select { 677 | case <-done: 678 | { 679 | 680 | } 681 | case <-timeout.C: 682 | { 683 | t.Error("test case timed out") 684 | } 685 | } 686 | } 687 | 688 | func TestSocketReadTimeoutError(t *testing.T) { 689 | done := make(chan bool) 690 | timeout := time.NewTicker(time.Second * 4) 691 | 692 | h := func(w http.ResponseWriter, r *http.Request) { 693 | q := Request{} 694 | s, err := q.Upgrade(w, r) 695 | 696 | if err != nil { 697 | t.Fatal("unexpected error was returned", err) 698 | } 699 | 700 | s.Listen() 701 | } 702 | 703 | s := httptest.NewServer(http.HandlerFunc(h)) 704 | defer s.Close() 705 | 706 | d := &Dialer{} 707 | c, _, err := d.Dial(adaptURL(s.URL)) 708 | 709 | if err != nil { 710 | t.Fatal("unexpected error returned", err) 711 | } 712 | 713 | defer c.TCPClose() 714 | 715 | c.CloseHandler = func(err error) { 716 | if e, k := err.(*CloseError); k { 717 | r := "abnormal closure" 718 | if e.Code != CloseAbnormalClosure { 719 | t.Errorf("expected Close Error Code to be '%d', but it is '%d'", CloseAbnormalClosure, e.Code) 720 | } 721 | 722 | if e.Reason != r { 723 | t.Errorf(`expected Close Error Reason to be "%s", but it is "%s"`, r, e.Reason) 724 | } 725 | } else { 726 | t.Errorf("expected error instance to be of type *CloseError") 727 | } 728 | done <- true 729 | } 730 | 731 | go c.Listen() 732 | 733 | c.SetReadDeadline(time.Now().Add(time.Second * 1)) 734 | 735 | select { 736 | case <-done: 737 | { 738 | 739 | } 740 | case <-timeout.C: 741 | { 742 | t.Error("test case timed out") 743 | } 744 | } 745 | } 746 | 747 | func TestSocketWriteTimeoutErorr(t *testing.T) { 748 | done := make(chan bool) 749 | timeout := time.NewTicker(time.Second * 4) 750 | 751 | h := func(w http.ResponseWriter, r *http.Request) { 752 | q := Request{} 753 | s, err := q.Upgrade(w, r) 754 | 755 | if err != nil { 756 | t.Fatal("unexpected error was returned", err) 757 | } 758 | 759 | s.CloseHandler = func(err error) { 760 | done <- true 761 | } 762 | 763 | s.SetWriteDeadline(time.Now().Add(time.Second * 1)) 764 | 765 | go s.Listen() 766 | 767 | time.Sleep(time.Second * 2) 768 | 769 | s.WriteMessage(OpcodeText, []byte("something")) 770 | } 771 | 772 | s := httptest.NewServer(http.HandlerFunc(h)) 773 | defer s.Close() 774 | 775 | d := &Dialer{} 776 | c, _, err := d.Dial(adaptURL(s.URL)) 777 | 778 | if err != nil { 779 | t.Fatal("unexpected error returned", err) 780 | } 781 | 782 | defer c.TCPClose() 783 | 784 | select { 785 | case <-done: 786 | { 787 | 788 | } 789 | case <-timeout.C: 790 | { 791 | t.Error("test case timed out") 792 | } 793 | } 794 | } 795 | 796 | func TestSocketWriteFromClient(t *testing.T) { 797 | payload := "expected payload" 798 | 799 | done := make(chan bool) 800 | timeout := time.NewTicker(time.Second * 2) 801 | 802 | h := func(w http.ResponseWriter, r *http.Request) { 803 | q := Request{} 804 | s, err := q.Upgrade(w, r) 805 | 806 | if err != nil { 807 | t.Fatal("unexpected error was returned", err) 808 | } 809 | 810 | s.ReadHandler = func(o int, p []byte) { 811 | if o != OpcodeText { 812 | t.Errorf("expected opcode to be '%d' but it is '%d'", OpcodeText, o) 813 | } 814 | 815 | if string(p) != payload { 816 | t.Errorf(`expected payload to be "%s" but it is "%s"`, p, payload) 817 | } 818 | 819 | done <- true 820 | } 821 | 822 | s.Listen() 823 | } 824 | 825 | s := httptest.NewServer(http.HandlerFunc(h)) 826 | defer s.Close() 827 | 828 | d := &Dialer{} 829 | c, _, err := d.Dial(adaptURL(s.URL)) 830 | 831 | if err != nil { 832 | t.Fatal("unexpected error was returned", err) 833 | } 834 | 835 | defer c.TCPClose() 836 | 837 | if err := c.WriteMessage(OpcodeText, []byte(payload)); err != nil { 838 | t.Fatal("unexpected error returned", err) 839 | } 840 | 841 | select { 842 | case <-done: 843 | { 844 | 845 | } 846 | case <-timeout.C: 847 | { 848 | t.Error("test case timed out") 849 | } 850 | } 851 | } 852 | 853 | func TestSocketWriteFromServer(t *testing.T) { 854 | payload := "expected payload" 855 | 856 | done := make(chan bool) 857 | timeout := time.NewTicker(time.Second * 2) 858 | 859 | h := func(w http.ResponseWriter, r *http.Request) { 860 | q := Request{} 861 | s, err := q.Upgrade(w, r) 862 | 863 | if err != nil { 864 | t.Fatal("unexpected error was returned", err) 865 | } 866 | 867 | if err := s.WriteMessage(OpcodeText, []byte(payload)); err != nil { 868 | t.Fatal("unexpected error was returned", err) 869 | } 870 | } 871 | 872 | s := httptest.NewServer(http.HandlerFunc(h)) 873 | defer s.Close() 874 | 875 | d := &Dialer{} 876 | c, _, err := d.Dial(adaptURL(s.URL)) 877 | 878 | if err != nil { 879 | t.Fatal("unexpected error returned", err) 880 | } 881 | 882 | defer c.TCPClose() 883 | 884 | c.ReadHandler = func(o int, p []byte) { 885 | if o != OpcodeText { 886 | t.Errorf("expected opcode to be '%d' but it is '%d'", OpcodeText, o) 887 | } 888 | 889 | if string(p) != payload { 890 | t.Errorf(`expected payload to be "%s" but it is "%s"`, p, payload) 891 | } 892 | 893 | done <- true 894 | } 895 | 896 | go c.Listen() 897 | 898 | select { 899 | case <-done: 900 | { 901 | 902 | } 903 | case <-timeout.C: 904 | { 905 | t.Error("test case timed out") 906 | } 907 | } 908 | } 909 | 910 | func TestSocketWriteWhenClosed(t *testing.T) { 911 | s := &Socket{ 912 | writeMutex: &sync.Mutex{}, 913 | } 914 | s.state = stateClosed 915 | 916 | if err := s.WriteMessage(1, []byte("test")); err != ErrSocketClosed { 917 | t.Errorf(`expected error "%s", but got "%v"`, ErrSocketClosed, err) 918 | } 919 | } 920 | 921 | func adaptURL(u string) string { 922 | return strings.Replace(u, "http://", "ws://", 1) 923 | } 924 | -------------------------------------------------------------------------------- /utils.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "bufio" 5 | "crypto/sha1" 6 | "encoding/base64" 7 | "encoding/binary" 8 | "io" 9 | "math/rand" 10 | "strings" 11 | "time" 12 | ) 13 | 14 | // wsAcceptSalt is the GUID used by the WebSocket protocol to generate the 15 | // value for the "Sec-Websocket-Accept" response HTTP Header field. 16 | const wsAcceptSalt string = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" 17 | 18 | // makeAcceptKey is used to generate the Accept Key which is then sent to the 19 | // client using the 'Sec-Websocket-Accept' Response Header Field. This is used 20 | // to prevent an attacker from ticking the server. 21 | // 22 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-1.3 23 | func makeAcceptKey(k string) string { 24 | h := sha1.New() 25 | io.WriteString(h, k+wsAcceptSalt) 26 | return base64.StdEncoding.EncodeToString(h.Sum(nil)) 27 | } 28 | 29 | // readFromBuffer reads from the buffer (b) provided the number of specified 30 | // bytes (l). 31 | func readFromBuffer(b *bufio.Reader, l uint64) ([]byte, error) { 32 | p := make([]byte, l) 33 | 34 | // If the number of buffered bytes will accommodate the number of bytes 35 | // requested, read once and return the read bytes. 36 | if uint64(b.Buffered()) >= l { 37 | _, err := b.Read(p) 38 | return p, err 39 | } 40 | 41 | // If the user requires more bytes than there is buffered, the buffer will 42 | // be read from multiple times. 43 | 44 | // Total number of bytes read from buffer. 45 | n := 0 46 | 47 | for { 48 | // Temporary slice of bytes. 49 | t := make([]byte, l) 50 | 51 | // Read from buffer and put read bytes in temporary slice of bytes. 52 | i, err := b.Read(t) 53 | 54 | if err != nil { 55 | return nil, err 56 | } 57 | 58 | // Append bytes to the slice of bytes to be returned. 59 | p = append(p[:n], t[:i]...) 60 | 61 | // Increment the total number of bytes with the bytes read. 62 | n += i 63 | 64 | // If the total number of bytes is the same as the number of bytes 65 | // requested, stop read operation and read bytes. 66 | if uint64(n) == l { 67 | break 68 | } 69 | } 70 | 71 | return p, nil 72 | } 73 | 74 | // stringExists is a utility function used to check whether a slice of string 75 | // ('l') contains a particular value ('k'). If it does, its position will be 76 | // returned otherwise '-1' is returned. 77 | func stringExists(l []string, k string) int { 78 | for i, v := range l { 79 | if k == v { 80 | return i 81 | } 82 | } 83 | 84 | return -1 85 | } 86 | 87 | // headerToSlice is used to turn the values of a multi value HTTP Header field 88 | // to a slice of string. 89 | // 90 | // From RFC2616: https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2 91 | func headerToSlice(v string) []string { 92 | l := strings.Split(v, ",") 93 | 94 | for i, v := range l { 95 | l[i] = strings.Trim(v, " ") 96 | } 97 | 98 | return l 99 | } 100 | 101 | // randomByteSlice is used to generate a byte slice of random 32 bit integers. 102 | func randomByteSlice(i int) []byte { 103 | // Slice of bytes which will grow to be 16 bytes in length once the 104 | // operation is ready. This slice will then be used to generate the key to 105 | // be sent with the clients opening handshake using the Sec-Websocket-Key 106 | // Header. 107 | var b []byte 108 | 109 | // Set seed. 110 | rand.Seed(time.Now().UnixNano()) 111 | 112 | // The challenge key must be 16 bytes in length. 113 | for l := 0; l < i; l++ { 114 | // Temp slice 115 | t := make([]byte, 4) 116 | 117 | // Generate a random 32bit number and store its binary value in 't'. 118 | binary.BigEndian.PutUint32(t, rand.Uint32()) 119 | 120 | // Finally append the random generated number to 'b'. 121 | b = append(b, t...) 122 | } 123 | 124 | return b 125 | } 126 | 127 | // closeErrorExist returns whether the error number provided as an argument is 128 | // a valid error number or not. 129 | func closeErrorExist(i int) bool { 130 | switch i { 131 | case CloseNormalClosure, CloseGoingAway, CloseProtocolError, CloseUnsupportedData, CloseNoStatusReceived, CloseAbnormalClosure, CloseInvalidFramePayloadData, ClosePolicyViolation, CloseMessageTooBig, CloseMandatoryExtension, CloseInternalServerErr, CloseTLSHandshake: 132 | { 133 | return true 134 | } 135 | } 136 | return false 137 | } 138 | -------------------------------------------------------------------------------- /utils_test.go: -------------------------------------------------------------------------------- 1 | package websocket 2 | 3 | import ( 4 | "bufio" 5 | "math/rand" 6 | "strings" 7 | "testing" 8 | ) 9 | 10 | func TestMakeAcceptKey(t *testing.T) { 11 | e := "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" 12 | k := makeAcceptKey("dGhlIHNhbXBsZSBub25jZQ==") 13 | if k != e { 14 | t.Errorf(`expected "%s" instead "%s" was returned.`, e, k) 15 | } 16 | } 17 | 18 | type payloadMock struct { 19 | p []byte 20 | } 21 | 22 | func (m *payloadMock) Read(p []byte) (int, error) { 23 | n := 0 24 | 25 | for i := range m.p { 26 | if i == len(p) { 27 | break 28 | } 29 | 30 | p[i] = m.p[i] 31 | n++ 32 | } 33 | 34 | m.p = append(make([]byte, 0), m.p[n:]...) 35 | 36 | return n, nil 37 | } 38 | 39 | func newBuffer(d []byte) *bufio.Reader { 40 | p := &payloadMock{p: d} 41 | r := bufio.NewReader(p) 42 | return r 43 | } 44 | 45 | func TestReadFromBufferSingleRead(t *testing.T) { 46 | var c uint64 = 3 47 | p := []byte{120, 123, 54, 32, 102} 48 | b := newBuffer(p) 49 | 50 | n, err := readFromBuffer(b, c) 51 | 52 | if err != nil { 53 | t.Fatal("An unexpected error was returned while invoking readFromBuffer():", err) 54 | } 55 | 56 | if uint64(len(n)) != c { 57 | t.Errorf("Expected slice of bytes returned from readFromBuffer to be of the length '%d'. Instead it is '%d'.", c, len(n)) 58 | } 59 | 60 | for i, v := range n { 61 | if v != p[i] { 62 | t.Fatalf("Expected slice of bytes to be '%v'. Instead it is '%v'.", p[:c], n) 63 | } 64 | } 65 | } 66 | 67 | func TestReadFromBufferMultiRead(t *testing.T) { 68 | // The slice to be read from the buffer must be greater than 4096. Since 69 | // this is the default size of a bufio buffer. 70 | // GO Ref: https://golang.org/src/bufio/bufio.go#L18 71 | p := make([]byte, 4100) 72 | 73 | for i := range p { 74 | rand.Seed(int64(i)) 75 | p[i] = byte(rand.Intn(255)) 76 | } 77 | 78 | b := newBuffer(p) 79 | 80 | readFromBuffer(b, 4090) 81 | n, err := readFromBuffer(b, 10) 82 | 83 | if err != nil { 84 | t.Error("Unexpected error was returned while invoking readFromBuffer:", err) 85 | } 86 | 87 | for i, v := range n { 88 | if v != p[i+4090] { 89 | t.Errorf("%v != %v", p[i+4090], v) 90 | } 91 | } 92 | } 93 | 94 | func TestStringExists(t *testing.T) { 95 | l := []string{"one", "two", "three"} 96 | 97 | type testCase struct { 98 | k string 99 | v int 100 | } 101 | 102 | testCases := []testCase{ 103 | {k: "one", v: 0}, 104 | {k: "four", v: -1}, 105 | } 106 | 107 | for i, c := range testCases { 108 | r := stringExists(l, c.k) 109 | 110 | if r != c.v { 111 | t.Errorf(`Test Case %d: Expected stringExists("%s") to return '%d' instead returned '%d'`, i, c.k, c.v, r) 112 | } 113 | } 114 | } 115 | 116 | func TestHeaderToSlice(t *testing.T) { 117 | l := []string{" both ", " left", "right ", "none"} 118 | 119 | r := headerToSlice(strings.Join(l, ",")) 120 | 121 | if len(l) != len(r) { 122 | t.Errorf("The length of the list of header value are not the same. '%d' != '%d'.", len(l), len(r)) 123 | } 124 | 125 | if r[0] != "both" { 126 | t.Errorf(`Expected "both" instead got "%s".`, r[0]) 127 | } 128 | 129 | if r[1] != "left" { 130 | t.Errorf(`Expected "left" instead got "%s".`, r[1]) 131 | } 132 | 133 | if r[2] != "right" { 134 | t.Errorf(`Expected "right" instead got "%s".`, r[2]) 135 | } 136 | 137 | if r[3] != "none" { 138 | t.Errorf(`Expected "none" instead got "%s".`, r[3]) 139 | } 140 | } 141 | 142 | func TestRandomByteSlice(t *testing.T) { 143 | type testCase struct { 144 | l int 145 | } 146 | 147 | testCases := []testCase{ 148 | {l: 2}, 149 | {l: 6}, 150 | } 151 | 152 | for i, c := range testCases { 153 | if b := randomByteSlice(c.l); len(b) != c.l*4 { 154 | t.Errorf("test case %d: expected slice of bytes to be '%d' in length, but it is '%d'", i, c.l*4, len(b)) 155 | } 156 | } 157 | } 158 | 159 | func TestCloseErrorExist(t *testing.T) { 160 | type testCase struct { 161 | e int 162 | v bool 163 | } 164 | 165 | testCases := []testCase{ 166 | // Should return false when opcode is invalid 167 | {e: 15, v: false}, 168 | // Should return true when opcode is valid. 169 | {e: CloseNormalClosure, v: true}, 170 | } 171 | 172 | for i, c := range testCases { 173 | if v := closeErrorExist(c.e); v != c.v { 174 | t.Errorf("test case %d: expected '%t' for '%d'", i, c.v, c.e) 175 | } 176 | } 177 | } 178 | --------------------------------------------------------------------------------