├── .github └── workflows │ └── govulncheck.yml ├── LICENSE ├── README.md ├── auth.go ├── endpoint.go ├── example └── example.go ├── forward.go ├── go.mod ├── go.sum ├── sshtun.go └── sshtun_test.go /.github/workflows/govulncheck.yml: -------------------------------------------------------------------------------- 1 | name: Govulncheck 2 | 3 | on: 4 | push: 5 | pull_request: 6 | schedule: 7 | - cron: '00 2 * * *' 8 | 9 | jobs: 10 | govulncheck: 11 | name: Run govulncheck 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - id: govulncheck 16 | uses: golang/govulncheck-action@v1 17 | with: 18 | go-version-file: go.mod 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Roger Zaragoza 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sshtun 2 | 3 | [![Go Reference](https://pkg.go.dev/badge/github.com/rgzr/sshtun.svg)](https://pkg.go.dev/github.com/rgzr/sshtun) 4 | 5 | sshtun is a Go package that provides an SSH tunnel with port forwarding supporting: 6 | 7 | * TCP and unix socket connections 8 | * Password authentication 9 | * Un/encrypted key file authentication 10 | * `ssh-agent` based authentication 11 | * Both local and remote port forwarding 12 | 13 | By default it reads the private key from the default Linux SSH key locations and falls back to using `ssh-agent`, but a specific authentication method can be set. 14 | 15 | The default locations are `~/.ssh/id_rsa`, `~/.ssh/id_dsa`, `~/.ssh/id_ecdsa`, `~/.ssh/id_ecdsa_sk`, `~/.ssh/id_ed25519` and `~/.ssh/id_ed25519_sk`. 16 | 17 | 18 | ## Installation 19 | 20 | `go get github.com/rgzr/sshtun` 21 | 22 | ## Example 23 | 24 | ```go 25 | package main 26 | 27 | import ( 28 | "context" 29 | "log" 30 | "time" 31 | 32 | "github.com/rgzr/sshtun" 33 | ) 34 | 35 | func main() { 36 | // We want to connect to port 8080 on our machine to access port 80 on my.super.host.com 37 | sshTun := sshtun.New(8080, "my.super.host.com", 80) 38 | 39 | // We print each tunneled state to see the connections status 40 | sshTun.SetTunneledConnState(func(tun *sshtun.SSHTun, state *sshtun.TunneledConnState) { 41 | log.Printf("%+v", state) 42 | }) 43 | 44 | // We set a callback to know when the tunnel is ready 45 | sshTun.SetConnState(func(tun *sshtun.SSHTun, state sshtun.ConnState) { 46 | switch state { 47 | case sshtun.StateStarting: 48 | log.Printf("STATE is Starting") 49 | case sshtun.StateStarted: 50 | log.Printf("STATE is Started") 51 | case sshtun.StateStopped: 52 | log.Printf("STATE is Stopped") 53 | } 54 | }) 55 | 56 | // We start the tunnel (and restart it every time it is stopped) 57 | go func() { 58 | for {
59 | if err := sshTun.Start(context.Background()); err != nil { 60 | log.Printf("SSH tunnel error: %v", err) 61 | time.Sleep(time.Second) // don't flood if there's a start error :) 62 | } 63 | } 64 | }() 65 | 66 | // We stop the tunnel every 20 seconds (just to see what happens) 67 | for { 68 | time.Sleep(time.Second * time.Duration(20)) 69 | log.Println("Lets stop the SSH tunnel...") 70 | sshTun.Stop() 71 | } 72 | } 73 | ``` 74 | -------------------------------------------------------------------------------- /auth.go: -------------------------------------------------------------------------------- 1 | package sshtun 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "net" 7 | "os" 8 | "os/user" 9 | 10 | "golang.org/x/crypto/ssh" 11 | "golang.org/x/crypto/ssh/agent" 12 | ) 13 | 14 | var defaultSSHKeys = []string{"id_rsa", "id_dsa", "id_ecdsa", "id_ecdsa_sk", "id_ed25519", "id_ed25519_sk"} 15 | 16 | // AuthType is the type of authentication to use for SSH. 17 | type AuthType int 18 | 19 | const ( 20 | // AuthTypeKeyFile uses the keys from a SSH key file read from the system. 21 | AuthTypeKeyFile AuthType = iota 22 | // AuthTypeEncryptedKeyFile uses the keys from an encrypted SSH key file read from the system. 23 | AuthTypeEncryptedKeyFile 24 | // AuthTypeKeyReader uses the keys from a SSH key reader. 25 | AuthTypeKeyReader 26 | // AuthTypeEncryptedKeyReader uses the keys from an encrypted SSH key reader. 27 | AuthTypeEncryptedKeyReader 28 | // AuthTypePassword uses a password directly. 29 | AuthTypePassword 30 | // AuthTypeSSHAgent will use registered users in the ssh-agent. 31 | AuthTypeSSHAgent 32 | // AuthTypeAuto tries to get the authentication method automatically. See SSHTun.Start for details on 33 | // this. 
34 | AuthTypeAuto 35 | ) 36 | 37 | func (tun *SSHTun) getSSHAuthMethod() (ssh.AuthMethod, error) { 38 | switch tun.authType { 39 | case AuthTypeKeyFile: 40 | return tun.getSSHAuthMethodForKeyFile(false) 41 | case AuthTypeEncryptedKeyFile: 42 | return tun.getSSHAuthMethodForKeyFile(true) 43 | case AuthTypeKeyReader: 44 | return tun.getSSHAuthMethodForKeyReader(false) 45 | case AuthTypeEncryptedKeyReader: 46 | return tun.getSSHAuthMethodForKeyReader(true) 47 | case AuthTypePassword: 48 | return ssh.Password(tun.authPassword), nil 49 | case AuthTypeSSHAgent: 50 | return tun.getSSHAuthMethodForSSHAgent() 51 | case AuthTypeAuto: 52 | method, errFile := tun.getSSHAuthMethodForKeyFile(false) 53 | if errFile == nil { 54 | return method, nil 55 | } 56 | method, errAgent := tun.getSSHAuthMethodForSSHAgent() 57 | if errAgent == nil { 58 | return method, nil 59 | } 60 | return nil, fmt.Errorf("auto auth failed (file based: %v) (ssh-agent: %v)", errFile, errAgent) 61 | default: 62 | return nil, fmt.Errorf("unknown auth type: %d", tun.authType) 63 | } 64 | } 65 | 66 | func (tun *SSHTun) getSSHAuthMethodForKeyFile(encrypted bool) (ssh.AuthMethod, error) { 67 | if tun.authKeyFile != "" { 68 | return tun.readPrivateKey(tun.authKeyFile, encrypted) 69 | } 70 | 71 | homeDir := "/root" 72 | usr, _ := user.Current() 73 | if usr != nil { 74 | homeDir = usr.HomeDir 75 | } 76 | 77 | for _, keyName := range defaultSSHKeys { 78 | keyFile := fmt.Sprintf("%s/.ssh/%s", homeDir, keyName) 79 | authMethod, err := tun.readPrivateKey(keyFile, encrypted) 80 | if err == nil { 81 | return authMethod, nil 82 | } 83 | } 84 | 85 | return nil, fmt.Errorf("could not read any default SSH key (%v)", defaultSSHKeys) 86 | } 87 | 88 | func (tun *SSHTun) readPrivateKey(keyFile string, encrypted bool) (ssh.AuthMethod, error) { 89 | buf, err := os.ReadFile(keyFile) 90 | if err != nil { 91 | return nil, fmt.Errorf("reading SSH key file %s: %w", keyFile, err) 92 | } 93 | 94 | key, err := tun.parsePrivateKey(buf, 
encrypted) 95 | if err != nil { 96 | return nil, fmt.Errorf("parsing SSH key file %s: %w", keyFile, err) 97 | } 98 | 99 | return key, nil 100 | } 101 | 102 | func (tun *SSHTun) getSSHAuthMethodForKeyReader(encrypted bool) (ssh.AuthMethod, error) { 103 | buf, err := io.ReadAll(tun.authKeyReader) 104 | if err != nil { 105 | return nil, fmt.Errorf("reading from SSH key reader: %w", err) 106 | } 107 | key, err := tun.parsePrivateKey(buf, encrypted) 108 | if err != nil { 109 | return nil, fmt.Errorf("reading from SSH key reader: %w", err) 110 | } 111 | return key, nil 112 | } 113 | 114 | func (tun *SSHTun) parsePrivateKey(buf []byte, encrypted bool) (ssh.AuthMethod, error) { 115 | var key ssh.Signer 116 | var err error 117 | if encrypted { 118 | key, err = ssh.ParsePrivateKeyWithPassphrase(buf, []byte(tun.authPassword)) 119 | if err != nil { 120 | return nil, fmt.Errorf("parsing encrypted key: %w", err) 121 | } 122 | } else { 123 | key, err = ssh.ParsePrivateKey(buf) 124 | if err != nil { 125 | return nil, fmt.Errorf("error parsing key: %w", err) 126 | } 127 | } 128 | return ssh.PublicKeys(key), nil 129 | } 130 | 131 | func (tun *SSHTun) getSSHAuthMethodForSSHAgent() (ssh.AuthMethod, error) { 132 | conn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")) 133 | if err != nil { 134 | return nil, fmt.Errorf("opening unix socket: %w", err) 135 | } 136 | 137 | agentClient := agent.NewClient(conn) 138 | 139 | signers, err := agentClient.Signers() 140 | if err != nil { 141 | return nil, fmt.Errorf("getting ssh-agent signers: %w", err) 142 | } 143 | 144 | if len(signers) == 0 { 145 | return nil, fmt.Errorf("no signers from ssh-agent (use 'ssh-add' to add keys to agent)") 146 | } 147 | 148 | return ssh.PublicKeys(signers...), nil 149 | } 150 | -------------------------------------------------------------------------------- /endpoint.go: -------------------------------------------------------------------------------- 1 | package sshtun 2 | 3 | import "fmt" 4 | 5 | const ( 6 | 
endpointTypeUnixSocket = "unix" 7 | endpointTypeTCP = "tcp" 8 | ) 9 | 10 | type Endpoint struct { 11 | host string 12 | port int 13 | unixSocket string 14 | } 15 | 16 | func (e *Endpoint) String() string { 17 | if e.unixSocket != "" { 18 | return e.unixSocket 19 | } 20 | return fmt.Sprintf("%s:%d", e.host, e.port) 21 | } 22 | 23 | func (e *Endpoint) Type() string { 24 | if e.unixSocket != "" { 25 | return endpointTypeUnixSocket 26 | } 27 | return endpointTypeTCP 28 | } 29 | 30 | func NewTCPEndpoint(host string, port int) *Endpoint { 31 | return &Endpoint{ 32 | host: host, 33 | port: port, 34 | } 35 | } 36 | 37 | func NewUnixEndpoint(socket string) *Endpoint { 38 | return &Endpoint{ 39 | unixSocket: socket, 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /example/example.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "log" 6 | "time" 7 | 8 | "github.com/rgzr/sshtun" 9 | ) 10 | 11 | func main() { 12 | // We want to connect to port 8080 on our machine to acces port 80 on my.super.host.com 13 | sshTun := sshtun.New(8080, "my.super.host.com", 80) 14 | 15 | // We print each tunneled state to see the connections status 16 | sshTun.SetTunneledConnState(func(tun *sshtun.SSHTun, state *sshtun.TunneledConnState) { 17 | log.Printf("%+v", state) 18 | }) 19 | 20 | // We set a callback to know when the tunnel is ready 21 | sshTun.SetConnState(func(tun *sshtun.SSHTun, state sshtun.ConnState) { 22 | switch state { 23 | case sshtun.StateStarting: 24 | log.Printf("STATE is Starting") 25 | case sshtun.StateStarted: 26 | log.Printf("STATE is Started") 27 | case sshtun.StateStopped: 28 | log.Printf("STATE is Stopped") 29 | } 30 | }) 31 | 32 | // We start the tunnel (and restart it every time it is stopped) 33 | go func() { 34 | for { 35 | if err := sshTun.Start(context.Background()); err != nil { 36 | log.Printf("SSH tunnel error: %v", err) 37 | 
time.Sleep(time.Second) // don't flood if there's a start error :) 38 | } 39 | } 40 | }() 41 | 42 | // We stop the tunnel every 20 seconds (just to see what happens) 43 | for { 44 | time.Sleep(time.Second * time.Duration(20)) 45 | log.Println("Lets stop the SSH tunnel...") 46 | sshTun.Stop() 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /forward.go: -------------------------------------------------------------------------------- 1 | package sshtun 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "net" 8 | 9 | "golang.org/x/sync/errgroup" 10 | ) 11 | 12 | // TunneledConnState represents the state of the final connections made through the tunnel. 13 | type TunneledConnState struct { 14 | // From is the address initating the connection. 15 | From string 16 | // Info holds a message with info on the state of the connection (useful for debug purposes). 17 | Info string 18 | // Error holds an error on the connection or nil if the connection is successful. 19 | Error error 20 | // Ready indicates if the connection is established. 21 | Ready bool 22 | // Closed indicates if the connection is closed. 
23 | Closed bool 24 | } 25 | 26 | func (s *TunneledConnState) String() string { 27 | out := fmt.Sprintf("[%s] ", s.From) 28 | if s.Info != "" { 29 | out += s.Info 30 | } 31 | if s.Error != nil { 32 | out += fmt.Sprintf("Error: %v", s.Error) 33 | } 34 | return out 35 | } 36 | 37 | func (tun *SSHTun) forward(fromConn net.Conn) { 38 | from := fromConn.RemoteAddr().String() 39 | 40 | tun.tunneledState(&TunneledConnState{ 41 | From: from, 42 | Info: fmt.Sprintf("accepted %s connection", tun.fromEndpoint().Type()), 43 | }) 44 | 45 | var toConn net.Conn 46 | var err error 47 | 48 | dialFunc := tun.sshClient.Dial 49 | if tun.forwardType == Remote { 50 | dialFunc = net.Dial 51 | } 52 | 53 | toConn, err = dialFunc(tun.toEndpoint().Type(), tun.toEndpoint().String()) 54 | if err != nil { 55 | tun.tunneledState(&TunneledConnState{ 56 | From: from, 57 | Error: fmt.Errorf("%s dial %s to %s failed: %w", tun.forwardToName(), 58 | tun.toEndpoint().Type(), tun.toEndpoint().String(), err), 59 | }) 60 | 61 | fromConn.Close() 62 | return 63 | } 64 | 65 | connStr := fmt.Sprintf("%s -(%s)> %s <(ssh)> %s -(%s)> %s", from, tun.fromEndpoint().Type(), 66 | tun.fromEndpoint().String(), tun.server.String(), tun.toEndpoint().Type(), tun.toEndpoint().String()) 67 | 68 | tun.tunneledState(&TunneledConnState{ 69 | From: from, 70 | Info: fmt.Sprintf("connection established: %s", connStr), 71 | Ready: true, 72 | Closed: false, 73 | }) 74 | 75 | connCtx, connCancel := context.WithCancel(tun.ctx) 76 | errGroup := &errgroup.Group{} 77 | 78 | errGroup.Go(func() error { 79 | defer connCancel() 80 | _, err = io.Copy(toConn, fromConn) 81 | if err != nil { 82 | return fmt.Errorf("failed copying bytes from %s to %s: %w", tun.forwardToName(), tun.forwardFromName(), err) 83 | } 84 | return nil 85 | }) 86 | 87 | errGroup.Go(func() error { 88 | defer connCancel() 89 | _, err = io.Copy(fromConn, toConn) 90 | if err != nil { 91 | return fmt.Errorf("failed copying bytes from %s to %s: %w", tun.forwardFromName(), 
tun.forwardToName(), err) 92 | } 93 | return nil 94 | }) 95 | 96 | <-connCtx.Done() 97 | 98 | fromConn.Close() 99 | toConn.Close() 100 | 101 | err = errGroup.Wait() 102 | 103 | select { 104 | case <-tun.ctx.Done(): 105 | default: 106 | if err != nil { 107 | tun.tunneledState(&TunneledConnState{ 108 | From: from, 109 | Error: err, 110 | Closed: true, 111 | }) 112 | } 113 | } 114 | 115 | tun.tunneledState(&TunneledConnState{ 116 | From: from, 117 | Info: fmt.Sprintf("connection closed: %s", connStr), 118 | Closed: true, 119 | }) 120 | } 121 | 122 | func (tun *SSHTun) tunneledState(state *TunneledConnState) { 123 | if tun.tunneledConnState != nil { 124 | tun.tunneledConnState(tun, state) 125 | } 126 | } 127 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/rgzr/sshtun 2 | 3 | go 1.23.0 4 | 5 | require ( 6 | github.com/avast/retry-go v3.0.0+incompatible 7 | github.com/gliderlabs/ssh v0.3.7 8 | github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 9 | github.com/stretchr/testify v1.9.0 10 | golang.org/x/crypto v0.35.0 11 | golang.org/x/sync v0.11.0 12 | ) 13 | 14 | require ( 15 | github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect 16 | github.com/davecgh/go-spew v1.1.1 // indirect 17 | github.com/pmezard/go-difflib v1.0.0 // indirect 18 | golang.org/x/sys v0.30.0 // indirect 19 | gopkg.in/yaml.v3 v3.0.1 // indirect 20 | ) 21 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= 2 | github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= 3 | github.com/avast/retry-go v3.0.0+incompatible 
h1:4SOWQ7Qs+oroOTQOYnAHqelpCO0biHSxpiH9JdtuBj0= 4 | github.com/avast/retry-go v3.0.0+incompatible/go.mod h1:XtSnn+n/sHqQIpZ10K1qAevBhOOCWBLXXy3hyiqqBrY= 5 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 6 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 7 | github.com/gliderlabs/ssh v0.3.7 h1:iV3Bqi942d9huXnzEF2Mt+CY9gLu8DNM4Obd+8bODRE= 8 | github.com/gliderlabs/ssh v0.3.7/go.mod h1:zpHEXBstFnQYtGnB8k8kQLol82umzn/2/snG7alWVD8= 9 | github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 h1:Ii+DKncOVM8Cu1Hc+ETb5K+23HdAMvESYE3ZJ5b5cMI= 10 | github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE= 11 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 12 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 13 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 14 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 15 | golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs= 16 | golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ= 17 | golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= 18 | golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 19 | golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= 20 | golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 21 | golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU= 22 | golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s= 23 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 24 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 25 
| gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 26 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 27 | -------------------------------------------------------------------------------- /sshtun.go: -------------------------------------------------------------------------------- 1 | // Package sshtun provides a SSH tunnel with port forwarding. 2 | package sshtun 3 | 4 | import ( 5 | "context" 6 | "fmt" 7 | "io" 8 | "net" 9 | "sync" 10 | "time" 11 | 12 | "golang.org/x/crypto/ssh" 13 | "golang.org/x/sync/errgroup" 14 | ) 15 | 16 | // SSHTun represents a SSH tunnel 17 | type SSHTun struct { 18 | mutex *sync.Mutex 19 | ctx context.Context 20 | cancel context.CancelFunc 21 | started bool 22 | user string 23 | authType AuthType 24 | authKeyFile string 25 | authKeyReader io.Reader 26 | authPassword string 27 | server *Endpoint 28 | local *Endpoint 29 | remote *Endpoint 30 | forwardType ForwardType 31 | timeout time.Duration 32 | connState func(*SSHTun, ConnState) 33 | tunneledConnState func(*SSHTun, *TunneledConnState) 34 | active int 35 | sshClient *ssh.Client 36 | sshConfig *ssh.ClientConfig 37 | sshConfigKeyExchanges []string 38 | sshConfigCiphers []string 39 | sshConfigMACs []string 40 | } 41 | 42 | // ForwardType is the type of port forwarding. 43 | // Local: forward from localhost. 44 | // Remote: forward from remote - reverse port forward. 45 | type ForwardType int 46 | 47 | const ( 48 | Local ForwardType = iota 49 | Remote 50 | ) 51 | 52 | // ConnState represents the state of the SSH tunnel. It's returned to an optional function provided to SetConnState. 53 | type ConnState int 54 | 55 | const ( 56 | // StateStopped represents a stopped tunnel. A call to Start will make the state to transition to StateStarting. 57 | StateStopped ConnState = iota 58 | 59 | // StateStarting represents a tunnel initializing and preparing to listen for connections. 
60 | // A successful initialization will make the state to transition to StateStarted, otherwise it will transition to StateStopped. 61 | StateStarting 62 | 63 | // StateStarted represents a tunnel ready to accept connections. 64 | // A call to stop or an error will make the state to transition to StateStopped. 65 | StateStarted 66 | ) 67 | 68 | // New creates a new SSH tunnel to the specified server redirecting a port on local localhost to a port on remote localhost. 69 | // By default the SSH connection is made to port 22 as root and using automatic detection of the authentication 70 | // method (see Start for details on this). 71 | // Calling SetPassword will change the authentication to password based. 72 | // Calling SetKeyFile will change the authentication to keyfile based.. 73 | // The SSH user and port can be changed with SetUser and SetPort. 74 | // The local and remote hosts can be changed to something different than localhost with SetLocalEndpoint and SetRemoteEndpoint. 75 | // The forward type can be changed with SetForwardType. 76 | // The states of the tunnel can be received throgh a callback function with SetConnState. 77 | // The states of the tunneled connections can be received through a callback function with SetTunneledConnState. 78 | func New(localPort int, server string, remotePort int) *SSHTun { 79 | sshTun := defaultSSHTun(server) 80 | sshTun.local = NewTCPEndpoint("localhost", localPort) 81 | sshTun.remote = NewTCPEndpoint("localhost", remotePort) 82 | return sshTun 83 | } 84 | 85 | // NewRemote does the same as New but for a remote port forward. 86 | func NewRemote(localPort int, server string, remotePort int) *SSHTun { 87 | sshTun := New(localPort, server, remotePort) 88 | sshTun.forwardType = Remote 89 | return sshTun 90 | } 91 | 92 | // NewUnix does the same as New but using unix sockets. 
93 | func NewUnix(localUnixSocket string, server string, remoteUnixSocket string) *SSHTun { 94 | sshTun := defaultSSHTun(server) 95 | sshTun.local = NewUnixEndpoint(localUnixSocket) 96 | sshTun.remote = NewUnixEndpoint(remoteUnixSocket) 97 | return sshTun 98 | } 99 | 100 | // NewUnixRemote does the same as NewRemote but using unix sockets. 101 | func NewUnixRemote(localUnixSocket string, server string, remoteUnixSocket string) *SSHTun { 102 | sshTun := NewUnix(localUnixSocket, server, remoteUnixSocket) 103 | sshTun.forwardType = Remote 104 | return sshTun 105 | } 106 | 107 | func defaultSSHTun(server string) *SSHTun { 108 | return &SSHTun{ 109 | mutex: &sync.Mutex{}, 110 | server: NewTCPEndpoint(server, 22), 111 | user: "root", 112 | authType: AuthTypeAuto, 113 | timeout: time.Second * 15, 114 | forwardType: Local, 115 | } 116 | } 117 | 118 | // SetPort changes the port where the SSH connection will be made. 119 | func (tun *SSHTun) SetPort(port int) { 120 | tun.server.port = port 121 | } 122 | 123 | // Set KeyExchanges 124 | // supported, forbidden and preferred algos are in https://pkg.go.dev/golang.org/x/crypto/ssh#Config 125 | func (tun *SSHTun) SetKeyExchanges(keyExchanges []string) { 126 | tun.sshConfigKeyExchanges = keyExchanges 127 | } 128 | 129 | // Set ssh Ciphers 130 | // preferred and supported ciphers are in https://pkg.go.dev/golang.org/x/crypto/ssh#Config 131 | func (tun *SSHTun) SetCiphers(ciphers []string) { 132 | tun.sshConfigCiphers = ciphers 133 | } 134 | 135 | // Set MACs 136 | // supported MACs are in https://pkg.go.dev/golang.org/x/crypto/ssh#Config 137 | func (tun *SSHTun) SetMACs(MACs []string) { 138 | tun.sshConfigMACs = MACs 139 | } 140 | 141 | // SetUser changes the user used to make the SSH connection. 142 | func (tun *SSHTun) SetUser(user string) { 143 | tun.user = user 144 | } 145 | 146 | // SetKeyFile changes the authentication to key-based and uses the specified file. 
147 | // Leaving the file empty defaults to the default linux private key locations: `~/.ssh/id_rsa`, `~/.ssh/id_dsa`, 148 | // `~/.ssh/id_ecdsa`, `~/.ssh/id_ecdsa_sk`, `~/.ssh/id_ed25519` and `~/.ssh/id_ed25519_sk`. 149 | func (tun *SSHTun) SetKeyFile(file string) { 150 | tun.authType = AuthTypeKeyFile 151 | tun.authKeyFile = file 152 | } 153 | 154 | // SetEncryptedKeyFile changes the authentication to encrypted key-based and uses the specified file and password. 155 | // Leaving the file empty defaults to the default linux private key locations: `~/.ssh/id_rsa`, `~/.ssh/id_dsa`, 156 | // `~/.ssh/id_ecdsa`, `~/.ssh/id_ecdsa_sk`, `~/.ssh/id_ed25519` and `~/.ssh/id_ed25519_sk`. 157 | func (tun *SSHTun) SetEncryptedKeyFile(file string, password string) { 158 | tun.authType = AuthTypeEncryptedKeyFile 159 | tun.authKeyFile = file 160 | tun.authPassword = password 161 | } 162 | 163 | // SetKeyReader changes the authentication to key-based and uses the specified reader. 164 | func (tun *SSHTun) SetKeyReader(reader io.Reader) { 165 | tun.authType = AuthTypeKeyReader 166 | tun.authKeyReader = reader 167 | } 168 | 169 | // SetEncryptedKeyReader changes the authentication to encrypted key-based and uses the specified reader and password. 170 | func (tun *SSHTun) SetEncryptedKeyReader(reader io.Reader, password string) { 171 | tun.authType = AuthTypeEncryptedKeyReader 172 | tun.authKeyReader = reader 173 | tun.authPassword = password 174 | } 175 | 176 | // SetForwardType changes the forward type. 177 | func (tun *SSHTun) SetForwardType(forwardType ForwardType) { 178 | tun.forwardType = forwardType 179 | } 180 | 181 | // SetSSHAgent changes the authentication to ssh-agent. 182 | func (tun *SSHTun) SetSSHAgent() { 183 | tun.authType = AuthTypeSSHAgent 184 | } 185 | 186 | // SetPassword changes the authentication to password-based and uses the specified password. 
187 | func (tun *SSHTun) SetPassword(password string) { 188 | tun.authType = AuthTypePassword 189 | tun.authPassword = password 190 | } 191 | 192 | // SetLocalHost sets the local host to redirect (defaults to localhost). 193 | func (tun *SSHTun) SetLocalHost(host string) { 194 | tun.local.host = host 195 | } 196 | 197 | // SetRemoteHost sets the remote host to redirect (defaults to localhost). 198 | func (tun *SSHTun) SetRemoteHost(host string) { 199 | tun.remote.host = host 200 | } 201 | 202 | // SetLocalEndpoint sets the local endpoint to redirect. 203 | func (tun *SSHTun) SetLocalEndpoint(endpoint *Endpoint) { 204 | tun.local = endpoint 205 | } 206 | 207 | // SetRemoteEndpoint sets the remote endpoint to redirect. 208 | func (tun *SSHTun) SetRemoteEndpoint(endpoint *Endpoint) { 209 | tun.remote = endpoint 210 | } 211 | 212 | // SetTimeout sets the connection timeouts (defaults to 15 seconds). 213 | func (tun *SSHTun) SetTimeout(timeout time.Duration) { 214 | tun.timeout = timeout 215 | } 216 | 217 | // SetConnState specifies an optional callback function that is called when a SSH tunnel changes state. 218 | // See the ConnState type and associated constants for details. 219 | func (tun *SSHTun) SetConnState(connStateFun func(*SSHTun, ConnState)) { 220 | tun.connState = connStateFun 221 | } 222 | 223 | // SetTunneledConnState specifies an optional callback function that is called when the underlying tunneled 224 | // connections change state. 225 | func (tun *SSHTun) SetTunneledConnState(tunneledConnStateFun func(*SSHTun, *TunneledConnState)) { 226 | tun.tunneledConnState = tunneledConnStateFun 227 | } 228 | 229 | // Start starts the SSH tunnel. It can be stopped by calling `Stop` or cancelling its context. 230 | // This call will block until the tunnel is stopped either calling those methods or by an error. 
231 | // Note on SSH authentication: in case the tunnel's authType is set to AuthTypeAuto the following will happen: 232 | // The default key files will be used, if that doesn't succeed it will try to use the SSH agent. 233 | // If that fails the whole authentication fails. 234 | // That means if you want to use password or encrypted key file authentication, you have to specify that explicitly. 235 | func (tun *SSHTun) Start(ctx context.Context) error { 236 | tun.mutex.Lock() 237 | if tun.started { 238 | tun.mutex.Unlock() 239 | return fmt.Errorf("already started") 240 | } 241 | tun.started = true 242 | tun.ctx, tun.cancel = context.WithCancel(ctx) 243 | tun.mutex.Unlock() 244 | 245 | if tun.connState != nil { 246 | tun.connState(tun, StateStarting) 247 | } 248 | 249 | config, err := tun.initSSHConfig() 250 | if err != nil { 251 | return tun.stop(fmt.Errorf("ssh config failed: %w", err)) 252 | } 253 | tun.sshConfig = config 254 | 255 | listenConfig := net.ListenConfig{} 256 | var listener net.Listener 257 | 258 | if tun.forwardType == Local { 259 | listener, err = listenConfig.Listen(tun.ctx, tun.local.Type(), tun.local.String()) 260 | if err != nil { 261 | return tun.stop(fmt.Errorf("local listen %s on %s failed: %w", tun.local.Type(), tun.local.String(), err)) 262 | } 263 | } else if tun.forwardType == Remote { 264 | sshClient, err := ssh.Dial(tun.server.Type(), tun.server.String(), tun.sshConfig) 265 | if err != nil { 266 | return tun.stop(fmt.Errorf("ssh dial %s to %s failed: %w", tun.server.Type(), tun.server.String(), err)) 267 | } 268 | listener, err = sshClient.Listen(tun.remote.Type(), tun.remote.String()) 269 | if err != nil { 270 | return tun.stop(fmt.Errorf("remote listen %s on %s failed: %w", tun.remote.Type(), tun.remote.String(), err)) 271 | } 272 | } 273 | 274 | errChan := make(chan error) 275 | go func() { 276 | errChan <- tun.listen(listener) 277 | }() 278 | 279 | if tun.connState != nil { 280 | tun.connState(tun, StateStarted) 281 | } 282 | 283 | 
return tun.stop(<-errChan) 284 | } 285 | 286 | // Stop closes all connections and makes Start exit gracefuly. 287 | func (tun *SSHTun) Stop() { 288 | tun.mutex.Lock() 289 | defer tun.mutex.Unlock() 290 | 291 | if tun.started { 292 | tun.cancel() 293 | } 294 | } 295 | 296 | func (tun *SSHTun) initSSHConfig() (*ssh.ClientConfig, error) { 297 | config := &ssh.ClientConfig{ 298 | Config: ssh.Config{ 299 | KeyExchanges: tun.sshConfigKeyExchanges, 300 | Ciphers: tun.sshConfigCiphers, 301 | MACs: tun.sshConfigMACs, 302 | }, 303 | User: tun.user, 304 | HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { 305 | return nil 306 | }, 307 | Timeout: tun.timeout, 308 | } 309 | 310 | authMethod, err := tun.getSSHAuthMethod() 311 | if err != nil { 312 | return nil, err 313 | } 314 | 315 | config.Auth = []ssh.AuthMethod{authMethod} 316 | 317 | return config, nil 318 | } 319 | 320 | func (tun *SSHTun) stop(err error) error { 321 | tun.mutex.Lock() 322 | tun.started = false 323 | tun.mutex.Unlock() 324 | if tun.connState != nil { 325 | tun.connState(tun, StateStopped) 326 | } 327 | return err 328 | } 329 | 330 | func (tun *SSHTun) fromEndpoint() *Endpoint { 331 | if tun.forwardType == Remote { 332 | return tun.remote 333 | } 334 | 335 | return tun.local 336 | } 337 | 338 | func (tun *SSHTun) toEndpoint() *Endpoint { 339 | if tun.forwardType == Remote { 340 | return tun.local 341 | } 342 | 343 | return tun.remote 344 | } 345 | 346 | func (tun *SSHTun) forwardFromName() string { 347 | if tun.forwardType == Remote { 348 | return "remote" 349 | } 350 | 351 | return "local" 352 | } 353 | 354 | func (tun *SSHTun) forwardToName() string { 355 | if tun.forwardType == Remote { 356 | return "local" 357 | } 358 | 359 | return "remote" 360 | } 361 | 362 | func (tun *SSHTun) listen(listener net.Listener) error { 363 | 364 | errGroup, groupCtx := errgroup.WithContext(tun.ctx) 365 | errGroup.Go(func() error { 366 | for { 367 | conn, err := listener.Accept() 368 | if 
err != nil { 369 | return fmt.Errorf("%s accept %s on %s failed: %w", tun.forwardFromName(), 370 | tun.fromEndpoint().Type(), tun.fromEndpoint().String(), err) 371 | } 372 | errGroup.Go(func() error { 373 | return tun.handle(conn) 374 | }) 375 | } 376 | }) 377 | 378 | <-groupCtx.Done() 379 | 380 | listener.Close() 381 | 382 | err := errGroup.Wait() 383 | 384 | select { 385 | case <-tun.ctx.Done(): 386 | default: 387 | return err 388 | } 389 | 390 | return nil 391 | } 392 | 393 | func (tun *SSHTun) handle(conn net.Conn) error { 394 | err := tun.addConn() 395 | if err != nil { 396 | return err 397 | } 398 | 399 | tun.forward(conn) 400 | tun.removeConn() 401 | 402 | return nil 403 | } 404 | 405 | func (tun *SSHTun) addConn() error { 406 | tun.mutex.Lock() 407 | defer tun.mutex.Unlock() 408 | 409 | if tun.forwardType == Local && tun.active == 0 { 410 | sshClient, err := ssh.Dial(tun.server.Type(), tun.server.String(), tun.sshConfig) 411 | if err != nil { 412 | return fmt.Errorf("ssh dial %s to %s failed: %w", tun.server.Type(), tun.server.String(), err) 413 | } 414 | tun.sshClient = sshClient 415 | } 416 | 417 | tun.active += 1 418 | 419 | return nil 420 | } 421 | 422 | func (tun *SSHTun) removeConn() { 423 | tun.mutex.Lock() 424 | defer tun.mutex.Unlock() 425 | 426 | tun.active -= 1 427 | 428 | if tun.forwardType == Local && tun.active == 0 { 429 | tun.sshClient.Close() 430 | tun.sshClient = nil 431 | } 432 | } 433 | -------------------------------------------------------------------------------- /sshtun_test.go: -------------------------------------------------------------------------------- 1 | package sshtun 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "log" 9 | "net" 10 | "sync/atomic" 11 | "testing" 12 | "time" 13 | 14 | "github.com/avast/retry-go" 15 | "github.com/gliderlabs/ssh" 16 | "github.com/phayes/freeport" 17 | "github.com/stretchr/testify/require" 18 | "golang.org/x/sync/errgroup" 19 | ) 20 | 21 | type testServers struct { 22 | 
localPort int 23 | sshPort int 24 | remotePort int 25 | 26 | sshServer *ssh.Server 27 | pingPongServer *net.TCPListener 28 | sshTun *SSHTun 29 | 30 | pingPongConnections atomic.Int32 31 | } 32 | 33 | func newTestServers(localPort, sshPort, remotePort int, forwardType ForwardType) *testServers { 34 | sshServer := &ssh.Server{ 35 | Addr: fmt.Sprintf(":%d", sshPort), 36 | } 37 | 38 | if forwardType == Local { 39 | sshServer.LocalPortForwardingCallback = ssh.LocalPortForwardingCallback( 40 | func(ctx ssh.Context, dhost string, dport uint32) bool { 41 | return true 42 | }) 43 | sshServer.ChannelHandlers = map[string]ssh.ChannelHandler{ 44 | "direct-tcpip": ssh.DirectTCPIPHandler, 45 | } 46 | } else if forwardType == Remote { 47 | sshServer.ReversePortForwardingCallback = ssh.ReversePortForwardingCallback( 48 | func(ctx ssh.Context, bindHost string, bindPort uint32) bool { 49 | return true 50 | }) 51 | forwarder := &ssh.ForwardedTCPHandler{} 52 | sshServer.RequestHandlers = map[string]ssh.RequestHandler{ 53 | "tcpip-forward": forwarder.HandleSSHRequest, 54 | } 55 | } 56 | 57 | sshTun := New(localPort, "localhost", remotePort) 58 | sshTun.SetForwardType(forwardType) 59 | sshTun.SetPort(sshPort) 60 | 61 | return &testServers{ 62 | localPort: localPort, 63 | sshPort: sshPort, 64 | remotePort: remotePort, 65 | sshServer: sshServer, 66 | sshTun: sshTun, 67 | } 68 | } 69 | 70 | func (s *testServers) start(ctx context.Context) error { 71 | errGroup, groupCtx := errgroup.WithContext(ctx) 72 | 73 | errGroup.Go(func() error { 74 | return s.serveSSH(groupCtx) 75 | }) 76 | 77 | errGroup.Go(func() error { 78 | return s.servePingPong(groupCtx) 79 | }) 80 | 81 | errGroup.Go(func() error { 82 | return s.sshTun.Start(groupCtx) 83 | }) 84 | 85 | return errGroup.Wait() 86 | } 87 | 88 | func (s *testServers) serveSSH(ctx context.Context) error { 89 | errCh := make(chan error) 90 | 91 | go func() { 92 | err := s.sshServer.ListenAndServe() 93 | if err == ssh.ErrServerClosed { 94 | err = nil 
95 | } 96 | 97 | errCh <- err 98 | }() 99 | 100 | <-ctx.Done() 101 | 102 | s.sshServer.Close() 103 | 104 | return <-errCh 105 | } 106 | 107 | func (s *testServers) servePingPong(ctx context.Context) error { 108 | errCh := make(chan error) 109 | 110 | listener, err := net.ListenTCP("tcp", &net.TCPAddr{ 111 | Port: s.sshTun.toEndpoint().port, 112 | }) 113 | if err != nil { 114 | return err 115 | } 116 | 117 | s.pingPongServer = listener 118 | 119 | go func() { 120 | errCh <- s.handlePingPongConnections(ctx) 121 | }() 122 | 123 | <-ctx.Done() 124 | 125 | s.pingPongServer.Close() 126 | 127 | return <-errCh 128 | } 129 | 130 | func (s *testServers) handlePingPongConnections(ctx context.Context) error { 131 | for i := 0; ; i++ { 132 | conn, err := s.pingPongServer.AcceptTCP() 133 | if err != nil { 134 | if ctx.Err() != nil { 135 | return nil 136 | } 137 | 138 | return err 139 | } 140 | 141 | go func(connID int) { 142 | handleErr := s.handlePingPongConnection(conn) 143 | if handleErr != nil { 144 | log.Printf("conn %d: %v", connID, handleErr) 145 | } 146 | }(i) 147 | } 148 | } 149 | 150 | func (s *testServers) handlePingPongConnection(conn *net.TCPConn) error { 151 | s.pingPongConnections.Add(1) 152 | defer s.pingPongConnections.Add(-1) 153 | 154 | for { 155 | recv := make([]byte, 4) 156 | readBytes, err := io.ReadAtLeast(conn, recv, 4) 157 | if err != nil { 158 | if err == io.EOF { 159 | return nil 160 | } 161 | 162 | return err 163 | } 164 | 165 | if readBytes != 4 { 166 | return errors.New("not read 4 bytes") 167 | } 168 | 169 | if string(recv) != "ping" { 170 | return errors.New("not received ping") 171 | } 172 | 173 | writeBytes, err := io.WriteString(conn, "pong") 174 | if err != nil { 175 | if errors.Is(err, net.ErrClosed) { 176 | return nil 177 | } 178 | 179 | return err 180 | } 181 | 182 | if writeBytes != 4 { 183 | return errors.New("not write 4 bytes") 184 | } 185 | } 186 | } 187 | 188 | type pingPongClient struct { 189 | servers *testServers 190 | conn 
*net.TCPConn 191 | pings int 192 | } 193 | 194 | func (s *testServers) connectPingPong() (*pingPongClient, error) { 195 | conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{ 196 | Port: s.sshTun.fromEndpoint().port, 197 | }) 198 | 199 | if err != nil { 200 | return nil, err 201 | } 202 | 203 | return &pingPongClient{ 204 | servers: s, 205 | conn: conn, 206 | }, nil 207 | } 208 | 209 | func (c *pingPongClient) ping() error { 210 | sentBytes, err := io.WriteString(c.conn, "ping") 211 | if err != nil { 212 | return err 213 | } 214 | 215 | if sentBytes != 4 { 216 | return errors.New("not sent 4 bytes") 217 | } 218 | 219 | recv := make([]byte, 4) 220 | recvBytes, err := io.ReadAtLeast(c.conn, recv, 4) 221 | if err != nil { 222 | return err 223 | } 224 | 225 | if recvBytes != 4 { 226 | return errors.New("not received 4 bytes") 227 | } 228 | 229 | if string(recv) != "pong" { 230 | return errors.New("not received pong") 231 | } 232 | 233 | c.pings++ 234 | 235 | return nil 236 | } 237 | 238 | func (c *pingPongClient) close() error { 239 | return c.conn.Close() 240 | } 241 | 242 | func runTestServers(t *testing.T, forwardType ForwardType) (*testServers, chan error, context.CancelFunc) { 243 | t.Helper() 244 | 245 | sshPort, err := freeport.GetFreePort() 246 | require.NoError(t, err) 247 | 248 | localPort, err := freeport.GetFreePort() 249 | require.NoError(t, err) 250 | 251 | remotePort, err := freeport.GetFreePort() 252 | require.NoError(t, err) 253 | 254 | testServers := newTestServers(localPort, sshPort, remotePort, forwardType) 255 | 256 | testServers.sshTun.SetConnState(func(tun *SSHTun, connState ConnState) { 257 | switch connState { 258 | case StateStarting: 259 | t.Log("ConnState: starting") 260 | case StateStarted: 261 | t.Log("ConnState: started") 262 | case StateStopped: 263 | t.Log("ConnState: stopped") 264 | default: 265 | t.Log("ConnState: unexpected") 266 | } 267 | }) 268 | 269 | testServers.sshTun.SetTunneledConnState(func(tun *SSHTun, tunneledConnState 
*TunneledConnState) { 270 | t.Logf("TunneledConnState: %+v\n", tunneledConnState) 271 | }) 272 | 273 | ctx, cancel := context.WithCancel(context.Background()) 274 | 275 | errCh := make(chan error) 276 | go func() { 277 | errServers := testServers.start(ctx) 278 | if errServers != nil { 279 | log.Println(errServers.Error()) 280 | } 281 | 282 | errCh <- errServers 283 | }() 284 | 285 | return testServers, errCh, cancel 286 | } 287 | 288 | func pingPongConnect(t *testing.T, testServers *testServers) *pingPongClient { 289 | t.Helper() 290 | 291 | var client *pingPongClient 292 | var err error 293 | 294 | err = retry.Do(func() error { 295 | client, err = testServers.connectPingPong() 296 | return err 297 | }, retry.Attempts(5), retry.Delay(500*time.Millisecond)) 298 | 299 | require.NoError(t, err) 300 | 301 | return client 302 | } 303 | 304 | func TestOneConnection(t *testing.T) { 305 | testServers, errCh, cancel := runTestServers(t, Local) 306 | 307 | client := pingPongConnect(t, testServers) 308 | 309 | err := client.ping() 310 | require.NoError(t, err) 311 | 312 | require.Equal(t, 1, client.pings) 313 | 314 | err = client.close() 315 | require.NoError(t, err) 316 | 317 | cancel() 318 | err = <-errCh 319 | 320 | require.NoError(t, err) 321 | } 322 | 323 | func TestMultipleConnections(t *testing.T) { 324 | testServers, errCh, cancel := runTestServers(t, Local) 325 | 326 | client1 := pingPongConnect(t, testServers) 327 | 328 | client2 := pingPongConnect(t, testServers) 329 | 330 | err := client1.ping() 331 | require.NoError(t, err) 332 | 333 | err = client2.ping() 334 | require.NoError(t, err) 335 | 336 | err = client1.ping() 337 | require.NoError(t, err) 338 | 339 | require.Equal(t, 2, client1.pings) 340 | require.Equal(t, 1, client2.pings) 341 | 342 | err = client1.close() 343 | require.NoError(t, err) 344 | 345 | err = client2.ping() 346 | require.NoError(t, err) 347 | 348 | require.Equal(t, 2, client2.pings) 349 | 350 | err = client2.close() 351 | require.NoError(t, 
err) 352 | 353 | cancel() 354 | err = <-errCh 355 | 356 | require.NoError(t, err) 357 | } 358 | 359 | func checkTunConnections(t *testing.T, testServers *testServers, connections int) error { 360 | t.Helper() 361 | 362 | testServers.sshTun.mutex.Lock() 363 | defer testServers.sshTun.mutex.Unlock() 364 | 365 | if connections != testServers.sshTun.active { 366 | return fmt.Errorf("there are %d active connections instead of %d expected", testServers.sshTun.active, connections) 367 | } 368 | 369 | if testServers.sshTun.forwardType == Local { 370 | if connections == 0 && testServers.sshTun.sshClient != nil { 371 | return fmt.Errorf("ssh client should be nil") 372 | } 373 | 374 | if connections != 0 && testServers.sshTun.sshClient == nil { 375 | return fmt.Errorf("ssh client should not be nil") 376 | } 377 | } 378 | 379 | return nil 380 | } 381 | 382 | func TestReconnectTunnel(t *testing.T) { 383 | testServers, errCh, cancel := runTestServers(t, Local) 384 | 385 | require.NoError(t, checkTunConnections(t, testServers, 0)) 386 | 387 | client := pingPongConnect(t, testServers) 388 | 389 | err := client.ping() 390 | require.NoError(t, err) 391 | 392 | require.NoError(t, checkTunConnections(t, testServers, 1)) 393 | 394 | err = client.close() 395 | require.NoError(t, err) 396 | 397 | err = retry.Do(func() error { 398 | return checkTunConnections(t, testServers, 0) 399 | }, retry.Attempts(5), retry.Delay(500*time.Millisecond)) 400 | 401 | require.NoError(t, err) 402 | 403 | client = pingPongConnect(t, testServers) 404 | 405 | err = client.ping() 406 | require.NoError(t, err) 407 | 408 | require.NoError(t, checkTunConnections(t, testServers, 1)) 409 | 410 | err = client.close() 411 | require.NoError(t, err) 412 | 413 | cancel() 414 | err = <-errCh 415 | 416 | require.NoError(t, err) 417 | } 418 | 419 | func TestOneRemoteConnection(t *testing.T) { 420 | testServers, errCh, cancel := runTestServers(t, Remote) 421 | 422 | client := pingPongConnect(t, testServers) 423 | 424 | err 
:= client.ping() 425 | require.NoError(t, err) 426 | 427 | require.Equal(t, 1, client.pings) 428 | 429 | err = client.close() 430 | require.NoError(t, err) 431 | 432 | cancel() 433 | err = <-errCh 434 | 435 | require.NoError(t, err) 436 | } 437 | 438 | func TestMultipleRemoteConnections(t *testing.T) { 439 | testServers, errCh, cancel := runTestServers(t, Remote) 440 | 441 | client1 := pingPongConnect(t, testServers) 442 | 443 | client2 := pingPongConnect(t, testServers) 444 | 445 | err := client1.ping() 446 | require.NoError(t, err) 447 | 448 | err = client2.ping() 449 | require.NoError(t, err) 450 | 451 | err = client1.ping() 452 | require.NoError(t, err) 453 | 454 | require.Equal(t, 2, client1.pings) 455 | require.Equal(t, 1, client2.pings) 456 | 457 | err = client1.close() 458 | require.NoError(t, err) 459 | 460 | err = client2.ping() 461 | require.NoError(t, err) 462 | 463 | require.Equal(t, 2, client2.pings) 464 | 465 | err = client2.close() 466 | require.NoError(t, err) 467 | 468 | cancel() 469 | err = <-errCh 470 | 471 | require.NoError(t, err) 472 | } 473 | 474 | func TestReconnectRemoteTunnel(t *testing.T) { 475 | testServers, errCh, cancel := runTestServers(t, Remote) 476 | 477 | require.NoError(t, checkTunConnections(t, testServers, 0)) 478 | 479 | client := pingPongConnect(t, testServers) 480 | 481 | err := client.ping() 482 | require.NoError(t, err) 483 | 484 | require.NoError(t, checkTunConnections(t, testServers, 1)) 485 | 486 | err = client.close() 487 | require.NoError(t, err) 488 | 489 | err = retry.Do(func() error { 490 | return checkTunConnections(t, testServers, 0) 491 | }, retry.Attempts(5), retry.Delay(500*time.Millisecond)) 492 | 493 | require.NoError(t, err) 494 | 495 | client = pingPongConnect(t, testServers) 496 | 497 | err = client.ping() 498 | require.NoError(t, err) 499 | 500 | require.NoError(t, checkTunConnections(t, testServers, 1)) 501 | 502 | err = client.close() 503 | require.NoError(t, err) 504 | 505 | cancel() 506 | err = 
<-errCh 507 | 508 | require.NoError(t, err) 509 | } 510 | --------------------------------------------------------------------------------