├── LICENSE ├── README.md ├── testdata ├── gencerts.go ├── root_cert.crt ├── root_key.crt ├── server_cert.crt └── server_key.crt ├── wsutil.go └── wsutil_test.go /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015, Yhat, Inc. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, 5 | are permitted provided that the following conditions are met: 6 | 7 | Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | Redistributions in binary form must reproduce the above copyright notice, this 11 | list of conditions and the following disclaimer in the documentation and/or 12 | other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 15 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 16 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 18 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 19 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 20 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 21 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 22 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 23 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # wsutil 2 | 3 | Like `net/http/httputil` but for WebSockets. 4 | 5 | [![GoDoc](https://godoc.org/github.com/yhat/wsutil?status.svg)](https://godoc.org/github.com/yhat/wsutil) 6 | 7 | ## A Reverse Proxy Example 8 | 9 | ```go 10 | package main 11 | 12 | import ( 13 | "fmt" 14 | "io" 15 | "log" 16 | "net/http" 17 | "net/url" 18 | "time" 19 | 20 | "github.com/yhat/wsutil" 21 | "golang.org/x/net/websocket" 22 | ) 23 | 24 | func main() { 25 | backend := ":8001" 26 | proxy := ":8002" 27 | 28 | // an webscket echo server 29 | backendHandler := websocket.Handler(func(ws *websocket.Conn) { 30 | io.Copy(ws, ws) 31 | }) 32 | 33 | // make a proxy pointing at that backend url 34 | backendURL := &url.URL{Scheme: "ws://", Host: backend} 35 | p := wsutil.NewSingleHostReverseProxy(backendURL) 36 | 37 | // run both servers and give them a second to start up 38 | go http.ListenAndServe(backend, backendHandler) 39 | go http.ListenAndServe(proxy, p) 40 | time.Sleep(1 * time.Second) 41 | 42 | // connect to the proxy 43 | origin := "http://localhost/" 44 | ws, err := websocket.Dial("ws://"+proxy, "", origin) 45 | if err != nil { 46 | log.Fatal(err) 47 | } 48 | 49 | // send a message along the websocket 50 | msg := []byte("isn't yhat awesome?") 51 | if _, err := ws.Write(msg); err != nil { 52 | log.Fatal(err) 53 | } 54 | 55 | // read the response from the proxy 56 | resp := make([]byte, 4096) 57 | if n, err := ws.Read(resp); err != nil { 58 | log.Fatal(err) 59 | } else { 60 | fmt.Printf("%s\n", resp[:n]) 61 | } 62 | } 63 | ``` 64 | -------------------------------------------------------------------------------- /testdata/gencerts.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "crypto/rand" 5 | "crypto/rsa" 6 | "crypto/x509" 7 | "crypto/x509/pkix" 8 | "encoding/pem" 9 | "errors" 10 | "fmt" 11 | "io/ioutil" 12 | "log" 13 | "math/big" 14 | "net" 15 | "time" 16 | ) 17 | 18 | type CertInfo struct { 19 | IsCA bool 20 | KeyUsage x509.KeyUsage 21 | ExtKeyUsage []x509.ExtKeyUsage 22 | IPAddresses []net.IP 23 | } 24 | 25 | // Generate a new certificate. If the certificate is self signed, parent and 26 | // parentKey should be nil. 27 | func NewCert(parent *x509.Certificate, parentKey *rsa.PrivateKey, info CertInfo) (cert, key []byte, err error) { 28 | priv, err := rsa.GenerateKey(rand.Reader, 2048) 29 | if err != nil { 30 | return nil, nil, errors.New("failed to generate key pair: " + err.Error()) 31 | } 32 | 33 | // this certs will be used forever 34 | notBefore := time.Now() 35 | notAfter := notBefore.Add(time.Hour * 876581) 36 | 37 | serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) 38 | serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) 39 | if err != nil { 40 | return nil, nil, errors.New("failed to generate serial number: " + err.Error()) 41 | } 42 | 43 | tmpl := x509.Certificate{ 44 | SerialNumber: serialNumber, 45 | Subject: pkix.Name{Organization: []string{"Yhat, Inc."}}, 46 | NotBefore: notBefore, 47 | NotAfter: notAfter, 48 | KeyUsage: info.KeyUsage, 49 | ExtKeyUsage: info.ExtKeyUsage, 50 | IsCA: info.IsCA, 51 | IPAddresses: info.IPAddresses, 52 | BasicConstraintsValid: true, 53 | } 54 | 55 | pub := priv.Public() 56 | var p *x509.Certificate 57 | var signingKey *rsa.PrivateKey 58 | if parent == nil { 59 | p = &tmpl 60 | signingKey = priv 61 | } else { 62 | p = parent 63 | signingKey = parentKey 64 | } 65 | 66 | certBytes, err := x509.CreateCertificate(rand.Reader, &tmpl, p, pub, signingKey) 67 | if err != nil { 68 | return nil, nil, errors.New("failed to create certificate: " + err.Error()) 69 | } 70 | 71 | certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes}) 72 | 73 | block := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)} 74 | keyPEM := pem.EncodeToMemory(block) 75 | 76 | return certPEM, keyPEM, nil 77 | } 78 | 79 | func DecodeKey(keyPEM []byte) (*rsa.PrivateKey, error) { 80 | block, _ := pem.Decode(keyPEM) 81 | if block == nil { 82 | return nil, errors.New("no private key found") 83 | } 84 | return x509.ParsePKCS1PrivateKey(block.Bytes) 85 | } 86 | 87 | func DecodeCert(certPEM []byte) (*x509.Certificate, error) { 88 | block, _ := pem.Decode(certPEM) 89 | if block == nil { 90 | return nil, errors.New("no cert found") 91 | } 92 | return x509.ParseCertificate(block.Bytes) 93 | } 94 | 95 | func genCerts() error { 96 | // create a root certificate 97 | rootInfo := CertInfo{ 98 | IsCA: true, 99 | KeyUsage: x509.KeyUsageCertSign, 100 | } 101 | 102 | rootCertPEM, rootKeyPEM, err := NewCert(nil, nil, rootInfo) 103 | if err != nil { 104 | return fmt.Errorf("failed to create cert: ", err) 105 | } 106 | 107 | rootKey, err := DecodeKey(rootKeyPEM) 108 | if err != nil { 109 | return fmt.Errorf("failed to parse private key: ", err) 110 | } 111 | rootCert, err := DecodeCert(rootCertPEM) 112 | if err != nil { 113 | return fmt.Errorf("failed to parse root cert: ", err) 114 | } 115 | 116 | // create the server certificate 117 | serverInfo := CertInfo{ 118 | IsCA: false, 119 | KeyUsage: x509.KeyUsageDigitalSignature, 120 | ExtKeyUsage: []x509.ExtKeyUsage{ 121 | x509.ExtKeyUsageServerAuth, 122 | }, 123 | IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, 124 | } 125 | 126 | serverCertPEM, serverKeyPEM, err := NewCert(rootCert, rootKey, serverInfo) 127 | if err != nil { 128 | return fmt.Errorf("failed to create cert: ", err) 129 | } 130 | 131 | certs := []struct { 132 | Name string 133 | Data []byte 134 | }{ 135 | {"root_cert.crt", rootCertPEM}, 136 | {"root_key.crt", rootKeyPEM}, 137 | {"server_cert.crt", serverCertPEM}, 138 | {"server_key.crt", serverKeyPEM}, 139 | } 140 | for _, cert := range certs { 141 | err = ioutil.WriteFile(cert.Name, cert.Data, 0600) 142 | if err != nil { 143 | return fmt.Errorf("error creating file %s %v", cert.Name, err) 144 | } 145 | } 146 | 147 | return nil 148 | } 149 | 150 | func main() { 151 | if err := genCerts(); err != nil { 152 | log.Fatal(err) 153 | } 154 | } 155 | -------------------------------------------------------------------------------- /testdata/root_cert.crt: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIC1TCCAb+gAwIBAgIQf5s3ykf8jM23qWBewTr9VDALBgkqhkiG9w0BAQswFTET 3 | MBEGA1UEChMKWWhhdCwgSW5jLjAgFw0xNTA2MTAxNTE2MzBaGA8yMTE1MDYxMDIw 4 | MTYzMFowFTETMBEGA1UEChMKWWhhdCwgSW5jLjCCASIwDQYJKoZIhvcNAQEBBQAD 5 | ggEPADCCAQoCggEBAJtKvre4kAVG7q5naD5Tov+z+yuvH/cBGAcT/SakeKBgACs/ 6 | pWdnuBfhjgLCbmWxS5+pHuNjj5mhbqVe4eIcUPg4kFmsisZr4ZQ3jEdL6TSxLYui 7 | XhxgxxZ2n47H9k3D5TkykU0TJ9txbLIMGG8Pz8VeDg1pRq9FVGYAIj3jLOU7z3Ee 8 | v4aeBMvOJe9enhde1pCfe2G/T+eksm4b2T7/HG+74te14dN+/Ik68wV/ZuaJRBCF 9 | iGws13YTxrgFLguxOUTyHJPyZz196i4mk55q0u4yqGO0hbrMx9XDz2JfjTjyBHZQ 10 | yx28ge5DbtXmEJemTIu7K/IXK5ccaBeO6RNQlbUCAwEAAaMjMCEwDgYDVR0PAQH/ 11 | BAQDAgAEMA8GA1UdEwEB/wQFMAMBAf8wCwYJKoZIhvcNAQELA4IBAQB3cbKUWXuN 12 | 8aggZ5wIGodKxDQuTE/2Kmg17tkvakQYHUIS2QCksOprK4uC4y8u6j3a3dsZhRpa 13 | vqtw0XWvYaiVZC0tphFDOj6hflLTh+l4yfLcPJ444zTaCW9U58CB+0TegFFBD348 14 | CxoVYLjbAA4Flsw+77LFuOFJoyWetzEnBtMXozUKWzYggMTO/ef0AGrBOffa2dWm 15 | FOb4uxjg8Yji1V5FTRObahWEKrYYMSKUbthpXuJecmVUjh+daHO8iQRFmOs+pU2K 16 | lJhtPxndm7bhrjh6vO5VMlNWy8UgfYLRS8uPVJhRvNQ2WPolU1cVyFEmXSiKLv7G 17 | pufHrroLKknf 18 | -----END CERTIFICATE----- 19 | -------------------------------------------------------------------------------- /testdata/root_key.crt: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEoQIBAAKCAQEAm0q+t7iQBUburmdoPlOi/7P7K68f9wEYBxP9JqR4oGAAKz+l 3 | Z2e4F+GOAsJuZbFLn6ke42OPmaFupV7h4hxQ+DiQWayKxmvhlDeMR0vpNLEti6Je 4 | HGDHFnafjsf2TcPlOTKRTRMn23FssgwYbw/PxV4ODWlGr0VUZgAiPeMs5TvPcR6/ 5 | hp4Ey84l716eF17WkJ97Yb9P56SybhvZPv8cb7vi17Xh0378iTrzBX9m5olEEIWI 6 | bCzXdhPGuAUuC7E5RPIck/JnPX3qLiaTnmrS7jKoY7SFuszH1cPPYl+NOPIEdlDL 7 | HbyB7kNu1eYQl6ZMi7sr8hcrlxxoF47pE1CVtQIDAQABAoIBAQCHmBQ/RyA5+12d 8 | Wx1ijpTcwMKnuhonCwV36LQ6cZICNtDu8nvydlYARCpDrpHGqbBmkL0vv1c7xgFe 9 | 1CJX1HG+y2T/Z/EkoD4vtPD/sADIyYwDSJr0HGy56IoZmfAupsYX01eb/QXoTnmi 10 | XX1YPG4m36FVhH77u4WeKq+7aglhfiPJQ6x4KQ6a3EoeMx6eM1DuiBMVnBWOnGkl 11 | GWzj1nDghXMyRzKC3VILX/fT3AhDeO0vW2XM+/KBNmNt15nklCUVY7dvFBOZZpax 12 | 15A+LsYuvq2yERZ1c2W/7hNnYu5Vwo+3qBm5A/3t/ZXotK+cg2tcQhcPp8Y9lwQq 13 | TMMIDRBhAoGBAMf2Ctr0Q8hkzCo7eghvFuT7cyenKGvq2YXOKvtsWgH9jjNpMvY8 14 | co/q+yjg74I5g+j5pOsrVAWpW/Dg0cbuqq2xLOVMp9oNY7dM0mE/l7JVUXsZggV0 15 | 3JXQkT6Wz3k9xDGpH3F3sK5RLW26vwvWx+P5hgy+OM7wZECSKQIc8yJDAoGBAMbP 16 | +NisKuRzaeSJNXJT2IAD76g8y068Cs7+V1uZJxGKBYNgXS1TVg50JvAGxKnysncT 17 | QpT7+xJymNI0/ZzixQH/K40o9kqz8Jx9/GWF8loMx2y0dfYZA7jSwOJevECJHWhR 18 | nRHHaXSxHLcaCAYSfYcZyc+irVcUdhRnXeW5WBSnAoGAecd9uvVyZpQEx1+rpYFK 19 | dzAwZKDn/DluOpBiGvdVJcsvFF5oPBB6UO3yAmZjV3MBxBxt9Q9RP5VyOhQhjj8C 20 | UYAK8Kcrvp/S7+poYfOhxmkxk3/ocLxILzFzk6OzPYqBdyEh1i/nuXIU8bP+8A3h 21 | dRUdL5uV62n6FF0vfmr1JBMCf2ETK26golDqCcqNNIueZRgc0+hRxvOq3Zw0lHMl 22 | VO47dnWvl8+J4XstO9X3eA+DcaCyxs/4OZ/IVNZPCYaRM6DF9331gkz3j4TZ/2OT 23 | A8L0emuZsf94N1kHjyb5GvJoAPPu5cLIm7VexaaiD0jnvmM5NFEuHXVniEBuOGrz 24 | FQ8CgYA3BpxAxDQeMWriZMDQuQ5Fa3vK89yUWOtW0WSEsmy9tRQT6oy4tLkgfGSc 25 | UiHAbFPnqXflYegViMMXT7PS/+Cj3C6gla2Md0vyi29mcucK/eBwAl1nkTLOITUX 26 | vIx0zukEsfEBj3+nHwBYnzLAiWQ52rCgoSgn1Ye8mPNa4Gcj9g== 27 | -----END RSA PRIVATE KEY----- 28 | -------------------------------------------------------------------------------- /testdata/server_cert.crt: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIC+DCCAeKgAwIBAgIQT3i9yIPOEk8MT+E0QzRQ5DALBgkqhkiG9w0BAQswFTET 3 | MBEGA1UEChMKWWhhdCwgSW5jLjAgFw0xNTA2MTAxNTE2MzJaGA8yMTE1MDYxMDIw 4 | MTYzMlowFTETMBEGA1UEChMKWWhhdCwgSW5jLjCCASIwDQYJKoZIhvcNAQEBBQAD 5 | ggEPADCCAQoCggEBAMHgUCKR5lt88LBCVHdbKUxCVrQsc2SbYFJiaafb/t4lxkEk 6 | PDFRbSxbLU//bOrHeix4NQjLmSWDLBkX3q97nNpLiQwZYFZDAnihHw2VGgjG3vye 7 | PVzcQMI4U4QozsItIyTszsc9oUKExTdM9iYA2EGXn6C9PJWGUJjQeBmk6OnXVsM+ 8 | B9+3czl1Xf78ZV0Lb6MExiCZWSRYCdizHx0M1IBczrxfm8Kg3uOl3jcFpYdhtIfw 9 | nBvm7y+qBLxF8a5ZQH/ztPC5FXyYHGre7QvM4nztnBmFfOLPi9yiCmb2B5GIbLa6 10 | IkcaZf0kSQJL2y5c5S2a4LsF8E11DvCt7njmEwcCAwEAAaNGMEQwDgYDVR0PAQH/ 11 | BAQDAgCAMBMGA1UdJQQMMAoGCCsGAQUFBwMBMAwGA1UdEwEB/wQCMAAwDwYDVR0R 12 | BAgwBocEfwAAATALBgkqhkiG9w0BAQsDggEBACE0rJ5tihu+D9DOSp7Z/ue/9VuF 13 | oaTVrPL9ZyJXYW9KSGhTlmMPC3tvgRevam4Gi7Arc9im4x1vg35/du8fY/tvw5sW 14 | 2yI8V3GmOLsf7brCGuMMA253BhPvmxCPQX2J9Zv/Ue+W8JeS7ciZz/TO+OLENdws 15 | 7hOCTLKozuydPCRWuoN8yOfW8WWJJ/Jw22GLKn4sUHKGOkdDOpCAN7mqsw/hpGbN 16 | 0w7LnohcLbmNz6jWLy+yJTOW0h4ny35R+NDAscNgdMuCuasyfdRT1BLP5bVdyv/r 17 | 9X4Y6w9wmrh7H5wN+Xy35EcfV+wZYsYfEViDJ1scOn7MSsqhTE9xY3lxn7w= 18 | -----END CERTIFICATE----- 19 | -------------------------------------------------------------------------------- /testdata/server_key.crt: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEowIBAAKCAQEAweBQIpHmW3zwsEJUd1spTEJWtCxzZJtgUmJpp9v+3iXGQSQ8 3 | MVFtLFstT/9s6sd6LHg1CMuZJYMsGRfer3uc2kuJDBlgVkMCeKEfDZUaCMbe/J49 4 | XNxAwjhThCjOwi0jJOzOxz2hQoTFN0z2JgDYQZefoL08lYZQmNB4GaTo6ddWwz4H 5 | 37dzOXVd/vxlXQtvowTGIJlZJFgJ2LMfHQzUgFzOvF+bwqDe46XeNwWlh2G0h/Cc 6 | G+bvL6oEvEXxrllAf/O08LkVfJgcat7tC8zifO2cGYV84s+L3KIKZvYHkYhstroi 7 | Rxpl/SRJAkvbLlzlLZrguwXwTXUO8K3ueOYTBwIDAQABAoIBAF1NysMYXDhjZKIv 8 | Cd93K+TbeEa1rcMZU73SRu7V3U6j90maWq2Rdm0GZEQ/tPc4kP+dykg1U7rN6gcZ 9 | ib4CS3ZzK7166PYawbG0OPGcdC8NJnuE3Bs4lKHj4a2UxgyFFIjMvqb8bgNJSnBW 10 | xY98kJbglZ4R9HoxbdBdY69TwhPnCg1ap3vfIm8dlYg4rzxVCs3jHKoDiwip/tIK 11 | JErcfdlorlNud3gJMAvsUMTt8ptzizcDbsDXbMTHv4ErFvq4/X0kZdQ6iCvnek3i 12 | ttlhYGqC5KJuRCA0A1fh+e453+yb9WBTg+qpRwJqGk6R7MrlMfdjeC6ovDCQUW2y 13 | s0HUiYECgYEAwhu8OR2FpguWXl8XujmiD9TJs294wVcOOZ1ZvdK0GDtZkBKBnJem 14 | YPgudOnK/PMaxolB2Xzy70KDcjNrUe3I1keyjxZE0nsT1J4gMpnhL0WQPTom5WQs 15 | HxzUG68HX7nKEoLMV+nXAcR5x+DvH5MRqIZ4jmvQTNFyFHMGq+wjwdECgYEA/7Gh 16 | hVSyDwku/cJjKchfafnypsu/ku4i49dopyaP266S6YCgxgGT/mn8fJv4qE7g1Stm 17 | eZne4h7Jv225Jk/fVvtt1JWGPrumtbBDcjz/hcvXMU+dp2IhqLdR5RtQfir06avB 18 | mfBT50q9r7bvOlFJP9zB5hzxkLVHXqJfRgT8JVcCgYAUSdXYmm9XtapX3tSaEGAS 19 | C4mxiZsziifgecPhhV5xkfKAjo6hkXBAfnBMpAsleTt5OOt7EgZKX8dhbmJvQ81U 20 | KFZPgmJbJaYi+Qwgfdj9meXDoIpkO54o+lhpNFgu9zpZyPYW0kg41RJtg+M6h6K2 21 | 3KdJK5ewD8w+uu8dlSb/oQKBgQD8sJA+kvAROfMtpvCW90WMFw691feyfhMO9e+f 22 | 2NUstn7LsmmwpRibwiRbBU0dEC7TnDt+ixkggGrC+u2SNjcy3/GvuEFeN9bOEa7l 23 | 8/BWSpeVTOgx5iH7eYe+klrfBRba4vnGZyKUHmINiA0tpe1s5n7dKdd1OiGZHYBo 24 | Uz7YqwKBgAt8bqOtxftVSI7C71h1xdc9okGxamiGj5l06REgq2NrNZ3qtEKTrSgV 25 | zkfmhrZ3z6DNXbBKWAZVrD0McdWviRDnS+vtXtMhSvUXVWIvZLdCNLkNlej2PR1G 26 | N65Cb2PvlOZrWZreAmSqRtOz2vYM7PfK6My3nr3FkvzDVBjYCYWM 27 | -----END RSA PRIVATE KEY----- 28 | -------------------------------------------------------------------------------- /wsutil.go: -------------------------------------------------------------------------------- 1 | package wsutil 2 | 3 | import ( 4 | "crypto/tls" 5 | "io" 6 | "log" 7 | "net" 8 | "net/http" 9 | "net/url" 10 | "strings" 11 | ) 12 | 13 | // ReverseProxy is a WebSocket reverse proxy. It will not work with a regular 14 | // HTTP request, so it is the caller's responsiblity to ensure the incoming 15 | // request is a WebSocket request. 16 | type ReverseProxy struct { 17 | // Director must be a function which modifies 18 | // the request into a new request to be sent 19 | // using Transport. Its response is then copied 20 | // back to the original client unmodified. 21 | Director func(*http.Request) 22 | 23 | // Dial specifies the dial function for dialing the proxied 24 | // server over tcp. 25 | // If Dial is nil, net.Dial is used. 26 | Dial func(network, addr string) (net.Conn, error) 27 | 28 | // TLSClientConfig specifies the TLS configuration to use for 'wss'. 29 | // If nil, the default configuration is used. 30 | TLSClientConfig *tls.Config 31 | 32 | // ErrorLog specifies an optional logger for errors 33 | // that occur when attempting to proxy the request. 34 | // If nil, logging goes to os.Stderr via the log package's 35 | // standard logger. 36 | ErrorLog *log.Logger 37 | } 38 | 39 | // stolen from net/http/httputil. singleJoiningSlash ensures that the route 40 | // '/a/' joined with '/b' becomes '/a/b'. 41 | func singleJoiningSlash(a, b string) string { 42 | aslash := strings.HasSuffix(a, "/") 43 | bslash := strings.HasPrefix(b, "/") 44 | switch { 45 | case aslash && bslash: 46 | return a + b[1:] 47 | case !aslash && !bslash: 48 | return a + "/" + b 49 | } 50 | return a + b 51 | } 52 | 53 | // NewSingleHostReverseProxy returns a new websocket ReverseProxy. The path 54 | // rewrites follow the same rules as the httputil.ReverseProxy. If the target 55 | // url has the path '/foo' and the incoming request '/bar', the request path 56 | // will be updated to '/foo/bar' before forwarding. 57 | // Scheme should specify if 'ws' or 'wss' should be used. 58 | func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { 59 | targetQuery := target.RawQuery 60 | director := func(req *http.Request) { 61 | req.URL.Scheme = target.Scheme 62 | req.URL.Host = target.Host 63 | req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) 64 | if targetQuery == "" || req.URL.RawQuery == "" { 65 | req.URL.RawQuery = targetQuery + req.URL.RawQuery 66 | } else { 67 | req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery 68 | } 69 | } 70 | return &ReverseProxy{Director: director} 71 | } 72 | 73 | // Function to implement the http.Handler interface. 74 | func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { 75 | logFunc := log.Printf 76 | if p.ErrorLog != nil { 77 | logFunc = p.ErrorLog.Printf 78 | } 79 | 80 | if !IsWebSocketRequest(r) { 81 | http.Error(w, "Cannot handle non-WebSocket requests", 500) 82 | logFunc("Received a request that was not a WebSocket request") 83 | return 84 | } 85 | 86 | outreq := new(http.Request) 87 | // shallow copying 88 | *outreq = *r 89 | p.Director(outreq) 90 | host := outreq.URL.Host 91 | 92 | if clientIP, _, err := net.SplitHostPort(r.RemoteAddr); err == nil { 93 | // If we aren't the first proxy retain prior 94 | // X-Forwarded-For information as a comma+space 95 | // separated list and fold multiple headers into one. 96 | if prior, ok := outreq.Header["X-Forwarded-For"]; ok { 97 | clientIP = strings.Join(prior, ", ") + ", " + clientIP 98 | } 99 | outreq.Header.Set("X-Forwarded-For", clientIP) 100 | } 101 | 102 | dial := p.Dial 103 | if dial == nil { 104 | dial = net.Dial 105 | } 106 | 107 | // if host does not specify a port, use the default http port 108 | if !strings.Contains(host, ":") { 109 | if outreq.URL.Scheme == "wss" { 110 | host = host + ":443" 111 | } else { 112 | host = host + ":80" 113 | } 114 | } 115 | 116 | if outreq.URL.Scheme == "wss" { 117 | var tlsConfig *tls.Config 118 | if p.TLSClientConfig == nil { 119 | tlsConfig = &tls.Config{} 120 | } else { 121 | tlsConfig = p.TLSClientConfig 122 | } 123 | dial = func(network, address string) (net.Conn, error) { 124 | return tls.Dial("tcp", host, tlsConfig) 125 | } 126 | } 127 | 128 | d, err := dial("tcp", host) 129 | if err != nil { 130 | http.Error(w, "Error forwarding request.", 500) 131 | logFunc("Error dialing websocket backend %s: %v", outreq.URL, err) 132 | return 133 | } 134 | // All request generated by the http package implement this interface. 135 | hj, ok := w.(http.Hijacker) 136 | if !ok { 137 | http.Error(w, "Not a hijacker?", 500) 138 | return 139 | } 140 | // Hijack() tells the http package not to do anything else with the connection. 141 | // After, it bcomes this functions job to manage it. `nc` is of type *net.Conn. 142 | nc, _, err := hj.Hijack() 143 | if err != nil { 144 | logFunc("Hijack error: %v", err) 145 | return 146 | } 147 | defer nc.Close() // must close the underlying net connection after hijacking 148 | defer d.Close() 149 | 150 | // write the modified incoming request to the dialed connection 151 | err = outreq.Write(d) 152 | if err != nil { 153 | logFunc("Error copying request to target: %v", err) 154 | return 155 | } 156 | errc := make(chan error, 2) 157 | cp := func(dst io.Writer, src io.Reader) { 158 | _, err := io.Copy(dst, src) 159 | errc <- err 160 | } 161 | go cp(d, nc) 162 | go cp(nc, d) 163 | <-errc 164 | } 165 | 166 | // IsWebSocketRequest returns a boolean indicating whether the request has the 167 | // headers of a WebSocket handshake request. 168 | func IsWebSocketRequest(r *http.Request) bool { 169 | contains := func(key, val string) bool { 170 | vv := strings.Split(r.Header.Get(key), ",") 171 | for _, v := range vv { 172 | if val == strings.ToLower(strings.TrimSpace(v)) { 173 | return true 174 | } 175 | } 176 | return false 177 | } 178 | if !contains("Connection", "upgrade") { 179 | return false 180 | } 181 | if !contains("Upgrade", "websocket") { 182 | return false 183 | } 184 | return true 185 | } 186 | -------------------------------------------------------------------------------- /wsutil_test.go: -------------------------------------------------------------------------------- 1 | package wsutil 2 | 3 | import ( 4 | "bytes" 5 | "crypto/rand" 6 | "crypto/tls" 7 | "crypto/x509" 8 | "fmt" 9 | "io" 10 | "io/ioutil" 11 | "log" 12 | "net/http" 13 | "net/http/httptest" 14 | "net/url" 15 | "testing" 16 | "time" 17 | 18 | "golang.org/x/net/websocket" 19 | ) 20 | 21 | var devnull = log.New(ioutil.Discard, "", 0) 22 | 23 | func EchoWSHandler(ws *websocket.Conn) { 24 | io.Copy(ws, ws) 25 | } 26 | 27 | // Helper function to send an WS request to a given path. urlStr is assumed to 28 | // be the url from a httptest.Server 29 | func SendWSRequest(urlStr, data string, t *testing.T) (string, error) { 30 | if data == "" { 31 | return "", fmt.Errorf("cannot send no data to a websocket") 32 | } 33 | u, err := url.Parse(urlStr) 34 | if err != nil { 35 | return "", err 36 | } 37 | u.Scheme = "ws" 38 | origin := "http://localhost/" 39 | errc := make(chan error) 40 | wsc := make(chan *websocket.Conn) 41 | go func() { 42 | ws, err := websocket.Dial(u.String(), "", origin) 43 | if err != nil { 44 | errc <- err 45 | return 46 | } 47 | wsc <- ws 48 | }() 49 | var ws *websocket.Conn 50 | select { 51 | case err := <-errc: 52 | return "", err 53 | case ws = <-wsc: 54 | case <-time.After(time.Second * 2): 55 | return "", fmt.Errorf("websocket dial timed out") 56 | } 57 | defer ws.Close() 58 | msgc := make(chan string) 59 | go func() { 60 | if _, err := ws.Write([]byte(data)); err != nil { 61 | errc <- err 62 | return 63 | } 64 | var msg = make([]byte, 512) 65 | var n int 66 | if n, err = ws.Read(msg); err != nil { 67 | errc <- err 68 | return 69 | } 70 | msgc <- string(msg[:n]) 71 | }() 72 | select { 73 | case err := <-errc: 74 | return "", err 75 | case msg := <-msgc: 76 | t.Logf("response from ws: '%s'", msg) 77 | return msg, nil 78 | case <-time.After(time.Second * 2): 79 | return "", fmt.Errorf("websocket request timed out") 80 | } 81 | } 82 | 83 | func TestWebSocketProxy(t *testing.T) { 84 | go func() { 85 | time.Sleep(5 * time.Second) 86 | panic("hi") 87 | }() 88 | echoServer := http.NewServeMux() 89 | echoServer.Handle("/echo/ws", websocket.Handler(EchoWSHandler)) 90 | // make sure that the proxy preserves url queries 91 | queryAssert := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 92 | if r.URL.Query().Get("foo") != "bar" { 93 | t.Errorf("request is missing url query") 94 | } 95 | echoServer.ServeHTTP(w, r) 96 | }) 97 | backend := httptest.NewServer(queryAssert) 98 | defer backend.Close() 99 | backendURL, err := url.Parse(backend.URL) 100 | if err != nil { 101 | t.Fatal(err) 102 | } 103 | backendURL.Path = "/echo" 104 | proxy := httptest.NewServer(NewSingleHostReverseProxy(backendURL)) 105 | defer proxy.Close() 106 | 107 | for _, data := range []string{"eric is so cool", "some data", "else"} { 108 | resp, err := SendWSRequest(proxy.URL+"/ws?foo=bar", data, t) 109 | if err != nil { 110 | t.Error(err) 111 | continue 112 | } 113 | if resp != data { 114 | t.Errorf("expected '%s' from server, got '%s'", data, resp) 115 | } 116 | } 117 | } 118 | 119 | func TestReverseProxy(t *testing.T) { 120 | h := http.NewServeMux() 121 | h.Handle("/ws", websocket.Handler(func(ws *websocket.Conn) { 122 | ws.Write([]byte("wssuccess")) 123 | ws.Close() 124 | })) 125 | h.HandleFunc("/http", func(w http.ResponseWriter, r *http.Request) { 126 | w.Write([]byte("httpsuccess")) 127 | }) 128 | isWSHandler := func(w http.ResponseWriter, r *http.Request) { 129 | isWS := IsWebSocketRequest(r) 130 | if isWS && (r.URL.Path != "/ws") { 131 | t.Errorf("detected ws and got path %s", r.URL.Path) 132 | } else if !isWS && (r.URL.Path != "/http") { 133 | t.Errorf("detected http and got path %s", r.URL.Path) 134 | } 135 | h.ServeHTTP(w, r) 136 | } 137 | n := httptest.NewServer(http.HandlerFunc(isWSHandler)) 138 | defer n.Close() 139 | errc := make(chan error) 140 | go func() { 141 | resp, err := http.Get(n.URL + "/http") 142 | if err != nil { 143 | errc <- fmt.Errorf("could not GET url: %v", err) 144 | return 145 | } 146 | defer resp.Body.Close() 147 | data, err := ioutil.ReadAll(resp.Body) 148 | if err != nil { 149 | errc <- fmt.Errorf("could not read from body") 150 | return 151 | } 152 | t.Logf("response from http request: %s", data) 153 | if string(data) != "httpsuccess" { 154 | errc <- fmt.Errorf("expected 'httpsuccess' got '%s'", string(data)) 155 | return 156 | } 157 | errc <- nil 158 | }() 159 | select { 160 | case err := <-errc: 161 | if err != nil { 162 | t.Error(err) 163 | } 164 | case <-time.After(4 * time.Second): 165 | t.Error("http request timed out") 166 | } 167 | go func() { 168 | t.Logf("making request to server") 169 | wsResp, err := SendWSRequest(n.URL+"/ws", "a lot of data", t) 170 | if err != nil { 171 | errc <- fmt.Errorf("could not connect to ws server: %v", err) 172 | return 173 | } 174 | t.Logf("got response from server: %s", wsResp) 175 | if wsResp != "wssuccess" { 176 | errc <- fmt.Errorf("expected 'wssuccess' got '%s'", wsResp) 177 | return 178 | } 179 | errc <- nil 180 | }() 181 | t.Logf("waiting for response from websocket") 182 | select { 183 | case err := <-errc: 184 | if err != nil { 185 | t.Error(err) 186 | } 187 | return 188 | case <-time.After(4 * time.Second): 189 | t.Error("websocket request timed out") 190 | return 191 | } 192 | } 193 | 194 | // HTTP requests should always create errors 195 | func TestHTTPReq(t *testing.T) { 196 | backendHF := func(w http.ResponseWriter, r *http.Request) { 197 | t.Error("non-websocket request was made through proxy") 198 | } 199 | backend := httptest.NewServer(http.HandlerFunc(backendHF)) 200 | defer backend.Close() 201 | 202 | u, err := url.Parse(backend.URL + "/") 203 | if err != nil { 204 | t.Error(err) 205 | return 206 | } 207 | u.Scheme = "ws" 208 | 209 | proxy := NewSingleHostReverseProxy(u) 210 | proxy.ErrorLog = devnull 211 | proxyServer := httptest.NewServer(proxy) 212 | defer proxyServer.Close() 213 | 214 | _, err = http.Get(proxyServer.URL + "/") 215 | if err != nil { 216 | // the websocket proxy should return with a 500 to an http request, not an error 217 | t.Error(err) 218 | return 219 | } 220 | } 221 | 222 | func TestTLS(t *testing.T) { 223 | cert, err := tls.LoadX509KeyPair("testdata/server_cert.crt", "testdata/server_key.crt") 224 | if err != nil { 225 | t.Fatal(err) 226 | } 227 | rootCert, err := ioutil.ReadFile("testdata/root_cert.crt") 228 | if err != nil { 229 | t.Fatal(err) 230 | } 231 | certPool := x509.NewCertPool() 232 | if !certPool.AppendCertsFromPEM(rootCert) { 233 | t.Fatal("no root certificate detected") 234 | } 235 | 236 | randData := make([]byte, 2048) 237 | if _, err = io.ReadFull(rand.Reader, randData); err != nil { 238 | t.Fatal(err) 239 | } 240 | 241 | backendHandler := func(ws *websocket.Conn) { 242 | _, err := ws.Write(randData) 243 | if err != nil { 244 | t.Errorf("error writing to websocket: %v", err) 245 | } 246 | ws.Close() 247 | } 248 | backend := httptest.NewUnstartedServer(websocket.Handler(backendHandler)) 249 | backend.TLS = &tls.Config{Certificates: []tls.Certificate{cert}} 250 | backend.StartTLS() 251 | defer backend.Close() 252 | 253 | u, err := url.Parse(backend.URL + "/") 254 | if err != nil { 255 | t.Error(err) 256 | return 257 | } 258 | u.Scheme = "wss" 259 | 260 | proxy := NewSingleHostReverseProxy(u) 261 | proxy.TLSClientConfig = &tls.Config{RootCAs: certPool} 262 | proxyServer := httptest.NewServer(proxy) 263 | defer proxyServer.Close() 264 | 265 | proxyURL, err := url.Parse(proxyServer.URL + "/") 266 | if err != nil { 267 | t.Error(err) 268 | return 269 | } 270 | proxyURL.Scheme = "ws" 271 | 272 | ws, err := websocket.Dial(proxyURL.String(), "", "http://localhost/") 273 | if err != nil { 274 | t.Errorf("could not dial proxy %s: %v", proxyURL.String(), err) 275 | return 276 | } 277 | readData := make([]byte, len(randData)) 278 | defer ws.Close() 279 | if _, err := io.ReadFull(ws, readData); err != nil { 280 | t.Errorf("error reading from ws: %v", err) 281 | return 282 | } 283 | if bytes.Compare(randData, readData) != 0 { 284 | t.Error("data send and data read was different") 285 | } 286 | } 287 | --------------------------------------------------------------------------------