├── util.go ├── json.go ├── README.md ├── client.go ├── server.go └── conn.go /util.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package websocket 6 | 7 | import ( 8 | "crypto/rand" 9 | "crypto/sha1" 10 | "encoding/base64" 11 | "io" 12 | "strings" 13 | ) 14 | 15 | // tokenListContainsValue returns true if the 1#token header with the given 16 | // name contains token. 17 | func tokenListContainsValue(name string, value string) bool { 18 | for _, s := range strings.Split(name, ",") { 19 | if strings.EqualFold(value, strings.TrimSpace(s)) { 20 | return true 21 | } 22 | } 23 | return false 24 | } 25 | 26 | var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") 27 | 28 | func computeAcceptKey(challengeKey string) string { 29 | h := sha1.New() 30 | h.Write([]byte(challengeKey)) 31 | h.Write(keyGUID) 32 | return base64.StdEncoding.EncodeToString(h.Sum(nil)) 33 | } 34 | 35 | func generateChallengeKey() (string, error) { 36 | p := make([]byte, 16) 37 | if _, err := io.ReadFull(rand.Reader, p); err != nil { 38 | return "", err 39 | } 40 | return base64.StdEncoding.EncodeToString(p), nil 41 | } 42 | -------------------------------------------------------------------------------- /json.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package websocket 6 | 7 | import ( 8 | "encoding/json" 9 | "io" 10 | ) 11 | 12 | // WriteJSON is deprecated, use c.WriteJSON instead. 
13 | func WriteJSON(c *Conn, v interface{}) error { 14 | return c.WriteJSON(v) 15 | } 16 | 17 | // WriteJSON writes the JSON encoding of v to the connection. 18 | // 19 | // See the documentation for encoding/json Marshal for details about the 20 | // conversion of Go values to JSON. 21 | func (c *Conn) WriteJSON(v interface{}) error { 22 | w, err := c.NextWriter(TextMessage) 23 | if err != nil { 24 | return err 25 | } 26 | err1 := json.NewEncoder(w).Encode(v) 27 | err2 := w.Close() 28 | if err1 != nil { 29 | return err1 30 | } 31 | return err2 32 | } 33 | 34 | // ReadJSON is deprecated, use c.ReadJSON instead. 35 | func ReadJSON(c *Conn, v interface{}) error { 36 | return c.ReadJSON(v) 37 | } 38 | 39 | // ReadJSON reads the next JSON-encoded message from the connection and stores 40 | // it in the value pointed to by v. 41 | // 42 | // See the documentation for the encoding/json Unmarshal function for details 43 | // about the conversion of JSON to a Go value. 44 | func (c *Conn) ReadJSON(v interface{}) error { 45 | _, r, err := c.NextReader() 46 | if err != nil { 47 | return err 48 | } 49 | err = json.NewDecoder(r).Decode(v) 50 | if err == io.EOF { 51 | // One value is expected in the message. 52 | err = io.ErrUnexpectedEOF 53 | } 54 | return err 55 | } 56 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Websockets 2 | 3 | The package was converted to work with fasthttp, is fork of https://github.com/gorilla/websocket. 4 | 5 | **WebSocket is a protocol providing full-duplex communication channels over a single TCP connection**. The WebSocket protocol was standardized by the IETF as RFC 6455 in 2011, and the WebSocket API in Web IDL is being standardized by the W3C. 6 | 7 | WebSocket is designed to be implemented in web browsers and web servers, but it can be used by any client or server application. 
The WebSocket Protocol is an independent TCP-based protocol. Its only relationship to HTTP is that its handshake is interpreted by HTTP servers as an Upgrade request. The WebSocket protocol enables richer interaction between a browser and a website, **facilitating real-time data transfer to and from the server**. 8 | 9 | [Read more about WebSockets](https://en.wikipedia.org/wiki/WebSocket) 10 | 11 | ----- 12 | 13 | How to use 14 | 15 | ```go 16 | import ( 17 | "github.com/fasthttp-contrib/websocket" 18 | "github.com/valyala/fasthttp" 19 | ) 20 | 21 | func chat(c *websocket.Conn) { 22 | // defer c.Close() 23 | // mt, message, err := c.ReadMessage() 24 | // c.WriteMessage(mt, message) 25 | } 26 | 27 | var upgrader = websocket.New(chat) // use default options 28 | //var upgrader = websocket.Custom(chat, 1024, 1024) // customized options: read and write buffer sizes (int); the default is 4096 29 | // var upgrader = (&websocket.Upgrader{Receiver: chat}).DontCheckOrigin() // accept any Origin, useful when the websocket server is on a different machine (New and Custom already accept any Origin) 30 | 31 | func myChatHandler(ctx *fasthttp.RequestCtx) { 32 | if err := upgrader.Upgrade(ctx); err != nil { ctx.Error(err.Error(), fasthttp.StatusBadRequest) } // Upgrade returns only an error; on success it runs the handler passed to websocket.New (the 'chat' function) 33 | } 34 | 35 | func main() { 36 | fasthttp.ListenAndServe(":8080", myChatHandler) 37 | } 38 | 39 | ``` 40 | 41 | For more examples, see the [gorilla/websocket examples](https://github.com/gorilla/websocket/tree/master/examples) and adapt them as shown in 'How to use' above. 42 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file.
4 | 5 | package websocket 6 | 7 | import ( 8 | "bufio" 9 | "bytes" 10 | "crypto/tls" 11 | "encoding/base64" 12 | "errors" 13 | "io" 14 | "io/ioutil" 15 | "net" 16 | "net/http" 17 | "net/url" 18 | "strings" 19 | "time" 20 | ) 21 | 22 | // ErrBadHandshake is returned when the server response to opening handshake is 23 | // invalid. 24 | var ErrBadHandshake = errors.New("websocket: bad handshake") 25 | 26 | // NewClient creates a new client connection using the given net connection. 27 | // The URL u specifies the host and request URI. Use requestHeader to specify 28 | // the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies 29 | // (Cookie). Use the response.Header to get the selected subprotocol 30 | // (Sec-WebSocket-Protocol) and cookies (Set-Cookie). 31 | // 32 | // If the WebSocket handshake fails, ErrBadHandshake is returned along with a 33 | // non-nil *http.Response so that callers can handle redirects, authentication, 34 | // etc. 35 | // 36 | // Deprecated: Use Dialer instead. 37 | func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) { 38 | d := Dialer{ 39 | ReadBufferSize: readBufSize, 40 | WriteBufferSize: writeBufSize, 41 | NetDial: func(net, addr string) (net.Conn, error) { 42 | return netConn, nil 43 | }, 44 | } 45 | return d.Dial(u.String(), requestHeader) 46 | } 47 | 48 | // A Dialer contains options for connecting to WebSocket server. 49 | type Dialer struct { 50 | // NetDial specifies the dial function for creating TCP connections. If 51 | // NetDial is nil, net.Dial is used. 52 | NetDial func(network, addr string) (net.Conn, error) 53 | 54 | // Proxy specifies a function to return a proxy for a given 55 | // Request. If the function returns a non-nil error, the 56 | // request is aborted with the provided error. 57 | // If Proxy is nil or returns a nil *URL, no proxy is used. 
58 | Proxy func(*http.Request) (*url.URL, error) 59 | 60 | // TLSClientConfig specifies the TLS configuration to use with tls.Client. 61 | // If nil, the default configuration is used. 62 | TLSClientConfig *tls.Config 63 | 64 | // HandshakeTimeout specifies the duration for the handshake to complete. 65 | HandshakeTimeout time.Duration 66 | 67 | // Input and output buffer sizes. If the buffer size is zero, then a 68 | // default value of 4096 is used. 69 | ReadBufferSize, WriteBufferSize int 70 | 71 | // Subprotocols specifies the client's requested subprotocols. 72 | Subprotocols []string 73 | } 74 | 75 | var errMalformedURL = errors.New("malformed ws or wss URL") 76 | 77 | // parseURL parses the URL. 78 | // 79 | // This function is a replacement for the standard library url.Parse function. 80 | // In Go 1.4 and earlier, url.Parse loses information from the path. 81 | func parseURL(s string) (*url.URL, error) { 82 | // From the RFC: 83 | // 84 | // ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ] 85 | // wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ] 86 | 87 | var u url.URL 88 | switch { 89 | case strings.HasPrefix(s, "ws://"): 90 | u.Scheme = "ws" 91 | s = s[len("ws://"):] 92 | case strings.HasPrefix(s, "wss://"): 93 | u.Scheme = "wss" 94 | s = s[len("wss://"):] 95 | default: 96 | return nil, errMalformedURL 97 | } 98 | 99 | if i := strings.Index(s, "?"); i >= 0 { 100 | u.RawQuery = s[i+1:] 101 | s = s[:i] 102 | } 103 | 104 | if i := strings.Index(s, "/"); i >= 0 { 105 | u.Opaque = s[i:] 106 | s = s[:i] 107 | } else { 108 | u.Opaque = "/" 109 | } 110 | 111 | u.Host = s 112 | 113 | if strings.Contains(u.Host, "@") { 114 | // Don't bother parsing user information because user information is 115 | // not allowed in websocket URIs. 
116 | return nil, errMalformedURL 117 | } 118 | 119 | return &u, nil 120 | } 121 | 122 | func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { 123 | hostPort = u.Host 124 | hostNoPort = u.Host 125 | if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") { 126 | hostNoPort = hostNoPort[:i] 127 | } else { 128 | switch u.Scheme { 129 | case "wss": 130 | hostPort += ":443" 131 | case "https": 132 | hostPort += ":443" 133 | default: 134 | hostPort += ":80" 135 | } 136 | } 137 | return hostPort, hostNoPort 138 | } 139 | 140 | // DefaultDialer is a dialer with all fields set to the default zero values. 141 | var DefaultDialer = &Dialer{ 142 | Proxy: http.ProxyFromEnvironment, 143 | } 144 | 145 | // Dial creates a new client connection. Use requestHeader to specify the 146 | // origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie). 147 | // Use the response.Header to get the selected subprotocol 148 | // (Sec-WebSocket-Protocol) and cookies (Set-Cookie). 149 | // 150 | // If the WebSocket handshake fails, ErrBadHandshake is returned along with a 151 | // non-nil *http.Response so that callers can handle redirects, authentication, 152 | // etcetera. The response body may not contain the entire response and does not 153 | // need to be closed by the application. 
154 | func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { 155 | 156 | if d == nil { 157 | d = &Dialer{ 158 | Proxy: http.ProxyFromEnvironment, 159 | } 160 | } 161 | 162 | challengeKey, err := generateChallengeKey() 163 | if err != nil { 164 | return nil, nil, err 165 | } 166 | 167 | u, err := parseURL(urlStr) 168 | if err != nil { 169 | return nil, nil, err 170 | } 171 | 172 | switch u.Scheme { 173 | case "ws": 174 | u.Scheme = "http" 175 | case "wss": 176 | u.Scheme = "https" 177 | default: 178 | return nil, nil, errMalformedURL 179 | } 180 | 181 | if u.User != nil { 182 | // User name and password are not allowed in websocket URIs. 183 | return nil, nil, errMalformedURL 184 | } 185 | 186 | req := &http.Request{ 187 | Method: "GET", 188 | URL: u, 189 | Proto: "HTTP/1.1", 190 | ProtoMajor: 1, 191 | ProtoMinor: 1, 192 | Header: make(http.Header), 193 | Host: u.Host, 194 | } 195 | 196 | // Set the request headers using the capitalization for names and values in 197 | // RFC examples. Although the capitalization shouldn't matter, there are 198 | // servers that depend on it. The Header.Set method is not used because the 199 | // method canonicalizes the header names. 
200 | req.Header["Upgrade"] = []string{"websocket"} 201 | req.Header["Connection"] = []string{"Upgrade"} 202 | req.Header["Sec-WebSocket-Key"] = []string{challengeKey} 203 | req.Header["Sec-WebSocket-Version"] = []string{"13"} 204 | if len(d.Subprotocols) > 0 { 205 | req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")} 206 | } 207 | for k, vs := range requestHeader { 208 | switch { 209 | case k == "Host": 210 | if len(vs) > 0 { 211 | req.Host = vs[0] 212 | } 213 | case k == "Upgrade" || 214 | k == "Connection" || 215 | k == "Sec-Websocket-Key" || 216 | k == "Sec-Websocket-Version" || 217 | (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0): 218 | return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) 219 | default: 220 | req.Header[k] = vs 221 | } 222 | } 223 | 224 | hostPort, hostNoPort := hostPortNoPort(u) 225 | 226 | var proxyURL *url.URL 227 | // Check wether the proxy method has been configured 228 | if d.Proxy != nil { 229 | proxyURL, err = d.Proxy(req) 230 | } 231 | if err != nil { 232 | return nil, nil, err 233 | } 234 | 235 | var targetHostPort string 236 | if proxyURL != nil { 237 | targetHostPort, _ = hostPortNoPort(proxyURL) 238 | } else { 239 | targetHostPort = hostPort 240 | } 241 | 242 | var deadline time.Time 243 | if d.HandshakeTimeout != 0 { 244 | deadline = time.Now().Add(d.HandshakeTimeout) 245 | } 246 | 247 | netDial := d.NetDial 248 | if netDial == nil { 249 | netDialer := &net.Dialer{Deadline: deadline} 250 | netDial = netDialer.Dial 251 | } 252 | 253 | netConn, err := netDial("tcp", targetHostPort) 254 | if err != nil { 255 | return nil, nil, err 256 | } 257 | 258 | defer func() { 259 | if netConn != nil { 260 | netConn.Close() 261 | } 262 | }() 263 | 264 | if err := netConn.SetDeadline(deadline); err != nil { 265 | return nil, nil, err 266 | } 267 | 268 | if proxyURL != nil { 269 | connectHeader := make(http.Header) 270 | if user := proxyURL.User; user != nil { 271 | proxyUser := 
user.Username() 272 | if proxyPassword, passwordSet := user.Password(); passwordSet { 273 | credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword)) 274 | connectHeader.Set("Proxy-Authorization", "Basic "+credential) 275 | } 276 | } 277 | connectReq := &http.Request{ 278 | Method: "CONNECT", 279 | URL: &url.URL{Opaque: hostPort}, 280 | Host: hostPort, 281 | Header: connectHeader, 282 | } 283 | 284 | connectReq.Write(netConn) 285 | 286 | // Read response. 287 | // Okay to use and discard buffered reader here, because 288 | // TLS server will not speak until spoken to. 289 | br := bufio.NewReader(netConn) 290 | resp, err := http.ReadResponse(br, connectReq) 291 | if err != nil { 292 | return nil, nil, err 293 | } 294 | if resp.StatusCode != 200 { 295 | f := strings.SplitN(resp.Status, " ", 2) 296 | return nil, nil, errors.New(f[1]) 297 | } 298 | } 299 | 300 | if u.Scheme == "https" { 301 | cfg := d.TLSClientConfig 302 | if cfg == nil { 303 | cfg = &tls.Config{ServerName: hostNoPort} 304 | } else if cfg.ServerName == "" { 305 | shallowCopy := *cfg 306 | cfg = &shallowCopy 307 | cfg.ServerName = hostNoPort 308 | } 309 | tlsConn := tls.Client(netConn, cfg) 310 | netConn = tlsConn 311 | if err := tlsConn.Handshake(); err != nil { 312 | return nil, nil, err 313 | } 314 | if !cfg.InsecureSkipVerify { 315 | if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { 316 | return nil, nil, err 317 | } 318 | } 319 | } 320 | 321 | conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize) 322 | 323 | if err := req.Write(netConn); err != nil { 324 | return nil, nil, err 325 | } 326 | 327 | resp, err := http.ReadResponse(conn.br, req) 328 | if err != nil { 329 | return nil, nil, err 330 | } 331 | if resp.StatusCode != 101 || 332 | !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || 333 | !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || 334 | resp.Header.Get("Sec-Websocket-Accept") != 
computeAcceptKey(challengeKey) { 335 | // Before closing the network connection on return from this 336 | // function, slurp up some of the response to aid application 337 | // debugging. 338 | buf := make([]byte, 1024) 339 | n, _ := io.ReadFull(resp.Body, buf) 340 | resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n])) 341 | return nil, resp, ErrBadHandshake 342 | } 343 | 344 | resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) 345 | conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") 346 | 347 | netConn.SetDeadline(time.Time{}) 348 | netConn = nil // to avoid close in defer. 349 | return conn, resp, nil 350 | } 351 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | // Copyright 2016 Gerasimos Maropoulos. All rights reserved. 3 | 4 | package websocket 5 | 6 | import ( 7 | "net" 8 | "net/url" 9 | "strings" 10 | "time" 11 | 12 | "github.com/valyala/fasthttp" 13 | ) 14 | 15 | // HandshakeError describes an error with the handshake from the peer. 16 | type HandshakeError struct { 17 | message string 18 | } 19 | 20 | func (e HandshakeError) Error() string { return e.message } 21 | 22 | // Upgrader specifies parameters for upgrading an HTTP connection to a 23 | // WebSocket connection. 24 | type Upgrader struct { 25 | // HandshakeTimeout specifies the duration for the handshake to complete. 26 | HandshakeTimeout time.Duration 27 | 28 | // ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer 29 | // size is zero, then a default value of 4096 is used. The I/O buffer sizes 30 | // do not limit the size of the messages that can be sent or received. 31 | ReadBufferSize, WriteBufferSize int 32 | 33 | // Subprotocols specifies the server's supported protocols in order of 34 | // preference. 
If this field is set, then the Upgrade method negotiates a 35 | // subprotocol by selecting the first match in this list with a protocol 36 | // requested by the client. 37 | Subprotocols []string 38 | 39 | // Error specifies the function for generating HTTP error responses. 40 | Error func(ctx *fasthttp.RequestCtx, status int, reason error) 41 | 42 | // CheckOrigin returns true if the request Origin header is acceptable. If 43 | // CheckOrigin is nil, the host in the Origin header must not be set or 44 | // must match the host of the request. 45 | CheckOrigin func(ctx *fasthttp.RequestCtx) bool 46 | 47 | //Receiver it's the receiver handler, acceps a *websocket.Conn 48 | Receiver func(*Conn) 49 | } 50 | 51 | // DontCheckOrigin set Upgrader.CheckOrigin to a function which always returns true 52 | // returns itself 53 | func (u *Upgrader) DontCheckOrigin() *Upgrader { 54 | u.CheckOrigin = func(ctx *fasthttp.RequestCtx) bool { 55 | return true 56 | } 57 | return u 58 | } 59 | 60 | func (u *Upgrader) returnError(ctx *fasthttp.RequestCtx, status int, reason string) error { 61 | err := HandshakeError{reason} 62 | if u.Error != nil { 63 | u.Error(ctx, status, err) 64 | } else { 65 | ctx.SetStatusCode(status) 66 | ctx.SetBodyString(reason) 67 | } 68 | return err 69 | } 70 | 71 | // checkSameOrigin returns true if the origin is not set or is equal to the request host. 
func checkSameOrigin(ctx *fasthttp.RequestCtx) bool {
	origin := string(ctx.Request.Header.Peek("Origin"))
	if len(origin) == 0 {
		return true
	}
	u, err := url.Parse(origin)
	if err != nil {
		return false
	}
	// Host names are case-insensitive (RFC 3986, section 3.2.2).
	return strings.EqualFold(u.Host, string(ctx.Host()))
}

// selectSubprotocol returns the subprotocol to announce in the handshake
// response. When u.Subprotocols is set, it is the first server-preferred
// protocol that the client also requested, or "" if there is none. Otherwise
// it is whatever the application already put in the response's
// Sec-Websocket-Protocol header.
func (u *Upgrader) selectSubprotocol(ctx *fasthttp.RequestCtx) string {
	if u.Subprotocols != nil {
		return checkSubprotocols(Subprotocols(ctx), u.Subprotocols)
	}
	// Peek through ctx.Response.Header in place: fasthttp headers must not be
	// copied (they carry a noCopy marker), and Peek of a missing header is "".
	return string(ctx.Response.Header.Peek("Sec-Websocket-Protocol"))
}

// getSubprotocol returns the subprotocol already set on the response, or else
// negotiates one from the client's request against u.Subprotocols and, when a
// match is found, sets it on the response.
func (u *Upgrader) getSubprotocol(ctx *fasthttp.RequestCtx) (subprotocol string) {
	// Respect a subprotocol that the application has already chosen.
	if h := string(ctx.Response.Header.Peek("Sec-Websocket-Protocol")); h != "" {
		return h
	}

	subprotocol = checkSubprotocols(Subprotocols(ctx), u.Subprotocols)
	if subprotocol != "" {
		ctx.Response.Header.Set("Sec-Websocket-Protocol", subprotocol)
	}
	return subprotocol
}

// checkSubprotocols returns the first protocol in resProtocols (server order
// of preference) that also appears in reqProtocols, or "" if none does.
func checkSubprotocols(reqProtocols []string, resProtocols []string) string {
	for _, resProtocol := range resProtocols {
		for _, reqProtocol := range reqProtocols {
			if reqProtocol == resProtocol {
				return reqProtocol
			}
		}
	}

	return ""
}

// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
//
// It validates the client's opening handshake: GET method,
// Sec-Websocket-Version 13, a Connection header with the "upgrade" token, an
// Upgrade header with the "websocket" token, an acceptable Origin, and a
// non-empty Sec-Websocket-Key. On failure it reports the problem through
// u.Error (or a plain HTTP error response when u.Error is nil) and returns a
// HandshakeError.
//
// On success it sets a 101 Switching Protocols response, including the
// negotiated Sec-Websocket-Protocol if any, and hijacks the connection. Other
// headers already set on ctx.Response (for example cookies) are sent with the
// 101 response. u.Receiver is then called with the resulting *Conn from
// fasthttp's hijack handler, which runs after the response has been written.
func (u *Upgrader) Upgrade(ctx *fasthttp.RequestCtx) error {
	if !ctx.IsGet() {
		return u.returnError(ctx, fasthttp.StatusMethodNotAllowed, "websocket: method not GET")
	}
	if string(ctx.Request.Header.Peek("Sec-Websocket-Version")) != "13" {
		return u.returnError(ctx, fasthttp.StatusBadRequest, "websocket: version != 13")
	}

	if !ctx.Request.Header.ConnectionUpgrade() {
		return u.returnError(ctx, fasthttp.StatusBadRequest, "websocket: could not find connection header with token 'upgrade'")
	}

	if !tokenListContainsValue(string(ctx.Request.Header.Peek("Upgrade")), "websocket") {
		return u.returnError(ctx, fasthttp.StatusBadRequest, "websocket: could not find upgrade header with token 'websocket'")
	}

	checkOrigin := u.CheckOrigin
	if checkOrigin == nil {
		checkOrigin = checkSameOrigin
	}
	if !checkOrigin(ctx) {
		return u.returnError(ctx, fasthttp.StatusForbidden, "websocket: origin not allowed")
	}

	challengeKey := string(ctx.Request.Header.Peek("Sec-Websocket-Key"))
	if challengeKey == "" {
		return u.returnError(ctx, fasthttp.StatusBadRequest, "websocket: key missing or blank")
	}

	// Without a receiver, the hijack handler below would call a nil func after
	// the 101 response was already sent; fail the request up front instead.
	if u.Receiver == nil {
		return u.returnError(ctx, fasthttp.StatusInternalServerError, "websocket: Upgrader.Receiver is nil")
	}

	// Set the handshake response headers.
	ctx.SetStatusCode(fasthttp.StatusSwitchingProtocols)
	ctx.Response.Header.Set("Upgrade", "websocket")
	ctx.Response.Header.Set("Connection", "Upgrade")
	ctx.Response.Header.Set("Sec-Websocket-Accept", computeAcceptKey(challengeKey))

	// A client that offered subprotocols expects the server's selection in the
	// 101 response (RFC 6455, section 4.2.2). When the application chose the
	// protocol itself, this re-sets the same value.
	subprotocol := u.selectSubprotocol(ctx)
	if subprotocol != "" {
		ctx.Response.Header.Set("Sec-Websocket-Protocol", subprotocol)
	}

	// The hijack handler must not retain references to ctx members, so copy
	// the request headers now to keep them available on the Conn.
	h := &fasthttp.RequestHeader{}
	ctx.Request.Header.CopyTo(h)

	ctx.Hijack(func(conn net.Conn) {
		c := newConn(conn, true, u.ReadBufferSize, u.WriteBufferSize)
		c.SetHeaders(h)
		c.subprotocol = subprotocol
		u.Receiver(c)
	})

	return nil
}

// Upgrade upgrades the HTTP server connection to the WebSocket protocol and
// runs receiverHandler on the resulting connection. It uses the same
// configuration as Custom: requests from any Origin are accepted and
// handshake failures are not written to the HTTP response.
//
// If the endpoint supports subprotocols, then the application is responsible
// for negotiating the protocol used on the connection. Use the Subprotocols()
// function to get the subprotocols requested by the client. Use the
// Sec-Websocket-Protocol response header (ctx.Response.Header) to specify the
// subprotocol selected by the application.
//
// Other headers set on ctx.Response before calling Upgrade, such as cookies,
// are sent with the handshake response.
//
// The connection buffers IO to the underlying network connection. The
// readBufSize and writeBufSize parameters specify the size of the buffers to
// use; zero selects a default of 4096 bytes. Messages can be larger than the
// buffers.
//
// If the request is not a valid WebSocket handshake, then Upgrade returns an
// error of type HandshakeError. Applications should handle this error by
// replying to the client with an HTTP error response.
284 | func Upgrade(ctx *fasthttp.RequestCtx, receiverHandler func(*Conn), readBufSize, writeBufSize int) error { 285 | u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize, Receiver: receiverHandler} 286 | u.Error = func(ctx *fasthttp.RequestCtx, status int, reason error) { 287 | // don't return errors to maintain backwards compatibility 288 | } 289 | u.CheckOrigin = func(ctx *fasthttp.RequestCtx) bool { 290 | // allow all connections by default 291 | return true 292 | } 293 | return u.Upgrade(ctx) 294 | } 295 | 296 | // Custom returns an Upgrader with customized options (readBufSize,writeBuf size int) 297 | // accepts 3 parameters 298 | // first parameter is the receiver, think it like a handler which accepts a *websocket.Conn (func *websocket.Conn) 299 | // second parameter is the readBufSize (int) 300 | // third parameter is the writeBufSize (int) 301 | func Custom(receiverHandler func(*Conn), readBufSize, writeBufSize int) Upgrader { 302 | u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize, Receiver: receiverHandler} 303 | u.Error = func(ctx *fasthttp.RequestCtx, status int, reason error) { 304 | // don't return errors to maintain backwards compatibility 305 | } 306 | u.CheckOrigin = func(ctx *fasthttp.RequestCtx) bool { 307 | // allow all connections by default 308 | return true 309 | } 310 | return u 311 | } 312 | 313 | // New returns an Upgrader with the default options 314 | // accepts one parameter 315 | // the receiver, think it like a handler which accepts a *websocket.Conn (func *websocket.Conn) 316 | func New(receiverHandler func(*Conn)) Upgrader { 317 | return Custom(receiverHandler, 4096, 4096) 318 | } 319 | 320 | // Subprotocols returns the subprotocols requested by the client in the 321 | // Sec-Websocket-Protocol header. 
322 | func Subprotocols(ctx *fasthttp.RequestCtx) []string { 323 | 324 | h := strings.TrimSpace(string(ctx.Request.Header.Peek("Sec-Websocket-Protocol"))) 325 | if h == "" { 326 | return nil 327 | } 328 | protocols := strings.Split(h, ",") 329 | for i := range protocols { 330 | protocols[i] = strings.TrimSpace(protocols[i]) 331 | } 332 | return protocols 333 | } 334 | 335 | // IsWebSocketUpgrade returns true if the client requested upgrade to the 336 | // WebSocket protocol. 337 | func IsWebSocketUpgrade(ctx *fasthttp.RequestCtx) bool { 338 | return tokenListContainsValue(string(ctx.Request.Header.Peek("Connection")), "upgrade") && 339 | tokenListContainsValue(string(ctx.Request.Header.Peek("Upgrade")), "websocket") 340 | } 341 | -------------------------------------------------------------------------------- /conn.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package websocket 6 | 7 | import ( 8 | "bufio" 9 | "encoding/binary" 10 | "errors" 11 | "io" 12 | "io/ioutil" 13 | "math/rand" 14 | "net" 15 | "strconv" 16 | "time" 17 | 18 | "github.com/valyala/fasthttp" 19 | ) 20 | 21 | const ( 22 | maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask 23 | maxControlFramePayloadSize = 125 24 | finalBit = 1 << 7 25 | maskBit = 1 << 7 26 | writeWait = time.Second 27 | 28 | defaultReadBufferSize = 4096 29 | defaultWriteBufferSize = 4096 30 | 31 | continuationFrame = 0 32 | noFrame = -1 33 | ) 34 | 35 | // Close codes defined in RFC 6455, section 11.7. 
36 | const ( 37 | CloseNormalClosure = 1000 38 | CloseGoingAway = 1001 39 | CloseProtocolError = 1002 40 | CloseUnsupportedData = 1003 41 | CloseNoStatusReceived = 1005 42 | CloseAbnormalClosure = 1006 43 | CloseInvalidFramePayloadData = 1007 44 | ClosePolicyViolation = 1008 45 | CloseMessageTooBig = 1009 46 | CloseMandatoryExtension = 1010 47 | CloseInternalServerErr = 1011 48 | CloseTLSHandshake = 1015 49 | ) 50 | 51 | // The message types are defined in RFC 6455, section 11.8. 52 | const ( 53 | // TextMessage denotes a text data message. The text message payload is 54 | // interpreted as UTF-8 encoded text data. 55 | TextMessage = 1 56 | 57 | // BinaryMessage denotes a binary data message. 58 | BinaryMessage = 2 59 | 60 | // CloseMessage denotes a close control message. The optional message 61 | // payload contains a numeric code and text. Use the FormatCloseMessage 62 | // function to format a close message payload. 63 | CloseMessage = 8 64 | 65 | // PingMessage denotes a ping control message. The optional message payload 66 | // is UTF-8 encoded text. 67 | PingMessage = 9 68 | 69 | // PongMessage denotes a ping control message. The optional message payload 70 | // is UTF-8 encoded text. 71 | PongMessage = 10 72 | ) 73 | 74 | // ErrCloseSent is returned when the application writes a message to the 75 | // connection after sending a close message. 76 | var ErrCloseSent = errors.New("websocket: close sent") 77 | 78 | // ErrReadLimit is returned when reading a message that is larger than the 79 | // read limit set for the connection. 80 | var ErrReadLimit = errors.New("websocket: read limit exceeded") 81 | 82 | // netError satisfies the net Error interface. 
83 | type netError struct { 84 | msg string 85 | temporary bool 86 | timeout bool 87 | } 88 | 89 | func (e *netError) Error() string { return e.msg } 90 | func (e *netError) Temporary() bool { return e.temporary } 91 | func (e *netError) Timeout() bool { return e.timeout } 92 | 93 | // CloseError represents close frame. 94 | type CloseError struct { 95 | 96 | // Code is defined in RFC 6455, section 11.7. 97 | Code int 98 | 99 | // Text is the optional text payload. 100 | Text string 101 | } 102 | 103 | func (e *CloseError) Error() string { 104 | s := []byte("websocket: close ") 105 | s = strconv.AppendInt(s, int64(e.Code), 10) 106 | switch e.Code { 107 | case CloseNormalClosure: 108 | s = append(s, " (normal)"...) 109 | case CloseGoingAway: 110 | s = append(s, " (going away)"...) 111 | case CloseProtocolError: 112 | s = append(s, " (protocol error)"...) 113 | case CloseUnsupportedData: 114 | s = append(s, " (unsupported data)"...) 115 | case CloseNoStatusReceived: 116 | s = append(s, " (no status)"...) 117 | case CloseAbnormalClosure: 118 | s = append(s, " (abnormal closure)"...) 119 | case CloseInvalidFramePayloadData: 120 | s = append(s, " (invalid payload data)"...) 121 | case ClosePolicyViolation: 122 | s = append(s, " (policy violation)"...) 123 | case CloseMessageTooBig: 124 | s = append(s, " (message too big)"...) 125 | case CloseMandatoryExtension: 126 | s = append(s, " (mandatory extension missing)"...) 127 | case CloseInternalServerErr: 128 | s = append(s, " (internal server error)"...) 129 | case CloseTLSHandshake: 130 | s = append(s, " (TLS handshake error)"...) 131 | } 132 | if e.Text != "" { 133 | s = append(s, ": "...) 134 | s = append(s, e.Text...) 135 | } 136 | return string(s) 137 | } 138 | 139 | // IsCloseError returns boolean indicating whether the error is a *CloseError 140 | // with one of the specified codes. 
141 | func IsCloseError(err error, codes ...int) bool { 142 | if e, ok := err.(*CloseError); ok { 143 | for _, code := range codes { 144 | if e.Code == code { 145 | return true 146 | } 147 | } 148 | } 149 | return false 150 | } 151 | 152 | // IsUnexpectedCloseError returns boolean indicating whether the error is a 153 | // *CloseError with a code not in the list of expected codes. 154 | func IsUnexpectedCloseError(err error, expectedCodes ...int) bool { 155 | if e, ok := err.(*CloseError); ok { 156 | for _, code := range expectedCodes { 157 | if e.Code == code { 158 | return false 159 | } 160 | } 161 | return true 162 | } 163 | return false 164 | } 165 | 166 | var ( 167 | errWriteTimeout = &netError{msg: "websocket: write timeout", timeout: true, temporary: true} 168 | errUnexpectedEOF = &CloseError{Code: CloseAbnormalClosure, Text: io.ErrUnexpectedEOF.Error()} 169 | errBadWriteOpCode = errors.New("websocket: bad write message type") 170 | errWriteClosed = errors.New("websocket: write closed") 171 | errInvalidControlFrame = errors.New("websocket: invalid control frame") 172 | ) 173 | 174 | func hideTempErr(err error) error { 175 | if e, ok := err.(net.Error); ok && e.Temporary() { 176 | err = &netError{msg: e.Error(), timeout: e.Timeout()} 177 | } 178 | return err 179 | } 180 | 181 | func isControl(frameType int) bool { 182 | return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage 183 | } 184 | 185 | func isData(frameType int) bool { 186 | return frameType == TextMessage || frameType == BinaryMessage 187 | } 188 | 189 | func maskBytes(key [4]byte, pos int, b []byte) int { 190 | for i := range b { 191 | b[i] ^= key[pos&3] 192 | pos++ 193 | } 194 | return pos & 3 195 | } 196 | 197 | func newMaskKey() [4]byte { 198 | n := rand.Uint32() 199 | return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)} 200 | } 201 | 202 | // Conn represents a WebSocket connection. 
203 | type Conn struct { 204 | conn net.Conn 205 | isServer bool 206 | subprotocol string 207 | 208 | // Write fields 209 | mu chan bool // used as mutex to protect write to conn and closeSent 210 | closeSent bool // true if close message was sent 211 | 212 | // Message writer fields. 213 | writeErr error 214 | writeBuf []byte // frame is constructed in this buffer. 215 | writePos int // end of data in writeBuf. 216 | writeFrameType int // type of the current frame. 217 | writeSeq int // incremented to invalidate message writers. 218 | writeDeadline time.Time 219 | isWriting bool // for best-effort concurrent write detection 220 | 221 | // Read fields 222 | readErr error 223 | br *bufio.Reader 224 | readRemaining int64 // bytes remaining in current frame. 225 | readFinal bool // true the current message has more frames. 226 | readSeq int // incremented to invalidate message readers. 227 | readLength int64 // Message size. 228 | readLimit int64 // Maximum message size. 229 | readMaskPos int 230 | readMaskKey [4]byte 231 | handlePong func(string) error 232 | handlePing func(string) error 233 | readErrCount int 234 | 235 | //request headers 236 | headers *fasthttp.RequestHeader 237 | } 238 | 239 | func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { 240 | mu := make(chan bool, 1) 241 | mu <- true 242 | 243 | if readBufferSize == 0 { 244 | readBufferSize = defaultReadBufferSize 245 | } 246 | if writeBufferSize == 0 { 247 | writeBufferSize = defaultWriteBufferSize 248 | } 249 | 250 | c := &Conn{ 251 | isServer: isServer, 252 | br: bufio.NewReaderSize(conn, readBufferSize), 253 | conn: conn, 254 | mu: mu, 255 | readFinal: true, 256 | writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize), 257 | writeFrameType: noFrame, 258 | writePos: maxFrameHeaderSize, 259 | } 260 | c.SetPingHandler(nil) 261 | c.SetPongHandler(nil) 262 | return c 263 | } 264 | 265 | // Subprotocol returns the negotiated protocol for the connection. 
266 | func (c *Conn) Subprotocol() string { 267 | return c.subprotocol 268 | } 269 | 270 | // Close closes the underlying network connection without sending or waiting for a close frame. 271 | func (c *Conn) Close() error { 272 | return c.conn.Close() 273 | } 274 | 275 | // LocalAddr returns the local network address. 276 | func (c *Conn) LocalAddr() net.Addr { 277 | return c.conn.LocalAddr() 278 | } 279 | 280 | // RemoteAddr returns the remote network address. 281 | func (c *Conn) RemoteAddr() net.Addr { 282 | return c.conn.RemoteAddr() 283 | } 284 | 285 | // Write methods 286 | 287 | func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error { 288 | <-c.mu 289 | defer func() { c.mu <- true }() 290 | 291 | if c.closeSent { 292 | return ErrCloseSent 293 | } else if frameType == CloseMessage { 294 | c.closeSent = true 295 | } 296 | 297 | c.conn.SetWriteDeadline(deadline) 298 | for _, buf := range bufs { 299 | if len(buf) > 0 { 300 | n, err := c.conn.Write(buf) 301 | if n != len(buf) { 302 | // Close on partial write. 303 | c.conn.Close() 304 | } 305 | if err != nil { 306 | return err 307 | } 308 | } 309 | } 310 | return nil 311 | } 312 | 313 | // WriteControl writes a control message with the given deadline. The allowed 314 | // message types are CloseMessage, PingMessage and PongMessage. 315 | func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error { 316 | if !isControl(messageType) { 317 | return errBadWriteOpCode 318 | } 319 | if len(data) > maxControlFramePayloadSize { 320 | return errInvalidControlFrame 321 | } 322 | 323 | b0 := byte(messageType) | finalBit 324 | b1 := byte(len(data)) 325 | if !c.isServer { 326 | b1 |= maskBit 327 | } 328 | 329 | buf := make([]byte, 0, maxFrameHeaderSize+maxControlFramePayloadSize) 330 | buf = append(buf, b0, b1) 331 | 332 | if c.isServer { 333 | buf = append(buf, data...) 334 | } else { 335 | key := newMaskKey() 336 | buf = append(buf, key[:]...) 337 | buf = append(buf, data...) 
338 | maskBytes(key, 0, buf[6:]) 339 | } 340 | 341 | d := time.Hour * 1000 342 | if !deadline.IsZero() { 343 | d = deadline.Sub(time.Now()) 344 | if d < 0 { 345 | return errWriteTimeout 346 | } 347 | } 348 | 349 | timer := time.NewTimer(d) 350 | select { 351 | case <-c.mu: 352 | timer.Stop() 353 | case <-timer.C: 354 | return errWriteTimeout 355 | } 356 | defer func() { c.mu <- true }() 357 | 358 | if c.closeSent { 359 | return ErrCloseSent 360 | } else if messageType == CloseMessage { 361 | c.closeSent = true 362 | } 363 | 364 | c.conn.SetWriteDeadline(deadline) 365 | n, err := c.conn.Write(buf) 366 | if n != 0 && n != len(buf) { 367 | c.conn.Close() 368 | } 369 | return hideTempErr(err) 370 | } 371 | 372 | // NextWriter returns a writer for the next message to send. The writer's 373 | // Close method flushes the complete message to the network. 374 | // 375 | // There can be at most one open writer on a connection. NextWriter closes the 376 | // previous writer if the application has not already done so. 377 | func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { 378 | if c.writeErr != nil { 379 | return nil, c.writeErr 380 | } 381 | 382 | if c.writeFrameType != noFrame { 383 | if err := c.flushFrame(true, nil); err != nil { 384 | return nil, err 385 | } 386 | } 387 | 388 | if !isControl(messageType) && !isData(messageType) { 389 | return nil, errBadWriteOpCode 390 | } 391 | 392 | c.writeFrameType = messageType 393 | return messageWriter{c, c.writeSeq}, nil 394 | } 395 | 396 | func (c *Conn) flushFrame(final bool, extra []byte) error { 397 | length := c.writePos - maxFrameHeaderSize + len(extra) 398 | 399 | // Check for invalid control frames. 
400 | if isControl(c.writeFrameType) && 401 | (!final || length > maxControlFramePayloadSize) { 402 | c.writeSeq++ 403 | c.writeFrameType = noFrame 404 | c.writePos = maxFrameHeaderSize 405 | return errInvalidControlFrame 406 | } 407 | 408 | b0 := byte(c.writeFrameType) 409 | if final { 410 | b0 |= finalBit 411 | } 412 | b1 := byte(0) 413 | if !c.isServer { 414 | b1 |= maskBit 415 | } 416 | 417 | // Assume that the frame starts at beginning of c.writeBuf. 418 | framePos := 0 419 | if c.isServer { 420 | // Adjust up if mask not included in the header. 421 | framePos = 4 422 | } 423 | 424 | switch { 425 | case length >= 65536: 426 | c.writeBuf[framePos] = b0 427 | c.writeBuf[framePos+1] = b1 | 127 428 | binary.BigEndian.PutUint64(c.writeBuf[framePos+2:], uint64(length)) 429 | case length > 125: 430 | framePos += 6 431 | c.writeBuf[framePos] = b0 432 | c.writeBuf[framePos+1] = b1 | 126 433 | binary.BigEndian.PutUint16(c.writeBuf[framePos+2:], uint16(length)) 434 | default: 435 | framePos += 8 436 | c.writeBuf[framePos] = b0 437 | c.writeBuf[framePos+1] = b1 | byte(length) 438 | } 439 | 440 | if !c.isServer { 441 | key := newMaskKey() 442 | copy(c.writeBuf[maxFrameHeaderSize-4:], key[:]) 443 | maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:c.writePos]) 444 | if len(extra) > 0 { 445 | c.writeErr = errors.New("websocket: internal error, extra used in client mode") 446 | return c.writeErr 447 | } 448 | } 449 | 450 | // Write the buffers to the connection with best-effort detection of 451 | // concurrent writes. See the concurrency section in the package 452 | // documentation for more info. 453 | 454 | if c.isWriting { 455 | panic("concurrent write to websocket connection") 456 | } 457 | c.isWriting = true 458 | 459 | c.writeErr = c.write(c.writeFrameType, c.writeDeadline, c.writeBuf[framePos:c.writePos], extra) 460 | 461 | if !c.isWriting { 462 | panic("concurrent write to websocket connection") 463 | } 464 | c.isWriting = false 465 | 466 | // Setup for next frame. 
467 | c.writePos = maxFrameHeaderSize 468 | c.writeFrameType = continuationFrame 469 | if final { 470 | c.writeSeq++ 471 | c.writeFrameType = noFrame 472 | } 473 | return c.writeErr 474 | } 475 | 476 | type messageWriter struct { 477 | c *Conn 478 | seq int 479 | } 480 | 481 | func (w messageWriter) err() error { 482 | c := w.c 483 | if c.writeSeq != w.seq { 484 | return errWriteClosed 485 | } 486 | if c.writeErr != nil { 487 | return c.writeErr 488 | } 489 | return nil 490 | } 491 | 492 | func (w messageWriter) ncopy(max int) (int, error) { 493 | n := len(w.c.writeBuf) - w.c.writePos 494 | if n <= 0 { 495 | if err := w.c.flushFrame(false, nil); err != nil { 496 | return 0, err 497 | } 498 | n = len(w.c.writeBuf) - w.c.writePos 499 | } 500 | if n > max { 501 | n = max 502 | } 503 | return n, nil 504 | } 505 | 506 | func (w messageWriter) write(final bool, p []byte) (int, error) { 507 | if err := w.err(); err != nil { 508 | return 0, err 509 | } 510 | 511 | if len(p) > 2*len(w.c.writeBuf) && w.c.isServer { 512 | // Don't buffer large messages. 
513 | err := w.c.flushFrame(final, p) 514 | if err != nil { 515 | return 0, err 516 | } 517 | return len(p), nil 518 | } 519 | 520 | nn := len(p) 521 | for len(p) > 0 { 522 | n, err := w.ncopy(len(p)) 523 | if err != nil { 524 | return 0, err 525 | } 526 | copy(w.c.writeBuf[w.c.writePos:], p[:n]) 527 | w.c.writePos += n 528 | p = p[n:] 529 | } 530 | return nn, nil 531 | } 532 | 533 | func (w messageWriter) Write(p []byte) (int, error) { 534 | return w.write(false, p) 535 | } 536 | 537 | func (w messageWriter) WriteString(p string) (int, error) { 538 | if err := w.err(); err != nil { 539 | return 0, err 540 | } 541 | 542 | nn := len(p) 543 | for len(p) > 0 { 544 | n, err := w.ncopy(len(p)) 545 | if err != nil { 546 | return 0, err 547 | } 548 | copy(w.c.writeBuf[w.c.writePos:], p[:n]) 549 | w.c.writePos += n 550 | p = p[n:] 551 | } 552 | return nn, nil 553 | } 554 | 555 | func (w messageWriter) ReadFrom(r io.Reader) (nn int64, err error) { 556 | if err := w.err(); err != nil { 557 | return 0, err 558 | } 559 | for { 560 | if w.c.writePos == len(w.c.writeBuf) { 561 | err = w.c.flushFrame(false, nil) 562 | if err != nil { 563 | break 564 | } 565 | } 566 | var n int 567 | n, err = r.Read(w.c.writeBuf[w.c.writePos:]) 568 | w.c.writePos += n 569 | nn += int64(n) 570 | if err != nil { 571 | if err == io.EOF { 572 | err = nil 573 | } 574 | break 575 | } 576 | } 577 | return nn, err 578 | } 579 | 580 | func (w messageWriter) Close() error { 581 | if err := w.err(); err != nil { 582 | return err 583 | } 584 | return w.c.flushFrame(true, nil) 585 | } 586 | 587 | // WriteMessage is a helper method for getting a writer using NextWriter, 588 | // writing the message and closing the writer. 
589 | func (c *Conn) WriteMessage(messageType int, data []byte) error { 590 | wr, err := c.NextWriter(messageType) 591 | if err != nil { 592 | return err 593 | } 594 | w := wr.(messageWriter) 595 | if _, err := w.write(true, data); err != nil { 596 | return err 597 | } 598 | if c.writeSeq == w.seq { 599 | if err := c.flushFrame(true, nil); err != nil { 600 | return err 601 | } 602 | } 603 | return nil 604 | } 605 | 606 | // SetWriteDeadline sets the write deadline on the underlying network 607 | // connection. After a write has timed out, the websocket state is corrupt and 608 | // all future writes will return an error. A zero value for t means writes will 609 | // not time out. 610 | func (c *Conn) SetWriteDeadline(t time.Time) error { 611 | c.writeDeadline = t 612 | return nil 613 | } 614 | 615 | // Read methods 616 | 617 | // readFull is like io.ReadFull except that io.EOF is never returned. 618 | func (c *Conn) readFull(p []byte) (err error) { 619 | var n int 620 | for n < len(p) && err == nil { 621 | var nn int 622 | nn, err = c.br.Read(p[n:]) 623 | n += nn 624 | } 625 | if n == len(p) { 626 | err = nil 627 | } else if err == io.EOF { 628 | err = errUnexpectedEOF 629 | } 630 | return 631 | } 632 | 633 | func (c *Conn) advanceFrame() (int, error) { 634 | 635 | // 1. Skip remainder of previous frame. 636 | 637 | if c.readRemaining > 0 { 638 | if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil { 639 | return noFrame, err 640 | } 641 | } 642 | 643 | // 2. Read and parse first two bytes of frame header. 
644 | 645 | var b [8]byte 646 | if err := c.readFull(b[:2]); err != nil { 647 | return noFrame, err 648 | } 649 | 650 | final := b[0]&finalBit != 0 651 | frameType := int(b[0] & 0xf) 652 | reserved := int((b[0] >> 4) & 0x7) 653 | mask := b[1]&maskBit != 0 654 | c.readRemaining = int64(b[1] & 0x7f) 655 | 656 | if reserved != 0 { 657 | return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved)) 658 | } 659 | 660 | switch frameType { 661 | case CloseMessage, PingMessage, PongMessage: 662 | if c.readRemaining > maxControlFramePayloadSize { 663 | return noFrame, c.handleProtocolError("control frame length > 125") 664 | } 665 | if !final { 666 | return noFrame, c.handleProtocolError("control frame not final") 667 | } 668 | case TextMessage, BinaryMessage: 669 | if !c.readFinal { 670 | return noFrame, c.handleProtocolError("message start before final message frame") 671 | } 672 | c.readFinal = final 673 | case continuationFrame: 674 | if c.readFinal { 675 | return noFrame, c.handleProtocolError("continuation after final message frame") 676 | } 677 | c.readFinal = final 678 | default: 679 | return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType)) 680 | } 681 | 682 | // 3. Read and parse frame length. 683 | 684 | switch c.readRemaining { 685 | case 126: 686 | if err := c.readFull(b[:2]); err != nil { 687 | return noFrame, err 688 | } 689 | c.readRemaining = int64(binary.BigEndian.Uint16(b[:2])) 690 | case 127: 691 | if err := c.readFull(b[:8]); err != nil { 692 | return noFrame, err 693 | } 694 | c.readRemaining = int64(binary.BigEndian.Uint64(b[:8])) 695 | } 696 | 697 | // 4. Handle frame masking. 698 | 699 | if mask != c.isServer { 700 | return noFrame, c.handleProtocolError("incorrect mask flag") 701 | } 702 | 703 | if mask { 704 | c.readMaskPos = 0 705 | if err := c.readFull(c.readMaskKey[:]); err != nil { 706 | return noFrame, err 707 | } 708 | } 709 | 710 | // 5. 
For text and binary messages, enforce read limit and return. 711 | 712 | if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage { 713 | 714 | c.readLength += c.readRemaining 715 | if c.readLimit > 0 && c.readLength > c.readLimit { 716 | c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) 717 | return noFrame, ErrReadLimit 718 | } 719 | 720 | return frameType, nil 721 | } 722 | 723 | // 6. Read control frame payload. 724 | 725 | var payload []byte 726 | if c.readRemaining > 0 { 727 | payload = make([]byte, c.readRemaining) 728 | c.readRemaining = 0 729 | if err := c.readFull(payload); err != nil { 730 | return noFrame, err 731 | } 732 | if c.isServer { 733 | maskBytes(c.readMaskKey, 0, payload) 734 | } 735 | } 736 | 737 | // 7. Process control frame payload. 738 | 739 | switch frameType { 740 | case PongMessage: 741 | if err := c.handlePong(string(payload)); err != nil { 742 | return noFrame, err 743 | } 744 | case PingMessage: 745 | if err := c.handlePing(string(payload)); err != nil { 746 | return noFrame, err 747 | } 748 | case CloseMessage: 749 | echoMessage := []byte{} 750 | closeCode := CloseNoStatusReceived 751 | closeText := "" 752 | if len(payload) >= 2 { 753 | echoMessage = payload[:2] 754 | closeCode = int(binary.BigEndian.Uint16(payload)) 755 | closeText = string(payload[2:]) 756 | } 757 | c.WriteControl(CloseMessage, echoMessage, time.Now().Add(writeWait)) 758 | return noFrame, &CloseError{Code: closeCode, Text: closeText} 759 | } 760 | 761 | return frameType, nil 762 | } 763 | 764 | func (c *Conn) handleProtocolError(message string) error { 765 | c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait)) 766 | return errors.New("websocket: " + message) 767 | } 768 | 769 | // NextReader returns the next data message received from the peer. The 770 | // returned messageType is either TextMessage or BinaryMessage. 
771 | // 772 | // There can be at most one open reader on a connection. NextReader discards 773 | // the previous message if the application has not already consumed it. 774 | // 775 | // Applications must break out of the application's read loop when this method 776 | // returns a non-nil error value. Errors returned from this method are 777 | // permanent. Once this method returns a non-nil error, all subsequent calls to 778 | // this method return the same error. 779 | func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { 780 | 781 | c.readSeq++ 782 | c.readLength = 0 783 | 784 | for c.readErr == nil { 785 | frameType, err := c.advanceFrame() 786 | if err != nil { 787 | c.readErr = hideTempErr(err) 788 | break 789 | } 790 | if frameType == TextMessage || frameType == BinaryMessage { 791 | return frameType, messageReader{c, c.readSeq}, nil 792 | } 793 | } 794 | 795 | // Applications that do handle the error returned from this method spin in 796 | // tight loop on connection failure. To help application developers detect 797 | // this error, panic on repeated reads to the failed connection. 
798 | c.readErrCount++ 799 | if c.readErrCount >= 1000 { 800 | panic("repeated read on failed websocket connection") 801 | } 802 | 803 | return noFrame, nil, c.readErr 804 | } 805 | 806 | type messageReader struct { 807 | c *Conn 808 | seq int 809 | } 810 | 811 | func (r messageReader) Read(b []byte) (int, error) { 812 | 813 | if r.seq != r.c.readSeq { 814 | return 0, io.EOF 815 | } 816 | 817 | for r.c.readErr == nil { 818 | 819 | if r.c.readRemaining > 0 { 820 | if int64(len(b)) > r.c.readRemaining { 821 | b = b[:r.c.readRemaining] 822 | } 823 | n, err := r.c.br.Read(b) 824 | r.c.readErr = hideTempErr(err) 825 | if r.c.isServer { 826 | r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b[:n]) 827 | } 828 | r.c.readRemaining -= int64(n) 829 | return n, r.c.readErr 830 | } 831 | 832 | if r.c.readFinal { 833 | r.c.readSeq++ 834 | return 0, io.EOF 835 | } 836 | 837 | frameType, err := r.c.advanceFrame() 838 | switch { 839 | case err != nil: 840 | r.c.readErr = hideTempErr(err) 841 | case frameType == TextMessage || frameType == BinaryMessage: 842 | r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader") 843 | } 844 | } 845 | 846 | err := r.c.readErr 847 | if err == io.EOF && r.seq == r.c.readSeq { 848 | err = errUnexpectedEOF 849 | } 850 | return 0, err 851 | } 852 | 853 | // ReadMessage is a helper method for getting a reader using NextReader and 854 | // reading from that reader to a buffer. 855 | func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { 856 | var r io.Reader 857 | messageType, r, err = c.NextReader() 858 | if err != nil { 859 | return messageType, nil, err 860 | } 861 | p, err = ioutil.ReadAll(r) 862 | return messageType, p, err 863 | } 864 | 865 | // SetReadDeadline sets the read deadline on the underlying network connection. 866 | // After a read has timed out, the websocket connection state is corrupt and 867 | // all future reads will return an error. 
A zero value for t means reads will 868 | // not time out. 869 | func (c *Conn) SetReadDeadline(t time.Time) error { 870 | return c.conn.SetReadDeadline(t) 871 | } 872 | 873 | // SetReadLimit sets the maximum size for a message read from the peer. If a 874 | // message exceeds the limit, the connection sends a close frame to the peer 875 | // and returns ErrReadLimit to the application. 876 | func (c *Conn) SetReadLimit(limit int64) { 877 | c.readLimit = limit 878 | } 879 | 880 | // SetPingHandler sets the handler for ping messages received from the peer. 881 | // The appData argument to h is the PING frame application data. The default 882 | // ping handler sends a pong to the peer. 883 | func (c *Conn) SetPingHandler(h func(appData string) error) { 884 | if h == nil { 885 | h = func(message string) error { 886 | err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait)) 887 | if err == ErrCloseSent { 888 | return nil 889 | } else if e, ok := err.(net.Error); ok && e.Temporary() { 890 | return nil 891 | } 892 | return err 893 | } 894 | } 895 | c.handlePing = h 896 | } 897 | 898 | // SetPongHandler sets the handler for pong messages received from the peer. 899 | // The appData argument to h is the PONG frame application data. The default 900 | // pong handler does nothing. 901 | func (c *Conn) SetPongHandler(h func(appData string) error) { 902 | if h == nil { 903 | h = func(string) error { return nil } 904 | } 905 | c.handlePong = h 906 | } 907 | 908 | // UnderlyingConn returns the internal net.Conn. This can be used to further 909 | // modifications to connection specific flags. 910 | func (c *Conn) UnderlyingConn() net.Conn { 911 | return c.conn 912 | } 913 | 914 | // FormatCloseMessage formats closeCode and text as a WebSocket close message. 
915 | func FormatCloseMessage(closeCode int, text string) []byte { 916 | buf := make([]byte, 2+len(text)) 917 | binary.BigEndian.PutUint16(buf, uint16(closeCode)) 918 | copy(buf[2:], text) 919 | return buf 920 | } 921 | 922 | // SetHeaders sets request headers 923 | func (c *Conn) SetHeaders(h *fasthttp.RequestHeader) { 924 | c.headers = h 925 | } 926 | 927 | // Header returns header by key 928 | func (c *Conn) Header(key string) (value string) { 929 | return string(c.headers.Peek(key)) 930 | } 931 | 932 | // Headers returns the RequestHeader struct 933 | func (c *Conn) Headers() *fasthttp.RequestHeader { 934 | return c.headers 935 | } 936 | --------------------------------------------------------------------------------