├── LICENSE.md ├── README.md ├── go.mod ├── go.sum ├── main.go ├── misc └── sshtunnel ├── server_test.go ├── tunnel.go └── tunnel_test.go /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright © 2016, The Go Authors. All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are met: 5 | 6 | * Redistributions of source code must retain the above copyright notice, this 7 | list of conditions and the following disclaimer. 8 | * Redistributions in binary form must reproduce the above copyright notice, 9 | this list of conditions and the following disclaimer in the documentation and/or 10 | other materials provided with the distribution. 11 | * Neither the copyright holder nor the names of its contributors may be used to 12 | endorse or promote products derived from this software without specific prior 13 | written permission. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER BE LIABLE FOR ANY 19 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SSH tunnel proxy daemon # 2 | 3 | ## Introduction ## 4 | 5 | This repository contains a simple implementation of an SSH proxy daemon used to 6 | securely tunnel TCP connections in forward and reverse proxy mode. 7 | This tool provides equivalent functionality to using the `ssh` command's 8 | `-L` and `-R` flags. 9 | 10 | Consider using [github.com/dsnet/udptunnel](https://github.com/dsnet/udptunnel) 11 | if running behind a NAT that drops long-running TCP connections, but allows 12 | UDP traffic to reliably pass through. 13 | 14 | ## Usage ## 15 | 16 | Build the daemon: 17 | 18 | ```go install github.com/dsnet/sshtunnel@latest``` 19 | 20 | Create a configuration file: 21 | 22 | ```javascript 23 | { 24 | "KeyFiles": ["/path/to/key.priv"], 25 | "KnownHostFiles": ["/path/to/known_hosts"], 26 | "Tunnels": [{ 27 | // Forward tunnel (locally bound socket proxies to remote target). 28 | "Tunnel": "bind_address:port -> dial_address:port", 29 | "Server": "user@host", 30 | }, { 31 | // Reverse tunnel (remotely bound socket proxies to local target). 
32 | "Tunnel": "dial_address:port <- bind_address:port", 33 | "Server": "user@host", 34 | }], 35 | } 36 | ``` 37 | 38 | The above configuration is equivalent to running the following: 39 | 40 | ```bash 41 | ssh $USER@$HOST -i /path/to/key.priv -L $BIND_ADDRESS:$BIND_PORT:$DIAL_ADDRESS:$DIAL_PORT 42 | ssh $USER@$HOST -i /path/to/key.priv -R $BIND_ADDRESS:$BIND_PORT:$DIAL_ADDRESS:$DIAL_PORT 43 | ``` 44 | 45 | Start the daemon (assuming `$GOPATH/bin` is in your `$PATH`): 46 | 47 | ```sshtunnel /path/to/config.json``` 48 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/dsnet/sshtunnel 2 | 3 | go 1.12 4 | 5 | require ( 6 | github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a 7 | golang.org/x/crypto v0.27.0 8 | ) 9 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 2 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= 3 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 4 | github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a h1:SJy1Pu0eH1C29XwJucQo73FrleVK6t4kYz4NVhp34Yw= 5 | github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a/go.mod h1:DFSS3NAGHthKo1gTlmEcSBiZrRJXi28rLNd/1udP1c8= 6 | github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= 7 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 8 | golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= 9 | golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= 10 | golang.org/x/crypto v0.19.0/go.mod 
h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= 11 | golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= 12 | golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= 13 | golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= 14 | golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= 15 | golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= 16 | golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= 17 | golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= 18 | golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= 19 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 20 | golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= 21 | golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= 22 | golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= 23 | golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= 24 | golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= 25 | golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= 26 | golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= 27 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 28 | golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 29 | golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 30 | golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= 31 | golang.org/x/sync v0.6.0/go.mod 
h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 32 | golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 33 | golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 34 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 35 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 36 | golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 37 | golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 38 | golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 39 | golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 40 | golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 41 | golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 42 | golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 43 | golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 44 | golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= 45 | golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 46 | golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= 47 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 48 | golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= 49 | golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= 50 | golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= 51 | golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= 52 | golang.org/x/term 
v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= 53 | golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= 54 | golang.org/x/term v0.24.0 h1:Mh5cbb+Zk2hqqXNO7S1iTjEphVL+jb8ZWaqh/g+JWkM= 55 | golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8= 56 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 57 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 58 | golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= 59 | golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= 60 | golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= 61 | golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= 62 | golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= 63 | golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= 64 | golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= 65 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 66 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 67 | golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= 68 | golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= 69 | golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= 70 | golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= 71 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 72 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | // Copyright 
2016, The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE.md file. 4 | 5 | // sshtunnel is daemon for setting up forward and reverse SSH tunnels. 6 | // 7 | // The daemon is started by executing sshtunnel with the path to a JSON 8 | // configuration file. The configuration takes the following form: 9 | // 10 | // { 11 | // "KeyFiles": ["/path/to/key.priv"], 12 | // "KnownHostFiles": ["/path/to/known_hosts"], 13 | // "Tunnels": [{ 14 | // // Forward tunnel (locally binded socket proxies to remote target). 15 | // "Tunnel": "bind_address:port -> dial_address:port", 16 | // "Server": "user@host:port", 17 | // }, { 18 | // // Reverse tunnel (remotely binded socket proxies to local target). 19 | // "Tunnel": "dial_address:port <- bind_address:port", 20 | // "Server": "user@host:port", 21 | // }], 22 | // } 23 | // 24 | // See the TunnelConfig struct for more details. 25 | package main 26 | 27 | import ( 28 | "bytes" 29 | "context" 30 | "crypto/sha256" 31 | "encoding/json" 32 | "fmt" 33 | "io" 34 | "log" 35 | "net" 36 | "os" 37 | "os/signal" 38 | "os/user" 39 | "path" 40 | "strings" 41 | "sync" 42 | "syscall" 43 | "time" 44 | 45 | "github.com/tailscale/hujson" 46 | "golang.org/x/crypto/ssh" 47 | "golang.org/x/crypto/ssh/knownhosts" 48 | ) 49 | 50 | // Version of the sshtunnel binary. May be set by linker when building. 51 | var version string 52 | 53 | type TunnelConfig struct { 54 | // Log configures how log lines are produced by zsync. 55 | Log struct { 56 | // File is where the daemon will direct its output log. 57 | // If the path is empty, then the log outputs to os.Stderr. 58 | File string `json:",omitempty"` 59 | 60 | // ExcludeTimestamp specifies that a timestamp should not be logged. 61 | // This is useful if another mechanism (e.g., systemd) records timestamps. 62 | ExcludeTimestamp bool `json:",omitempty"` 63 | } 64 | 65 | // KeyFiles is a list of SSH private key files. 
66 | KeyFiles []string 67 | 68 | // KnownHostFiles is a list of key database files for host public keys 69 | // in the OpenSSH known_hosts file format. 70 | KnownHostFiles []string 71 | 72 | // KeepAlive sets the keep alive settings for each SSH connection. 73 | // It is recommended that these values match the AliveInterval and 74 | // AliveCountMax parameters on the remote OpenSSH server. 75 | // If unset, then the default is an interval of 30s with 2 max counts. 76 | KeepAlive *KeepAliveConfig `json:",omitempty"` 77 | 78 | // Tunnels is a list of tunnels to establish. 79 | // The same set of SSH keys will be used to authenticate the 80 | // SSH connection for each server. 81 | Tunnels []struct { 82 | // Tunnel is a pair of host:port endpoints that can be configured 83 | // to either operate as a forward tunnel or a reverse tunnel. 84 | // 85 | // The syntax of a forward tunnel is: 86 | // "bind_address:port -> dial_address:port" 87 | // 88 | // A forward tunnel opens a listening TCP socket on the 89 | // local side (at bind_address:port) and proxies all traffic to a 90 | // socket on the remote side (at dial_address:port). 91 | // 92 | // The syntax of a reverse tunnel is: 93 | // "dial_address:port <- bind_address:port" 94 | // 95 | // A reverse tunnel opens a listening TCP socket on the 96 | // remote side (at bind_address:port) and proxies all traffic to a 97 | // socket on the local side (at dial_address:port). 98 | Tunnel string 99 | 100 | // Server is a remote SSH host. It has the following syntax: 101 | // "user@host:port" 102 | // 103 | // If the user is missing, then it defaults to the current process user. 104 | // If the port is missing, then it defaults to 22. 105 | Server string 106 | 107 | // KeepAlive is a tunnel-specific setting of the global KeepAlive. 108 | // If unspecified, it uses the global KeepAlive settings. 
109 | KeepAlive *KeepAliveConfig `json:",omitempty"` 110 | } 111 | } 112 | 113 | type KeepAliveConfig struct { 114 | // Interval is the amount of time in seconds to wait before the 115 | // tunnel client will send a keep-alive message to ensure some minimum 116 | // traffic on the SSH connection. 117 | Interval uint 118 | 119 | // CountMax is the maximum number of consecutive failed responses to 120 | // keep-alive messages the client is willing to tolerate before considering 121 | // the SSH connection as dead. 122 | CountMax uint 123 | } 124 | 125 | func loadConfig(conf string) (tunns []tunnel, logger *log.Logger, closer func() error) { 126 | var logBuf bytes.Buffer 127 | logger = log.New(io.MultiWriter(os.Stderr, &logBuf), "", log.Ldate|log.Ltime|log.Lshortfile) 128 | 129 | var hash string 130 | if b, _ := os.ReadFile(os.Args[0]); len(b) > 0 { 131 | hash = fmt.Sprintf("%x", sha256.Sum256(b)) 132 | } 133 | 134 | // Load configuration file. 135 | var config TunnelConfig 136 | c, err := os.ReadFile(conf) 137 | if err != nil { 138 | logger.Fatalf("unable to read config: %v", err) 139 | } 140 | c, _ = hujson.Standardize(c) 141 | if c, err = hujson.Format(c); err != nil { 142 | logger.Fatalf("unable to parse config: %v", err) 143 | } 144 | if err := json.Unmarshal(c, &config); err != nil { 145 | logger.Fatalf("unable to decode config: %v", err) 146 | } 147 | for _, t := range config.Tunnels { 148 | if config.KeepAlive == nil && t.KeepAlive == nil { 149 | config.KeepAlive = &KeepAliveConfig{Interval: 30, CountMax: 2} 150 | break 151 | } 152 | } 153 | if config.Log.ExcludeTimestamp { 154 | logger.SetFlags(log.Lshortfile) 155 | } 156 | 157 | // Print the configuration. 
158 | var b bytes.Buffer 159 | enc := json.NewEncoder(&b) 160 | enc.SetEscapeHTML(false) 161 | enc.SetIndent("", "\t") 162 | enc.Encode(struct { 163 | TunnelConfig 164 | BinaryVersion string `json:",omitempty"` 165 | BinarySHA256 string `json:",omitempty"` 166 | }{config, version, hash}) 167 | logger.Printf("loaded config:\n%s", b.String()) 168 | 169 | // Setup the log output. 170 | if config.Log.File == "" { 171 | logger.SetOutput(os.Stderr) 172 | closer = func() error { return nil } 173 | } else { 174 | f, err := os.OpenFile(config.Log.File, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0664) 175 | if err != nil { 176 | logger.Fatalf("error opening log file: %v", err) 177 | } 178 | f.Write(logBuf.Bytes()) // Write log output prior to this point 179 | logger.Printf("suppress stderr logging (redirected to %s)", f.Name()) 180 | logger.SetOutput(f) 181 | closer = f.Close 182 | } 183 | 184 | // Parse all of the private keys. 185 | var keys []ssh.Signer 186 | if len(config.KeyFiles) == 0 { 187 | logger.Fatal("no private keys specified") 188 | } 189 | for _, kf := range config.KeyFiles { 190 | b, err := os.ReadFile(kf) 191 | if err != nil { 192 | logger.Fatalf("private key error: %v", err) 193 | } 194 | k, err := ssh.ParsePrivateKey(b) 195 | if err != nil { 196 | logger.Fatalf("private key error: %v", err) 197 | } 198 | keys = append(keys, k) 199 | } 200 | auth := []ssh.AuthMethod{ssh.PublicKeys(keys...)} 201 | 202 | // Parse all of the host public keys. 203 | if len(config.KnownHostFiles) == 0 { 204 | logger.Fatal("no host public keys specified") 205 | } 206 | hostKeys, err := knownhosts.New(config.KnownHostFiles...) 207 | if err != nil { 208 | logger.Fatalf("public key error: %v", err) 209 | } 210 | 211 | // Parse all of the tunnels. 212 | for _, t := range config.Tunnels { 213 | var tunn tunnel 214 | tt := strings.Fields(t.Tunnel) 215 | if len(tt) != 3 { 216 | logger.Fatalf("invalid tunnel syntax: %s", t.Tunnel) 217 | } 218 | 219 | // Parse for the tunnel endpoints. 
220 | switch tt[1] { 221 | case "->": 222 | tunn.bindAddr, tunn.mode, tunn.dialAddr = tt[0], '>', tt[2] 223 | case "<-": 224 | tunn.dialAddr, tunn.mode, tunn.bindAddr = tt[0], '<', tt[2] 225 | default: 226 | logger.Fatalf("invalid tunnel syntax: %s", t.Tunnel) 227 | } 228 | for _, addr := range []string{tunn.bindAddr, tunn.dialAddr} { 229 | if _, _, err := net.SplitHostPort(addr); err != nil { 230 | logger.Fatalf("invalid endpoint: %s", addr) 231 | } 232 | } 233 | 234 | // Parse for the SSH target host. 235 | tunn.hostAddr = t.Server 236 | if i := strings.IndexByte(t.Server, '@'); i >= 0 { 237 | tunn.user = t.Server[:i] 238 | tunn.hostAddr = t.Server[i+1:] 239 | } 240 | if _, _, err := net.SplitHostPort(tunn.hostAddr); err != nil { 241 | tunn.hostAddr = net.JoinHostPort(tunn.hostAddr, "22") 242 | } 243 | if _, _, err := net.SplitHostPort(tunn.hostAddr); err != nil { 244 | logger.Fatalf("invalid server: %s", t.Server) 245 | } 246 | 247 | // Parse for the SSH user. 248 | if tunn.user == "" { 249 | u, err := user.Current() 250 | if err != nil { 251 | logger.Fatalf("unexpected error: %v", err) 252 | } 253 | tunn.user = u.Username 254 | } 255 | 256 | if t.KeepAlive == nil { 257 | tunn.keepAlive = *config.KeepAlive 258 | } else { 259 | tunn.keepAlive = *t.KeepAlive 260 | } 261 | tunn.retryInterval = 30 * time.Second 262 | tunn.auth = auth 263 | tunn.hostKeys = hostKeys 264 | tunn.log = logger 265 | tunns = append(tunns, tunn) 266 | } 267 | 268 | return tunns, logger, closer 269 | } 270 | 271 | func main() { 272 | if len(os.Args) != 2 { 273 | fmt.Fprintf(os.Stderr, "Usage:\n") 274 | fmt.Fprintf(os.Stderr, "\t%s CONFIG_PATH\n", os.Args[0]) 275 | os.Exit(1) 276 | } 277 | tunns, logger, closer := loadConfig(os.Args[1]) 278 | defer closer() 279 | 280 | // Setup signal handler to initiate shutdown. 
281 | ctx, cancel := context.WithCancel(context.Background()) 282 | go func() { 283 | sigc := make(chan os.Signal, 1) 284 | signal.Notify(sigc, syscall.SIGINT, syscall.SIGTERM) 285 | logger.Printf("received %v - initiating shutdown", <-sigc) 286 | cancel() 287 | }() 288 | 289 | // Start a bridge for each tunnel. 290 | var wg sync.WaitGroup 291 | logger.Printf("%s starting", path.Base(os.Args[0])) 292 | defer logger.Printf("%s shutdown", path.Base(os.Args[0])) 293 | for _, t := range tunns { 294 | wg.Add(1) 295 | go t.bindTunnel(ctx, &wg) 296 | } 297 | wg.Wait() 298 | } 299 | -------------------------------------------------------------------------------- /misc/sshtunnel: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | ### BEGIN INIT INFO 4 | # Provides: sshtunnel 5 | # Required-Start: $all 6 | # Required-Stop: $all 7 | # Default-Start: 2 3 4 5 8 | # Default-Stop: 0 1 6 9 | # Short-Description: Starts a SSH reverse proxy tunnel 10 | # Description: Starts sshtunnel using start-stop-daemon 11 | ### END INIT INFO 12 | 13 | NAME=sshtunnel 14 | DESC=sshtunnel 15 | PID_FILE=/var/run/$NAME.pid 16 | DAEMON_USER=tunnel 17 | DAEMON_PATH=/usr/local/sshtunnel 18 | DAEMON_BINARY=sshtunnel 19 | DAEMON_CONFIG=config.json 20 | 21 | # Pre-check. 22 | if [ ! -f $DAEMON_PATH/$DAEMON_BINARY ]; then 23 | echo "$DAEMON_PATH/$DAEMON_BINARY does not exist!" 24 | exit 1 25 | fi 26 | set -e 27 | . /lib/lsb/init-functions 28 | 29 | running() { 30 | [ ! -f "$PID_FILE" ] && return 1 31 | [ ! -d /proc/$(cat $PID_FILE) ] && return 1 32 | return 0 33 | } 34 | 35 | start_server() { 36 | echo "Starting $DESC:" 37 | start-stop-daemon --start --quiet --pidfile $PID_FILE --make-pidfile --background --chuid $DAEMON_USER --chdir $DAEMON_PATH --startas $DAEMON_BINARY -- $DAEMON_CONFIG && RET_CODE=0 || RET_CODE=$? 
38 | if [ $RET_CODE -eq 0 ]; then 39 | sleep 0.5 40 | if running; then 41 | log_success_msg "$NAME started" 42 | else 43 | log_failure_msg "$NAME not started" 44 | rm -f $PID_FILE 45 | return 1 46 | fi 47 | elif [ $RET_CODE -eq 1 ]; then 48 | log_warning_msg "$NAME already started" 49 | else 50 | log_failure_msg "$NAME not started" 51 | fi 52 | return $RET_CODE 53 | } 54 | 55 | stop_server() { 56 | echo "Stopping $DESC:" 57 | start-stop-daemon --stop --quiet --pidfile $PID_FILE && RET_CODE=0 || RET_CODE=$? 58 | if [ $RET_CODE -eq 0 ]; then 59 | log_success_msg "$NAME stopped" 60 | elif [ $RET_CODE -eq 1 ]; then 61 | log_warning_msg "$NAME already stopped" 62 | else 63 | log_failure_msg "$NAME not stopped" 64 | fi 65 | running || rm -f $PID_FILE 66 | return $RET_CODE 67 | } 68 | 69 | status_server() { 70 | echo "Checking $DESC status:" 71 | if running; then 72 | log_success_msg "$NAME is running" 73 | else 74 | log_warning_msg "$NAME is not running" 75 | fi 76 | return 0 77 | } 78 | 79 | case "$1" in 80 | start) 81 | start_server 82 | ;; 83 | stop) 84 | stop_server 85 | ;; 86 | restart) 87 | stop_server || true 88 | sleep 0.5 89 | start_server 90 | ;; 91 | status) 92 | status_server 93 | ;; 94 | *) 95 | echo "Usage: $NAME {start|stop|restart|status}" >&2 96 | exit 1 97 | ;; 98 | esac 99 | -------------------------------------------------------------------------------- /server_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2017, The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE.md file. 4 | 5 | package main 6 | 7 | import ( 8 | "bytes" 9 | "context" 10 | "encoding/binary" 11 | "fmt" 12 | "io" 13 | "net" 14 | "reflect" 15 | "strconv" 16 | "sync" 17 | "testing" 18 | 19 | "golang.org/x/crypto/ssh" 20 | ) 21 | 22 | // runServer starts an SSH server capable of handling forward and reverse 23 | // TCP tunnels. 
This function blocks for the entire duration that the 24 | // server is running and can be stopped by canceling the context. 25 | // 26 | // The server listens on the provided Listener and will present to clients 27 | // a certificate from serverKey and will only accept users that match 28 | // the provided clientKeys. Only users of the name "user%d" are allowed where 29 | // the ID number is the index for the specified client key provided. 30 | func runServer(t *testing.T, ctx context.Context, ln net.Listener, serverKey ssh.Signer, clientKeys ...ssh.PublicKey) { 31 | wg := new(sync.WaitGroup) 32 | defer wg.Wait() 33 | 34 | // Generate SSH server configuration. 35 | conf := ssh.ServerConfig{ 36 | PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { 37 | var uid int 38 | _, err := fmt.Sscanf(c.User(), "user%d", &uid) 39 | if err != nil || uid >= len(clientKeys) || !bytes.Equal(clientKeys[uid].Marshal(), pubKey.Marshal()) { 40 | return nil, fmt.Errorf("unknown public key for %q", c.User()) 41 | } 42 | return nil, nil 43 | }, 44 | } 45 | conf.AddHostKey(serverKey) 46 | 47 | // Handle every SSH client connection. 48 | for { 49 | tcpCn, err := ln.Accept() 50 | if err != nil { 51 | if !isDone(ctx) { 52 | t.Errorf("accept error: %v", err) 53 | } 54 | return 55 | } 56 | wg.Add(1) 57 | go handleServerConn(t, ctx, wg, tcpCn, &conf) 58 | } 59 | } 60 | 61 | // handleServerConn handles a single SSH connection. 
62 | func handleServerConn(t *testing.T, ctx context.Context, wg *sync.WaitGroup, tcpCn net.Conn, conf *ssh.ServerConfig) { 63 | defer wg.Done() 64 | go closeWhenDone(ctx, tcpCn) 65 | defer tcpCn.Close() 66 | 67 | sshCn, chans, reqs, err := ssh.NewServerConn(tcpCn, conf) 68 | if err != nil { 69 | t.Errorf("new connection error: %v", err) 70 | return 71 | } 72 | go closeWhenDone(ctx, sshCn) 73 | defer sshCn.Close() 74 | 75 | wg.Add(1) 76 | go handleServerChannels(t, ctx, wg, sshCn, chans) 77 | 78 | wg.Add(1) 79 | go handleServerRequests(t, ctx, wg, sshCn, reqs) 80 | 81 | if err := sshCn.Wait(); err != nil && err != io.EOF && !isDone(ctx) { 82 | t.Errorf("connection error: %v", err) 83 | } 84 | } 85 | 86 | // handleServerChannels handles new channels on a SSH connection. 87 | // The client initiates a new channel when forwarding a TCP dial. 88 | func handleServerChannels(t *testing.T, ctx context.Context, wg *sync.WaitGroup, sshCn ssh.Conn, chans <-chan ssh.NewChannel) { 89 | defer wg.Done() 90 | for nc := range chans { 91 | if nc.ChannelType() != "direct-tcpip" { 92 | nc.Reject(ssh.UnknownChannelType, "not implemented") 93 | continue 94 | } 95 | var args struct { 96 | DstHost string 97 | DstPort uint32 98 | SrcHost string 99 | SrcPort uint32 100 | } 101 | if !unmarshalData(nc.ExtraData(), &args) { 102 | nc.Reject(ssh.Prohibited, "invalid request") 103 | continue 104 | } 105 | 106 | // Open a connection for both sides. 107 | cn, err := net.Dial("tcp", net.JoinHostPort(args.DstHost, strconv.Itoa(int(args.DstPort)))) 108 | if err != nil { 109 | nc.Reject(ssh.ConnectionFailed, err.Error()) 110 | continue 111 | } 112 | ch, reqs, err := nc.Accept() 113 | if err != nil { 114 | t.Errorf("accept channel error: %v", err) 115 | cn.Close() 116 | continue 117 | } 118 | go ssh.DiscardRequests(reqs) 119 | 120 | wg.Add(1) 121 | go bidirCopyAndClose(t, ctx, wg, cn, ch) 122 | } 123 | } 124 | 125 | // handleServerRequests handles new requests on a SSH connection. 
126 | // The client initiates a new request for binding a local TCP socket. 127 | func handleServerRequests(t *testing.T, ctx context.Context, wg *sync.WaitGroup, sshCn ssh.Conn, reqs <-chan *ssh.Request) { 128 | defer wg.Done() 129 | for r := range reqs { 130 | if !r.WantReply { 131 | continue 132 | } 133 | if r.Type != "tcpip-forward" { 134 | r.Reply(false, nil) 135 | continue 136 | } 137 | var args struct { 138 | Host string 139 | Port uint32 140 | } 141 | if !unmarshalData(r.Payload, &args) { 142 | r.Reply(false, nil) 143 | continue 144 | } 145 | ln, err := net.Listen("tcp", net.JoinHostPort(args.Host, strconv.Itoa(int(args.Port)))) 146 | if err != nil { 147 | r.Reply(false, nil) 148 | continue 149 | } 150 | 151 | var resp struct{ Port uint32 } 152 | _, resp.Port = splitHostPort(ln.Addr().String()) 153 | if err := r.Reply(true, marshalData(resp)); err != nil { 154 | t.Errorf("request reply error: %v", err) 155 | ln.Close() 156 | continue 157 | } 158 | 159 | wg.Add(1) 160 | go handleLocalListener(t, ctx, wg, sshCn, ln, args.Host) 161 | 162 | } 163 | } 164 | 165 | // handleLocalListener handles every new connection on the provided socket. 166 | // All local connections will be forwarded to the client via a new channel. 167 | func handleLocalListener(t *testing.T, ctx context.Context, wg *sync.WaitGroup, sshCn ssh.Conn, ln net.Listener, host string) { 168 | defer wg.Done() 169 | go closeWhenDone(ctx, ln) 170 | defer ln.Close() 171 | 172 | for { 173 | // Open a connection for both sides. 174 | cn, err := ln.Accept() 175 | if err != nil { 176 | if !isDone(ctx) { 177 | t.Errorf("accept error: %v", err) 178 | } 179 | return 180 | } 181 | var args struct { 182 | DstHost string 183 | DstPort uint32 184 | SrcHost string 185 | SrcPort uint32 186 | } 187 | args.DstHost, args.DstPort = splitHostPort(cn.LocalAddr().String()) 188 | args.SrcHost, args.SrcPort = splitHostPort(cn.RemoteAddr().String()) 189 | args.DstHost = host // This must match on client side! 
190 | ch, reqs, err := sshCn.OpenChannel("forwarded-tcpip", marshalData(args)) 191 | if err != nil { 192 | t.Errorf("open channel error: %v", err) 193 | cn.Close() 194 | continue 195 | } 196 | go ssh.DiscardRequests(reqs) 197 | 198 | wg.Add(1) 199 | go bidirCopyAndClose(t, ctx, wg, cn, ch) 200 | } 201 | } 202 | 203 | // bidirCopyAndClose performs a bi-directional copy on both connections 204 | // until either side closes the connection or the context is canceled. 205 | // This will close both connections before returning. 206 | func bidirCopyAndClose(t *testing.T, ctx context.Context, wg *sync.WaitGroup, c1, c2 io.ReadWriteCloser) { 207 | defer wg.Done() 208 | go closeWhenDone(ctx, c1) 209 | go closeWhenDone(ctx, c2) 210 | defer c1.Close() 211 | defer c2.Close() 212 | 213 | errc := make(chan error, 2) 214 | go func() { 215 | _, err := io.Copy(c1, c2) 216 | errc <- err 217 | }() 218 | go func() { 219 | _, err := io.Copy(c2, c1) 220 | errc <- err 221 | }() 222 | if err := <-errc; err != nil && err != io.EOF && !isDone(ctx) { 223 | t.Errorf("copy error: %v", err) 224 | } 225 | } 226 | 227 | // unmarshalData parses b into s, where s is a pointer to a struct. 228 | // Only unexported fields of type uint32 or string are allowed. 
229 | func unmarshalData(b []byte, s interface{}) bool {
230 | 	v := reflect.ValueOf(s)
231 | 	if !v.IsValid() || v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
232 | 		panic("destination must be pointer to struct")
233 | 	}
234 | 	v = v.Elem()
235 | 	for i := 0; i < v.NumField(); i++ {
236 | 		switch v.Type().Field(i).Type.Kind() {
237 | 		case reflect.Uint32:
238 | 			if len(b) < 4 {
239 | 				return false
240 | 			}
241 | 			v.Field(i).Set(reflect.ValueOf(binary.BigEndian.Uint32(b))) // panics unless the field is exported with type exactly uint32
242 | 			b = b[4:]
243 | 		case reflect.String: // 4-byte big-endian length followed by that many bytes
244 | 			if len(b) < 4 {
245 | 				return false
246 | 			}
247 | 			n := binary.BigEndian.Uint32(b)
248 | 			b = b[4:]
249 | 			if uint64(len(b)) < uint64(n) { // compare in uint64 so a large n cannot overflow int
250 | 				return false
251 | 			}
252 | 			v.Field(i).Set(reflect.ValueOf(string(b[:n])))
253 | 			b = b[n:]
254 | 		default:
255 | 			panic("invalid field type: " + v.Type().Field(i).Type.String())
256 | 		}
257 | 	}
258 | 	return len(b) == 0 // trailing bytes count as a parse failure
259 | }
260 | 
261 | // marshalData serializes s into b, where s is a struct (or a pointer to one), using the layout unmarshalData parses.
262 | // Fields must have kind uint32 or string; they are only read here, so unexported fields also work.
263 | func marshalData(s interface{}) (b []byte) {
264 | 	v := reflect.ValueOf(s)
265 | 	if v.IsValid() && v.Kind() == reflect.Ptr {
266 | 		v = v.Elem()
267 | 	}
268 | 	if !v.IsValid() || v.Kind() != reflect.Struct {
269 | 		panic("source must be a struct")
270 | 	}
271 | 	var arr32 [4]byte // scratch buffer reused for every big-endian uint32
272 | 	for i := 0; i < v.NumField(); i++ {
273 | 		switch v.Type().Field(i).Type.Kind() {
274 | 		case reflect.Uint32:
275 | 			binary.BigEndian.PutUint32(arr32[:], uint32(v.Field(i).Uint()))
276 | 			b = append(b, arr32[:]...)
277 | 		case reflect.String:
278 | 			binary.BigEndian.PutUint32(arr32[:], uint32(v.Field(i).Len())) // length prefix; lengths >= 1<<32 would be truncated
279 | 			b = append(b, arr32[:]...)
280 | 			b = append(b, v.Field(i).String()...)
281 | 		default:
282 | 			panic("invalid field type: " + v.Type().Field(i).Type.String())
283 | 		}
284 | 	}
285 | 	return b
286 | 
287 | }
288 | // splitHostPort splits s into a host and numeric port; parse errors are ignored and yield "" and/or 0.
289 | func splitHostPort(s string) (string, uint32) {
290 | 	host, port, _ := net.SplitHostPort(s)
291 | 	p, _ := strconv.Atoi(port)
292 | 	return host, uint32(p)
293 | }
294 | // closeWhenDone blocks until ctx is done and then closes c, ignoring the Close error.
295 | func closeWhenDone(ctx context.Context, c io.Closer) {
296 | 	<-ctx.Done()
297 | 	c.Close()
298 | }
299 | // isDone reports whether ctx is already done, without blocking.
300 | func isDone(ctx context.Context) bool {
301 | 	select {
302 | 	case <-ctx.Done():
303 | 		return true
304 | 	default:
305 | 		return false
306 | 	}
307 | }
308 | 
--------------------------------------------------------------------------------
/tunnel.go:
--------------------------------------------------------------------------------
1 | // Copyright 2017, The Go Authors. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE.md file.
4 | 
5 | package main
6 | 
7 | import (
8 | 	"context"
9 | 	"fmt"
10 | 	"io"
11 | 	"net"
12 | 	"sync"
13 | 	"sync/atomic"
14 | 	"time"
15 | 
16 | 	"golang.org/x/crypto/ssh"
17 | )
18 | // logger is the minimal logging interface a tunnel writes its status messages to.
19 | type logger interface {
20 | 	Printf(string, ...interface{})
21 | }
22 | // tunnel configures one SSH tunnel: forward ('>') binds locally and dials through the SSH server; reverse ('<') binds on the server and dials locally.
23 | type tunnel struct {
24 | 	auth     []ssh.AuthMethod
25 | 	hostKeys ssh.HostKeyCallback
26 | 	mode     byte // '>' for forward, '<' for reverse
27 | 	user     string
28 | 	hostAddr string
29 | 	bindAddr string
30 | 	dialAddr string
31 | 
32 | 	retryInterval time.Duration
33 | 	keepAlive     KeepAliveConfig // Interval is interpreted in seconds by keepAliveMonitor
34 | 
35 | 	log logger
36 | }
37 | // String describes the tunnel for log messages as "user@host | left <arrow> right".
38 | func (t tunnel) String() string {
39 | 	var left, right string
40 | 	mode := ""
41 | 	switch t.mode {
42 | 	case '>':
43 | 		left, mode, right = t.bindAddr, "->", t.dialAddr
44 | 	case '<':
45 | 		left, mode, right = t.dialAddr, "<-", t.bindAddr
46 | 	}
47 | 	return fmt.Sprintf("%s@%s | %s %s %s", t.user, t.hostAddr, left, mode, right)
48 | }
49 | // bindTunnel repeatedly connects to the SSH server, binds t.bindAddr (locally for '>', on the server for '<'), and serves connections until ctx is canceled, waiting t.retryInterval between sessions.
50 | func (t tunnel) bindTunnel(ctx context.Context, wg *sync.WaitGroup) {
51 | 	defer wg.Done()
52 | 
53 | 	for {
54 | 		var once sync.Once // Only print errors once per session
55 | 		func() { // one SSH session per iteration; its defers run before the retry wait below
56 | 			// Connect to the server host via SSH.
57 | 			cl, err := ssh.Dial("tcp", t.hostAddr, &ssh.ClientConfig{
58 | 				User:            t.user,
59 | 				Auth:            t.auth,
60 | 				HostKeyCallback: t.hostKeys,
61 | 				Timeout:         5 * time.Second,
62 | 			})
63 | 			if err != nil {
64 | 				once.Do(func() { t.log.Printf("(%v) SSH dial error: %v", t, err) })
65 | 				return
66 | 			}
67 | 			wg.Add(1) // balanced by the deferred wg.Done in keepAliveMonitor
68 | 			go t.keepAliveMonitor(&once, wg, cl)
69 | 			defer cl.Close()
70 | 
71 | 			// Attempt to bind to the inbound socket.
72 | 			var ln net.Listener
73 | 			switch t.mode {
74 | 			case '>':
75 | 				ln, err = net.Listen("tcp", t.bindAddr)
76 | 			case '<':
77 | 				ln, err = cl.Listen("tcp", t.bindAddr)
78 | 			}
79 | 			if err != nil {
80 | 				once.Do(func() { t.log.Printf("(%v) bind error: %v", t, err) })
81 | 				return
82 | 			}
83 | 
84 | 			// The socket is bound. Make sure we close it eventually.
85 | 			bindCtx, cancel := context.WithCancel(ctx)
86 | 			defer cancel()
87 | 			go func() {
88 | 				cl.Wait() // returns once the SSH connection shuts down, e.g. when keepAliveMonitor closes it
89 | 				cancel()
90 | 			}()
91 | 			go func() {
92 | 				<-bindCtx.Done()
93 | 				once.Do(func() {}) // Suppress future errors
94 | 				ln.Close()
95 | 			}()
96 | 
97 | 			t.log.Printf("(%v) binded tunnel", t)
98 | 			defer t.log.Printf("(%v) collapsed tunnel", t)
99 | 
100 | 			// Accept all incoming connections.
101 | 			for {
102 | 				cn1, err := ln.Accept()
103 | 				if err != nil {
104 | 					once.Do(func() { t.log.Printf("(%v) accept error: %v", t, err) })
105 | 					return
106 | 				}
107 | 				wg.Add(1)
108 | 				go t.dialTunnel(bindCtx, wg, cl, cn1)
109 | 			}
110 | 		}()
111 | 
112 | 		select {
113 | 		case <-ctx.Done():
114 | 			return
115 | 		case <-time.After(t.retryInterval): // NOTE(review): a zero retryInterval retries without delay — confirm the caller validates it
116 | 			t.log.Printf("(%v) retrying...", t)
117 | 		}
118 | 	}
119 | }
120 | // dialTunnel connects the accepted connection cn1 to t.dialAddr (through the SSH client for '>', directly for '<') and copies data both ways until either side closes or ctx is done.
121 | func (t tunnel) dialTunnel(ctx context.Context, wg *sync.WaitGroup, client *ssh.Client, cn1 net.Conn) {
122 | 	defer wg.Done()
123 | 
124 | 	// The inbound connection is established. Make sure we close it eventually.
125 | 	connCtx, cancel := context.WithCancel(ctx)
126 | 	defer cancel()
127 | 	go func() {
128 | 		<-connCtx.Done()
129 | 		cn1.Close()
130 | 	}()
131 | 
132 | 	// Establish the outbound connection.
133 | 	var cn2 net.Conn
134 | 	var err error
135 | 	switch t.mode {
136 | 	case '>':
137 | 		cn2, err = client.Dial("tcp", t.dialAddr)
138 | 	case '<':
139 | 		cn2, err = net.Dial("tcp", t.dialAddr)
140 | 	}
141 | 	if err != nil { // the deferred cancel still closes cn1 on this path
142 | 		t.log.Printf("(%v) dial error: %v", t, err)
143 | 		return
144 | 	}
145 | 
146 | 	go func() {
147 | 		<-connCtx.Done()
148 | 		cn2.Close()
149 | 	}()
150 | 
151 | 	t.log.Printf("(%v) connection established", t)
152 | 	defer t.log.Printf("(%v) connection closed", t)
153 | 
154 | 	// Copy bytes from one connection to the other until one side closes.
155 | 	var once sync.Once // log at most one copy error per connection
156 | 	var wg2 sync.WaitGroup
157 | 	wg2.Add(2)
158 | 	go func() {
159 | 		defer wg2.Done()
160 | 		defer cancel()
161 | 		if _, err := io.Copy(cn1, cn2); err != nil {
162 | 			once.Do(func() { t.log.Printf("(%v) connection error: %v", t, err) })
163 | 		}
164 | 		once.Do(func() {}) // Suppress future errors
165 | 	}()
166 | 	go func() {
167 | 		defer wg2.Done()
168 | 		defer cancel()
169 | 		if _, err := io.Copy(cn2, cn1); err != nil {
170 | 			once.Do(func() { t.log.Printf("(%v) connection error: %v", t, err) })
171 | 		}
172 | 		once.Do(func() {}) // Suppress future errors
173 | 	}()
174 | 	wg2.Wait() // each copy cancels connCtx on exit, which closes both conns and ends the other copy
175 | }
176 | 
177 | // keepAliveMonitor periodically sends messages to invoke a response.
178 | // If the server does not respond after some period of time,
179 | // assume that the underlying net.Conn abruptly died.
180 | func (t tunnel) keepAliveMonitor(once *sync.Once, wg *sync.WaitGroup, client *ssh.Client) {
181 | 	defer wg.Done()
182 | 	if t.keepAlive.Interval == 0 || t.keepAlive.CountMax == 0 {
183 | 		return
184 | 	}
185 | 
186 | 	// Detect when the SSH connection is closed.
187 | 	wait := make(chan error, 1) // buffered so the sender never blocks after this function returns
188 | 	wg.Add(1)
189 | 	go func() {
190 | 		defer wg.Done()
191 | 		wait <- client.Wait()
192 | 	}()
193 | 
194 | 	// Repeatedly check if the remote server is still alive.
195 | 	var aliveCount int32 // ticks since the last successful keep-alive reply; accessed atomically
196 | 	ticker := time.NewTicker(time.Duration(t.keepAlive.Interval) * time.Second)
197 | 	defer ticker.Stop()
198 | 	for {
199 | 		select {
200 | 		case err := <-wait:
201 | 			if err != nil && err != io.EOF {
202 | 				once.Do(func() { t.log.Printf("(%v) SSH error: %v", t, err) })
203 | 			}
204 | 			return
205 | 		case <-ticker.C:
206 | 			if n := atomic.AddInt32(&aliveCount, 1); n > int32(t.keepAlive.CountMax) {
207 | 				once.Do(func() { t.log.Printf("(%v) SSH keep-alive termination", t) })
208 | 				client.Close()
209 | 				return
210 | 			}
211 | 		}
212 | 
213 | 		wg.Add(1) // send asynchronously: SendRequest with wantReply waits for the server's reply
214 | 		go func() {
215 | 			defer wg.Done()
216 | 			_, _, err := client.SendRequest("keepalive@openssh.com", true, nil)
217 | 			if err == nil {
218 | 				atomic.StoreInt32(&aliveCount, 0)
219 | 			}
220 | 		}()
221 | 	}
222 | }
223 | 
--------------------------------------------------------------------------------
/tunnel_test.go:
--------------------------------------------------------------------------------
1 | // Copyright 2017, The Go Authors. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE.md file.
4 | 
5 | package main
6 | 
7 | import (
8 | 	"bytes"
9 | 	"context"
10 | 	"crypto/md5"
11 | 	"crypto/rsa"
12 | 	"encoding/binary"
13 | 	"io"
14 | 	"math/rand"
15 | 	"net"
16 | 	"sync"
17 | 	"testing"
18 | 	"time"
19 | 
20 | 	"golang.org/x/crypto/ssh"
21 | )
22 | // testLogger adapts *testing.T to the logger interface by forwarding Printf to t.Logf.
23 | type testLogger struct {
24 | 	*testing.T // Already has Fatalf method
25 | }
26 | 
27 | func (t testLogger) Printf(f string, x ...interface{}) { t.Logf(f, x...) }
28 | // TestTunnel chains a reverse tunnel (tcpLn0 -> tcpLn1) into a forward tunnel (tcpLn1 -> tcpLn2) and checks that data written into tcpLn0 arrives intact at tcpLn2.
29 | func TestTunnel(t *testing.T) {
30 | 	rootWG := new(sync.WaitGroup)
31 | 	defer rootWG.Wait()
32 | 	rootCtx, cancelAll := context.WithCancel(context.Background())
33 | 	defer cancelAll()
34 | 
35 | 	// Open all of the TCP sockets needed for the test.
36 | 	tcpLn0 := openListener(t) // Start of the chain
37 | 	tcpLn1 := openListener(t) // Mid-point of the chain
38 | 	tcpLn2 := openListener(t) // End of the chain
39 | 	srvLn0 := openListener(t) // Socket for SSH server in reverse mode
40 | 	srvLn1 := openListener(t) // Socket for SSH server in forward mode
41 | 
42 | 	tcpLn0.Close() // To be later bound by the reverse tunnel
43 | 	tcpLn1.Close() // To be later bound by the forward tunnel
44 | 	go closeWhenDone(rootCtx, tcpLn2)
45 | 	go closeWhenDone(rootCtx, srvLn0)
46 | 	go closeWhenDone(rootCtx, srvLn1)
47 | 
48 | 	// Generate keys for both the servers and clients.
49 | 	clientPriv0, clientPub0 := generateKeys(t)
50 | 	clientPriv1, clientPub1 := generateKeys(t)
51 | 	serverPriv0, serverPub0 := generateKeys(t)
52 | 	serverPriv1, serverPub1 := generateKeys(t)
53 | 
54 | 	// Start the SSH servers.
55 | 	rootWG.Add(2)
56 | 	go func() {
57 | 		defer rootWG.Done()
58 | 		runServer(t, rootCtx, srvLn0, serverPriv0, clientPub0, clientPub1)
59 | 	}()
60 | 	go func() {
61 | 		defer rootWG.Done()
62 | 		runServer(t, rootCtx, srvLn1, serverPriv1, clientPub0, clientPub1)
63 | 	}()
64 | 
65 | 	wg := new(sync.WaitGroup)
66 | 	defer wg.Wait()
67 | 	ctx, cancel := context.WithCancel(context.Background())
68 | 	defer cancel()
69 | 
70 | 	// Create the tunnel configurations.
71 | 	tn0 := tunnel{
72 | 		auth:     []ssh.AuthMethod{ssh.PublicKeys(clientPriv0)},
73 | 		hostKeys: ssh.FixedHostKey(serverPub0),
74 | 		mode:     '<', // Reverse tunnel
75 | 		user:     "user0",
76 | 		hostAddr: srvLn0.Addr().String(),
77 | 		bindAddr: tcpLn0.Addr().String(),
78 | 		dialAddr: tcpLn1.Addr().String(),
79 | 		log:      testLogger{t},
80 | 	}
81 | 	tn1 := tunnel{
82 | 		auth:     []ssh.AuthMethod{ssh.PublicKeys(clientPriv1)},
83 | 		hostKeys: ssh.FixedHostKey(serverPub1),
84 | 		mode:     '>', // Forward tunnel
85 | 		user:     "user1",
86 | 		hostAddr: srvLn1.Addr().String(),
87 | 		bindAddr: tcpLn1.Addr().String(),
88 | 		dialAddr: tcpLn2.Addr().String(),
89 | 		log:      testLogger{t},
90 | 	}
91 | 
92 | 	// Start the SSH client tunnels.
93 | 	wg.Add(2)
94 | 	go tn0.bindTunnel(ctx, wg)
95 | 	go tn1.bindTunnel(ctx, wg)
96 | 
97 | 	t.Log("test started")
98 | 	done := make(chan bool, 10) // cap(done) is also the number of transmitter/receiver pairs
99 | 
100 | 	// Start all the transmitters.
101 | 	for i := 0; i < cap(done); i++ {
102 | 		i := i
103 | 		go func() {
104 | 			for {
105 | 				rnd := rand.New(rand.NewSource(int64(i))) // seeded by i, so a retry regenerates the same payload
106 | 				hash := md5.New()
107 | 				size := uint32((1 << 10) + rnd.Intn(1<<20)) // payload between 1 KiB and 1 KiB + 1 MiB
108 | 				buf4 := make([]byte, 4)
109 | 				binary.LittleEndian.PutUint32(buf4, size)
110 | 
111 | 				cnStart, err := net.Dial("tcp", tcpLn0.Addr().String())
112 | 				if err != nil {
113 | 					time.Sleep(10 * time.Millisecond)
114 | 					continue // NOTE(review): retries forever if the tunnel never binds; not tied to ctx
115 | 				}
116 | 				defer cnStart.Close() // every path below breaks out, so this is registered at most once
117 | 				if _, err := cnStart.Write(buf4); err != nil {
118 | 					t.Errorf("write size error: %v", err)
119 | 					break
120 | 				}
121 | 				r := io.LimitReader(rnd, int64(size))
122 | 				w := io.MultiWriter(cnStart, hash)
123 | 				if _, err := io.Copy(w, r); err != nil {
124 | 					t.Errorf("copy error: %v", err)
125 | 					break
126 | 				}
127 | 				if _, err := cnStart.Write(hash.Sum(nil)); err != nil {
128 | 					t.Errorf("write hash error: %v", err)
129 | 					break
130 | 				}
131 | 				if err := cnStart.Close(); err != nil {
132 | 					t.Errorf("close error: %v", err)
133 | 					break
134 | 				}
135 | 				break
136 | 			}
137 | 		}()
138 | 	}
139 | 
140 | 	// Start all the receivers.
141 | 	for i := 0; i < cap(done); i++ {
142 | 		go func() {
143 | 			for {
144 | 				hash := md5.New()
145 | 				buf4 := make([]byte, 4)
146 | 
147 | 				cnEnd, err := tcpLn2.Accept()
148 | 				if err != nil {
149 | 					time.Sleep(10 * time.Millisecond)
150 | 					continue // NOTE(review): if the test times out, leftover receivers spin here once tcpLn2 is closed
151 | 				}
152 | 				defer cnEnd.Close() // every path below breaks out, so this is registered at most once
153 | 
154 | 				if _, err := io.ReadFull(cnEnd, buf4); err != nil {
155 | 					t.Errorf("read size error: %v", err)
156 | 					break
157 | 				}
158 | 				size := binary.LittleEndian.Uint32(buf4)
159 | 				r := io.LimitReader(cnEnd, int64(size))
160 | 				if _, err := io.Copy(hash, r); err != nil {
161 | 					t.Errorf("copy error: %v", err)
162 | 					break
163 | 				}
164 | 				wantHash, err := io.ReadAll(cnEnd) // the sender closes right after the digest, so this reads just the hash
165 | 				if err != nil {
166 | 					t.Errorf("read hash error: %v", err)
167 | 					break
168 | 				}
169 | 				if err := cnEnd.Close(); err != nil {
170 | 					t.Errorf("close error: %v", err)
171 | 					break
172 | 				}
173 | 
174 | 				if gotHash := hash.Sum(nil); !bytes.Equal(gotHash, wantHash) {
175 | 					t.Errorf("hash mismatch:\ngot %x\nwant %x", gotHash, wantHash)
176 | 				}
177 | 				break
178 | 			}
179 | 			done <- true
180 | 		}()
181 | 	}
182 | 
183 | 	for i := 0; i < cap(done); i++ {
184 | 		select {
185 | 		case <-done:
186 | 		case <-time.After(10 * time.Second):
187 | 			t.Errorf("timed out: %d remaining", cap(done)-i)
188 | 			return
189 | 		}
190 | 	}
191 | 	t.Log("test complete")
192 | }
193 | 
194 | // generateKeys generates a random 1024-bit RSA key pair and returns it as an SSH signer and public key.
195 | func generateKeys(t *testing.T) (priv ssh.Signer, pub ssh.PublicKey) {
196 | 	t.Helper()
197 | 	// math/rand is not a cryptographic source; that is tolerable only because
198 | 	// these keys are throwaway test fixtures (NOTE(review): crypto/rand.Reader
199 | 	// would be the idiomatic source). Seed from the nanosecond clock so the
200 | 	// back-to-back calls in TestTunnel do not share a seed, as they could when
201 | 	// seeded with the one-second-resolution time.Now().Unix().
202 | 	rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
203 | 	rsaKey, err := rsa.GenerateKey(rnd, 1024)
204 | 	if err != nil {
205 | 		t.Fatalf("unable to generate RSA key pair: %v", err)
206 | 	}
207 | 	priv, err = ssh.NewSignerFromKey(rsaKey)
208 | 	if err != nil {
209 | 		t.Fatalf("unable to generate signer: %v", err)
210 | 	}
211 | 	pub, err = ssh.NewPublicKey(&rsaKey.PublicKey)
212 | 	if err != nil {
213 | 		t.Fatalf("unable to generate public key: %v", err)
214 | 	}
215 | 	return priv, pub
216 | }
217 | 
218 | // openListener opens a TCP listener on an ephemeral loopback port so test
219 | // sockets are not exposed on every network interface. It prefers IPv4 and
220 | // falls back to IPv6 on hosts without 127.0.0.1, failing the test otherwise.
221 | func openListener(t *testing.T) net.Listener {
222 | 	t.Helper()
223 | 	ln, err := net.Listen("tcp", "127.0.0.1:0")
224 | 	if err == nil {
225 | 		return ln
226 | 	}
227 | 	ln, err6 := net.Listen("tcp", "[::1]:0")
228 | 	if err6 != nil {
229 | 		t.Fatalf("listen error: %v (IPv6 fallback: %v)", err, err6)
230 | 	}
231 | 	return ln
232 | }
233 | 
--------------------------------------------------------------------------------