├── .travis.yml ├── LICENSE.md ├── README.md ├── websocketproxy.go └── websocketproxy_test.go /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | go: 1.8 3 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014 Koding, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WebsocketProxy [![GoDoc](https://godoc.org/github.com/koding/websocketproxy?status.svg)](https://godoc.org/github.com/koding/websocketproxy) [![Build Status](https://travis-ci.org/koding/websocketproxy.svg)](https://travis-ci.org/koding/websocketproxy) 2 | 3 | WebsocketProxy is an http.Handler built on top of 4 | [gorilla/websocket](https://github.com/gorilla/websocket) that you can plug 5 | into your existing Go webserver to provide a WebSocket reverse proxy. 6 | 7 | ## Install 8 | 9 | ```bash 10 | go get github.com/koding/websocketproxy 11 | ``` 12 | 13 | ## Example 14 | 15 | Below is a simple server that proxies to the given backend URL: 16 | 17 | ```go 18 | package main 19 | 20 | import ( 21 | "flag" 22 | "log" 23 | "net/http" 24 | "net/url" 25 | 26 | "github.com/koding/websocketproxy" 27 | ) 28 | 29 | var ( 30 | flagBackend = flag.String("backend", "", "Backend URL for proxying") 31 | ) 32 | 33 | func main() { 34 | flag.Parse() 35 | u, err := url.Parse(*flagBackend) 36 | if err != nil { 37 | log.Fatalln(err) 38 | } 39 | 40 | err = http.ListenAndServe(":80", websocketproxy.NewProxy(u)) 41 | if err != nil { 42 | log.Fatalln(err) 43 | } 44 | } 45 | ``` 46 | 47 | Save it as `proxy.go` and run it as: 48 | 49 | ```bash 50 | go run proxy.go -backend ws://example.com:3000 51 | ``` 52 | 53 | Now all incoming WebSocket requests to this server will be proxied to 54 | `ws://example.com:3000`. 55 | 56 | 57 | -------------------------------------------------------------------------------- /websocketproxy.go: -------------------------------------------------------------------------------- 1 | // Package websocketproxy is a reverse proxy for WebSocket connections.
2 | package websocketproxy 3 | 4 | import ( 5 | "fmt" 6 | "io" 7 | "log" 8 | "net" 9 | "net/http" 10 | "net/url" 11 | "strings" 12 | 13 | "github.com/gorilla/websocket" 14 | ) 15 | 16 | var ( 17 | // DefaultUpgrader specifies the parameters for upgrading an HTTP 18 | // connection to a WebSocket connection. 19 | DefaultUpgrader = &websocket.Upgrader{ 20 | ReadBufferSize: 1024, 21 | WriteBufferSize: 1024, 22 | } 23 | 24 | // DefaultDialer is a dialer with all fields set to the default zero values. 25 | DefaultDialer = websocket.DefaultDialer 26 | ) 27 | 28 | // WebsocketProxy is an HTTP Handler that takes an incoming WebSocket 29 | // connection and proxies it to another server. 30 | type WebsocketProxy struct { 31 | // Director, if non-nil, is a function that may copy additional request 32 | // headers from the incoming WebSocket connection into the output headers 33 | // which will be forwarded to another server. 34 | Director func(incoming *http.Request, out http.Header) 35 | 36 | // Backend returns the backend URL which the proxy uses to reverse proxy 37 | // the incoming WebSocket connection. Request is the initial incoming and 38 | // unmodified request. 39 | Backend func(*http.Request) *url.URL 40 | 41 | // Upgrader specifies the parameters for upgrading a incoming HTTP 42 | // connection to a WebSocket connection. If nil, DefaultUpgrader is used. 43 | Upgrader *websocket.Upgrader 44 | 45 | // Dialer contains options for connecting to the backend WebSocket server. 46 | // If nil, DefaultDialer is used. 47 | Dialer *websocket.Dialer 48 | } 49 | 50 | // ProxyHandler returns a new http.Handler interface that reverse proxies the 51 | // request to the given target. 52 | func ProxyHandler(target *url.URL) http.Handler { return NewProxy(target) } 53 | 54 | // NewProxy returns a new Websocket reverse proxy that rewrites the 55 | // URL's to the scheme, host and base path provider in target. 
56 | func NewProxy(target *url.URL) *WebsocketProxy { 57 | backend := func(r *http.Request) *url.URL { 58 | // Shallow copy 59 | u := *target 60 | u.Fragment = r.URL.Fragment 61 | u.Path = r.URL.Path 62 | u.RawQuery = r.URL.RawQuery 63 | return &u 64 | } 65 | return &WebsocketProxy{Backend: backend} 66 | } 67 | 68 | // ServeHTTP implements the http.Handler that proxies WebSocket connections. 69 | func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { 70 | if w.Backend == nil { 71 | log.Println("websocketproxy: backend function is not defined") 72 | http.Error(rw, "internal server error (code: 1)", http.StatusInternalServerError) 73 | return 74 | } 75 | 76 | backendURL := w.Backend(req) 77 | if backendURL == nil { 78 | log.Println("websocketproxy: backend URL is nil") 79 | http.Error(rw, "internal server error (code: 2)", http.StatusInternalServerError) 80 | return 81 | } 82 | 83 | dialer := w.Dialer 84 | if w.Dialer == nil { 85 | dialer = DefaultDialer 86 | } 87 | 88 | // Pass headers from the incoming request to the dialer to forward them to 89 | // the final destinations. 90 | requestHeader := http.Header{} 91 | if origin := req.Header.Get("Origin"); origin != "" { 92 | requestHeader.Add("Origin", origin) 93 | } 94 | for _, prot := range req.Header[http.CanonicalHeaderKey("Sec-WebSocket-Protocol")] { 95 | requestHeader.Add("Sec-WebSocket-Protocol", prot) 96 | } 97 | for _, cookie := range req.Header[http.CanonicalHeaderKey("Cookie")] { 98 | requestHeader.Add("Cookie", cookie) 99 | } 100 | if req.Host != "" { 101 | requestHeader.Set("Host", req.Host) 102 | } 103 | 104 | // Pass X-Forwarded-For headers too, code below is a part of 105 | // httputil.ReverseProxy. 
See http://en.wikipedia.org/wiki/X-Forwarded-For 106 | // for more information 107 | // TODO: use RFC7239 http://tools.ietf.org/html/rfc7239 108 | if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { 109 | // If we aren't the first proxy retain prior 110 | // X-Forwarded-For information as a comma+space 111 | // separated list and fold multiple headers into one. 112 | if prior, ok := req.Header["X-Forwarded-For"]; ok { 113 | clientIP = strings.Join(prior, ", ") + ", " + clientIP 114 | } 115 | requestHeader.Set("X-Forwarded-For", clientIP) 116 | } 117 | 118 | // Set the originating protocol of the incoming HTTP request. The SSL might 119 | // be terminated on our site and because we doing proxy adding this would 120 | // be helpful for applications on the backend. 121 | requestHeader.Set("X-Forwarded-Proto", "http") 122 | if req.TLS != nil { 123 | requestHeader.Set("X-Forwarded-Proto", "https") 124 | } 125 | 126 | // Enable the director to copy any additional headers it desires for 127 | // forwarding to the remote server. 128 | if w.Director != nil { 129 | w.Director(req, requestHeader) 130 | } 131 | 132 | // Connect to the backend URL, also pass the headers we get from the requst 133 | // together with the Forwarded headers we prepared above. 134 | // TODO: support multiplexing on the same backend connection instead of 135 | // opening a new TCP connection time for each request. This should be 136 | // optional: 137 | // http://tools.ietf.org/html/draft-ietf-hybi-websocket-multiplexing-01 138 | connBackend, resp, err := dialer.Dial(backendURL.String(), requestHeader) 139 | if err != nil { 140 | log.Printf("websocketproxy: couldn't dial to remote backend url %s", err) 141 | if resp != nil { 142 | // If the WebSocket handshake fails, ErrBadHandshake is returned 143 | // along with a non-nil *http.Response so that callers can handle 144 | // redirects, authentication, etcetera. 
145 | if err := copyResponse(rw, resp); err != nil { 146 | log.Printf("websocketproxy: couldn't write response after failed remote backend handshake: %s", err) 147 | } 148 | } else { 149 | http.Error(rw, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable) 150 | } 151 | return 152 | } 153 | defer connBackend.Close() 154 | 155 | upgrader := w.Upgrader 156 | if w.Upgrader == nil { 157 | upgrader = DefaultUpgrader 158 | } 159 | 160 | // Only pass those headers to the upgrader. 161 | upgradeHeader := http.Header{} 162 | if hdr := resp.Header.Get("Sec-Websocket-Protocol"); hdr != "" { 163 | upgradeHeader.Set("Sec-Websocket-Protocol", hdr) 164 | } 165 | if hdr := resp.Header.Get("Set-Cookie"); hdr != "" { 166 | upgradeHeader.Set("Set-Cookie", hdr) 167 | } 168 | 169 | // Now upgrade the existing incoming request to a WebSocket connection. 170 | // Also pass the header that we gathered from the Dial handshake. 171 | connPub, err := upgrader.Upgrade(rw, req, upgradeHeader) 172 | if err != nil { 173 | log.Printf("websocketproxy: couldn't upgrade %s", err) 174 | return 175 | } 176 | defer connPub.Close() 177 | 178 | errClient := make(chan error, 1) 179 | errBackend := make(chan error, 1) 180 | replicateWebsocketConn := func(dst, src *websocket.Conn, errc chan error) { 181 | for { 182 | msgType, msg, err := src.ReadMessage() 183 | if err != nil { 184 | m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", err)) 185 | if e, ok := err.(*websocket.CloseError); ok { 186 | if e.Code != websocket.CloseNoStatusReceived { 187 | m = websocket.FormatCloseMessage(e.Code, e.Text) 188 | } 189 | } 190 | errc <- err 191 | dst.WriteMessage(websocket.CloseMessage, m) 192 | break 193 | } 194 | err = dst.WriteMessage(msgType, msg) 195 | if err != nil { 196 | errc <- err 197 | break 198 | } 199 | } 200 | } 201 | 202 | go replicateWebsocketConn(connPub, connBackend, errClient) 203 | go replicateWebsocketConn(connBackend, connPub, errBackend) 
204 | 205 | var message string 206 | select { 207 | case err = <-errClient: 208 | message = "websocketproxy: Error when copying from backend to client: %v" 209 | case err = <-errBackend: 210 | message = "websocketproxy: Error when copying from client to backend: %v" 211 | 212 | } 213 | if e, ok := err.(*websocket.CloseError); !ok || e.Code == websocket.CloseAbnormalClosure { 214 | log.Printf(message, err) 215 | } 216 | } 217 | 218 | func copyHeader(dst, src http.Header) { 219 | for k, vv := range src { 220 | for _, v := range vv { 221 | dst.Add(k, v) 222 | } 223 | } 224 | } 225 | 226 | func copyResponse(rw http.ResponseWriter, resp *http.Response) error { 227 | copyHeader(rw.Header(), resp.Header) 228 | rw.WriteHeader(resp.StatusCode) 229 | defer resp.Body.Close() 230 | 231 | _, err := io.Copy(rw, resp.Body) 232 | return err 233 | } 234 | -------------------------------------------------------------------------------- /websocketproxy_test.go: -------------------------------------------------------------------------------- 1 | package websocketproxy 2 | 3 | import ( 4 | "log" 5 | "net/http" 6 | "net/url" 7 | "testing" 8 | "time" 9 | 10 | "github.com/gorilla/websocket" 11 | ) 12 | 13 | var ( 14 | serverURL = "ws://127.0.0.1:7777" 15 | backendURL = "ws://127.0.0.1:8888" 16 | ) 17 | 18 | func TestProxy(t *testing.T) { 19 | // websocket proxy 20 | supportedSubProtocols := []string{"test-protocol"} 21 | upgrader := &websocket.Upgrader{ 22 | ReadBufferSize: 4096, 23 | WriteBufferSize: 4096, 24 | CheckOrigin: func(r *http.Request) bool { 25 | return true 26 | }, 27 | Subprotocols: supportedSubProtocols, 28 | } 29 | 30 | u, _ := url.Parse(backendURL) 31 | proxy := NewProxy(u) 32 | proxy.Upgrader = upgrader 33 | 34 | mux := http.NewServeMux() 35 | mux.Handle("/proxy", proxy) 36 | go func() { 37 | if err := http.ListenAndServe(":7777", mux); err != nil { 38 | t.Fatal("ListenAndServe: ", err) 39 | } 40 | }() 41 | 42 | time.Sleep(time.Millisecond * 100) 43 | 44 | // backend 
echo server 45 | go func() { 46 | mux2 := http.NewServeMux() 47 | mux2.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 48 | // Don't upgrade if original host header isn't preserved 49 | if r.Host != "127.0.0.1:7777" { 50 | log.Printf("Host header set incorrectly. Expecting 127.0.0.1:7777 got %s", r.Host) 51 | return 52 | } 53 | 54 | conn, err := upgrader.Upgrade(w, r, nil) 55 | if err != nil { 56 | log.Println(err) 57 | return 58 | } 59 | 60 | messageType, p, err := conn.ReadMessage() 61 | if err != nil { 62 | return 63 | } 64 | 65 | if err = conn.WriteMessage(messageType, p); err != nil { 66 | return 67 | } 68 | }) 69 | 70 | err := http.ListenAndServe(":8888", mux2) 71 | if err != nil { 72 | t.Fatal("ListenAndServe: ", err) 73 | } 74 | }() 75 | 76 | time.Sleep(time.Millisecond * 100) 77 | 78 | // let's us define two subprotocols, only one is supported by the server 79 | clientSubProtocols := []string{"test-protocol", "test-notsupported"} 80 | h := http.Header{} 81 | for _, subprot := range clientSubProtocols { 82 | h.Add("Sec-WebSocket-Protocol", subprot) 83 | } 84 | 85 | // frontend server, dial now our proxy, which will reverse proxy our 86 | // message to the backend websocket server. 87 | conn, resp, err := websocket.DefaultDialer.Dial(serverURL+"/proxy", h) 88 | if err != nil { 89 | t.Fatal(err) 90 | } 91 | 92 | // check if the server really accepted only the first one 93 | in := func(desired string) bool { 94 | for _, prot := range resp.Header[http.CanonicalHeaderKey("Sec-WebSocket-Protocol")] { 95 | if desired == prot { 96 | return true 97 | } 98 | } 99 | return false 100 | } 101 | 102 | if !in("test-protocol") { 103 | t.Error("test-protocol should be available") 104 | } 105 | 106 | if in("test-notsupported") { 107 | t.Error("test-notsupported should be not recevied from the server.") 108 | } 109 | 110 | // now write a message and send it to the backend server (which goes trough 111 | // proxy..) 
112 | msg := "hello kite" 113 | err = conn.WriteMessage(websocket.TextMessage, []byte(msg)) 114 | if err != nil { 115 | t.Error(err) 116 | } 117 | 118 | messageType, p, err := conn.ReadMessage() 119 | if err != nil { 120 | t.Error(err) 121 | } 122 | 123 | if messageType != websocket.TextMessage { 124 | t.Error("incoming message type is not Text") 125 | } 126 | 127 | if msg != string(p) { 128 | t.Errorf("expecting: %s, got: %s", msg, string(p)) 129 | } 130 | } 131 | --------------------------------------------------------------------------------