├── .gitignore
├── LICENSE
├── README.md
├── dialer.go
├── dialer_test.go
├── dialer_utils.go
├── dialer_utils_test.go
├── doc.go
├── errors.go
├── errors_test.go
├── examples
└── chat
│ ├── main.go
│ ├── manager.go
│ └── public
│ ├── css
│ └── style.css
│ ├── index.html
│ └── js
│ └── app.js
├── frame.go
├── frame_test.go
├── frame_utils.go
├── frame_utils_test.go
├── request.go
├── request_test.go
├── request_utils.go
├── request_utils_test.go
├── socket.go
├── socket_test.go
├── utils.go
└── utils_test.go
/.gitignore:
--------------------------------------------------------------------------------
1 | # Compiled Object files, Static and Dynamic libs (Shared Objects)
2 | *.o
3 | *.a
4 | *.so
5 |
6 | # Folders
7 | _obj
8 | _test
9 |
10 | # Architecture specific extensions/prefixes
11 | *.[568vq]
12 | [568vq].out
13 |
14 | *.cgo1.go
15 | *.cgo2.c
16 | _cgo_defun.c
17 | _cgo_gotypes.go
18 | _cgo_export.*
19 |
20 | _testmain.go
21 |
22 | *.exe
23 | *.test
24 | *.prof
25 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2016 Luca Tabone
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [Go Report Card](https://goreportcard.com/report/github.com/tabone/websocket)
2 | [GoDoc](https://godoc.org/github.com/tabone/websocket)
3 |
4 | # websocket
5 | Package websocket implements the WebSocket protocol defined in RFC 6455.
6 |
7 | ## Installation
8 | go get github.com/tabone/websocket
9 |
10 | ## Documentation
11 | - [API Reference](https://godoc.org/github.com/tabone/websocket)
12 | - [Examples](https://github.com/tabone/websocket/tree/master/examples)
13 |
--------------------------------------------------------------------------------
/dialer.go:
--------------------------------------------------------------------------------
1 | package websocket
2 |
3 | import (
4 | "bufio"
5 | "crypto/tls"
6 | "net"
7 | "net/http"
8 | "net/url"
9 | "regexp"
10 | "strings"
11 | "sync"
12 | )
13 |
// Dialer is a websocket client.
type Dialer struct {
	// Header holds the HTTP header fields to be included in the opening
	// handshake request.
	Header http.Header

	// SubProtocols lists the sub protocols which the client supports.
	SubProtocols []string

	// TLSConfig is used to configure the TLS client.
	TLSConfig *tls.Config
}
31 |
32 | // Dial is the method used to start the websocket connection.
33 | func (d *Dialer) Dial(u string) (*Socket, *http.Response, error) {
34 | // Parse URL to return a valid URL instance.
35 | l, err := parseURL(u)
36 | if err != nil {
37 | return nil, nil, err
38 | }
39 |
40 | // Get a valid websocket opening handshake request instance.
41 | q := d.createRequest(l)
42 |
43 | // Connect with the websocket server.
44 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-3
45 | conn, err := net.Dial("tcp", l.Host+"/"+l.Path+"?"+l.RawQuery)
46 | if err != nil {
47 | return nil, nil, err
48 | }
49 |
50 | // When the connection will be over TLS, we need to do the TLS handshake.
51 | if l.Scheme == "wss" {
52 | g := d.TLSConfig
53 |
54 | // Create tls config instance if user hasn't specified one since it is
55 | // required.
56 | if g == nil {
57 | g = &tls.Config{}
58 | }
59 |
60 | // If ServerName is empty, use the host provided by the user.
61 | if g.ServerName == "" {
62 | g.ServerName = strings.Split(l.Host, ":")[0]
63 | }
64 |
65 | // Change the current conenction to a secure one.
66 | c := tls.Client(conn, g)
67 |
68 | // Do the handshake.
69 | if err := c.Handshake(); err != nil {
70 | return nil, nil, err
71 | }
72 |
73 | conn = c
74 | }
75 |
76 | // Send request
77 | if err := q.Write(conn); err != nil {
78 | return nil, nil, err
79 | }
80 |
81 | // Buffer connection.
82 | b := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
83 |
84 | // Read response
85 | r, err := http.ReadResponse(b.Reader, q)
86 |
87 | if err != nil {
88 | return nil, nil, err
89 | }
90 |
91 | // Validate response.
92 | if err := validateResponse(r); err != nil {
93 | return nil, nil, err
94 | }
95 |
96 | return &Socket{
97 | conn: conn,
98 | buf: b,
99 | writeMutex: &sync.Mutex{},
100 | }, r, nil
101 | }
102 |
103 | // createOpeningHandshakeRequest is used to return a valid websocket opening
104 | // handshake client request.
105 | //
106 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.1
107 | func (d *Dialer) createRequest(l *url.URL) *http.Request {
108 | // Initialize header if not already initialized.
109 | if d.Header == nil {
110 | d.Header = make(http.Header)
111 | }
112 |
113 | // When using the default port the Host header field should only consist of
114 | // the host (no port is shown).
115 | t := l.Host
116 |
117 | switch l.Scheme {
118 | case "ws":
119 | {
120 | re := regexp.MustCompile(":22$")
121 | t = re.ReplaceAllString(t, "")
122 | }
123 | case "wss":
124 | {
125 | re := regexp.MustCompile(":443$")
126 | t = re.ReplaceAllString(t, "")
127 | }
128 | }
129 |
130 | // Include headers
131 | d.Header.Set("Host", t)
132 | d.Header.Set("Upgrade", "websocket")
133 | d.Header.Set("Connection", "upgrade")
134 | d.Header.Set("Sec-WebSocket-Version", "13")
135 | d.Header.Set("Sec-WebSocket-Key", makeChallengeKey())
136 | d.Header.Set("Sec-WebSocket-Protocol", strings.Join(d.SubProtocols, ", "))
137 |
138 | // Create request instance
139 | q := &http.Request{
140 | Method: "GET",
141 | URL: l,
142 | Proto: "HTTP/1.1",
143 | ProtoMajor: 1,
144 | ProtoMinor: 1,
145 | Header: d.Header,
146 | Host: l.Host,
147 | }
148 |
149 | return q
150 | }
151 |
--------------------------------------------------------------------------------
/dialer_test.go:
--------------------------------------------------------------------------------
1 | package websocket
2 |
3 | import (
4 | "encoding/base64"
5 | "net/http"
6 | "net/url"
7 | "strings"
8 | "testing"
9 | )
10 |
11 | func TestDialerCreateRequestNilHeader(t *testing.T) {
12 | d := &Dialer{Header: nil}
13 |
14 | q := d.createRequest(&url.URL{})
15 |
16 | if q.Header == nil {
17 | t.Errorf("expected header to be initialized")
18 | }
19 | }
20 |
21 | func TestDialerCreateRequestNonNilHeader(t *testing.T) {
22 | h := make(http.Header)
23 | k := "testKey"
24 | v := "testValue"
25 |
26 | h.Add(k, v)
27 |
28 | d := &Dialer{Header: h}
29 |
30 | q := d.createRequest(&url.URL{})
31 |
32 | if q.Header.Get(k) != v {
33 | t.Errorf("expected header to be the one provided in dialer instance")
34 | }
35 | }
36 |
37 | func TestDialerCreateRequestHostHeader(t *testing.T) {
38 | d := &Dialer{}
39 |
40 | type testCase struct {
41 | u *url.URL
42 | v string
43 | }
44 |
45 | testCases := []testCase{
46 | {u: &url.URL{Scheme: "ws", Host: "localhost:22"}, v: "localhost"},
47 | {u: &url.URL{Scheme: "wss", Host: "localhost:443"}, v: "localhost"},
48 | {u: &url.URL{Scheme: "ws", Host: "localhost:80"}, v: "localhost:80"},
49 | {u: &url.URL{Scheme: "wss", Host: "localhost:80"}, v: "localhost:80"},
50 | }
51 |
52 | for i, c := range testCases {
53 | q := d.createRequest(c.u)
54 | v := q.Header.Get("Host")
55 |
56 | if v != c.v {
57 | t.Errorf(`test case %d: expected Host header field value to be "%s", but it is "%s"`, i, c.v, v)
58 | }
59 | }
60 | }
61 |
62 | func TestDialerCreateRequestHeaders(t *testing.T) {
63 | d := &Dialer{
64 | SubProtocols: []string{"chat", "v1"},
65 | }
66 |
67 | q := d.createRequest(&url.URL{Scheme: "ws", Host: "localhost"})
68 |
69 | v := q.Header.Get("Upgrade")
70 | e := "websocket"
71 | if strings.ToLower(v) != e {
72 | t.Errorf(`expected Upgrade header field value to be "%s", but it is "%s"`, v, e)
73 | }
74 |
75 | v = q.Header.Get("Connection")
76 | e = "upgrade"
77 | if strings.ToLower(v) != e {
78 | t.Errorf(`expected Connection header field value to be "%s", but it is "%s"`, v, e)
79 | }
80 |
81 | v = q.Header.Get("Sec-WebSocket-Version")
82 | e = "13"
83 | if strings.ToLower(v) != e {
84 | t.Errorf(`expected Sec-WebSocket-Version header field value to be "%s", but it is "%s"`, v, e)
85 | }
86 |
87 | v = q.Header.Get("Sec-WebSocket-Protocol")
88 | e = strings.Join(d.SubProtocols, ", ")
89 | if strings.ToLower(v) != e {
90 | t.Errorf(`expected Sec-WebSocket-Protocol header field value to be "%s", but it is "%s"`, v, e)
91 | }
92 |
93 | l, err := base64.StdEncoding.DecodeString(q.Header.Get("Sec-WebSocket-Key"))
94 |
95 | if err != nil {
96 | t.Errorf(`unexpected error returned when decoding Sec-WebSocket-Key %s`, err)
97 | }
98 |
99 | if len(l) != 16 {
100 | t.Errorf(`expected Sec-WebSocket-Protocol header field value to be '%d' in length, but it is '%d'`, len(l), 16)
101 | }
102 | }
103 |
104 | func TestDialerCreateRequestRequest(t *testing.T) {
105 | d := &Dialer{}
106 | u := &url.URL{
107 | Scheme: "ws",
108 | Host: "localhost:8080",
109 | }
110 |
111 | q := d.createRequest(u)
112 |
113 | if q.URL != u {
114 | t.Errorf("expected URL instance to be the one provided")
115 | }
116 |
117 | if q.Method != "GET" {
118 | t.Errorf(`expected method to be "GET", but it is "%s"`, q.Method)
119 | }
120 |
121 | if q.Proto != "HTTP/1.1" {
122 | t.Errorf(`expected http protocol to be "HTTP/1.1", but it is "%s"`, q.Proto)
123 | }
124 |
125 | if !q.ProtoAtLeast(1, 1) {
126 | t.Errorf("expected http protocol used to be at least version 1.1, but it is %d.%d", q.ProtoMajor, q.ProtoMinor)
127 | }
128 |
129 | if q.Host != u.Host {
130 | t.Errorf(`expected host to be "%s", but it is "%s"`, u.Host, q.Host)
131 | }
132 | }
133 |
--------------------------------------------------------------------------------
/dialer_utils.go:
--------------------------------------------------------------------------------
1 | package websocket
2 |
3 | import (
4 | "encoding/base64"
5 | "errors"
6 | "net/http"
7 | "net/url"
8 | "regexp"
9 | "strings"
10 | )
11 |
12 | // validateResponse is used to determine whether the servers handshake request
13 | // conforms with the WebSocket spec. When it doesn't the client fails the
14 | // websocket connection.
15 | //
16 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.1
17 | func validateResponse(r *http.Response) *OpenError {
18 | validations := []func(*http.Response) *OpenError{
19 | validateResponseStatus,
20 | validateResponseUpgradeHeader,
21 | validateResponseConnectionHeader,
22 | validateResponseSecWebsocketAcceptHeader,
23 | }
24 |
25 | for _, v := range validations {
26 | if err := v(r); err != nil {
27 | return err
28 | }
29 | }
30 |
31 | return nil
32 | }
33 |
34 | // validateResponseStatus verifies that status code of the server's opening
35 | // handshake response is '101'. If it is not, it means that the handshake has
36 | // been rejected and thus the endpoints are still communicating using http.
37 | //
38 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.1
39 | func validateResponseStatus(r *http.Response) *OpenError {
40 | if r.StatusCode != 101 {
41 | return &OpenError{
42 | Reason: "http status not 101",
43 | }
44 | }
45 | return nil
46 | }
47 |
48 | // validateResponseUpgradeHeader verifies that the Upgrade HTTP Header value
49 | // in the servers's opening handshake response is "websocket".
50 | //
51 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.1
52 | func validateResponseUpgradeHeader(r *http.Response) *OpenError {
53 | if strings.ToLower(r.Header.Get("Upgrade")) != "websocket" {
54 | return &OpenError{
55 | Reason: `"Upgrade" Header should have the value of "websocket"`,
56 | }
57 | }
58 | return nil
59 | }
60 |
61 | // validateResponseConnectionHeader verifies that the Connection HTTP Header
62 | // value in the servers's opening handshake response is "upgrade".
63 | //
64 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.1
65 | func validateResponseConnectionHeader(r *http.Response) *OpenError {
66 | if strings.ToLower(r.Header.Get("Connection")) != "upgrade" {
67 | return &OpenError{
68 | Reason: `"Connection" Header should have the value of "upgrade"`,
69 | }
70 | }
71 | return nil
72 | }
73 |
74 | // validateResponseSecWebsocketAcceptHeader verifies that the
75 | // Sec-WebSocket-Accept HTTP Header value in the server's opening handshake
76 | // response is the base64-encoded SHA-1 of the concatenation of the
77 | // Sec-WebSocket-Key value (sent with the opening handshake request) (as a
78 | // string, not base64-decoded) with the websocket accept key.
79 | //
80 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.1
81 | func validateResponseSecWebsocketAcceptHeader(r *http.Response) *OpenError {
82 | if r.Header.Get("Sec-WebSocket-Accept") != makeAcceptKey(r.Request.Header.Get("Sec-Websocket-Key")) {
83 | return &OpenError{
84 | Reason: `challenge key failure`,
85 | }
86 | }
87 | return nil
88 | }
89 |
90 | // validateResponseSecWebsocketProtocol verifies that the sub protocol the
91 | // server has agreed to use (Sec-WebSocket-Protocol Header) was in the list the
92 | // client has sent in the opening handshake request.
93 | //
94 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.1
95 | func validateResponseSecWebsocketProtocol(r *http.Response) *OpenError {
96 | // Sub protocols sent by the client.
97 | c := headerToSlice(r.Request.Header.Get("Sec-WebSocket-Protocol"))
98 | // Sub protocol the server has agreed to use.
99 | s := r.Header.Get("Sec-WebSocket-Protocol")
100 |
101 | // If the server hasn't agreed to use anything, stop process.
102 | if len(s) == 0 {
103 | return nil
104 | }
105 |
106 | // Loop through the lists of sub protocols the client has sent in its
107 | // opening handshake request and if the sub protocol the server argeed to
108 | // use is found stop the process.
109 | for _, cv := range c {
110 | if cv == s {
111 | return nil
112 | }
113 | }
114 |
115 | // At this point the server has agreed to use a sub protocol which the
116 | // client doesn't support and thus return an error.
117 | return &OpenError{
118 | Reason: `server choose a sub protocol which was not in the list sent by the client`,
119 | }
120 | }
121 |
122 | // makeChallengeKey is used to generate the key to be sent with the client's
123 | // opening handshake using the Sec-Websocket-Key header field.
124 | //
125 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.1
126 | func makeChallengeKey() string {
127 | // return Base64 encode version of the byte generated.
128 | return base64.StdEncoding.EncodeToString(randomByteSlice(4))
129 | }
130 |
131 | // parseURL is used to parse the URL string provided and verifies that it
132 | // conforms with the websocket spec. If it does it will create and return a URL
133 | // instance representing the URL string provided.
134 | //
135 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-3
136 | func parseURL(u string) (*url.URL, error) {
137 | // Parse scheme.
138 | if err := parseURLScheme(&u); err != nil {
139 | return nil, err
140 | }
141 |
142 | // Create URL Instance.
143 | l, err := url.Parse(u)
144 | if err != nil {
145 | return nil, err
146 | }
147 |
148 | // Parse Host.
149 | if err := parseURLHost(l); err != nil {
150 | return nil, err
151 | }
152 |
153 | return l, nil
154 | }
155 |
156 | // parseURLScheme is used to parse the Scheme portion of a URL string. If the
157 | // scheme provided is not a valid websocket scheme an error is returned. If no
158 | // scheme is given it will be defaulted to "ws".
159 | //
160 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-3
161 | func parseURLScheme(u *string) error {
162 | // Regex to retrieve Scheme portion of a URL string.
163 | re := regexp.MustCompile("^([a-zA-Z]+)://")
164 | m := re.FindStringSubmatch(*u)
165 |
166 | // If m is smaller than 2 it means that the user hasn't provided one and
167 | // thus the default sheme (ws) is used.
168 | if len(m) < 2 {
169 | *u = "ws://" + *u
170 | return nil
171 | }
172 |
173 | // If a sheme was captured, make sure it is valid.
174 | if !schemeValid(m[1]) {
175 | return errors.New("invalid scheme: " + m[1])
176 | }
177 |
178 | return nil
179 | }
180 |
// hostPortRe matches a port at the very end of a host. It is anchored so the
// colons inside an IPv6 literal (e.g. "[::1]") are not mistaken for a port.
var hostPortRe = regexp.MustCompile(`:\d+$`)

// parseURLHost is used to parse the Host portion of a URL instance to
// determine whether it has a port or not. When no port is found the default
// port of the URL instance scheme is appended (ws = 80, wss = 443). If the
// scheme is not a valid scheme for websocket an error is returned.
//
// Ref Spec: https://tools.ietf.org/html/rfc6455#section-3
func parseURLHost(u *url.URL) error {
	// Pick the default port of the scheme, rejecting invalid schemes.
	var port string

	switch u.Scheme {
	case "ws":
		port = "80"
	case "wss":
		port = "443"
	default:
		return errors.New("invalid scheme: " + u.Scheme)
	}

	// The user has provided a port, thus there is no need to include the
	// default one.
	if hostPortRe.MatchString(u.Host) {
		return nil
	}

	u.Host += ":" + port

	return nil
}
218 |
// schemeValid reports whether s is a valid scheme for the websocket protocol,
// i.e. either "ws" or "wss".
func schemeValid(s string) bool {
	switch s {
	case "ws", "wss":
		return true
	}

	return false
}
224 |
--------------------------------------------------------------------------------
/dialer_utils_test.go:
--------------------------------------------------------------------------------
1 | package websocket
2 |
3 | import (
4 | "encoding/base64"
5 | "net/http"
6 | "net/url"
7 | "testing"
8 | )
9 |
10 | func TestParseURLScheme(t *testing.T) {
11 | type testCase struct {
12 | u string
13 | f string
14 | }
15 |
16 | testCases := []testCase{
17 | {u: "ws://localhost:8080", f: "ws://localhost:8080"},
18 | {u: "wss://localhost:8080", f: "wss://localhost:8080"},
19 | {u: "localhost:8080", f: "ws://localhost:8080"},
20 | }
21 |
22 | for i, c := range testCases {
23 | n := c.u
24 |
25 | if err := parseURLScheme(&n); err != nil {
26 | t.Errorf("test case %d: unexpected error was returned %s", i, err)
27 | }
28 |
29 | if n != c.f {
30 | t.Errorf(`test case %d: expected url to be "%s", but it is "%s"`, i, c.f, n)
31 | }
32 | }
33 | }
34 |
35 | func TestParseURLSchemeError(t *testing.T) {
36 | u := "http://localhost:8080"
37 |
38 | if err := parseURLScheme(&u); err == nil {
39 | t.Error("expected an error for", u)
40 | }
41 | }
42 |
43 | func TestParseURLHost(t *testing.T) {
44 | type testCase struct {
45 | u *url.URL
46 | h string
47 | }
48 |
49 | testCases := []testCase{
50 | {u: &url.URL{Scheme: "ws", Host: "localhost:80"}, h: "localhost:80"},
51 | {u: &url.URL{Scheme: "wss", Host: "localhost:80"}, h: "localhost:80"},
52 | {u: &url.URL{Scheme: "ws", Host: "localhost"}, h: "localhost:22"},
53 | {u: &url.URL{Scheme: "wss", Host: "localhost"}, h: "localhost:443"},
54 | }
55 |
56 | for i, c := range testCases {
57 | if err := parseURLHost(c.u); err != nil {
58 | t.Errorf("test case %d: unexpected error was returned %s", i, err)
59 | }
60 |
61 | if c.u.Host != c.h {
62 | t.Errorf(`test case %d: expected host to be "%s", but it is "%s"`, i, c.h, c.u.Host)
63 | }
64 | }
65 | }
66 |
67 | func TestParseURLHostError(t *testing.T) {
68 | u := &url.URL{
69 | Scheme: "http",
70 | Host: "localhost",
71 | }
72 |
73 | if err := parseURLHost(u); err == nil {
74 | t.Errorf("expected an error to be returned")
75 | }
76 | }
77 |
78 | func TestMakeChallengeKey(t *testing.T) {
79 | k := makeChallengeKey()
80 | b, err := base64.StdEncoding.DecodeString(k)
81 |
82 | if err != nil {
83 | t.Errorf("unexpected error was returned while decoding value: %s", err)
84 | }
85 |
86 | if len(b) != 16 {
87 | t.Errorf("expected length of decoded challenge key to be 16, but it is %d", len(b))
88 | }
89 | }
90 |
91 | func TestValidateResponseStatus(t *testing.T) {
92 | type testCase struct {
93 | s int
94 | e bool
95 | }
96 |
97 | testCases := []testCase{
98 | {s: 101, e: false},
99 | {s: 200, e: true},
100 | }
101 |
102 | for i, c := range testCases {
103 | r := &http.Response{
104 | StatusCode: c.s,
105 | }
106 |
107 | err := validateResponseStatus(r)
108 |
109 | if c.e && err == nil {
110 | t.Errorf(`test case %d: expected an error for '%d'`, i, c.s)
111 | }
112 |
113 | if !c.e && err != nil {
114 | t.Errorf(`test case %d: unexpected error returned for '%d'`, i, c.s)
115 | }
116 | }
117 | }
118 |
119 | func TestValidateResponseUpgradeHeader(t *testing.T) {
120 | type testCase struct {
121 | v string
122 | e bool
123 | }
124 |
125 | testCases := []testCase{
126 | {v: "websocket", e: false},
127 | {v: "WebSocket", e: false},
128 | {v: "wrong", e: true},
129 | }
130 |
131 | for i, c := range testCases {
132 | r := &http.Response{
133 | Header: make(http.Header),
134 | }
135 |
136 | r.Header.Add("Upgrade", c.v)
137 | err := validateResponseUpgradeHeader(r)
138 |
139 | if c.e && err == nil {
140 | t.Errorf(`test case %d: expected an error for "%s"`, i, c.v)
141 | }
142 |
143 | if !c.e && err != nil {
144 | t.Errorf(`test case %d: unexpected error returned for "%s"`, i, c.v)
145 | }
146 | }
147 | }
148 |
149 | func TestValidateResponseConnectionHeader(t *testing.T) {
150 | type testCase struct {
151 | v string
152 | e bool
153 | }
154 |
155 | testCases := []testCase{
156 | {v: "upgrade", e: false},
157 | {v: "UpgrADE", e: false},
158 | {v: "wrong", e: true},
159 | }
160 |
161 | for i, c := range testCases {
162 | r := &http.Response{
163 | Header: make(http.Header),
164 | }
165 |
166 | r.Header.Add("Connection", c.v)
167 | err := validateResponseConnectionHeader(r)
168 |
169 | if c.e && err == nil {
170 | t.Errorf(`test case %d: expected an error for "%s"`, i, c.v)
171 | }
172 |
173 | if !c.e && err != nil {
174 | t.Errorf(`test case %d: unexpected error returned for "%s"`, i, c.v)
175 | }
176 | }
177 | }
178 |
179 | func TestValidateResponseSecWebsocketProtocol(t *testing.T) {
180 | type testCase struct {
181 | c string
182 | s string
183 | e bool
184 | }
185 |
186 | testCases := []testCase{
187 | {c: "client, v1", s: "", e: false},
188 | {c: "client, v1", s: "v1", e: false},
189 | {c: "client, v1", s: "v2", e: true},
190 | }
191 |
192 | for i, c := range testCases {
193 | // Headers sent by client
194 | hq := make(http.Header)
195 | hq.Set("Sec-WebSocket-Protocol", c.c)
196 |
197 | // Headers sent by server
198 | hr := make(http.Header)
199 | hr.Set("Sec-WebSocket-Protocol", c.s)
200 |
201 | q := &http.Request{
202 | Header: hq,
203 | }
204 |
205 | r := &http.Response{
206 | Header: hr,
207 | Request: q,
208 | }
209 |
210 | err := validateResponseSecWebsocketProtocol(r)
211 |
212 | if c.e && err == nil {
213 | t.Errorf(`test case %d: expected an error when the client sent "%s" as supported protocols and the server agreed to use "%s"`, i, c.c, c.s)
214 | }
215 |
216 | if !c.e && err != nil {
217 | t.Errorf(`test case %d: unexpected error was returned when the client sent "%s" as supported protocols and the server agreed to use "%s"`, i, c.c, c.s)
218 | }
219 | }
220 | }
221 |
--------------------------------------------------------------------------------
/doc.go:
--------------------------------------------------------------------------------
1 | // Package websocket implements the websocket protocol defined in rfc6455.
2 | package websocket
3 |
--------------------------------------------------------------------------------
/errors.go:
--------------------------------------------------------------------------------
1 | package websocket
2 |
3 | import (
4 | "encoding/binary"
5 | "errors"
6 | "fmt"
7 | )
8 |
// CloseError represents errors related to the websocket closing handshake. It
// carries the close status code and a textual reason.
type CloseError struct {
	Code   int
	Reason string
}

// Error implements the built in error interface. The message has the form
// "Close Error: <code> <reason>".
func (c *CloseError) Error() string {
	code, reason := c.Code, c.Reason

	return "Close Error: " + fmt.Sprint(code) + " " + reason
}
19 |
20 | // ToBytes returns the representation of a CloseError instance in a []bytes
21 | // that conforms with the way the websocket rfc expects the payload data of
22 | // CLOSE FRAMES to be.
23 | //
24 | // While generating the []bytes, if the CloseError instance has an invalid
25 | // error code, it will instead create the representation of a 'No Status
26 | // Received Error' (i.e. 1005).
27 | //
28 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-5.5.1
29 | func (c *CloseError) ToBytes() ([]byte, error) {
30 | // Validate Error Code
31 | if !closeErrorExist(c.Code) {
32 | // If it is not valid, return bytes for No Status Received error.
33 | n := &CloseError{
34 | Code: CloseNoStatusReceived,
35 | Reason: "no status recieved",
36 | }
37 | b, _ := n.ToBytes()
38 | return b, errors.New("invalid error code")
39 | }
40 |
41 | return append(c.toBytesCode(), []byte(c.Reason)...), nil
42 | }
43 |
44 | // toBytesCode is used to get a representation of the CloseError instance
45 | // status code in []bytes.
46 | func (c *CloseError) toBytesCode() []byte {
47 | b := make([]byte, 2)
48 | binary.BigEndian.PutUint16(b, uint16(c.Code))
49 | return b
50 | }
51 |
52 | // NewCloseError is used to create a new CloseError instance by parsing 'b'. In
53 | // order for this to happen the []bytes needs to conform with the way the
54 | // websocket rfc expects the payload data of CLOSE FRAMES to be.
55 | //
56 | // While parsing if the error code (i.e. first two bytes) is invalid, it will
57 | // default the CloseError instance returned to represent a 'No Status Received
58 | // Error' (i.e. 1005).
59 | //
60 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-5.5.1
61 | func NewCloseError(b []byte) (*CloseError, error) {
62 | var c int
63 |
64 | if len(b) >= 2 {
65 | cb := b[:2]
66 | c = int(binary.BigEndian.Uint16(cb))
67 | }
68 |
69 | if !closeErrorExist(c) {
70 | return &CloseError{
71 | Code: CloseNoStatusReceived,
72 | Reason: "no status recieved",
73 | }, errors.New("invalid error code")
74 | }
75 |
76 | return &CloseError{
77 | Code: c,
78 | Reason: string(b[2:]),
79 | }, nil
80 | }
81 |
// OpenError represents errors related to the websocket opening handshake.
type OpenError struct {
	Reason string
}

// Error implements the built in error interface. The message is the reason
// prefixed with "Handshake Error: ".
func (h *OpenError) Error() string {
	const prefix = "Handshake Error: "

	return prefix + h.Reason
}
91 |
--------------------------------------------------------------------------------
/errors_test.go:
--------------------------------------------------------------------------------
1 | package websocket
2 |
3 | import (
4 | "testing"
5 | )
6 |
7 | func TestCloseErrorToBytes(t *testing.T) {
8 | type testCase struct {
9 | c int
10 | r string
11 | b []byte
12 | }
13 |
14 | testCases := []testCase{
15 | {c: 1001, r: "normal closure", b: []byte{3, 233, 110, 111, 114, 109, 97, 108, 32, 99, 108, 111, 115, 117, 114, 101}},
16 | {c: 1001, r: "", b: []byte{3, 233}},
17 | }
18 |
19 | for i, c := range testCases {
20 | e := &CloseError{Code: c.c, Reason: c.r}
21 |
22 | b, err := e.ToBytes()
23 |
24 | if err != nil {
25 | t.Errorf(`test case %d: unexpected error`, i)
26 | }
27 |
28 | if len(b) != len(c.b) {
29 | t.Errorf(`test case %d: unexpected slice of bytes`, i)
30 | }
31 |
32 | same := true
33 | for bi, bv := range b {
34 | if bv != c.b[bi] {
35 | same = false
36 | break
37 | }
38 | }
39 |
40 | if !same {
41 | t.Errorf(`test case %d: unexpected slice of bytes`, i)
42 | }
43 | }
44 | }
45 |
46 | func TestCloseErrorToBytesError(t *testing.T) {
47 | b := []byte{3, 237, 110, 111, 32, 115, 116, 97, 116, 117, 115, 32, 114, 101, 99, 105, 101, 118, 101, 100}
48 |
49 | c := &CloseError{Code: 0, Reason: "woops"}
50 | e, err := c.ToBytes()
51 |
52 | if err == nil {
53 | t.Error("expected an error")
54 | }
55 |
56 | same := true
57 | for bi, bv := range b {
58 | if bv != e[bi] {
59 | same = false
60 | break
61 | }
62 | }
63 |
64 | if !same {
65 | t.Errorf(`unexpected slice of bytes`)
66 | }
67 | }
68 |
69 | func TestNewCloseError(t *testing.T) {
70 | type testCase struct {
71 | c int
72 | r string
73 | b []byte
74 | }
75 |
76 | testCases := []testCase{
77 | {c: 1001, r: "normal closure", b: []byte{3, 233, 110, 111, 114, 109, 97, 108, 32, 99, 108, 111, 115, 117, 114, 101}},
78 | {c: 1001, r: "", b: []byte{3, 233}},
79 | }
80 |
81 | for i, c := range testCases {
82 | e, err := NewCloseError(c.b)
83 |
84 | if err != nil {
85 | t.Errorf(`test case %d: unexpected error`, i)
86 | }
87 |
88 | if e.Code != c.c {
89 | t.Errorf("test case %d: expected Code to be '%d', but it is '%d'", i, c.c, e.Code)
90 | }
91 |
92 | if e.Reason != c.r {
93 | t.Errorf(`test case %d: expected Reason to be "%s", but it is "%s"`, i, c.r, e.Reason)
94 | }
95 | }
96 | }
97 |
98 | func TestNewCloseErrorError(t *testing.T) {
99 | type testCase struct {
100 | p []byte
101 | }
102 |
103 | testCases := []testCase{
104 | {p: make([]byte, 0)},
105 | {p: []byte{3, 133}},
106 | }
107 |
108 | for i, c := range testCases {
109 | c, err := NewCloseError(c.p)
110 | r := "no status recieved"
111 |
112 | if err == nil {
113 | t.Errorf("test case %d: expected an error", i)
114 | }
115 |
116 | if c.Code != CloseNoStatusReceived {
117 | t.Errorf("test case %d, expected Code to be '%d', but it is '%d'", i, CloseNoStatusReceived, c.Code)
118 | }
119 |
120 | if c.Reason != r {
121 | t.Errorf(`test case %d, expected Reason to be "%s", but it is "%s"`, i, r, c.Reason)
122 | }
123 | }
124 | }
125 |
--------------------------------------------------------------------------------
/examples/chat/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "github.com/tabone/websocket"
5 | "log"
6 | "net/http"
7 | )
8 |
9 | var m *manager
10 |
11 | func main() {
12 | m = &manager{
13 | users: make(map[int]*websocket.Socket),
14 | }
15 | http.HandleFunc("/ws", wsHandler)
16 | http.Handle("/", http.FileServer(http.Dir("public/")))
17 |
18 | log.Println("listening on localhost:8080.")
19 | http.ListenAndServe("localhost:8080", nil)
20 | }
21 |
22 | func wsHandler(w http.ResponseWriter, r *http.Request) {
23 | log.Println("new connection")
24 |
25 | // Create a new websocket request
26 | q := &websocket.Request{
27 | CheckOrigin: func(r *http.Request) bool {
28 | // Accept all requests.
29 | return true
30 | },
31 | }
32 |
33 | // Try to upgrade the http request.
34 | s, err := q.Upgrade(w, r)
35 |
36 | if err != nil {
37 | log.Println("upgrade failed:", err)
38 | }
39 |
40 | // If upgrade has been successfull, include the socket with the other online
41 | // sockets
42 | m.addSocket(s)
43 | }
44 |
--------------------------------------------------------------------------------
/examples/chat/manager.go:
--------------------------------------------------------------------------------
1 | package main
2 |
import (
	"encoding/json"
	"fmt"
	"log"
	"time"

	"github.com/tabone/websocket"
)
9 |
// manager keeps track of the sockets taking part in the chat.
//
// NOTE(review): neither field is protected by a lock, while addSocket is
// reached from the HTTP handler - confirm whether concurrent connections can
// touch seq and users at the same time.
type manager struct {
	// seq is a sequence which will be used to assign a unique id to each
	// socket added to the list of online users.
	seq int

	// users will contain a reference to all the online sockets, keyed by the
	// id taken from seq when the socket was added.
	users map[int]*websocket.Socket
}
22 |
23 | /*
24 | addSocket is used to add a socket to the online list of users.
25 | */
26 | func (m *manager) addSocket(s *websocket.Socket) {
27 | m.seq++
28 | log.Println("user", m.seq, "has logged in")
29 | m.users[m.seq] = s
30 | m.config(m.seq)
31 |
32 | j := fmt.Sprintf(`{"type":"login","data":{"user": %d, "count":%d}}`, m.seq, len(m.users))
33 | m.broadcast([]byte(j))
34 |
35 | go m.ping(s)
36 |
37 | // Start listening for new data.
38 | s.Listen()
39 | }
40 |
41 | func (m *manager) ping(s *websocket.Socket) {
42 | t := time.NewTicker(time.Second * 5)
43 |
44 | for {
45 | <-t.C
46 | if err := s.WriteMessage(websocket.OpcodePing, nil); err != nil {
47 | log.Println(err)
48 | break
49 | }
50 | }
51 | t.Stop()
52 | }
53 |
54 | /*
55 | removeSocket is used to remove a socket from the online list of users using
56 | its id.
57 | */
58 | func (m *manager) removeSocket(i int) {
59 | log.Println("user", i, "has logged out")
60 | delete(m.users, i)
61 | }
62 |
63 | /*
64 | config is used to configure the socket instance.
65 | */
66 | func (m *manager) config(i int) {
67 | s := m.users[i]
68 |
69 | s.ReadHandler = func(o int, p []byte) {
70 | log.Println("user", i, "sent a message:", string(p))
71 | j := fmt.Sprintf(`{"type":"message","data":"%s"}`, p)
72 | m.broadcast([]byte(j))
73 | }
74 |
75 | s.CloseHandler = func(err error) {
76 | log.Println("user", i, "disconnected:", err)
77 | m.removeSocket(i)
78 | j := fmt.Sprintf(`{"type":"logout","data":{"user": %d, "count":%d}}`, i, len(m.users))
79 | m.broadcast([]byte(j))
80 | }
81 |
82 | s.PongHandler = func(p []byte) {
83 | log.Println("user", i, "pong recieved")
84 | }
85 | }
86 |
87 | /*
88 | broadcast is used to send a message to all the connected users.
89 | */
90 | func (m *manager) broadcast(p []byte) {
91 | for _, s := range m.users {
92 | s.WriteMessage(websocket.OpcodeText, p)
93 | }
94 | }
95 |
--------------------------------------------------------------------------------
/examples/chat/public/css/style.css:
--------------------------------------------------------------------------------
/* Reset: no default spacing, and every element is sized by its border box. */
* {
  margin:0px;
  padding:0px;
  box-sizing:border-box;
  -moz-box-sizing:border-box;
  -webkit-box-sizing:border-box;
}

/* Page defaults; the em-based sizes further down are relative to this 10px. */
body {
  font-size:10px;
  font-family:sans-serif;
  background-color: #F7F7F7;
  color:#555;
}

/* Centered page title. */
header h1 {
  text-align:center;
  padding:30px;
}

/* Title link: looks like plain heading text until hovered. */
header h1 a {
  color:inherit;
  text-decoration:none;
  font-size:1.5em;
}

header h1 a:hover {
  color:#5E97D6;
}

/* Shared box look: inner padding, rounded corners and a soft drop shadow. */
.cntr {
  padding: 7px 14px;
  border-radius:3px;
  -moz-border-radius:3px;
  -webkit-border-radius:3px;
  box-shadow: 0px 1px 2px 0px #AFAFAF;
  -moz-box-shadow: 0px 1px 2px 0px #AFAFAF;
  -webkit-box-shadow: 0px 1px 2px 0px #AFAFAF;
}

/* Fixed-height, scrollable list of messages. */
main #conversation {
  overflow: scroll;
  margin-bottom:10px;
  background-color:#fff;
  height:400px;
}

/* Fixed-width column, centered horizontally. */
main {
  margin:0px auto;
  width:400px;
}

/* A single entry of the conversation. */
main #conversation .message {
  background-color: #f1f1f1;
  padding: 7px 10px;
  margin-bottom: 8px;
  font-size: 1.3em;
  border-radius: 3px;
  -moz-border-radius: 3px;
  -webkit-border-radius: 3px;
}

/* Logout notices get a red tint. */
main #conversation .message.logout {
  background-color:#E8BABA;
}

/* Login notices get a green tint. */
main #conversation .message.login {
  background-color:#C0DAB7;
}

/* Borderless text input that fills its container. */
main input[type="text"] {
  width: 100%;
  border:none;
  font-size: 1.3em;
  outline:none;
}

footer {
  text-align:center;
  margin-top:10px;
}

/* Emphasized figure inside the footer. */
footer #count {
  font-size: 2.3em;
  font-weight: bold;
}
--------------------------------------------------------------------------------
/examples/chat/public/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 | Chat Example
4 |
5 |
6 |
7 |
8 |
15 |
16 |
17 |
18 |
19 | Hello, this is a test
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
31 |
32 |
33 |
34 |
--------------------------------------------------------------------------------
/examples/chat/public/js/app.js:
--------------------------------------------------------------------------------
1 | 'use strict'
2 |
/**
 * Chat client. Opens the websocket connection, looks up the DOM elements it
 * works with and then runs the initializer.
 * @constructor
 */
function Application() {
  var byId = function (id) {
    return document.getElementById(id)
  }

  this._socket = new WebSocket("ws://localhost:8080/ws")

  /**
   * References to the DOM elements used by the application.
   * @type {Object}
   */
  this._dom = {
    /**
     * Container of all the messages.
     * @type {HTML Element}
     */
    conversation: byId("conversation"),

    /**
     * Input used to type and send new messages.
     * @type {HTML Element}
     */
    textbox: byId("textbox"),

    /**
     * Element showing the number of online users.
     * @type {HTML Element}
     */
    count: byId("count")
  }

  this._init()
}
32 |
/**
 * Initializer: wires up the socket listeners first, then the textbox.
 */
Application.prototype._init = function () {
  var app = this._setupSocket()
  app._setupTextbox()
}
40 |
/**
 * Register the listener which dispatches every message received from the
 * websocket server to the handler matching its 'type'. Messages of any other
 * type are ignored.
 * @return {Application} The instance.
 */
Application.prototype._setupSocket = function () {
  var app = this

  this._socket.onmessage = function (resp) {
    var msg = JSON.parse(resp.data)
    var kind = msg.type

    if (kind === "message") {
      app._onMessage(msg.data)
    } else if (kind === "login") {
      app._onLogin(msg.data)
    } else if (kind === "logout") {
      app._onLogout(msg.data)
    }
  }

  return this
}
67 |
/**
 * Add a listener on the textbox element: when the user presses the Enter key
 * and the input is not empty, its text is sent to the websocket server and
 * the input is cleared.
 * @return {Application} The instance.
 */
Application.prototype._setupTextbox = function () {
  var app = this
  var ENTER = 13

  this._dom.textbox.onkeydown = function (ev) {
    if (ev.keyCode != ENTER || this.value === "") {
      return
    }

    app._socket.send(this.value)
    this.value = ""
  }

  return this
}
83 |
/**
 * Create an empty comment box: a detached 'div' carrying the 'message' class.
 * @return {HTML Element} The comment box element.
 */
Application.prototype._createCommentBox = function () {
  var box = document.createElement("div")
  box.className = "message"
  return box
}
93 |
/**
 * Method triggered when a new message is received from the server.
 *
 * The text was typed by another user, so it is inserted as plain text: using
 * it as markup would let any participant inject HTML or scripts into every
 * other participant's page.
 * @param  {String} msg The message to be displayed.
 */
Application.prototype._onMessage = function (msg) {
  var elem = this._createCommentBox()
  elem.textContent = msg
  this._dom.conversation.appendChild(elem)
}
103 |
/**
 * Method triggered when the message received is of type 'login', which means
 * a new user has joined the conversation. Values coming from the server are
 * written as plain text, never as markup.
 * @param  {Object} msg       The data of the login message.
 * @param  {Number} msg.user  The id of the user.
 * @param  {Number} msg.count The number of online users.
 */
Application.prototype._onLogin = function (msg) {
  var elem = this._createCommentBox()
  elem.className = "message login"
  elem.textContent = "User " + msg.user + " Logged in"
  this._dom.conversation.appendChild(elem)
  this._dom.count.textContent = msg.count
}
117 |
/**
 * Method triggered when the message received is of type 'logout', which means
 * a user has left the conversation. Values coming from the server are written
 * as plain text, never as markup.
 * @param  {Object} msg       The data of the logout message.
 * @param  {Number} msg.user  The id of the user.
 * @param  {Number} msg.count The number of online users.
 */
Application.prototype._onLogout = function (msg) {
  var elem = this._createCommentBox()
  elem.className = "message logout"
  elem.textContent = "User " + msg.user + " Logged out"
  this._dom.conversation.appendChild(elem)
  this._dom.count.textContent = msg.count
}
131 |
// Start the application once the script has been loaded; the instance is
// kept private to this closure.
;(function () {
  new Application()
}())
--------------------------------------------------------------------------------
/frame.go:
--------------------------------------------------------------------------------
1 | package websocket
2 |
3 | import (
4 | "bufio"
5 | "encoding/binary"
6 | "fmt"
7 | )
8 |
// WebSocket opcodes, i.e. the values the 4-bit opcode field of a frame may
// take.
// Ref Spec: https://tools.ietf.org/html/rfc6455#section-5.2
const (
	OpcodeContinuation int = 0x0
	OpcodeText         int = 0x1
	OpcodeBinary       int = 0x2
	OpcodeClose        int = 0x8
	OpcodePing         int = 0x9
	OpcodePong         int = 0xA
)
19 |
// frame represents a Websocket Data Frame.
// Ref Spec: https://tools.ietf.org/html/rfc6455#section-5.2
type frame struct {
	// fin indicates that the frame is the final fragment.
	fin bool

	// opcode defines the interpretation of the payload data.
	opcode int

	// masked defines whether the payload data is masked. It is only filled
	// in when a frame is read; toBytes derives the MASK bit from key instead.
	masked bool

	// length specifies the length of the payload data in bytes. It is only
	// filled in when a frame is read; toBytes derives the length from
	// payload instead.
	length uint64

	// key contains the masking key to be used to decode the payload data (if
	// data is masked). It is 32 bits in length.
	key []byte

	// payload contains the payload data in its unmasked form: readPayload
	// unmasks what it read before storing it, and toBytes masks a copy.
	payload []byte
}
54 |
55 | // newFrame is a constructor function to create a new instance of frame by
56 | // reading from a buffer. The construction of the websocket frame is divided
57 | // into four sections:
58 | // 1. Parsing of first 2 bytes.
59 | // 2. Parsing of 'payload length' if 'payload length' parsed in first
60 | // section is greater 125.
61 | // 3. Parsing of 'masking key' if 'masked' value parsed in first section is
62 | // set to true.
63 | // 4. Parsing of payload data.
64 | func newFrame(b *bufio.Reader) (*frame, error) {
65 | // Create frame instance.
66 | f := &frame{}
67 |
68 | reads := []func(*bufio.Reader) error{
69 | f.readInitial,
70 | f.readLength,
71 | f.readMaskKey,
72 | f.readPayload,
73 | }
74 |
75 | for _, read := range reads {
76 | if err := read(b); err != nil {
77 | return nil, err
78 | }
79 | }
80 |
81 | return f, nil
82 | }
83 |
84 | // readInitial is the first method that should be invoked to create the frame
85 | // instance based on the contents from a buffer. This method reads first 2
86 | // bytes of a websocket frame which includes: fin (1 bit), rsv 1-3 (3 bits),
87 | // opcode (4 bits), mask (1 bit) and payload length (7 bits). It accepts a
88 | // buffer as an argument which will be used to read the frame from.
89 | func (f *frame) readInitial(b *bufio.Reader) error {
90 | // Read first 2 bytes.
91 | p, err := readFromBuffer(b, 2)
92 |
93 | if err != nil {
94 | return err
95 | }
96 |
97 | // Reading 'fin'
98 | if p[0]>>7 == 1 {
99 | f.fin = true
100 | }
101 |
102 | // Since library doesn't support extensions if RSV1-3 are non zeros, fail
103 | // connection
104 | if p[0]&112 /* 01110000 */ != 0 {
105 | return &CloseError{
106 | Code: CloseProtocolError,
107 | Reason: "no support for extensions",
108 | }
109 | }
110 |
111 | // Reading 'opcode'
112 | f.opcode = int(p[0]) & 15 /* 00001111 */
113 |
114 | // if opcode doesn't exists, must stop connection
115 | if !opcodeExist(f.opcode) {
116 | return &CloseError{
117 | Code: CloseProtocolError,
118 | Reason: fmt.Sprintf("unsupported opcode: %d", f.opcode),
119 | }
120 | }
121 |
122 | // Reading 'mask'
123 | if p[1]>>7 == 1 {
124 | f.masked = true
125 | }
126 |
127 | // Reading 'payload len'
128 | f.length = uint64(p[1]) & 127 /* 01111111 */
129 |
130 | return nil
131 | }
132 |
133 | // readLength should be invoked after readInitial method and is used to read
134 | // the next 2 (if f.length == 126) or 8 (if f.length == 127) bytes. If f.length
135 | // is <= 125, no read operations are done to the buffer provided as an
136 | // argument.
137 | func (f *frame) readLength(b *bufio.Reader) error {
138 | // If f.length is <= 125 it means that we already have the payload length,
139 | // thus stop read operation.
140 | if f.length <= 125 {
141 | return nil
142 | }
143 |
144 | // For when f.length == 126, read next 2 bytes.
145 | var l uint64 = 2
146 |
147 | // If f.length == 127, read next 8 bytes.
148 | if f.length == 127 {
149 | l = 8
150 | }
151 |
152 | // Read number of bytes based on f.length.
153 | u, err := readFromBuffer(b, l)
154 |
155 | if err != nil {
156 | return err
157 | }
158 |
159 | // Reset length
160 | f.length = 0
161 |
162 | // At this point the bytes that represent the real payload length has been
163 | // retrieved from the buffer. So the next thing to do is to convert the byte
164 | // slice (representing the length) to an integer by combining the bytes
165 | // together.
166 | //
167 | // Example: Let say the slice of bytes repesenting the payload length is
168 | // [134, 129] (or [10000110, 10000001] in binary).
169 | //
170 | // loop 1: f.length == 0
171 | // line 1: Bitwise left shift of 8
172 | // length = 0
173 | // line 2: Add the byte being traversed to f.length.
174 | // length = 1310000110
175 | //
176 | // loop 2: f.length == 134 (or 10000110)
177 | // line 1: Bitwise left shift of 8
178 | // length = 10000110 00000000 (i.e. 34304)
179 | // line 2: Add the byte being traversed to f.length.
180 | // length = 10000110 10000001 (i.e. 34433)
181 | for _, v := range u {
182 | f.length = f.length << 8
183 | f.length += uint64(v)
184 | }
185 |
186 | // Most Significant Bit must be 0.
187 | f.length = f.length & 9223372036854775807
188 |
189 | return nil
190 | }
191 |
192 | // readMaskKey should be invoked after readLength method and is used to read
193 | // the next 4 bytes from the buffer to retrieve the masking key. Note that if
194 | // the payload data is not masked (f.masked == false - info retrieved from
195 | // readInitial) no read operations are done to the buffer provided as an
196 | // argument.
197 | func (f *frame) readMaskKey(b *bufio.Reader) error {
198 | // If payload is not masked, stop process
199 | if !f.masked {
200 | return nil
201 | }
202 |
203 | // Read 4 bytes for masking key
204 | p, err := readFromBuffer(b, 4)
205 |
206 | if err != nil {
207 | return err
208 | }
209 |
210 | // Store key in frame instance
211 | f.key = p
212 |
213 | return nil
214 | }
215 |
216 | // readPayload should be invoked after readMaskKey method and is used to read
217 | // the payload data from the buffer. The number of bytes to read are known from
218 | // f.length (info retrieved from either readInitial or readLength). In addition
219 | // to this if the payload data is masked (f.masked == true - info retrieved
220 | // from readInitial) the payload data will also be decoded using the masking
221 | // key provided with the frame (f.key - info retrieved from readMaskKey).
222 | func (f *frame) readPayload(b *bufio.Reader) error {
223 | // Read f.length bytes
224 | p, err := readFromBuffer(b, f.length)
225 |
226 | if err != nil {
227 | return err
228 | }
229 |
230 | if f.masked {
231 | // Unmask (decode) payload data
232 | mask(p, f.key)
233 | }
234 |
235 | // Store payload in frame instance.
236 | f.payload = p
237 |
238 | return nil
239 | }
240 |
241 | // toBytes returns a representation of the frame instance as a slice of bytes.
242 | // This method does not consider the values assigned to f.length and f.masked
243 | // since these are calculated using the length of f.payload and value of f.key
244 | // respectively.
245 | func (f *frame) toBytes() ([]byte, error) {
246 | if err := f.validate(); err != nil {
247 | return nil, err
248 | }
249 |
250 | // Slice of bytes used to contain the payload data.
251 | p := make([]byte, 2)
252 |
253 | // Include info for FIN bit.
254 | f.toBytesFin(p)
255 |
256 | // Include info for OPCODE bits.
257 | f.toBytesOpcode(p)
258 |
259 | // Include info for MASK bit.
260 | f.toBytesMasked(p)
261 |
262 | // Include info for PAYLOAD LEN bits.
263 | f.toBytesPayloadLength(p)
264 |
265 | // Append (if any) info for PAYLOAD LENGTH EXTENDED bits.
266 | p = append(p, f.toBytesPayloadLengthExt()...)
267 |
268 | // Append (if any) MASK KEY bits.
269 | p = append(p, f.key...)
270 |
271 | // Append (Masked) Payload data. bits
272 | p = append(p, f.toBytesPayloadData()...)
273 |
274 | // Append and PAYLOAD DATA bits and return whole payload
275 | return p, nil
276 | }
277 |
278 | // validate verifies that the data of the frame instance will result in a valid
279 | // websocket data frame.
280 | func (f *frame) validate() *CloseError {
281 | switch {
282 | // Opcode must exists.
283 | case !opcodeExist(f.opcode):
284 | {
285 | return &CloseError{
286 | Code: CloseProtocolError,
287 | Reason: fmt.Sprintf("unsupported opcode: %d", f.opcode),
288 | }
289 | }
290 | // Masking key must have a valid length.
291 | case !validateKey(f.key):
292 | {
293 | return &CloseError{
294 | Code: CloseProtocolError,
295 | Reason: "masking key must either be 0 or 4 bytes long",
296 | }
297 | }
298 | // Payload data must have a valid length.
299 | case !validatePayload(f.payload):
300 | {
301 | return &CloseError{
302 | Code: CloseMessageTooBig,
303 | Reason: "maximum payload data exceeded",
304 | }
305 | }
306 | }
307 | return nil
308 | }
309 |
310 | // toBytesFin is used by toBytes to include info in 'p' about the FIN bit of
311 | // the frame instance. Note that this method should be invoked before
312 | // toBytesOpcode method.
313 | func (f *frame) toBytesFin(p []byte) {
314 | if f.fin {
315 | p[0] = 128
316 | }
317 | }
318 |
319 | // toBytesOpcode is used by toBytes to include info in 'p' about the OPCODE
320 | // bits of the frame instance. Note that this method should be invoked after
321 | // toBytesFin.
322 | func (f *frame) toBytesOpcode(p []byte) {
323 | p[0] += byte(f.opcode)
324 | }
325 |
326 | // toBytesMasked is used by toBytes to include info in 'p' about the MASK bit
327 | // of the frame instance. This method does not consider f.masked but instead it
328 | // calculates the MASK bit value based on f.key. Note that this method should
329 | // be invoked before toBytesPayloadLength.
330 | func (f *frame) toBytesMasked(p []byte) {
331 | if len(f.key) != 0 {
332 | p[1] = 128
333 | }
334 | }
335 |
336 | // toBytesPayloadLength is used by toBytes to include info in 'p' about the
337 | // PAYLOAD LENGTH bits of the frame instance. This method does not consider
338 | // f.length but instead it calculates the PAYLOAD LENGTH value based on the
339 | // payload that will be sent (f.payload). Note that this method should
340 | // be invoked after toBytesMasked.
341 | func (f *frame) toBytesPayloadLength(p []byte) {
342 | l := len(f.payload)
343 |
344 | switch {
345 | case l <= 125:
346 | {
347 | p[1] += byte(l)
348 | return
349 | }
350 | case l <= 65535:
351 | {
352 | p[1] += 126
353 | }
354 | case l <= 9223372036854775807:
355 | {
356 | p[1] += 127
357 | }
358 | }
359 | }
360 |
361 | // toBytesPayloadLengthExt is used by toBytes to include info about the PAYLOAD
362 | // LENGTH EXTENDED bits. Just like toBytesPayloadLength, this method does not
363 | // consider f.length but instead it calculates the PAYLOAD LENGTH EXTENDED bits
364 | // using the payload that will be sent (f.payload).
365 | func (f *frame) toBytesPayloadLengthExt() []byte {
366 | l := len(f.payload)
367 |
368 | // If <= 125, stop process since the true length is already known.
369 | if l <= 125 {
370 | return nil
371 | }
372 |
373 | var p []byte
374 |
375 | switch {
376 | case l <= 65535:
377 | {
378 | // Convert to binary.
379 | p = make([]byte, 2)
380 | binary.BigEndian.PutUint16(p, uint16(l))
381 | }
382 | case l <= 9223372036854775807:
383 | {
384 | // Convert to binary.
385 | p = make([]byte, 8)
386 | binary.BigEndian.PutUint64(p, uint64(l))
387 | }
388 | }
389 |
390 | return p
391 | }
392 |
393 | // toBytesPayloadData is used by toBytes to include info about the PAYLOAD
394 | // DATA. This method also handles the masking of the payload data (f.payload).
395 | // Note that just like toBytesMasked, this method does not consider f.masked
396 | // but instead it directly checks for the masking key (f.key).
397 | func (f *frame) toBytesPayloadData() []byte {
398 | // Put payload into another slice of bytes - so that the payload in the
399 | // frame instance is left untouched.
400 | p := append([]byte{}, f.payload...)
401 |
402 | // If masking key is present, use it to mask the payload data.
403 | if len(f.key) == 4 {
404 | mask(p, f.key)
405 | }
406 |
407 | return p
408 | }
409 |
--------------------------------------------------------------------------------
/frame_test.go:
--------------------------------------------------------------------------------
1 | package websocket
2 |
3 | import (
4 | "bufio"
5 | "testing"
6 | )
7 |
8 | func TestReadInitialForFin(t *testing.T) {
9 | type testCase struct {
10 | b *bufio.Reader
11 | v bool
12 | }
13 |
14 | testCases := []testCase{
15 | // When fin bit is '0' should set fin to false.
16 | {b: newBuffer([]byte{1 /* 00000001 */, 0}), v: false},
17 | // When fin bit is '1' should set fin to true.
18 | {b: newBuffer([]byte{129 /* 10000001 */, 0}), v: true},
19 | }
20 |
21 | for i, c := range testCases {
22 | f := &frame{}
23 |
24 | if err := f.readInitial(c.b); err != nil {
25 | t.Errorf("test case %d: unexpected error returned: %v", i, err)
26 | }
27 |
28 | if f.fin != c.v {
29 | t.Errorf("test case %d: expected 'fin' to be '%t'", i, c.v)
30 | }
31 | }
32 | }
33 |
34 | func TestReadInitialForOpcode(t *testing.T) {
35 | type testCase struct {
36 | b *bufio.Reader
37 | v int
38 | }
39 |
40 | // When opcode is valid, should not return an error.
41 | testCases := []testCase{
42 | // Without mask bit.
43 | {b: newBuffer([]byte{0, 0}), v: OpcodeContinuation},
44 | {b: newBuffer([]byte{1, 0}), v: OpcodeText},
45 | {b: newBuffer([]byte{2, 0}), v: OpcodeBinary},
46 | {b: newBuffer([]byte{8, 0}), v: OpcodeClose},
47 | {b: newBuffer([]byte{9, 0}), v: OpcodePing},
48 | {b: newBuffer([]byte{10, 0}), v: OpcodePong},
49 |
50 | // With mask bit.
51 | {b: newBuffer([]byte{128, 0}), v: OpcodeContinuation},
52 | {b: newBuffer([]byte{129, 0}), v: OpcodeText},
53 | {b: newBuffer([]byte{130, 0}), v: OpcodeBinary},
54 | {b: newBuffer([]byte{136, 0}), v: OpcodeClose},
55 | {b: newBuffer([]byte{137, 0}), v: OpcodePing},
56 | {b: newBuffer([]byte{138, 0}), v: OpcodePong},
57 | }
58 |
59 | for i, c := range testCases {
60 | f := &frame{}
61 |
62 | if err := f.readInitial(c.b); err != nil {
63 | t.Errorf("test case %d: unexpected error returned: %v", i, err)
64 | }
65 |
66 | if f.opcode != c.v {
67 | t.Errorf("test case %d: expected 'opcode' to be '%d'", i, c.v)
68 | }
69 | }
70 | }
71 |
72 | func TestReadInitialForRSVError(t *testing.T) {
73 | type testCase struct {
74 | b *bufio.Reader
75 | }
76 |
77 | // Library doesn't support extensions thus when extension bits are used,
78 | // lib should return an error.
79 | testCases := []testCase{
80 | {b: newBuffer([]byte{17 /* 00010001 */, 0})},
81 | {b: newBuffer([]byte{33 /* 00100001 */, 0})},
82 | {b: newBuffer([]byte{49 /* 00110001 */, 0})},
83 | {b: newBuffer([]byte{65 /* 01000001 */, 0})},
84 | {b: newBuffer([]byte{81 /* 01010001 */, 0})},
85 | {b: newBuffer([]byte{97 /* 01100001 */, 0})},
86 | {b: newBuffer([]byte{113 /* 01110001 */, 0})},
87 | }
88 |
89 | for i, c := range testCases {
90 | f := &frame{}
91 |
92 | err := f.readInitial(c.b)
93 |
94 | if err == nil {
95 | t.Errorf("test case %d: an error was expected.", i)
96 | }
97 |
98 | e, k := err.(*CloseError)
99 |
100 | if !k {
101 | t.Errorf("test case %d: expected error to be of type '*CloseError' but it is '%T'.", i, e)
102 | }
103 |
104 | if e.Reason != "no support for extensions" {
105 | t.Errorf(`test case %d: expected error to have reason "no support for extensions", instead it got "%s".`, i, e.Reason)
106 | }
107 | }
108 | }
109 |
110 | // Should return an error if opcode is invalid.
111 | func TestReadInitialForOpcodeError(t *testing.T) {
112 | f := &frame{}
113 | b := newBuffer([]byte{15, 0})
114 |
115 | err := f.readInitial(b)
116 |
117 | if err == nil {
118 | t.Error("unexpected error returned")
119 | }
120 |
121 | e, k := err.(*CloseError)
122 |
123 | if !k {
124 | t.Fatalf("expected error to be of type '*websocket.CloseError', but it is '%T'.", e)
125 | }
126 |
127 | if e.Reason != "unsupported opcode: 15" {
128 | t.Errorf(`expected error to have reason "unsupported opcode: 15", but it got "%s".`, e.Reason)
129 | }
130 | }
131 |
132 | func TestReadInitialForMasked(t *testing.T) {
133 | type testCase struct {
134 | b *bufio.Reader
135 | v bool
136 | }
137 |
138 | testCases := []testCase{
139 | // When masked bit is '0' should set masked to false.
140 | {b: newBuffer([]byte{1, 0}), v: false},
141 | // When masked bit is '1' should set masked to true.
142 | {b: newBuffer([]byte{1, 128}), v: true},
143 | }
144 |
145 | for i, c := range testCases {
146 | f := &frame{}
147 |
148 | if err := f.readInitial(c.b); err != nil {
149 | t.Errorf("test case %d: unexpected error returned: %v", i, err)
150 | }
151 |
152 | if f.masked != c.v {
153 | t.Errorf("test case %d: expected 'masked' to be '%t'", i, c.v)
154 | }
155 | }
156 | }
157 |
158 | func TestReadInitialForLength(t *testing.T) {
159 | type testCase struct {
160 | b *bufio.Reader
161 | v uint64
162 | }
163 |
164 | testCases := []testCase{
165 | // Should set length to 124 when payload len is 124.
166 | {b: newBuffer([]byte{1, 124}), v: 124},
167 | {b: newBuffer([]byte{1, 252}), v: 124},
168 | // Should set length to 125 when payload len is 125.
169 | {b: newBuffer([]byte{1, 125}), v: 125},
170 | {b: newBuffer([]byte{1, 253}), v: 125},
171 | // Should set length to 126 when payload len is 126.
172 | {b: newBuffer([]byte{1, 126}), v: 126},
173 | {b: newBuffer([]byte{1, 254}), v: 126},
174 | // Should set length to 127 when payload len is 127.
175 | {b: newBuffer([]byte{1, 127}), v: 127},
176 | {b: newBuffer([]byte{1, 255}), v: 127},
177 | }
178 |
179 | for i, c := range testCases {
180 | f := &frame{}
181 |
182 | if err := f.readInitial(c.b); err != nil {
183 | t.Errorf("test case %d: unexpected error returned: %v", i, err)
184 | }
185 |
186 | if f.length != c.v {
187 | t.Errorf("test case %d: expected 'length' to be '%d'", i, c.v)
188 | }
189 | }
190 | }
191 |
192 | func TestReadMaskKey(t *testing.T) {
193 | f := &frame{}
194 | p := []byte{102, 100, 1, 54}
195 | b := newBuffer(p)
196 |
197 | // When f.masked is false, it means that the payload is not masked and
198 | // therefore no key has been sent. For this reason f.key should be left
199 | // untouched.
200 | f.masked = false
201 | if err := f.readMaskKey(b); err != nil {
202 | t.Error("unexpected error returned:", err)
203 | }
204 |
205 | if len(f.key) != 0 {
206 | t.Error("expected f.key to be empty but it is:", len(f.key))
207 | }
208 |
209 | // When f.masked is true, it means that the payload is masked and therefore
210 | // the key must be read and stored in f.key.
211 | f.masked = true
212 | f.key = nil
213 | if err := f.readMaskKey(b); err != nil {
214 | t.Error("unexpected error returned:", err)
215 | }
216 |
217 | if len(f.key) != 4 {
218 | t.Errorf("expected f.key to be '4 bytes' long but it is '%d bytes'", len(f.key))
219 | }
220 |
221 | for i, v := range p {
222 | if v != f.key[i] {
223 | t.Fatalf("expected mask key to be '%v' but it is '%v'", p, f.key)
224 | }
225 | }
226 | }
227 |
228 | func TestReadLength(t *testing.T) {
229 | f := &frame{}
230 |
231 | type testCase struct {
232 | // initial length
233 | i uint64
234 | // final length
235 | l uint64
236 | }
237 |
238 | testCases := []testCase{
239 | {i: 124, l: 124},
240 | {i: 125, l: 125},
241 | {i: 126, l: 65535},
242 | {i: 127, l: 9223372036854775807},
243 | }
244 |
245 | for i, c := range testCases {
246 | f.length = c.i
247 |
248 | b := newBuffer([]byte{255, 255, 255, 255, 255, 255, 255, 255})
249 | if err := f.readLength(b); err != nil {
250 | t.Errorf("test case %d: unexpected error returned: %v", i, err)
251 | }
252 |
253 | if f.length != c.l {
254 | t.Errorf("test case %d: expected f.length to be '%d', but it is '%d'", i, c.l, f.length)
255 | }
256 | }
257 | }
258 |
259 | func TestReadPayload(t *testing.T) {
260 | type testCase struct {
261 | // Masked or not
262 | m bool
263 | }
264 |
265 | testCases := []testCase{
266 | {m: false},
267 | {m: true},
268 | }
269 |
270 | for i, c := range testCases {
271 | // Data Frame Received
272 | p := []byte{120, 15, 17}
273 | b := newBuffer(p)
274 |
275 | // Creation and config of frame instance.
276 | f := &frame{}
277 | f.key = []byte{10, 15, 1, 120}
278 | f.length = 2
279 |
280 | // Setting Masked.
281 | f.masked = c.m
282 |
283 | if err := f.readPayload(b); err != nil {
284 | t.Fatalf("test case %d: unexpected error was returned: %v", i, err)
285 | }
286 |
287 | // If masked unmask it.
288 | if f.masked {
289 | mask(f.payload, f.key)
290 | }
291 |
292 | if uint64(len(f.payload)) != f.length {
293 | t.Errorf("test case %d: expected length of f.payload to be '%d', but it is '%d'", i, f.length, len(f.payload))
294 | }
295 |
296 | for i, v := range f.payload {
297 | if v != p[i] {
298 | t.Fatalf("test case %d: expected slice of bytes to be '%v', but it is '%v'.", i, p[:f.length], f.payload)
299 | }
300 | }
301 | }
302 | }
303 |
304 | func TestToBytesFin(t *testing.T) {
305 | type testCase struct {
306 | v bool
307 | r byte
308 | }
309 |
310 | testCases := []testCase{
311 | // When f.fin is false first byte should not be affected
312 | {v: false, r: 0},
313 | // When f.fin is true first byte should has its MSB == 1.
314 | {v: true, r: 128},
315 | }
316 |
317 | for i, c := range testCases {
318 | f := &frame{fin: c.v}
319 | p := make([]byte, 1)
320 |
321 | f.toBytesFin(p)
322 |
323 | if p[0] != c.r {
324 | t.Errorf("test case %d: expected slice of bytes to be [%d] but it is [%d]", i, c.r, p[0])
325 | }
326 | }
327 | }
328 |
329 | func TestToBytesOpcode(t *testing.T) {
330 | type testCase struct {
331 | // Fin Value
332 | v bool
333 | // Opcode Value
334 | o int
335 | // Resultant byte
336 | r byte
337 | }
338 |
339 | testCases := []testCase{
340 | // With Fin == false
341 | {v: false, o: OpcodeText, r: byte(OpcodeText)},
342 | // With Fin == true
343 | {v: true, o: OpcodeText, r: 128 + byte(OpcodeText)},
344 | }
345 |
346 | for i, c := range testCases {
347 | f := &frame{fin: c.v, opcode: c.o}
348 | p := make([]byte, 1)
349 |
350 | f.toBytesFin(p)
351 | f.toBytesOpcode(p)
352 |
353 | if p[0] != c.r {
354 | t.Errorf("test case %d: expected slice of bytes to be [%d] but it is [%d]", i, c.r, p[0])
355 | }
356 | }
357 | }
358 |
359 | func TestToBytesMasked(t *testing.T) {
360 | type testCase struct {
361 | // Value of f.key.
362 | v []byte
363 | // Resultant byte.
364 | r byte
365 | }
366 |
367 | testCases := []testCase{
368 | {v: nil, r: 0},
369 | {v: []byte{1, 2, 3, 4}, r: 128},
370 | }
371 |
372 | for i, c := range testCases {
373 | f := frame{key: c.v}
374 | p := make([]byte, 2)
375 |
376 | f.toBytesMasked(p)
377 |
378 | if p[1] != c.r {
379 | t.Errorf("test case %d: expected slice of bytes to be [%d] but it is [%d]", i, c.r, p[1])
380 | }
381 | }
382 | }
383 |
384 | func TestToBytesPayloadLength(t *testing.T) {
385 | type testCase struct {
386 | m bool
387 | r byte
388 | l int
389 | }
390 |
391 | testCases := []testCase{
392 | // With Mask Bit (f.masked) set to false
393 | {m: false, r: 124, l: 124},
394 | {m: false, r: 125, l: 125},
395 | {m: false, r: 126, l: 30000},
396 | {m: false, r: 126, l: 65535},
397 | {m: false, r: 127, l: 700000},
398 | // With Mask Bit (f.masked) set to true
399 | {m: true, r: 128 + 124, l: 124},
400 | {m: true, r: 128 + 125, l: 125},
401 | {m: true, r: 128 + 126, l: 30000},
402 | {m: true, r: 128 + 126, l: 65535},
403 | {m: true, r: 128 + 127, l: 700000},
404 | // testCase{m: false, r: 127, l: 9223372036854775807},
405 | }
406 |
407 | for i, c := range testCases {
408 | p := make([]byte, 2)
409 | f := frame{payload: make([]byte, c.l)}
410 |
411 | if c.m {
412 | f.key = []byte{1, 2, 3, 4}
413 | }
414 |
415 | f.toBytesMasked(p)
416 | f.toBytesPayloadLength(p)
417 |
418 | if p[1] != c.r {
419 | t.Errorf("test case %d: expected slice of bytes to be [%d] but it is [%d]", i, c.r, p[1])
420 | }
421 | }
422 | }
423 |
424 | func TestToBytesPayloadLengthExt(t *testing.T) {
425 | type testCase struct {
426 | l int
427 | r []byte
428 | }
429 |
430 | testCases := []testCase{
431 | // Length Known.
432 | {l: 124, r: nil},
433 | // Length Known.
434 | {l: 125, r: nil},
435 | // Read next 2 bytes.
436 | {l: 30000, r: []byte{117, 48}},
437 | // Read next 2 bytes.
438 | {l: 65535, r: []byte{255, 255}},
439 | // Read next 8 bytes.
440 | {l: 700000, r: []byte{0, 0, 0, 0, 0, 10, 174, 96}},
441 | }
442 |
443 | for i, c := range testCases {
444 | f := frame{payload: make([]byte, c.l)}
445 | p := f.toBytesPayloadLengthExt()
446 |
447 | if len(p) != len(c.r) {
448 | t.Errorf("test case %d: expected length to be '%d' but it is '%d'", i, len(c.r), len(p))
449 | }
450 |
451 | for ci, cv := range c.r {
452 | if cv != p[ci] {
453 | t.Errorf("test case %d: Expected slice of bytes to be %v but it is %v", i, c.r, p)
454 | break
455 | }
456 | }
457 | }
458 | }
459 |
460 | func TestToBytesPayloadData(t *testing.T) {
461 | type testCase struct {
462 | m []byte
463 | p []byte
464 | }
465 |
466 | testCases := []testCase{
467 | // When masking key is present and valid, payload must be masked.
468 | {p: []byte{3, 4, 5, 6}, m: nil},
469 | // When masking key is not present, payload must not be masked.
470 | {p: []byte{3, 4, 5, 6}, m: []byte{1, 2, 3, 4}},
471 | }
472 |
473 | for i, c := range testCases {
474 | f := &frame{key: c.m, payload: c.p}
475 |
476 | p := f.toBytesPayloadData()
477 |
478 | if c.m != nil {
479 | mask(p, c.m)
480 | }
481 |
482 | for ci, cv := range c.p {
483 | if cv != p[ci] {
484 | t.Errorf("test case %d: Expected slice of bytes to be %v but it is %v", i, c.p, p)
485 | }
486 | }
487 | }
488 | }
489 |
--------------------------------------------------------------------------------
/frame_utils.go:
--------------------------------------------------------------------------------
1 | package websocket
2 |
// mask masks or unmasks p in place: every byte of p is XOR-ed with the byte
// of the masking key k found at the byte's index modulo 4. The operation is
// its own inverse, so calling mask twice with the same key restores p.
//
// An empty key (nil or zero length) means the data is not masked; p is then
// left untouched instead of panicking on the key lookup. Only the first four
// bytes of k are used.
//
// From spec: https://tools.ietf.org/html/rfc6455#section-5.3
func mask(p, k []byte) {
	// validateKey accepts a zero length key, so it must be safe to get here
	// without one.
	if len(k) == 0 {
		return
	}

	for i := range p {
		p[i] ^= k[i%4]
	}
}
13 |
14 | // opcodeExist returns whether the opcode number provided as an argument is a
15 | // valid opcode or not.
16 | func opcodeExist(i int) bool {
17 | switch i {
18 | case OpcodeContinuation, OpcodeText, OpcodeBinary, OpcodeClose, OpcodePing, OpcodePong:
19 | {
20 | return true
21 | }
22 | }
23 | return false
24 | }
25 |
// validateKey reports whether k is an acceptable masking key, i.e. whether it
// is either absent (length 0) or exactly 4 bytes long.
//
// Ref Spec: https://tools.ietf.org/html/rfc6455#section-5.2
func validateKey(k []byte) bool {
	switch len(k) {
	case 0, 4:
		return true
	default:
		return false
	}
}
33 |
// validatePayload reports whether the payload data has an acceptable size.
// The largest payload a frame can describe is 9223372036854775807 bytes
// (a 64-bit length whose most significant bit must be 0).
//
// Ref Spec: https://tools.ietf.org/html/rfc6455#section-5.2
func validatePayload(p []byte) bool {
	// The limit does not fit in an int on 32-bit platforms, where comparing it
	// with len(p) directly fails to compile; compare as uint64 instead.
	const maxPayloadLen = 1<<63 - 1

	return uint64(len(p)) <= maxPayloadLen
}
41 |
--------------------------------------------------------------------------------
/frame_utils_test.go:
--------------------------------------------------------------------------------
1 | package websocket
2 |
3 | import "testing"
4 |
5 | func TestOpcodeExist(t *testing.T) {
6 | type testCase struct {
7 | o int
8 | v bool
9 | }
10 |
11 | testCases := []testCase{
12 | // Should return false when opcode is invalid
13 | {o: 15, v: false},
14 | // Should return true when opcode is valid.
15 | {o: OpcodeText, v: true},
16 | }
17 |
18 | for i, c := range testCases {
19 | if v := opcodeExist(c.o); v != c.v {
20 | t.Errorf("test case %d: expected '%t' for '%d'", i, c.v, c.o)
21 | }
22 | }
23 | }
24 |
25 | func TestValidateKey(t *testing.T) {
26 | type testCase struct {
27 | k []byte
28 | r bool
29 | }
30 |
31 | testCases := []testCase{
32 | {k: []byte{1, 2, 3, 4}, r: true},
33 | {k: []byte{}, r: true},
34 | {k: []byte{1, 2, 3, 4, 5}, r: false},
35 | {k: []byte{1, 2, 3}, r: false},
36 | }
37 |
38 | for i, c := range testCases {
39 | if validateKey(c.k) != c.r {
40 | t.Errorf("test case %d: expected '%t' for %v", i, c.r, c.k)
41 | }
42 | }
43 | }
44 |
45 | func TestValidatePayload(t *testing.T) {
46 | type testCase struct {
47 | l uint64
48 | r bool
49 | }
50 |
51 | testCases := []testCase{
52 | {l: 125, r: true},
53 | // testCase{l: 9223372036854775807, r: true},
54 | // testCase{l: 9223372036854775808, r: false},
55 | }
56 |
57 | for i, c := range testCases {
58 | b := make([]byte, c.l)
59 | if validatePayload(b) != c.r {
60 | t.Errorf("test case %d: expected '%t' for payload of size '%d'", i, c.r, c.l)
61 | }
62 | }
63 | }
64 |
--------------------------------------------------------------------------------
/request.go:
--------------------------------------------------------------------------------
1 | package websocket
2 |
3 | import (
4 | "net/http"
5 | "sync"
6 | )
7 |
8 | // wsVersion is the websocket version this library supports.
9 | const wsVersion = "13"
10 |
// Request describes an HTTP request that is upgraded to the WebSocket protocol
// once it has passed validation. See Upgrade.
type Request struct {
	/*
		request is the HTTP request being upgraded. Upgrade stores it here
		before any check runs.
	*/
	request *http.Request

	/*
		CheckOrigin decides whether the request is acceptable with regards to
		its origin; returning false fails the opening handshake. When it is
		left nil a default check is used, which only accepts requests that
		have no Origin header or whose Origin matches the request's host.
	*/
	CheckOrigin func(r *http.Request) bool

	/*
		SubProtocol is the name of the sub-protocol the server wants to use.
		It is only announced to the client (Sec-WebSocket-Protocol response
		header) when the client listed it in its own Sec-WebSocket-Protocol
		request header; otherwise the response header is omitted.
	*/
	SubProtocol string
}
37 |
38 | // Upgrade is used to upgrade the HTTP connection to use the WS protocol once
39 | // the client request is validated.
40 | func (q *Request) Upgrade(w http.ResponseWriter, r *http.Request) (*Socket, error) {
41 | // Store a reference to the HTTP Request.
42 | q.request = r
43 |
44 | // Check origin.
45 | // Ref spec: https://tools.ietf.org/html/rfc6455#section-4.2.2
46 | if err := q.handleOrigin(); err != nil {
47 | http.Error(w, "Forbidden", http.StatusForbidden)
48 | return nil, err
49 | }
50 |
51 | // Check websocket version.
52 | // Ref spec: https://tools.ietf.org/html/rfc6455#section-4.2.2
53 | if err := validateWSVersionHeader(r); err != nil {
54 | w.Header().Set("Sec-WebSocket-Version", wsVersion)
55 | http.Error(w, "Upgrade Required", 426)
56 | return nil, err
57 | }
58 |
59 | // Check handshake request.
60 | // Ref spec: https://tools.ietf.org/html/rfc6455#section-4.2.2
61 | if err := validateRequest(r); err != nil {
62 | http.Error(w, "Bad Request", http.StatusBadRequest)
63 | return nil, err
64 | }
65 |
66 | // At this point, the clients handshake request is valid and therefore the
67 | // connection can be upgraded to use the ws protocol.
68 | s, err := q.upgrade(w)
69 |
70 | if err != nil {
71 | http.Error(w, "Internal Server Error", http.StatusInternalServerError)
72 | return nil, err
73 | }
74 |
75 | return s, nil
76 | }
77 |
78 | func (q *Request) upgrade(w http.ResponseWriter) (*Socket, error) {
79 | // Take control of the net.Conn instance.
80 | h, k := w.(http.Hijacker)
81 |
82 | if !k {
83 | return nil, &OpenError{Reason: "assertion failed with current http.ResponseWriter instance"}
84 | }
85 |
86 | conn, buf, err := h.Hijack()
87 | if err != nil {
88 | return nil, err
89 | }
90 |
91 | // Build the HTTP Header response code required for the ws opening
92 | // handshake.
93 | // From RFC2616: https://www.w3.org/Protocols/rfc2616/rfc2616-sec6.html
94 | resp := "HTTP/1.1 101 Switching Protocols\n"
95 | resp += "Upgrade: websocket\n"
96 | resp += "Connection: upgrade\n"
97 | resp += "Sec-WebSocket-Version: " + wsVersion + "\n"
98 |
99 | // If server has agreed to use a sub-protocol, the chosen sub-protocol needs
100 | // to be an option provided by the clients endpoint. If not, the
101 | // Sec-WebSocket-Protocol HTTP Header field is not sent.
102 | if q.SubProtocol != "" && stringExists(q.ClientSubProtocols(), q.SubProtocol) != -1 {
103 | resp += "Sec-WebSocket-Protocol: " + q.SubProtocol + "\n"
104 | }
105 |
106 | // Generate the accept key based on the challenge key provided by the
107 | // client and include it inside 'Sec-WebSocket-Accept' response header
108 | // field.
109 | acceptKey := makeAcceptKey(q.request.Header.Get("Sec-WebSocket-Key"))
110 | resp += "Sec-WebSocket-Accept: " + acceptKey + "\n\n"
111 |
112 | // Send response
113 | buf.WriteString(resp)
114 | buf.Flush()
115 |
116 | // Create and return socket.
117 | return &Socket{
118 | conn: conn,
119 | buf: buf,
120 | server: true,
121 | writeMutex: &sync.Mutex{},
122 | }, nil
123 | }
124 |
125 | // handleOrigin is used to invoke either the CheckOrigin method provided by the
126 | // user or the default method (if the user doesn't provide one).
127 | func (q *Request) handleOrigin() *OpenError {
128 | fn := q.CheckOrigin
129 |
130 | if fn == nil {
131 | fn = checkOrigin
132 | }
133 |
134 | if !fn(q.request) {
135 | return &OpenError{Reason: `failure due to origin.`}
136 | }
137 |
138 | return nil
139 | }
140 |
141 | // ClientSubProtocols returns the list of Sub Protocols the client can interact
142 | // with.
143 | //
144 | // From spec: https://tools.ietf.org/html/rfc6455#section-4.2.1
145 | func (q *Request) ClientSubProtocols() []string {
146 | return headerToSlice(q.request.Header.Get("Sec-WebSocket-Protocol"))
147 | }
148 |
149 | // ClientExtensions returns the list of Extensions the client can interact
150 | // with.
151 | //
152 | // From spec: https://tools.ietf.org/html/rfc6455#section-4.2.1
153 | func (q *Request) ClientExtensions() []string {
154 | return headerToSlice(q.request.Header.Get("Sec-WebSocket-Extensions"))
155 | }
156 |
--------------------------------------------------------------------------------
/request_test.go:
--------------------------------------------------------------------------------
1 | package websocket
2 |
3 | import (
4 | "net/http"
5 | "net/http/httptest"
6 | "strings"
7 | "testing"
8 | )
9 |
10 | func makeRequestValid(r *http.Request) {
11 | r.Header.Set("Sec-WebSocket-Version", wsVersion)
12 | r.Header.Set("Upgrade", "websocket")
13 | r.Header.Set("Connection", "upgrade")
14 | r.Header.Set("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
15 | }
16 |
17 | func TestUpgradeResponseWhenInvalidOrigin(t *testing.T) {
18 | r, err := http.NewRequest("GET", "example.com", nil)
19 |
20 | if err != nil {
21 | t.Fatal("error occured while creating request:", err)
22 | }
23 |
24 | w := httptest.NewRecorder()
25 |
26 | h := func(w http.ResponseWriter, r *http.Request) {
27 | wsr := &Request{
28 | CheckOrigin: func(r *http.Request) bool {
29 | return false
30 | },
31 | }
32 |
33 | makeRequestValid(r)
34 |
35 | s, err := wsr.Upgrade(w, r)
36 |
37 | if err == nil {
38 | t.Error("expected Upgrade() to return a OpenError")
39 | }
40 |
41 | if s != nil {
42 | t.Error("expected Upgrade() to return a nil Socket instance")
43 | }
44 | }
45 |
46 | h(w, r)
47 |
48 | if w.Code != 403 {
49 | t.Errorf(`expected HTTP Status '403'. '%d' was returned.`, w.Code)
50 | }
51 | }
52 |
53 | func TestUpgradeResponseWhenInvalidWSVersion(t *testing.T) {
54 | r, err := http.NewRequest("GET", "example.com", nil)
55 |
56 | if err != nil {
57 | t.Fatal("error occured while creating request:", err)
58 | }
59 |
60 | w := httptest.NewRecorder()
61 |
62 | h := func(w http.ResponseWriter, r *http.Request) {
63 | wsr := &Request{}
64 |
65 | makeRequestValid(r)
66 | r.Header.Set("Sec-WebSocket-Version", "14")
67 |
68 | s, err := wsr.Upgrade(w, r)
69 |
70 | if err == nil {
71 | t.Error("expected Upgrade() to return a OpenError")
72 | }
73 |
74 | if s != nil {
75 | t.Error("expected Upgrade() to return a nil Socket instance")
76 | }
77 | }
78 |
79 | h(w, r)
80 |
81 | if w.Code != 426 {
82 | t.Errorf(`expected HTTP Status '426'. '%d' was returned.`, w.Code)
83 | }
84 |
85 | if w.Header().Get("Sec-WebSocket-Version") != wsVersion {
86 | t.Errorf(`expected "Sec-WebSocket-Version" HTTP Header field value to be %s`, wsVersion)
87 | }
88 | }
89 |
90 | func TestUpgradeResponseWhenNotValid(t *testing.T) {
91 | r, err := http.NewRequest("POST", "example.com", nil)
92 |
93 | if err != nil {
94 | t.Fatal("error occured while creating request:", err)
95 | }
96 |
97 | w := httptest.NewRecorder()
98 |
99 | h := func(w http.ResponseWriter, r *http.Request) {
100 | wsr := &Request{
101 | CheckOrigin: func(r *http.Request) bool {
102 | return true
103 | },
104 | }
105 |
106 | makeRequestValid(r)
107 |
108 | s, err := wsr.Upgrade(w, r)
109 |
110 | if err == nil {
111 | t.Error("expected Upgrade() to return a OpenError.")
112 | }
113 |
114 | if s != nil {
115 | t.Error("expected Upgrade() to return a nil Socket instance.")
116 | }
117 | }
118 |
119 | h(w, r)
120 |
121 | if w.Code != 400 {
122 | t.Errorf(`expected HTTP Status '400'. '%d' was returned.`, w.Code)
123 | }
124 | }
125 |
126 | func TestUpgradeGoodRequest(t *testing.T) {
127 | h := func(w http.ResponseWriter, r *http.Request) {
128 | wsr := &Request{
129 | CheckOrigin: func(r *http.Request) bool {
130 | return true
131 | },
132 | }
133 |
134 | makeRequestValid(r)
135 |
136 | s, err := wsr.Upgrade(w, r)
137 |
138 | if err != nil {
139 | t.Error("unexpected error from Upgrade():", err)
140 | }
141 |
142 | if s == nil {
143 | t.Error("expected Upgrade() to return a non-nil Socket instance")
144 | }
145 |
146 | if !s.server {
147 | t.Error("expected socket to have 'server' property set to 'true'")
148 | }
149 | }
150 |
151 | s := httptest.NewServer(http.HandlerFunc(h))
152 | defer s.Close()
153 |
154 | w, err := http.Get(s.URL)
155 |
156 | if err != nil {
157 | t.Error("unexpected error when requesting the test server:", err)
158 | }
159 |
160 | if w.StatusCode != 101 {
161 | t.Errorf("expected HTTP Status to be '101' but it is '%d'", w.StatusCode)
162 | }
163 |
164 | if w.Header.Get("Upgrade") != "websocket" {
165 | t.Errorf(`expected "Upgrade" HTTP Header value to be "websocket" but it is "%s"`, w.Header.Get("Upgrade"))
166 | }
167 |
168 | if w.Header.Get("Connection") != "upgrade" {
169 | t.Errorf(`expected "Connection" HTTP Header value to be "upgrade" but it is "%s"`, w.Header.Get("Connection"))
170 | }
171 |
172 | if w.Header.Get("Sec-WebSocket-Accept") != "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" {
173 | t.Errorf(`expected "Sec-WebSocket-Accept" HTTP Header value to be "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" but it is "%s"`, w.Header.Get("Sec-WebSocket-Accept"))
174 | }
175 | }
176 |
177 | func TestUpgradeWithSubProtocols(t *testing.T) {
178 | h := func(w http.ResponseWriter, r *http.Request) {
179 | wsr := &Request{
180 | CheckOrigin: func(r *http.Request) bool {
181 | return true
182 | },
183 | }
184 |
185 | makeRequestValid(r)
186 | wsr.SubProtocol = "one"
187 |
188 | s, err := wsr.Upgrade(w, r)
189 |
190 | if err != nil {
191 | t.Error("unexpected error from Upgrade():", err)
192 | }
193 |
194 | if s == nil {
195 | t.Error("expected Upgrade() to return a non-nil Socket instance")
196 | }
197 | }
198 |
199 | s := httptest.NewServer(http.HandlerFunc(h))
200 | defer s.Close()
201 |
202 | type testCase struct {
203 | p string
204 | v bool
205 | }
206 |
207 | testCases := []testCase{
208 | {p: "one, two, three", v: true},
209 | {p: "two, three", v: false},
210 | {p: "", v: false},
211 | }
212 |
213 | for i, c := range testCases {
214 | _ = i
215 | r, err := http.NewRequest("GET", s.URL, nil)
216 |
217 | if err != nil {
218 | t.Error("unexpected error returned while trying to create a request instance:", err)
219 | }
220 |
221 | if c.p != "" {
222 | r.Header.Set("Sec-WebSocket-Protocol", c.p)
223 | }
224 |
225 | l := &http.Client{}
226 | w, err := l.Do(r)
227 |
228 | if err != nil {
229 | t.Error("unexpected error returned while trying to create a client instance:", err)
230 | }
231 |
232 | if c.v {
233 | v := w.Header.Get("Sec-WebSocket-Protocol")
234 | if w.Header.Get("Sec-WebSocket-Protocol") != "one" {
235 | t.Errorf(`expected 'Sec-WebSocket-Protocol' Response Header to be "one", but it is "%v".`, v)
236 | }
237 | } else {
238 | v := w.Header.Get("Sec-WebSocket-Protocol")
239 | if w.Header.Get("Sec-WebSocket-Protocol") != "" {
240 | t.Errorf(`expected 'Sec-WebSocket-Protocol' Response Header to be "", but it is "%v".`, v)
241 | }
242 | }
243 | }
244 | }
245 |
246 | func TestClientSubProtocols(t *testing.T) {
247 | r := &http.Request{}
248 |
249 | l := []string{"one", "two", "three"}
250 |
251 | r.Header = make(http.Header)
252 | r.Header.Set("Sec-WebSocket-Protocol", strings.Join(l, ", "))
253 |
254 | q := &Request{
255 | request: r,
256 | }
257 |
258 | p := q.ClientSubProtocols()
259 |
260 | if len(l) != len(p) {
261 | t.Errorf("The length of the list of header value assigned to Sec-WebSocket-Protocol HTTP Header are not the same. %d != %d", len(l), len(p))
262 | }
263 |
264 | for _, v := range p {
265 | k := false
266 | for _, h := range l {
267 | if v == h {
268 | k = true
269 | break
270 | }
271 | }
272 | if !k {
273 | t.Errorf(`"%s" was not returned in the slice of Sub Protocols.`, v)
274 | }
275 | }
276 | }
277 |
278 | func TestClientExtensions(t *testing.T) {
279 | r := &http.Request{}
280 |
281 | l := []string{"one", "two", "three"}
282 |
283 | r.Header = make(http.Header)
284 | r.Header.Set("Sec-WebSocket-Extensions", strings.Join(l, ", "))
285 |
286 | q := &Request{
287 | request: r,
288 | }
289 |
290 | p := q.ClientExtensions()
291 |
292 | if len(l) != len(p) {
293 | t.Errorf("The length of the list of header value assigned to Sec-WebSocket-Extensions HTTP Header are not the same. '%d' != '%d'", len(l), len(p))
294 | }
295 |
296 | for _, v := range p {
297 | k := false
298 | for _, h := range l {
299 | if v == h {
300 | k = true
301 | break
302 | }
303 | }
304 | if !k {
305 | t.Errorf(`"%s" was not returned in the slice of Extensions.`, v)
306 | }
307 | }
308 | }
309 |
--------------------------------------------------------------------------------
/request_utils.go:
--------------------------------------------------------------------------------
1 | package websocket
2 |
3 | import (
4 | "encoding/base64"
5 | "net/http"
6 | "strings"
7 | )
8 |
9 | // validateRequest is used to determine whether the client handshake request
10 | // conforms with the WebSocket spec. When it doesn't the server should respond
11 | // with an HTTP Status 400 Bad Request.
12 | //
13 | // Note that this method doesn't validate the websocket version
14 | // ("Sec-WebSocket-Version" HTTP Header Field) and origin ("Origin" HTTP
15 | // Header Field) since these require specific HTTP Status Code (427 and 403
16 | // respectively).
17 | //
18 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.2.1
19 | // https://tools.ietf.org/html/rfc6455#section-4.2.2
20 |
21 | func validateRequest(r *http.Request) *OpenError {
22 | validations := []func(*http.Request) *OpenError{
23 | // Check HTTP version to be at least v1.1.
24 | validateRequestVersion,
25 | // Check HTTP method to be 'GET'.
26 | validateRequestMethod,
27 | // Validate 'Upgrade' header field.
28 | validateRequestUpgradeHeader,
29 | // Validate 'Connection' header field.
30 | validateRequestConnectionHeader,
31 | // Validate 'Sec-WebSocket-Key' header field.
32 | validateRequestSecWebsocketKeyHeader,
33 | }
34 |
35 | for _, v := range validations {
36 | if err := v(r); err != nil {
37 | return err
38 | }
39 | }
40 |
41 | return nil
42 | }
43 |
44 | // validateRequestVersion verifies that the HTTP Version used in the client's
45 | // opening handshake request is at least v1.1.
46 | //
47 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.2.1
48 | func validateRequestVersion(r *http.Request) *OpenError {
49 | if !r.ProtoAtLeast(1, 1) {
50 | return &OpenError{Reason: `HTTP must be v1.1 or higher`}
51 | }
52 | return nil
53 | }
54 |
55 | // validateRequestMethod verifies that the HTTP Method used in the client's
56 | // opening handshake request is 'GET'.
57 | //
58 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.2.1
59 | func validateRequestMethod(r *http.Request) *OpenError {
60 | if r.Method != "GET" {
61 | return &OpenError{Reason: `HTTP method must be "GET"`}
62 | }
63 | return nil
64 | }
65 |
66 | // validateRequestUpgradeHeader verifies that the Upgrade HTTP Header value in the
67 | // client's opening handshake request is "websocket".
68 | //
69 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.2.1
70 | func validateRequestUpgradeHeader(r *http.Request) *OpenError {
71 | h := r.Header.Get("Upgrade")
72 |
73 | if strings.ToLower(h) != "websocket" {
74 | return &OpenError{Reason: `"Upgrade" Header should have the value of "websocket"`}
75 | }
76 |
77 | return nil
78 | }
79 |
80 | // validateRequestConnectionHeader verfies that the Connection HTTP Header value in
81 | // the client's opening handshake request is "upgrade".
82 | //
83 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.2.1
84 | func validateRequestConnectionHeader(r *http.Request) *OpenError {
85 | h := r.Header.Get("Connection")
86 |
87 | if strings.ToLower(h) != "upgrade" {
88 | return &OpenError{Reason: `"Connection" Header should have the value of "upgrade"`}
89 | }
90 |
91 | return nil
92 | }
93 |
94 | // validateRequestSecWebsocketKeyHeader verifies that the Sec-WebSocket-Key HTTP Header value in
95 | // the client's opening handshake request is of length 16 when base64 decoded.
96 | //
97 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.2.1
98 | func validateRequestSecWebsocketKeyHeader(r *http.Request) *OpenError {
99 | h := r.Header.Get("Sec-WebSocket-Key")
100 | d, err := base64.StdEncoding.DecodeString(h)
101 |
102 | // Check for decoding errors.
103 | if err != nil {
104 | return &OpenError{Reason: `an error had occured while validating "Sec-WebSocket-Key" header`}
105 | }
106 |
107 | // Check that the length of the decoded Sec-WebSocket-Key value is 16
108 | // (bytes).
109 | if len(d) != 16 {
110 | return &OpenError{Reason: `"Sec-WebSocket-Key" must be 16 bytes in length when decoded`}
111 | }
112 |
113 | return nil
114 | }
115 |
116 | // validateWSVersionHeader verifies that the Sec-WebSocket-Verion HTTP Header
117 | // value in the client's opening handshake request is "13".
118 | //
119 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-4.2.1
120 | func validateWSVersionHeader(r *http.Request) *OpenError {
121 | if r.Header.Get("Sec-WebSocket-Version") != wsVersion {
122 | return &OpenError{Reason: "upgrade required"}
123 | }
124 |
125 | return nil
126 | }
127 |
// checkOrigin is the origin check used when the Request struct is not given a
// CheckOrigin function. It accepts a request when its Origin HTTP Header is
// omitted (non-browser client) or when the Origin, with a leading "http://"
// or "https://" removed, is equal to the request's Host.
//
// Ref spec: https://tools.ietf.org/html/rfc6455#section-4.2.1
func checkOrigin(r *http.Request) bool {
	origin := r.Header.Get("Origin")

	// Drop the scheme; at most one prefix is removed.
	switch {
	case strings.HasPrefix(origin, "http://"):
		origin = origin[len("http://"):]
	case strings.HasPrefix(origin, "https://"):
		origin = origin[len("https://"):]
	}

	if origin == "" {
		return true
	}

	return origin == r.Host
}
144 |
--------------------------------------------------------------------------------
/request_utils_test.go:
--------------------------------------------------------------------------------
1 | package websocket
2 |
3 | import (
4 | "net/http"
5 | "testing"
6 | )
7 |
8 | func TestValidateRequestVersion(t *testing.T) {
9 | r := &http.Request{}
10 | r.Header = make(http.Header)
11 |
12 | type testCase struct {
13 | a int
14 | i int
15 | r bool
16 | }
17 |
18 | testCases := []testCase{
19 | // HTTP v1.1 should be valid.
20 | {a: 1, i: 1, r: true},
21 | // HTTP v2.1 should be valid.
22 | {a: 2, i: 1, r: true},
23 | // HTTP v1.0 should be not valid.
24 | {a: 1, i: 0, r: false},
25 | // HTTP v0.1 should be not valid.
26 | {a: 0, i: 1, r: false},
27 | }
28 |
29 | for i, c := range testCases {
30 | r.ProtoMajor = c.a
31 | r.ProtoMinor = c.i
32 |
33 | err := validateRequestVersion(r)
34 |
35 | if c.r && err != nil {
36 | t.Errorf(`test case %d: unexpected error retured for "v%d.%d"`, i, c.a, c.i)
37 | }
38 |
39 | if !c.r && err == nil {
40 | t.Errorf(`test case %d: expected an error for "v%d.%d"`, i, c.a, c.i)
41 | }
42 | }
43 | }
44 |
45 | func TestValidateRequestMethod(t *testing.T) {
46 | r := &http.Request{}
47 | r.Header = make(http.Header)
48 |
49 | type testCase struct {
50 | m string
51 | r bool
52 | }
53 |
54 | testCases := []testCase{
55 | // HTTP GET should be valid.
56 | {m: "GET", r: true},
57 | // HTTP POST should be not valid.
58 | {m: "POST", r: false},
59 | }
60 |
61 | for i, c := range testCases {
62 | r.Method = c.m
63 |
64 | err := validateRequestMethod(r)
65 |
66 | if c.r && err != nil {
67 | t.Errorf(`test case %d: unexpected error retured for "%s" request`, i, c.m)
68 | }
69 |
70 | if !c.r && err == nil {
71 | t.Errorf(`test case %d: expected an error for "%s" request`, i, c.m)
72 | }
73 | }
74 | }
75 |
76 | func TestValidateRequestUpgradeHeader(t *testing.T) {
77 | r := &http.Request{}
78 | r.Header = make(http.Header)
79 |
80 | type testCase struct {
81 | v string
82 | r bool
83 | }
84 |
85 | testCases := []testCase{
86 | // When value is "websocket" should be valid.
87 | {v: "websocket", r: true},
88 | // When value is "webSocket" should be valid.
89 | {v: "webSocket", r: true},
90 | // When value is not "websocket" should not be valid.
91 | {v: "ValueOtherThanWebsocket", r: false},
92 | }
93 |
94 | for i, c := range testCases {
95 | r.Header.Set("Upgrade", c.v)
96 |
97 | err := validateRequestUpgradeHeader(r)
98 |
99 | if c.r && err != nil {
100 | t.Errorf(`test case %d: unexpected error retured for "%s"`, i, c.v)
101 | }
102 |
103 | if !c.r && err == nil {
104 | t.Errorf(`test case %d: expected an error for "%s"`, i, c.v)
105 | }
106 | }
107 | }
108 |
109 | func TestValidateRequestConnectionHeader(t *testing.T) {
110 | r := &http.Request{}
111 | r.Header = make(http.Header)
112 |
113 | type testCase struct {
114 | v string
115 | r bool
116 | }
117 |
118 | testCases := []testCase{
119 | // When value is "upgrade" should be valid.
120 | {v: "upgrade", r: true},
121 | // When value is "Upgrade" should be valid.
122 | {v: "Upgrade", r: true},
123 | // When value is not "upgrade" should not be valid.
124 | {v: "ValueOtherThanUpgrade", r: false},
125 | }
126 |
127 | for i, c := range testCases {
128 | r.Header.Set("Connection", c.v)
129 |
130 | err := validateRequestConnectionHeader(r)
131 |
132 | if c.r && err != nil {
133 | t.Errorf(`test case %d: unexpected error retured for "%s"`, i, c.v)
134 | }
135 |
136 | if !c.r && err == nil {
137 | t.Errorf(`test case %d: expected an error for "%s"`, i, c.v)
138 | }
139 | }
140 | }
141 |
142 | func TestValidateRequestSecWebsocketKeyHeader(t *testing.T) {
143 | r := &http.Request{}
144 | r.Header = make(http.Header)
145 |
146 | type testCase struct {
147 | v string
148 | r bool
149 | }
150 |
151 | testCases := []testCase{
152 | // Valid key.
153 | {v: "FlBPpXKmN36AUZxV0tYHYw==", r: true},
154 | // Invalid decoded length.
155 | {v: "InvalidKey==", r: false},
156 | // Invalid encoded data.
157 | {v: "InvalidKeyError", r: false},
158 | }
159 |
160 | for i, c := range testCases {
161 | r.Header.Set("Sec-WebSocket-Key", c.v)
162 |
163 | err := validateRequestSecWebsocketKeyHeader(r)
164 |
165 | if c.r && err != nil {
166 | t.Errorf(`test case %d: unexpected error retured for "%s"`, i, c.v)
167 | }
168 |
169 | if !c.r && err == nil {
170 | t.Errorf(`test case %d: expected an error for "%s"`, i, c.v)
171 | }
172 | }
173 | }
174 |
175 | func TestValidateWSVersionHeader(t *testing.T) {
176 | r := &http.Request{}
177 | r.Header = make(http.Header)
178 |
179 | type testCase struct {
180 | v string
181 | r bool
182 | }
183 |
184 | testCases := []testCase{
185 | // Valid when value is the same as the version of the ws supported.
186 | {v: wsVersion, r: true},
187 | // Not valid when value is not the same as the version of the ws
188 | // supported.
189 | {v: "14", r: false},
190 | }
191 |
192 | for i, c := range testCases {
193 | r.Header.Set("Sec-WebSocket-Version", c.v)
194 |
195 | err := validateWSVersionHeader(r)
196 |
197 | if c.r && err != nil {
198 | t.Errorf(`test case %d: unexpected error retured for "%s"`, i, c.v)
199 | }
200 |
201 | if !c.r && err == nil {
202 | t.Errorf(`test case %d: expected an error for "%s"`, i, c.v)
203 | }
204 | }
205 | }
206 |
207 | func TestCheckOrigin(t *testing.T) {
208 | r := &http.Request{}
209 | r.Header = make(http.Header)
210 | r.Host = "example.com:8080"
211 |
212 | type testCase struct {
213 | v string
214 | r bool
215 | }
216 |
217 | testCases := []testCase{
218 | // Valid when origin is omitted (non-browser client).
219 | {v: "", r: true},
220 | // Valid when same origin.
221 | {v: r.Host, r: true},
222 | {v: "example.com:8080", r: true},
223 | {v: "http://example.com:8080", r: true},
224 | {v: "https://example.com:8080", r: true},
225 | }
226 |
227 | for i, c := range testCases {
228 | r.Header.Set("Origin", c.v)
229 |
230 | if checkOrigin(r) != c.r {
231 | t.Errorf(`Test Case %d: Expected checkOrigin() to return '%t' when 'Origin' header == "%s" and Host is at "%s".`, i, c.r, c.v, r.Host)
232 | }
233 | }
234 | }
235 |
--------------------------------------------------------------------------------
/socket.go:
--------------------------------------------------------------------------------
1 | package websocket
2 |
3 | import (
4 | "bufio"
5 | "errors"
6 | "io"
7 | "net"
8 | "sync"
9 | "time"
10 | )
11 |
// ErrSocketClosed is the error returned when a user tries to send a frame
// through a socket that has already been closed. It is a sentinel value, so
// callers can compare a returned error against it.
var ErrSocketClosed = errors.New("socket has been closed")
15 |
// WebSocket close status codes: the reason an endpoint gives for closing the
// connection. Per the spec, 1005, 1006 and 1015 are reserved values that are
// never sent in a Close frame; they only describe how a connection ended.
//
// Ref Spec: https://tools.ietf.org/html/rfc6455#section-7.4.1
const (
	CloseNormalClosure           int = 1000 // the purpose of the connection has been fulfilled
	CloseGoingAway               int = 1001 // the endpoint is going away (server shutdown, page navigation)
	CloseProtocolError           int = 1002 // the peer violated the protocol
	CloseUnsupportedData         int = 1003 // received a type of data the endpoint cannot accept
	CloseNoStatusReceived        int = 1005 // reserved: no status code was present
	CloseAbnormalClosure         int = 1006 // reserved: connection closed without a Close frame
	CloseInvalidFramePayloadData int = 1007 // data inconsistent with the message type (e.g. non UTF-8 text)
	ClosePolicyViolation         int = 1008 // received a message that violates the endpoint's policy
	CloseMessageTooBig           int = 1009 // received a message too big to process
	CloseMandatoryExtension      int = 1010 // client expected an extension the server did not negotiate
	CloseInternalServerErr       int = 1011 // server hit an unexpected condition
	CloseTLSHandshake            int = 1015 // reserved: TLS handshake failure
)
32 |
// Life cycle states of a Socket instance, in the order a socket goes through
// them.
const (
	/*
		stateOpened (0): the socket is open.
	*/
	stateOpened int = iota

	/*
		stateClosing (1): the socket is in the middle of the closing
		handshake.
	*/
	stateClosing

	/*
		stateClosed (2): the socket is closed.
	*/
	stateClosed
)
51 |
// Socket represents one endpoint (client or server) of a websocket
// connection.
type Socket struct {
	// conn is the underlying tcp connection.
	conn net.Conn

	// buf is a buffered version of the underlying tcp connection.
	buf *bufio.ReadWriter

	// server indicates whether the socket instance represents a server
	// (true) or a client (false) endpoint.
	server bool

	// state is the current state of the socket instance: stateOpened,
	// stateClosing or stateClosed.
	state int

	// CloseDelay is the maximum time the socket instance waits, once the
	// closing handshake has been completed, before it closes the underlying
	// tcp connection itself.
	//
	// The websocket rfc suggests that the tcp connection should first be
	// terminated by the server endpoint, without forbidding the client
	// endpoint from doing so. Server endpoints should therefore always
	// leave this property set to 0.
	//
	// Note: tcpClose waits for time.Second * CloseDelay, so the value is
	// interpreted as a number of seconds (CloseDelay = 2 waits two seconds)
	// rather than as a ready-made duration.
	//
	// Ref Spec: https://tools.ietf.org/html/rfc6455#section-5.5.1
	CloseDelay time.Duration

	// ReadHandler is invoked whenever a text or binary frame is received.
	// The opcode and the payload data are provided as args respectively.
	ReadHandler func(int, []byte)

	// PingHandler is invoked whenever a ping frame is received, with the
	// payload data as arg. When it is nil a pong frame carrying the same
	// payload is sent instead.
	PingHandler func([]byte)

	// PongHandler is invoked whenever a pong frame is received, with the
	// payload data as arg.
	PongHandler func([]byte)

	// CloseHandler is invoked when the underlying tcp connection is closed.
	// The reason for the closure is provided as an arg.
	CloseHandler func(error)

	// closeError holds the error which caused the websocket connection to
	// terminate. It is handed to CloseHandler once the underlying tcp
	// connection has been closed.
	closeError error

	// writeMutex serializes the writes performed through WriteMessage.
	writeMutex *sync.Mutex
}
129 |
130 | // Listen is used to start listening for new frames sent by the connected
131 | // endpoint.
132 | func (s *Socket) Listen() {
133 | s.read()
134 | }
135 |
136 | func (s *Socket) read() {
137 | Read:
138 | for {
139 | // Read frame
140 | f, err := newFrame(s.buf.Reader)
141 |
142 | if s.state == stateClosed {
143 | break Read
144 | }
145 |
146 | if err != nil {
147 | // If an error occurred due to something which doesn't conform with
148 | // the websocket rfc, use the error itself as a reason.
149 | if c, k := err.(*CloseError); k {
150 | s.CloseWithError(c)
151 | return
152 | }
153 |
154 | // When EOF returns it means that the other endpoint isn't reachable
155 | // and thus there won't be the need to initate the closing
156 | // handshake.
157 | if err == io.EOF {
158 | s.closeError = &CloseError{
159 | Code: CloseAbnormalClosure,
160 | Reason: "abnormal closure",
161 | }
162 | s.TCPClose()
163 | break Read
164 | }
165 |
166 | // When Read times out or connection is closed the other endpoing
167 | // won't be reachable and thus there won't be the need to initiate
168 | // the closing handshake.
169 | if _, k := err.(*net.OpError); k {
170 | s.closeError = &CloseError{
171 | Code: CloseAbnormalClosure,
172 | Reason: "abnormal closure",
173 | }
174 | s.TCPClose()
175 | break Read
176 | }
177 |
178 | // Else use a generic error.
179 | s.CloseWithError(&CloseError{
180 | Code: CloseProtocolError,
181 | Reason: "protocol error",
182 | })
183 |
184 | return
185 | }
186 |
187 | // If Socket instance represents a server endpoint, payload data must be
188 | // masked.
189 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-5.1
190 | if s.server && !f.masked {
191 | s.CloseWithError(&CloseError{
192 | Code: CloseProtocolError,
193 | Reason: "expected payload to be masked",
194 | })
195 | return
196 | }
197 |
198 | // If Socket instance represents a client endpoint, payload data must
199 | // not be masked.
200 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-5.1
201 | if !s.server && f.masked {
202 | s.CloseWithError(&CloseError{
203 | Code: CloseProtocolError,
204 | Reason: "expected payload to not be masked",
205 | })
206 | return
207 | }
208 |
209 | switch f.opcode {
210 | case OpcodeText, OpcodeBinary:
211 | {
212 | s.callReadHandler(f.opcode, f.payload)
213 | }
214 | case OpcodePing:
215 | {
216 | s.callPingHandler(f.payload)
217 | }
218 | case OpcodePong:
219 | {
220 | s.callPongHandler(f.payload)
221 | }
222 | case OpcodeClose:
223 | {
224 | // Create a new CloseError using the payload data
225 | c, cerr := NewCloseError(f.payload)
226 |
227 | // Store close error for close handler.
228 | s.closeError = c
229 |
230 | // If the state of the socket instance is CLOSING, it means that
231 | // the closing handshake has been initiated from this socket
232 | // instance and the retrieved frame was the acknowledge close
233 | // frame. At this point the closing handshake has been completed
234 | // and therefore the underlying tcp connection can be closed,
235 | // since the connected endpoint won't be waiting for furthur
236 | // frames.
237 | if s.state == stateClosing {
238 | // closing handshake has been finalized therefore close tcp
239 | // connection.
240 | s.tcpClose()
241 | // Stop reading from connection.
242 | break Read
243 | }
244 |
245 | // If the state of the socket instance is not CLOSING, it means
246 | // that the closing handshake has been initiated by the
247 | // connected endpoint and therefore it is still waiting for the
248 | // acknowledgement close frame.
249 | s.state = stateClosing
250 |
251 | // The acknowledgment close frame to be sent will echo the
252 | // status code of the close frame just received.
253 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-5.5.1
254 | var b []byte
255 |
256 | // If the status code of the close frame received is valid, echo
257 | // it. Else leave the payload data of the acknowledgement close
258 | // frame empty.
259 | if cerr == nil {
260 | b = c.toBytesCode()
261 | }
262 |
263 | // Send acknowledgement close frame.
264 | s.WriteMessage(OpcodeClose, b)
265 |
266 | // At this point the closing handshake would have been finalized
267 | // therefore the tcp connection can be closed.
268 | s.tcpClose()
269 |
270 | // Stop reading from connection.
271 | break Read
272 | }
273 | }
274 | }
275 | }
276 |
277 | // WriteMessage is used to send frames to the connected endpoint. It accepts
278 | // two arguments 'o' opcode, 'p' payload data.
279 | func (s *Socket) WriteMessage(o int, p []byte) error {
280 | s.writeMutex.Lock()
281 | defer s.writeMutex.Unlock()
282 |
283 | // Before writing make sure that the socket instance is still in an open
284 | // state.
285 | if s.state == stateClosed {
286 | return ErrSocketClosed
287 | }
288 |
289 | // Create a frame instance which will represent the frame to be sent.
290 | f := &frame{
291 | fin: true,
292 | opcode: o,
293 | payload: p,
294 | }
295 |
296 | // If the socket instance represents a client endpoint, the payload data
297 | // must be masked.
298 | if !s.server {
299 | // Generate random mask key
300 | f.key = randomByteSlice(1)
301 | }
302 |
303 | // Get a []byte representation of the frame instance.
304 | b, err := f.toBytes()
305 |
306 | // If an error is not nil, since the error doesn't relate with the socket
307 | // connection itself, the error is returned.
308 | if err != nil {
309 | return err
310 | }
311 |
312 | // Send frame
313 | s.buf.Write(b)
314 | if err := s.buf.Flush(); err != nil {
315 | // Store error.
316 | s.closeError = err
317 |
318 | // Close TCP Connection.
319 | s.TCPClose()
320 |
321 | // Since the error is related with the socket connection the error is
322 | // not returned but passed to the close handler.
323 | return nil
324 | }
325 |
326 | // If frame sent is a close frame, change state to closing.
327 | if f.opcode == OpcodeClose {
328 | s.state = stateClosing
329 | }
330 |
331 | return nil
332 | }
333 |
334 | // SetReadDeadline sets the deadline for future Read calls. A zero value for t
335 | // means Read will not time out.
336 | func (s *Socket) SetReadDeadline(t time.Time) {
337 | s.conn.SetReadDeadline(t)
338 | }
339 |
340 | // SetWriteDeadline sets the deadline for future Write calls. Even if write
341 | // times out, it may return n > 0, indicating that some of the data was
342 | // successfully written. A zero value for t means Write will not time out.
343 | func (s *Socket) SetWriteDeadline(t time.Time) {
344 | s.conn.SetWriteDeadline(t)
345 | }
346 |
347 | // callReadHandler invokes the read handler provided by the user (if any).
348 | func (s *Socket) callReadHandler(o int, p []byte) {
349 | if s.ReadHandler != nil {
350 | s.ReadHandler(o, p)
351 | }
352 | }
353 |
354 | // callPingHandler first tries to invoke the ping handler provided by the
355 | // user. If the user hasn't provided one it invokes the default functionality.
356 | func (s *Socket) callPingHandler(p []byte) {
357 | if s.PingHandler != nil {
358 | s.PingHandler(p)
359 | return
360 | }
361 | s.defaultPingHandler(p)
362 | }
363 |
364 | // defaultPingHandler sends a pong frame with the same payload data of the ping
365 | // frame just received.
366 | //
367 | // Ref Spec: https://tools.ietf.org/html/rfc6455#section-5.5.3
368 | func (s *Socket) defaultPingHandler(p []byte) {
369 | s.WriteMessage(OpcodePong, p)
370 | }
371 |
372 | // callPongHandler invokes the pong handler provided by the user (if any).
373 | func (s *Socket) callPongHandler(p []byte) {
374 | if s.PongHandler != nil {
375 | s.PongHandler(p)
376 | return
377 | }
378 | }
379 |
380 | // callCloseHandler first tries to invoke the close handler provided by the
381 | // user.
382 | func (s *Socket) callCloseHandler(e error) {
383 | if s.CloseHandler != nil {
384 | s.CloseHandler(e)
385 | }
386 | }
387 |
388 | // TCPClose closes the underlying tcp connection if it hasn't already been
389 | // closed.
390 | func (s *Socket) TCPClose() {
391 | // If socket has already been closed, don't reclose the tcp connection
392 | if s.state == stateClosed {
393 | return
394 | }
395 |
396 | // Change state of socket instance to closed.
397 | s.state = stateClosed
398 |
399 | // Close tcp connection
400 | s.conn.Close()
401 |
402 | // Invoke close handler.
403 | s.callCloseHandler(s.closeError)
404 | }
405 |
406 | // tcpClose closes the underlying tcp connection after s.CloseDelay seconds if
407 | // it hasn't already been closed . More info on why this is needed documented
408 | // in s.CloseDelay.
409 | func (s *Socket) tcpClose() {
410 | // If socket has already been closed, don't reclose the tcp connection
411 | if s.state == stateClosed {
412 | return
413 | }
414 |
415 | if s.CloseDelay > 0 {
416 | t := time.NewTicker(time.Second * s.CloseDelay)
417 | <-t.C
418 | }
419 |
420 | // Close tcp connection
421 | s.TCPClose()
422 | }
423 |
424 | // Close initiates the normal closures (1000) closing handshake.
425 | func (s *Socket) Close() {
426 | s.CloseWithError(&CloseError{
427 | Code: CloseNormalClosure,
428 | Reason: "normal closure",
429 | })
430 | }
431 |
432 | // CloseWithError initiates the closing handshake.
433 | func (s *Socket) CloseWithError(e *CloseError) {
434 | // Store error.
435 | s.closeError = e
436 |
437 | // Start the closing handshake
438 | b, _ := e.ToBytes()
439 | s.WriteMessage(OpcodeClose, b)
440 | }
441 |
--------------------------------------------------------------------------------
/socket_test.go:
--------------------------------------------------------------------------------
1 | package websocket
2 |
3 | import (
4 | "net/http"
5 | "net/http/httptest"
6 | "strings"
7 | "sync"
8 | "testing"
9 | "time"
10 | )
11 |
12 | func TestSocketReadTextFrame(t *testing.T) {
13 | payload := "expected payload"
14 |
15 | done := make(chan bool)
16 | timeout := time.NewTicker(time.Second * 2)
17 |
18 | h := func(w http.ResponseWriter, r *http.Request) {
19 | q := Request{}
20 | s, err := q.Upgrade(w, r)
21 |
22 | if err != nil {
23 | t.Fatal("unexpected error was returned", err)
24 | }
25 |
26 | s.ReadHandler = func(o int, p []byte) {
27 | if o != OpcodeText {
28 | t.Errorf("expected opcode to be '%d' but it is '%d'", OpcodeText, o)
29 | }
30 |
31 | if string(p) != payload {
32 | t.Errorf(`expected payload to be "%s" but it is "%s"`, p, payload)
33 | }
34 |
35 | done <- true
36 | }
37 |
38 | s.Listen()
39 | }
40 |
41 | s := httptest.NewServer(http.HandlerFunc(h))
42 | defer s.Close()
43 |
44 | d := &Dialer{}
45 | c, _, err := d.Dial(adaptURL(s.URL))
46 |
47 | if err != nil {
48 | t.Fatal("unexpected error returned", err)
49 | }
50 |
51 | defer c.TCPClose()
52 |
53 | f := &frame{
54 | fin: true,
55 | opcode: OpcodeText,
56 | key: []byte{1, 1, 1, 1},
57 | payload: []byte(payload),
58 | }
59 |
60 | b, err := f.toBytes()
61 |
62 | if err != nil {
63 | t.Fatal("unexpected error returned", err)
64 | }
65 |
66 | c.buf.Write(b)
67 | if err := c.buf.Flush(); err != nil {
68 | t.Fatal("unexpected error returned", err)
69 | }
70 |
71 | select {
72 | case <-done:
73 | {
74 |
75 | }
76 | case <-timeout.C:
77 | {
78 | t.Error("test case timed out")
79 | }
80 | }
81 | }
82 |
83 | func TestSocketReadBinaryFrame(t *testing.T) {
84 | payload := "expected payload"
85 |
86 | done := make(chan bool)
87 | timeout := time.NewTicker(time.Second * 2)
88 |
89 | h := func(w http.ResponseWriter, r *http.Request) {
90 | q := Request{}
91 | s, err := q.Upgrade(w, r)
92 |
93 | if err != nil {
94 | t.Fatal("unexpected error was returned", err)
95 | }
96 |
97 | s.ReadHandler = func(o int, p []byte) {
98 | if o != OpcodeBinary {
99 | t.Errorf("expected opcode to be '%d' but it is '%d'", OpcodeBinary, o)
100 | }
101 |
102 | if string(p) != payload {
103 | t.Errorf(`expected payload to be "%s" but it is "%s"`, p, payload)
104 | }
105 |
106 | done <- true
107 | }
108 |
109 | s.Listen()
110 | }
111 |
112 | s := httptest.NewServer(http.HandlerFunc(h))
113 | defer s.Close()
114 |
115 | d := &Dialer{}
116 | c, _, err := d.Dial(adaptURL(s.URL))
117 |
118 | if err != nil {
119 | t.Fatal("unexpected error returned", err)
120 | }
121 |
122 | defer c.TCPClose()
123 |
124 | f := &frame{
125 | fin: true,
126 | opcode: OpcodeBinary,
127 | key: []byte{1, 1, 1, 1},
128 | payload: []byte(payload),
129 | }
130 |
131 | b, err := f.toBytes()
132 |
133 | if err != nil {
134 | t.Fatal("unexpected error returned", err)
135 | }
136 |
137 | c.buf.Write(b)
138 | if err := c.buf.Flush(); err != nil {
139 | t.Fatal("unexpected error returned", err)
140 | }
141 |
142 | select {
143 | case <-done:
144 | {
145 |
146 | }
147 | case <-timeout.C:
148 | {
149 | t.Error("test case timed out")
150 | }
151 | }
152 | }
153 |
154 | func TestSocketReadPingFrame(t *testing.T) {
155 | payload := "expected payload"
156 |
157 | done := make(chan bool)
158 | timeout := time.NewTicker(time.Second * 2)
159 |
160 | h := func(w http.ResponseWriter, r *http.Request) {
161 | q := Request{}
162 | s, err := q.Upgrade(w, r)
163 |
164 | if err != nil {
165 | t.Fatal("unexpected error was returned", err)
166 | }
167 |
168 | s.PingHandler = func(p []byte) {
169 | if string(p) != payload {
170 | t.Errorf(`expected payload to be "%s" but it is "%s"`, p, payload)
171 | }
172 |
173 | done <- true
174 | }
175 |
176 | s.Listen()
177 | }
178 |
179 | s := httptest.NewServer(http.HandlerFunc(h))
180 | defer s.Close()
181 |
182 | d := &Dialer{}
183 | c, _, err := d.Dial(adaptURL(s.URL))
184 |
185 | if err != nil {
186 | t.Fatal("unexpected error returned", err)
187 | }
188 |
189 | defer c.TCPClose()
190 |
191 | f := &frame{
192 | fin: true,
193 | opcode: OpcodePing,
194 | key: []byte{1, 1, 1, 1},
195 | payload: []byte(payload),
196 | }
197 |
198 | b, err := f.toBytes()
199 |
200 | if err != nil {
201 | t.Fatal("unexpected error returned", err)
202 | }
203 |
204 | c.buf.Write(b)
205 | if err := c.buf.Flush(); err != nil {
206 | t.Fatal("unexpected error returned", err)
207 | }
208 |
209 | select {
210 | case <-done:
211 | {
212 |
213 | }
214 | case <-timeout.C:
215 | {
216 | t.Error("test case timed out")
217 | }
218 | }
219 | }
220 |
221 | func TestSocketReadPongFrame(t *testing.T) {
222 | payload := "expected payload"
223 |
224 | done := make(chan bool)
225 | timeout := time.NewTicker(time.Second * 2)
226 |
227 | h := func(w http.ResponseWriter, r *http.Request) {
228 | q := Request{}
229 | s, err := q.Upgrade(w, r)
230 |
231 | if err != nil {
232 | t.Fatal("unexpected error was returned", err)
233 | }
234 |
235 | s.PongHandler = func(p []byte) {
236 | if string(p) != payload {
237 | t.Errorf(`expected payload to be "%s" but it is "%s"`, p, payload)
238 | }
239 |
240 | done <- true
241 | }
242 |
243 | s.Listen()
244 | }
245 |
246 | s := httptest.NewServer(http.HandlerFunc(h))
247 | defer s.Close()
248 |
249 | d := &Dialer{}
250 | c, _, err := d.Dial(adaptURL(s.URL))
251 |
252 | if err != nil {
253 | t.Fatal("unexpected error returned", err)
254 | }
255 |
256 | defer c.TCPClose()
257 |
258 | f := &frame{
259 | fin: true,
260 | opcode: OpcodePong,
261 | key: []byte{1, 1, 1, 1},
262 | payload: []byte(payload),
263 | }
264 |
265 | b, err := f.toBytes()
266 |
267 | if err != nil {
268 | t.Fatal("unexpected error returned", err)
269 | }
270 |
271 | c.buf.Write(b)
272 | if err := c.buf.Flush(); err != nil {
273 | t.Fatal("unexpected error returned", err)
274 | }
275 |
276 | select {
277 | case <-done:
278 | {
279 |
280 | }
281 | case <-timeout.C:
282 | {
283 | t.Error("test case timed out")
284 | }
285 | }
286 | }
287 |
288 | func TestSocketdefaultPingHandler(t *testing.T) {
289 | payload := "expected payload"
290 |
291 | done := make(chan bool)
292 | timeout := time.NewTicker(time.Second * 2)
293 |
294 | h := func(w http.ResponseWriter, r *http.Request) {
295 | q := Request{}
296 | s, err := q.Upgrade(w, r)
297 |
298 | if err != nil {
299 | t.Fatal("unexpected error was returned", err)
300 | }
301 |
302 | s.Listen()
303 | }
304 |
305 | s := httptest.NewServer(http.HandlerFunc(h))
306 | defer s.Close()
307 |
308 | d := &Dialer{}
309 | c, _, err := d.Dial(adaptURL(s.URL))
310 |
311 | if err != nil {
312 | t.Fatal("unexpected error returned", err)
313 | }
314 |
315 | defer c.TCPClose()
316 |
317 | f := &frame{
318 | fin: true,
319 | opcode: OpcodePing,
320 | key: []byte{1, 1, 1, 1},
321 | payload: []byte(payload),
322 | }
323 |
324 | b, err := f.toBytes()
325 |
326 | if err != nil {
327 | t.Fatal("unexpected error returned", err)
328 | }
329 |
330 | c.PongHandler = func(p []byte) {
331 | if string(p) != payload {
332 | t.Errorf(`expected payload to be "%s" but it is "%s"`, p, payload)
333 | }
334 | done <- true
335 | }
336 |
337 | go c.Listen()
338 |
339 | c.buf.Write(b)
340 | if err := c.buf.Flush(); err != nil {
341 | t.Fatal("unexpected error returned", err)
342 | }
343 |
344 | select {
345 | case <-done:
346 | {
347 |
348 | }
349 | case <-timeout.C:
350 | {
351 | t.Error("test case timed out")
352 | }
353 | }
354 | }
355 |
356 | func TestSocketReadInvalidFrame(t *testing.T) {
357 | done := make(chan bool)
358 | timeout := time.NewTicker(time.Second * 2)
359 |
360 | h := func(w http.ResponseWriter, r *http.Request) {
361 | q := Request{}
362 | s, err := q.Upgrade(w, r)
363 |
364 | if err != nil {
365 | t.Fatal("unexpected error was returned", err)
366 | }
367 |
368 | s.ReadHandler = func(o int, p []byte) {
369 | t.Error("unexpected invocation of Read Handler")
370 | }
371 |
372 | s.Listen()
373 | }
374 |
375 | s := httptest.NewServer(http.HandlerFunc(h))
376 | defer s.Close()
377 |
378 | d := &Dialer{}
379 | c, _, err := d.Dial(adaptURL(s.URL))
380 |
381 | if err != nil {
382 | t.Fatal("unexpected error returned", err)
383 | }
384 |
385 | defer c.TCPClose()
386 |
387 | c.CloseHandler = func(err error) {
388 | if e, k := err.(*CloseError); k {
389 | if e.Code != CloseProtocolError {
390 | t.Errorf("expected Close Error Code to be '%d', but it is '%d'", CloseProtocolError, e.Code)
391 | }
392 | } else {
393 | t.Errorf("expected error instance to be of type *CloseError")
394 | }
395 | done <- true
396 | }
397 |
398 | go c.Listen()
399 |
400 | c.buf.Write([]byte("bad frame"))
401 | if err := c.buf.Flush(); err != nil {
402 | t.Fatal("unexpected error returned", err)
403 | }
404 |
405 | select {
406 | case <-done:
407 | {
408 |
409 | }
410 | case <-timeout.C:
411 | {
412 | t.Error("test case timed out")
413 | }
414 | }
415 | }
416 |
417 | func TestSocketReadClientUnMaskedFrame(t *testing.T) {
418 | done := make(chan bool)
419 | timeout := time.NewTicker(time.Second * 2)
420 |
421 | h := func(w http.ResponseWriter, r *http.Request) {
422 | q := Request{}
423 | s, err := q.Upgrade(w, r)
424 |
425 | if err != nil {
426 | t.Fatal("unexpected error was returned", err)
427 | }
428 |
429 | s.ReadHandler = func(o int, p []byte) {
430 | t.Errorf("unexpected invocation of Read Handler")
431 | }
432 |
433 | s.Listen()
434 | }
435 |
436 | s := httptest.NewServer(http.HandlerFunc(h))
437 | defer s.Close()
438 |
439 | d := &Dialer{}
440 | c, _, err := d.Dial(adaptURL(s.URL))
441 |
442 | if err != nil {
443 | t.Fatal("unexpected error returned", err)
444 | }
445 |
446 | defer c.TCPClose()
447 |
448 | f := &frame{
449 | fin: true,
450 | opcode: OpcodeText,
451 | payload: []byte("something"),
452 | }
453 |
454 | b, err := f.toBytes()
455 |
456 | if err != nil {
457 | t.Fatal("unexpected error returned", err)
458 | }
459 |
460 | c.CloseHandler = func(err error) {
461 | if e, k := err.(*CloseError); k {
462 | r := "expected payload to be masked"
463 |
464 | if e.Code != CloseProtocolError {
465 | t.Errorf("expected Close Error Code to be '%d', but it is '%d'", CloseProtocolError, e.Code)
466 | }
467 |
468 | if e.Reason != r {
469 | t.Errorf(`expected Close Error Reason to be "%s", but it is "%s"`, r, e.Reason)
470 | }
471 | } else {
472 | t.Errorf("expected error instance to be of type *CloseError")
473 | }
474 | done <- true
475 | }
476 |
477 | go c.Listen()
478 |
479 | c.buf.Write(b)
480 | if err := c.buf.Flush(); err != nil {
481 | t.Fatal("unexpected error returned", err)
482 | }
483 |
484 | select {
485 | case <-done:
486 | {
487 |
488 | }
489 | case <-timeout.C:
490 | {
491 | t.Error("test case timed out")
492 | }
493 | }
494 | }
495 |
496 | func TestSocketReadServerMaskedFrame(t *testing.T) {
497 | done := make(chan bool)
498 | timeout := time.NewTicker(time.Second * 2)
499 |
500 | h := func(w http.ResponseWriter, r *http.Request) {
501 | q := Request{}
502 | s, err := q.Upgrade(w, r)
503 |
504 | if err != nil {
505 | t.Fatal("unexpected error was returned", err)
506 | }
507 |
508 | f := &frame{
509 | fin: true,
510 | opcode: OpcodeText,
511 | key: []byte{1, 1, 1, 1},
512 | payload: []byte("something"),
513 | }
514 |
515 | b, err := f.toBytes()
516 |
517 | if err != nil {
518 | t.Fatal("unexpected error returned", err)
519 | }
520 |
521 | s.CloseHandler = func(err error) {
522 | if e, k := err.(*CloseError); k {
523 | r := "expected payload to not be masked"
524 |
525 | if e.Code != CloseProtocolError {
526 | t.Errorf("expected Close Error Code to be '%d', but it is '%d'", CloseProtocolError, e.Code)
527 | }
528 |
529 | if e.Reason != r {
530 | t.Errorf(`expected Close Error Reason to be "%s", but it is "%s"`, r, e.Reason)
531 | }
532 | } else {
533 | t.Errorf("expected error instance to be of type *CloseError")
534 | }
535 | done <- true
536 | }
537 |
538 | s.buf.Write(b)
539 | if err := s.buf.Flush(); err != nil {
540 | t.Error("unexpected error returned", err)
541 | }
542 |
543 | s.Listen()
544 | }
545 |
546 | s := httptest.NewServer(http.HandlerFunc(h))
547 | defer s.Close()
548 |
549 | d := &Dialer{}
550 | c, _, err := d.Dial(adaptURL(s.URL))
551 |
552 | if err != nil {
553 | t.Fatal("unexpected error returned", err)
554 | }
555 |
556 | defer c.TCPClose()
557 |
558 | c.ReadHandler = func(o int, p []byte) {
559 | t.Errorf("unexpected invocation of Read Handler")
560 | }
561 |
562 | go c.Listen()
563 |
564 | select {
565 | case <-done:
566 | {
567 |
568 | }
569 | case <-timeout.C:
570 | {
571 | t.Error("test case timed out")
572 | }
573 | }
574 | }
575 |
576 | func TestSocketClose(t *testing.T) {
577 | done := make(chan bool)
578 | timeout := time.NewTicker(time.Second * 2)
579 |
580 | h := func(w http.ResponseWriter, r *http.Request) {
581 | q := Request{}
582 | s, err := q.Upgrade(w, r)
583 |
584 | if err != nil {
585 | t.Fatal("unexpected error was returned", err)
586 | }
587 |
588 | s.Listen()
589 | }
590 |
591 | s := httptest.NewServer(http.HandlerFunc(h))
592 | defer s.Close()
593 |
594 | d := &Dialer{}
595 | c, _, err := d.Dial(adaptURL(s.URL))
596 |
597 | if err != nil {
598 | t.Fatal("unexpected error returned", err)
599 | }
600 |
601 | c.CloseHandler = func(err error) {
602 | if e, k := err.(*CloseError); k {
603 | if e.Code != CloseNormalClosure {
604 | t.Errorf("expected Close Error Code to be '%d', but it is '%d'", CloseNormalClosure, e.Code)
605 | }
606 |
607 | if e.Reason != "" {
608 | t.Errorf(`expected Close Error Reason to be empty, but it is "%s"`, e.Reason)
609 | }
610 | } else {
611 | t.Errorf("expected error instance to be of type *CloseError")
612 | }
613 | done <- true
614 | }
615 |
616 | go c.Listen()
617 |
618 | c.Close()
619 |
620 | select {
621 | case <-done:
622 | {
623 | }
624 | case <-timeout.C:
625 | {
626 | t.Error("test case timed out")
627 | }
628 | }
629 | }
630 |
631 | func TestSocketReadEOFError(t *testing.T) {
632 | done := make(chan bool)
633 | timeout := time.NewTicker(time.Second * 2)
634 |
635 | h := func(w http.ResponseWriter, r *http.Request) {
636 | q := Request{}
637 | s, err := q.Upgrade(w, r)
638 |
639 | if err != nil {
640 | t.Fatal("unexpected error was returned", err)
641 | }
642 |
643 | s.TCPClose()
644 | }
645 |
646 | s := httptest.NewServer(http.HandlerFunc(h))
647 | defer s.Close()
648 |
649 | d := &Dialer{}
650 | c, _, err := d.Dial(adaptURL(s.URL))
651 |
652 | if err != nil {
653 | t.Fatal("unexpected error returned", err)
654 | }
655 |
656 | defer c.TCPClose()
657 |
658 | c.CloseHandler = func(err error) {
659 | if e, k := err.(*CloseError); k {
660 | r := "abnormal closure"
661 | if e.Code != CloseAbnormalClosure {
662 | t.Errorf("expected Close Error Code to be '%d', but it is '%d'", CloseAbnormalClosure, e.Code)
663 | }
664 |
665 | if e.Reason != r {
666 | t.Errorf(`expected Close Error Reason to be "%s", but it is "%s"`, r, e.Reason)
667 | }
668 | } else {
669 | t.Errorf("expected error instance to be of type *CloseError")
670 | }
671 | done <- true
672 | }
673 |
674 | go c.Listen()
675 |
676 | select {
677 | case <-done:
678 | {
679 |
680 | }
681 | case <-timeout.C:
682 | {
683 | t.Error("test case timed out")
684 | }
685 | }
686 | }
687 |
688 | func TestSocketReadTimeoutError(t *testing.T) {
689 | done := make(chan bool)
690 | timeout := time.NewTicker(time.Second * 4)
691 |
692 | h := func(w http.ResponseWriter, r *http.Request) {
693 | q := Request{}
694 | s, err := q.Upgrade(w, r)
695 |
696 | if err != nil {
697 | t.Fatal("unexpected error was returned", err)
698 | }
699 |
700 | s.Listen()
701 | }
702 |
703 | s := httptest.NewServer(http.HandlerFunc(h))
704 | defer s.Close()
705 |
706 | d := &Dialer{}
707 | c, _, err := d.Dial(adaptURL(s.URL))
708 |
709 | if err != nil {
710 | t.Fatal("unexpected error returned", err)
711 | }
712 |
713 | defer c.TCPClose()
714 |
715 | c.CloseHandler = func(err error) {
716 | if e, k := err.(*CloseError); k {
717 | r := "abnormal closure"
718 | if e.Code != CloseAbnormalClosure {
719 | t.Errorf("expected Close Error Code to be '%d', but it is '%d'", CloseAbnormalClosure, e.Code)
720 | }
721 |
722 | if e.Reason != r {
723 | t.Errorf(`expected Close Error Reason to be "%s", but it is "%s"`, r, e.Reason)
724 | }
725 | } else {
726 | t.Errorf("expected error instance to be of type *CloseError")
727 | }
728 | done <- true
729 | }
730 |
731 | go c.Listen()
732 |
733 | c.SetReadDeadline(time.Now().Add(time.Second * 1))
734 |
735 | select {
736 | case <-done:
737 | {
738 |
739 | }
740 | case <-timeout.C:
741 | {
742 | t.Error("test case timed out")
743 | }
744 | }
745 | }
746 |
747 | func TestSocketWriteTimeoutErorr(t *testing.T) {
748 | done := make(chan bool)
749 | timeout := time.NewTicker(time.Second * 4)
750 |
751 | h := func(w http.ResponseWriter, r *http.Request) {
752 | q := Request{}
753 | s, err := q.Upgrade(w, r)
754 |
755 | if err != nil {
756 | t.Fatal("unexpected error was returned", err)
757 | }
758 |
759 | s.CloseHandler = func(err error) {
760 | done <- true
761 | }
762 |
763 | s.SetWriteDeadline(time.Now().Add(time.Second * 1))
764 |
765 | go s.Listen()
766 |
767 | time.Sleep(time.Second * 2)
768 |
769 | s.WriteMessage(OpcodeText, []byte("something"))
770 | }
771 |
772 | s := httptest.NewServer(http.HandlerFunc(h))
773 | defer s.Close()
774 |
775 | d := &Dialer{}
776 | c, _, err := d.Dial(adaptURL(s.URL))
777 |
778 | if err != nil {
779 | t.Fatal("unexpected error returned", err)
780 | }
781 |
782 | defer c.TCPClose()
783 |
784 | select {
785 | case <-done:
786 | {
787 |
788 | }
789 | case <-timeout.C:
790 | {
791 | t.Error("test case timed out")
792 | }
793 | }
794 | }
795 |
796 | func TestSocketWriteFromClient(t *testing.T) {
797 | payload := "expected payload"
798 |
799 | done := make(chan bool)
800 | timeout := time.NewTicker(time.Second * 2)
801 |
802 | h := func(w http.ResponseWriter, r *http.Request) {
803 | q := Request{}
804 | s, err := q.Upgrade(w, r)
805 |
806 | if err != nil {
807 | t.Fatal("unexpected error was returned", err)
808 | }
809 |
810 | s.ReadHandler = func(o int, p []byte) {
811 | if o != OpcodeText {
812 | t.Errorf("expected opcode to be '%d' but it is '%d'", OpcodeText, o)
813 | }
814 |
815 | if string(p) != payload {
816 | t.Errorf(`expected payload to be "%s" but it is "%s"`, p, payload)
817 | }
818 |
819 | done <- true
820 | }
821 |
822 | s.Listen()
823 | }
824 |
825 | s := httptest.NewServer(http.HandlerFunc(h))
826 | defer s.Close()
827 |
828 | d := &Dialer{}
829 | c, _, err := d.Dial(adaptURL(s.URL))
830 |
831 | if err != nil {
832 | t.Fatal("unexpected error was returned", err)
833 | }
834 |
835 | defer c.TCPClose()
836 |
837 | if err := c.WriteMessage(OpcodeText, []byte(payload)); err != nil {
838 | t.Fatal("unexpected error returned", err)
839 | }
840 |
841 | select {
842 | case <-done:
843 | {
844 |
845 | }
846 | case <-timeout.C:
847 | {
848 | t.Error("test case timed out")
849 | }
850 | }
851 | }
852 |
853 | func TestSocketWriteFromServer(t *testing.T) {
854 | payload := "expected payload"
855 |
856 | done := make(chan bool)
857 | timeout := time.NewTicker(time.Second * 2)
858 |
859 | h := func(w http.ResponseWriter, r *http.Request) {
860 | q := Request{}
861 | s, err := q.Upgrade(w, r)
862 |
863 | if err != nil {
864 | t.Fatal("unexpected error was returned", err)
865 | }
866 |
867 | if err := s.WriteMessage(OpcodeText, []byte(payload)); err != nil {
868 | t.Fatal("unexpected error was returned", err)
869 | }
870 | }
871 |
872 | s := httptest.NewServer(http.HandlerFunc(h))
873 | defer s.Close()
874 |
875 | d := &Dialer{}
876 | c, _, err := d.Dial(adaptURL(s.URL))
877 |
878 | if err != nil {
879 | t.Fatal("unexpected error returned", err)
880 | }
881 |
882 | defer c.TCPClose()
883 |
884 | c.ReadHandler = func(o int, p []byte) {
885 | if o != OpcodeText {
886 | t.Errorf("expected opcode to be '%d' but it is '%d'", OpcodeText, o)
887 | }
888 |
889 | if string(p) != payload {
890 | t.Errorf(`expected payload to be "%s" but it is "%s"`, p, payload)
891 | }
892 |
893 | done <- true
894 | }
895 |
896 | go c.Listen()
897 |
898 | select {
899 | case <-done:
900 | {
901 |
902 | }
903 | case <-timeout.C:
904 | {
905 | t.Error("test case timed out")
906 | }
907 | }
908 | }
909 |
910 | func TestSocketWriteWhenClosed(t *testing.T) {
911 | s := &Socket{
912 | writeMutex: &sync.Mutex{},
913 | }
914 | s.state = stateClosed
915 |
916 | if err := s.WriteMessage(1, []byte("test")); err != ErrSocketClosed {
917 | t.Errorf(`expected error "%s", but got "%v"`, ErrSocketClosed, err)
918 | }
919 | }
920 |
// adaptURL rewrites the first occurrence of "http://" in u to "ws://". If u
// does not contain "http://" it is returned unchanged.
func adaptURL(u string) string {
	const (
		httpScheme = "http://"
		wsScheme   = "ws://"
	)

	at := strings.Index(u, httpScheme)

	if at < 0 {
		return u
	}

	return u[:at] + wsScheme + u[at+len(httpScheme):]
}
924 |
--------------------------------------------------------------------------------
/utils.go:
--------------------------------------------------------------------------------
1 | package websocket
2 |
3 | import (
4 | "bufio"
5 | "crypto/sha1"
6 | "encoding/base64"
7 | "encoding/binary"
8 | "io"
9 | "math/rand"
10 | "strings"
11 | "time"
12 | )
13 |
// wsAcceptSalt is the fixed GUID which the WebSocket protocol appends to the
// client's challenge key when deriving the value of the
// "Sec-Websocket-Accept" response HTTP Header field.
const wsAcceptSalt string = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

// makeAcceptKey derives the Accept Key that is sent back to the client in the
// 'Sec-Websocket-Accept' Response Header Field: the base64 encoding of the
// SHA-1 digest of the challenge key (k) followed by wsAcceptSalt. It is used
// to prevent an attacker from tricking the server.
//
// Ref Spec: https://tools.ietf.org/html/rfc6455#section-1.3
func makeAcceptKey(k string) string {
	digest := sha1.New()

	// Writing to a hash never returns an error, hence the results are
	// deliberately discarded.
	_, _ = io.WriteString(digest, k)
	_, _ = io.WriteString(digest, wsAcceptSalt)

	sum := digest.Sum(nil)

	return base64.StdEncoding.EncodeToString(sum)
}
28 |
// readFromBuffer reads exactly l bytes from the buffered reader b.
//
// A single call to b.Read may return fewer bytes than requested, so the
// reader is polled until the destination slice is full. Every call is only
// offered the part of the slice which is still empty. This guarantees that no
// more than l bytes are ever consumed from b; anything beyond that remains
// available to the next caller.
//
// On success the returned slice has a length of l. If a read fails before l
// bytes were collected, nil and the error reported by b are returned (the
// bytes read up to that point are discarded).
//
// NOTE(review): l is allocated up front; callers passing a length taken from
// untrusted input may want to bound it before calling.
func readFromBuffer(b *bufio.Reader, l uint64) ([]byte, error) {
	p := make([]byte, l)

	// Number of bytes stored in p so far.
	n := 0

	for n < len(p) {
		// Read straight into the unfilled tail of p; this caps the read at
		// the number of bytes still missing.
		i, err := b.Read(p[n:])

		if err != nil {
			return nil, err
		}

		n += i
	}

	return p, nil
}
73 |
// stringExists reports the position of the first element of the slice of
// string ('l') which is equal to the value ('k'). When no element matches,
// '-1' is returned.
func stringExists(l []string, k string) int {
	for pos := 0; pos < len(l); pos++ {
		if l[pos] == k {
			return pos
		}
	}

	return -1
}
86 |
// headerToSlice is used to turn the values of a multi value HTTP Header field
// to a slice of string. The value (v) is split on commas and every element
// has its surrounding linear whitespace (spaces and horizontal tabs) removed.
//
// Elements are not filtered: an empty value yields a slice containing a
// single empty string, and consecutive commas yield empty elements.
//
// From RFC2616: https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
func headerToSlice(v string) []string {
	l := strings.Split(v, ",")

	for i, e := range l {
		// Both SP and HT are allowed around list elements.
		l[i] = strings.Trim(e, " \t")
	}

	return l
}
100 |
// randomByteSlice generates a byte slice made up of 'i' random 32 bit
// integers, each stored in big endian order. The returned slice is therefore
// 4*i bytes in length; when 'i' is zero or negative nil is returned.
//
// The result is used to generate the key to be sent with the clients opening
// handshake using the Sec-Websocket-Key Header (presumably with an 'i' of 4
// to obtain the 16 byte challenge key - confirm against the caller).
//
// NOTE(review): the values come from math/rand, which is not a
// cryptographically secure source.
func randomByteSlice(i int) []byte {
	// Seed the shared generator on every invocation.
	rand.Seed(time.Now().UnixNano())

	if i <= 0 {
		return nil
	}

	// Every random number occupies 4 bytes.
	const wordSize = 4

	out := make([]byte, wordSize*i)

	for w := 0; w < i; w++ {
		// Store the binary value of the next random number in its own 4 byte
		// window.
		binary.BigEndian.PutUint32(out[wordSize*w:], rand.Uint32())
	}

	return out
}
126 |
127 | // closeErrorExist returns whether the error number provided as an argument is
128 | // a valid error number or not.
129 | func closeErrorExist(i int) bool {
130 | switch i {
131 | case CloseNormalClosure, CloseGoingAway, CloseProtocolError, CloseUnsupportedData, CloseNoStatusReceived, CloseAbnormalClosure, CloseInvalidFramePayloadData, ClosePolicyViolation, CloseMessageTooBig, CloseMandatoryExtension, CloseInternalServerErr, CloseTLSHandshake:
132 | {
133 | return true
134 | }
135 | }
136 | return false
137 | }
138 |
--------------------------------------------------------------------------------
/utils_test.go:
--------------------------------------------------------------------------------
1 | package websocket
2 |
3 | import (
4 | "bufio"
5 | "math/rand"
6 | "strings"
7 | "testing"
8 | )
9 |
10 | func TestMakeAcceptKey(t *testing.T) {
11 | e := "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
12 | k := makeAcceptKey("dGhlIHNhbXBsZSBub25jZQ==")
13 | if k != e {
14 | t.Errorf(`expected "%s" instead "%s" was returned.`, e, k)
15 | }
16 | }
17 |
// payloadMock is a test reader which hands out the bytes held in p.
type payloadMock struct {
	// p holds the bytes which have not been handed out yet.
	p []byte
}

// Read copies as many pending bytes as fit into p, drops them from the mock
// and returns the number of bytes copied. The returned error is always nil,
// including when no bytes are left.
func (m *payloadMock) Read(p []byte) (int, error) {
	copied := copy(p, m.p)

	// Keep the remainder in a fresh slice so that it no longer shares memory
	// with the bytes already handed out.
	m.p = append([]byte{}, m.p[copied:]...)

	return copied, nil
}
38 |
39 | func newBuffer(d []byte) *bufio.Reader {
40 | p := &payloadMock{p: d}
41 | r := bufio.NewReader(p)
42 | return r
43 | }
44 |
45 | func TestReadFromBufferSingleRead(t *testing.T) {
46 | var c uint64 = 3
47 | p := []byte{120, 123, 54, 32, 102}
48 | b := newBuffer(p)
49 |
50 | n, err := readFromBuffer(b, c)
51 |
52 | if err != nil {
53 | t.Fatal("An unexpected error was returned while invoking readFromBuffer():", err)
54 | }
55 |
56 | if uint64(len(n)) != c {
57 | t.Errorf("Expected slice of bytes returned from readFromBuffer to be of the length '%d'. Instead it is '%d'.", c, len(n))
58 | }
59 |
60 | for i, v := range n {
61 | if v != p[i] {
62 | t.Fatalf("Expected slice of bytes to be '%v'. Instead it is '%v'.", p[:c], n)
63 | }
64 | }
65 | }
66 |
67 | func TestReadFromBufferMultiRead(t *testing.T) {
68 | // The slice to be read from the buffer must be greater than 4096. Since
69 | // this is the default size of a bufio buffer.
70 | // GO Ref: https://golang.org/src/bufio/bufio.go#L18
71 | p := make([]byte, 4100)
72 |
73 | for i := range p {
74 | rand.Seed(int64(i))
75 | p[i] = byte(rand.Intn(255))
76 | }
77 |
78 | b := newBuffer(p)
79 |
80 | readFromBuffer(b, 4090)
81 | n, err := readFromBuffer(b, 10)
82 |
83 | if err != nil {
84 | t.Error("Unexpected error was returned while invoking readFromBuffer:", err)
85 | }
86 |
87 | for i, v := range n {
88 | if v != p[i+4090] {
89 | t.Errorf("%v != %v", p[i+4090], v)
90 | }
91 | }
92 | }
93 |
94 | func TestStringExists(t *testing.T) {
95 | l := []string{"one", "two", "three"}
96 |
97 | type testCase struct {
98 | k string
99 | v int
100 | }
101 |
102 | testCases := []testCase{
103 | {k: "one", v: 0},
104 | {k: "four", v: -1},
105 | }
106 |
107 | for i, c := range testCases {
108 | r := stringExists(l, c.k)
109 |
110 | if r != c.v {
111 | t.Errorf(`Test Case %d: Expected stringExists("%s") to return '%d' instead returned '%d'`, i, c.k, c.v, r)
112 | }
113 | }
114 | }
115 |
116 | func TestHeaderToSlice(t *testing.T) {
117 | l := []string{" both ", " left", "right ", "none"}
118 |
119 | r := headerToSlice(strings.Join(l, ","))
120 |
121 | if len(l) != len(r) {
122 | t.Errorf("The length of the list of header value are not the same. '%d' != '%d'.", len(l), len(r))
123 | }
124 |
125 | if r[0] != "both" {
126 | t.Errorf(`Expected "both" instead got "%s".`, r[0])
127 | }
128 |
129 | if r[1] != "left" {
130 | t.Errorf(`Expected "left" instead got "%s".`, r[1])
131 | }
132 |
133 | if r[2] != "right" {
134 | t.Errorf(`Expected "right" instead got "%s".`, r[2])
135 | }
136 |
137 | if r[3] != "none" {
138 | t.Errorf(`Expected "none" instead got "%s".`, r[3])
139 | }
140 | }
141 |
142 | func TestRandomByteSlice(t *testing.T) {
143 | type testCase struct {
144 | l int
145 | }
146 |
147 | testCases := []testCase{
148 | {l: 2},
149 | {l: 6},
150 | }
151 |
152 | for i, c := range testCases {
153 | if b := randomByteSlice(c.l); len(b) != c.l*4 {
154 | t.Errorf("test case %d: expected slice of bytes to be '%d' in length, but it is '%d'", i, c.l*4, len(b))
155 | }
156 | }
157 | }
158 |
159 | func TestCloseErrorExist(t *testing.T) {
160 | type testCase struct {
161 | e int
162 | v bool
163 | }
164 |
165 | testCases := []testCase{
166 | // Should return false when opcode is invalid
167 | {e: 15, v: false},
168 | // Should return true when opcode is valid.
169 | {e: CloseNormalClosure, v: true},
170 | }
171 |
172 | for i, c := range testCases {
173 | if v := closeErrorExist(c.e); v != c.v {
174 | t.Errorf("test case %d: expected '%t' for '%d'", i, c.v, c.e)
175 | }
176 | }
177 | }
178 |
--------------------------------------------------------------------------------