├── .github └── workflows │ └── ci.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── _examples └── main.go ├── doc.go ├── go.mod ├── go.sum ├── reverseproxy.go ├── reverseproxy_test.go ├── ssh_test.go └── test.Dockerfile /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Test and coverage 2 | on: push 3 | jobs: 4 | test: 5 | runs-on: ubuntu-latest 6 | steps: 7 | - uses: actions/checkout@v2 8 | - uses: actions/setup-go@v2 9 | with: 10 | go-version: 1.18 11 | - name: Run test 12 | run: make test 13 | - name: Upload coverage to Codecov 14 | run: bash <(curl -s https://codecov.io/bash) 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | coverage.txt 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Charles Moog 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SHELL := /usr/bin/env bash 2 | 3 | TEST_IMG_TAG := sshproxy-test-target 4 | TEST_SERVER_PORT := 2222 5 | TEST_CONTAINER_NAME := sshproxy-test-target 6 | TEST_USER := test 7 | TEST_PASSWORD := testpassword 8 | 9 | build/image/tests: 10 | docker build \ 11 | --tag $(TEST_IMG_TAG) \ 12 | --build-arg PORT=$(TEST_SERVER_PORT) \ 13 | --build-arg USER=$(TEST_USER) \ 14 | --build-arg PASSWORD=$(TEST_PASSWORD) \ 15 | --file test.Dockerfile \ 16 | . 17 | .PHONY: build/image/tests 18 | 19 | clean: 20 | docker kill $(TEST_CONTAINER_NAME) || true 21 | .PHONY: clean 22 | 23 | setup/tests: build/image/tests clean 24 | docker run \ 25 | --detach \ 26 | --rm \ 27 | --network host \ 28 | --name $(TEST_CONTAINER_NAME) \ 29 | $(TEST_IMG_TAG) 30 | .PHONY: setup/tests 31 | 32 | test: setup/tests 33 | go test . 
\ 34 | -count 20 \ 35 | -race \ 36 | -coverprofile coverage.txt \ 37 | -covermode atomic \ 38 | -ssh-addr localhost:$(TEST_SERVER_PORT) \ 39 | -ssh-user $(TEST_USER) \ 40 | -ssh-passwd $(TEST_PASSWORD) 41 | docker kill $(TEST_CONTAINER_NAME) 42 | .PHONY: test 43 | 44 | fmt: 45 | go fmt 46 | goimports -w -local=github.com/cmoog/sshproxy $(shell git ls-files '*.go') 47 | .PHONY: fmt 48 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sshproxy 2 | 3 | [![Documentation](https://godoc.org/github.com/cmoog/sshproxy?status.svg)](https://pkg.go.dev/github.com/cmoog/sshproxy) 4 | [![Go Report Card](https://goreportcard.com/badge/github.com/cmoog/sshproxy)](https://goreportcard.com/report/github.com/cmoog/sshproxy) 5 | [![codecov](https://codecov.io/gh/cmoog/sshproxy/branch/master/graph/badge.svg?token=IQ87G7H7OA)](https://codecov.io/gh/cmoog/sshproxy) 6 | 7 | Package sshproxy provides a slim SSH reverse proxy built 8 | atop the `golang.org/x/crypto/ssh` package. 9 | 10 | ```text 11 | go get github.com/cmoog/sshproxy 12 | ``` 13 | 14 | ## Authorization termination proxy 15 | 16 | `sshproxy.ReverseProxy` implements a single host reverse proxy 17 | for SSH servers and clients. Its API is modeled after the ergonomics 18 | of the [HTTP reverse proxy](https://pkg.go.dev/net/http/httputil#ReverseProxy) implementation 19 | from the standard library. 20 | 21 | It enables the proxy to perform authorization termination, 22 | whereby custom authorization logic of the single entrypoint can protect 23 | a set of SSH hosts hidden in a private network. 24 | 25 | For example, one could conceivably use OAuth as a basis for verifying 26 | identity and ownership of public keys. 27 | 28 | ## Example usage 29 | 30 | Consider the following bare-bones example with error handling omitted for brevity. 
31 | 32 | ```go 33 | package main 34 | 35 | import ( 36 | "net" 37 | "golang.org/x/crypto/ssh" 38 | "github.com/cmoog/sshproxy" 39 | ) 40 | 41 | func main() { 42 | serverConfig := ssh.ServerConfig{ 43 | // TODO: add your custom public key authentication logic 44 | PublicKeyCallback: customPublicKeyAuthenticationLogic, 45 | } 46 | serverConfig.AddHostKey(reverseProxyHostKey) 47 | 48 | listener, _ := net.Listen("tcp", reverseProxyEntrypoint) 49 | for { 50 | clientConnection, _ := listener.Accept() 51 | go func() { 52 | defer clientConnection.Close() 53 | sshConn, sshChannels, sshRequests, _ := ssh.NewServerConn(clientConnection, &serverConfig) 54 | 55 | // TODO: add your custom routing logic based on the SSH `user` string and/or the public key 56 | targetServer, targetServerConnectionConfig := customRoutingLogic(sshConn.User()) 57 | 58 | proxy := sshproxy.New(targetServer, targetServerConnectionConfig) 59 | _ = proxy.Serve(ctx, sshConn, sshChannels, sshRequests) 60 | }() 61 | } 62 | } 63 | ``` 64 | -------------------------------------------------------------------------------- /_examples/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "crypto/ecdsa" 6 | "crypto/elliptic" 7 | "crypto/rand" 8 | "crypto/x509" 9 | "encoding/pem" 10 | "fmt" 11 | "log" 12 | "net" 13 | "time" 14 | 15 | "golang.org/x/crypto/ssh" 16 | 17 | "github.com/cmoog/sshproxy" 18 | ) 19 | 20 | // The following example demonstrates a simple usage of sshproxy.ReverseProxy. 21 | // 22 | // Run this example on your local machine, with "username" and "password" 23 | // substituted properly. This will allow you to dial port 2222 and be reverse 24 | // proxied through to your OpenSSH server on port 22.
25 | // 26 | // Run this server in the background, then dial 27 | // 28 | // $ ssh -p2222 localhost 29 | // 30 | 31 | const exampleUsername = "username" 32 | const examplePassword = "password" 33 | 34 | func main() { 35 | ctx, cancel := context.WithCancel(context.Background()) 36 | defer cancel() 37 | 38 | serverConfig := ssh.ServerConfig{ 39 | NoClientAuth: true, 40 | } 41 | signer, err := generateSigner() 42 | if err != nil { 43 | log.Fatal(err) 44 | } 45 | serverConfig.AddHostKey(signer) 46 | 47 | l, err := net.Listen("tcp", "localhost:2222") 48 | if err != nil { 49 | log.Fatal(err) 50 | } 51 | for { 52 | conn, err := l.Accept() 53 | if err != nil { 54 | log.Fatal(err) 55 | } 56 | go func() { 57 | defer conn.Close() 58 | serverConn, serverChans, serverReqs, err := ssh.NewServerConn(conn, &serverConfig) 59 | if err != nil { 60 | log.Println(err) 61 | return 62 | } 63 | rp := sshproxy.New("localhost:22", &ssh.ClientConfig{ 64 | User: exampleUsername, 65 | Auth: []ssh.AuthMethod{ssh.Password(examplePassword)}, 66 | HostKeyCallback: ssh.InsecureIgnoreHostKey(), 67 | Timeout: 3 * time.Second, 68 | }) 69 | err = rp.Serve(ctx, serverConn, serverChans, serverReqs) 70 | if err != nil { 71 | log.Println(err) 72 | return 73 | } 74 | }() 75 | } 76 | } 77 | 78 | func generateSigner() (ssh.Signer, error) { 79 | const blockType = "EC PRIVATE KEY" 80 | pkey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 81 | if err != nil { 82 | return nil, fmt.Errorf("generate ecdsa private key: %w", err) 83 | } 84 | 85 | byt, err := x509.MarshalECPrivateKey(pkey) 86 | if err != nil { 87 | return nil, fmt.Errorf("marshal private key: %w", err) 88 | } 89 | pb := pem.Block{ 90 | Type: blockType, 91 | Headers: nil, 92 | Bytes: byt, 93 | } 94 | p, err := ssh.ParsePrivateKey(pem.EncodeToMemory(&pb)) 95 | if err != nil { 96 | return nil, err 97 | } 98 | return p, nil 99 | } 100 | -------------------------------------------------------------------------------- /doc.go:
-------------------------------------------------------------------------------- 1 | // Package sshproxy provides a slim SSH reverse proxy built 2 | // atop the `golang.org/x/crypto/ssh` package. 3 | package sshproxy // import "github.com/cmoog/sshproxy" 4 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/cmoog/sshproxy 2 | 3 | go 1.18 4 | 5 | require golang.org/x/crypto v0.12.0 6 | 7 | require golang.org/x/sys v0.11.0 // indirect 8 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 2 | golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad h1:DN0cp81fZ3njFcrLCytUHRSUkqBjfTo4Tx9RJTWs0EY= 3 | golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= 4 | golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 h1:/ZScEX8SfEmUGRHs0gxpqteO5nfNW6axyZbBdw9A12g= 5 | golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= 6 | golang.org/x/crypto v0.0.0-20211202192323-5770296d904e h1:MUP6MR3rJ7Gk9LEia0LP2ytiH6MuCfs7qYz+47jGdD8= 7 | golang.org/x/crypto v0.0.0-20211202192323-5770296d904e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= 8 | golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd h1:XcWmESyNjXJMLahc3mqVQJcgSTDxFxhETVlfk9uGc38= 9 | golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= 10 | golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc= 11 | golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= 12 | golang.org/x/crypto v0.12.0 
h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk= 13 | golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= 14 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 15 | golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= 16 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 17 | golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 18 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 19 | golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 20 | golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 21 | golang.org/x/sys v0.0.0-20211204120058-94396e421777 h1:QAkhGVjOxMa+n4mlsAWeAU+BMZmimQAaNiMu+iUi94E= 22 | golang.org/x/sys v0.0.0-20211204120058-94396e421777/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 23 | golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= 24 | golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 25 | golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= 26 | golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 27 | golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= 28 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 29 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 30 | golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 31 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 32 | 
-------------------------------------------------------------------------------- /reverseproxy.go: -------------------------------------------------------------------------------- 1 | package sshproxy 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "log" 9 | "net" 10 | 11 | "golang.org/x/crypto/ssh" 12 | ) 13 | 14 | // ReverseProxy is an SSH Handler that takes an incoming request and sends it 15 | // to another server, proxying the response back to the client. 16 | type ReverseProxy struct { 17 | TargetAddress string 18 | TargetClientConfig *ssh.ClientConfig 19 | 20 | // ErrorLog specifies an optional logger for errors 21 | // that occur when attempting to proxy. 22 | // If nil, logging is done via the log package's standard logger. 23 | ErrorLog *log.Logger 24 | } 25 | 26 | // New constructs a new *ReverseProxy instance. 27 | func New(targetAddr string, clientConfig *ssh.ClientConfig) *ReverseProxy { 28 | return &ReverseProxy{ 29 | TargetAddress: targetAddr, 30 | TargetClientConfig: clientConfig, 31 | } 32 | } 33 | 34 | // Serve dials r.TargetAddress and proxies channels and requests in both directions between that target and serverConn, blocking until serverConn closes (returning its Wait error) or ctx is done (returning ctx.Err()). 35 | func (r *ReverseProxy) Serve(ctx context.Context, serverConn *ssh.ServerConn, serverChans <-chan ssh.NewChannel, serverReqs <-chan *ssh.Request) error { 36 | ctx, cancel := context.WithCancel(ctx) 37 | defer cancel() 38 | 39 | var logger logger = defaultLogger{} 40 | if r.ErrorLog != nil { 41 | logger = r.ErrorLog 42 | } 43 | 44 | // TODO: do we need to make "network" an argument?
45 | targetConn, err := net.DialTimeout("tcp", r.TargetAddress, r.TargetClientConfig.Timeout) 46 | if err != nil { 47 | return fmt.Errorf("dial reverse proxy target: %w", err) 48 | } 49 | defer targetConn.Close() 50 | 51 | destConn, destChans, destReqs, err := ssh.NewClientConn(targetConn, r.TargetAddress, r.TargetClientConfig) 52 | if err != nil { 53 | return fmt.Errorf("new ssh client conn: %w", err) 54 | } 55 | 56 | shutdownErr := make(chan error, 1) 57 | go func() { 58 | shutdownErr <- serverConn.Conn.Wait() 59 | }() 60 | 61 | go processChannels(ctx, destConn, serverChans, logger) 62 | go processChannels(ctx, serverConn.Conn, destChans, logger) 63 | go processRequests(ctx, destConn, serverReqs, logger) 64 | go processRequests(ctx, serverConn.Conn, destReqs, logger) 65 | 66 | select { 67 | case <-ctx.Done(): 68 | return ctx.Err() 69 | case err := <-shutdownErr: 70 | return err 71 | } 72 | } 73 | 74 | type defaultLogger struct{} 75 | 76 | // Printf forwards to the log package's standard logger via log.Printf. 77 | func (defaultLogger) Printf(format string, v ...any) { log.Printf(format, v...) } 78 | 79 | type logger interface { 80 | Printf(format string, v ...any) 81 | } 82 | 83 | // processChannels handles each ssh.NewChannel concurrently and closes destConn once chans is closed. 84 | func processChannels(ctx context.Context, destConn ssh.Conn, chans <-chan ssh.NewChannel, logger logger) { 85 | defer destConn.Close() 86 | for newCh := range chans { 87 | // shadow the loop variable so each goroutine captures its own copy (required before Go 1.22) 88 | newCh := newCh 89 | go func() { 90 | err := handleChannel(ctx, destConn, newCh, logger) 91 | if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, context.Canceled) { 92 | logger.Printf("sshproxy: ReverseProxy handle channel error: %v", err) 93 | } 94 | }() 95 | } 96 | } 97 | 98 | // processRequests forwards each *ssh.Request to dest in series, logging errors other than io.EOF.
99 | func processRequests(ctx context.Context, dest requestDest, requests <-chan *ssh.Request, logger logger) { 100 | for req := range requests { 101 | err := handleRequest(ctx, dest, req) 102 | if err != nil && !errors.Is(err, io.EOF) { 103 | logger.Printf("sshproxy: ReverseProxy handle request error: %v", err) 104 | } 105 | } 106 | } 107 | 108 | // handleChannel opens a matching channel on destConn (rejecting newChannel with the target's reason if that fails), 109 | // accepts newChannel, and proxies data and channel requests between the two. 110 | func handleChannel(ctx context.Context, destConn ssh.Conn, newChannel ssh.NewChannel, logger logger) error { 111 | destCh, destReqs, err := destConn.OpenChannel(newChannel.ChannelType(), newChannel.ExtraData()) 112 | if err != nil { 113 | if openChanErr, ok := err.(*ssh.OpenChannelError); ok { 114 | _ = newChannel.Reject(openChanErr.Reason, openChanErr.Message) 115 | } else { 116 | _ = newChannel.Reject(ssh.ConnectionFailed, err.Error()) 117 | } 118 | return fmt.Errorf("open channel: %w", err) 119 | } 120 | defer destCh.Close() 121 | 122 | originCh, originRequests, err := newChannel.Accept() 123 | if err != nil { 124 | return fmt.Errorf("accept new channel: %w", err) 125 | } 126 | defer originCh.Close() 127 | 128 | destRequestsDone := make(chan struct{}) 129 | go func() { 130 | defer close(destRequestsDone) 131 | processRequests(ctx, channelRequestDest{originCh}, destReqs, logger) 132 | }() 133 | 134 | // This request channel does not get closed 135 | // by the client causing this function to hang if we wait on it. 136 | go processRequests(ctx, channelRequestDest{destCh}, originRequests, logger) 137 | 138 | if err := bicopy(ctx, originCh, destCh, logger); err != nil { 139 | return fmt.Errorf("channel bidirectional copy: %w", err) 140 | } 141 | 142 | select { 143 | case <-destRequestsDone: 144 | return nil 145 | case <-ctx.Done(): 146 | return ctx.Err() 147 | } 148 | } 149 | 150 | // bicopy copies data in both directions between the two channels, 151 | // but does not perform complete closure.
152 | // It blocks until the context is cancelled or everything read from `beta` 153 | // has been written to `alpha` (whose write side is then closed). The copy 154 | // from `alpha` to `beta` is not waited on. 155 | func bicopy(ctx context.Context, alpha, beta ssh.Channel, logger logger) error { 156 | alphaWriteDone := make(chan struct{}) 157 | go func() { 158 | defer close(alphaWriteDone) 159 | copyChannels(alpha, beta, logger) 160 | }() 161 | go copyChannels(beta, alpha, logger) 162 | 163 | select { 164 | case <-alphaWriteDone: 165 | return nil 166 | case <-ctx.Done(): 167 | return ctx.Err() 168 | } 169 | } 170 | 171 | // copyChannels pipes data read from r into w, covering both the primary 172 | // stream and the stderr stream, and calls w.CloseWrite once writes have 173 | // completed. It blocks until both copy streams exit. Non-EOF errors are 174 | // logged to the given logger. 175 | func copyChannels(w, r ssh.Channel, logger logger) { 176 | defer func() { _ = w.CloseWrite() }() 177 | 178 | copyDone := make(chan struct{}) 179 | go func() { 180 | defer close(copyDone) 181 | _, err := io.Copy(w, r) 182 | if err != nil && !errors.Is(err, io.EOF) { 183 | logger.Printf("sshproxy: bicopy channel: %v", err) 184 | } 185 | }() 186 | _, err := io.Copy(w.Stderr(), r.Stderr()) 187 | if err != nil && !errors.Is(err, io.EOF) { 188 | logger.Printf("sshproxy: bicopy channel: %v", err) 189 | } 190 | <-copyDone 191 | } 192 | 193 | // channelRequestDest wraps the ssh.Channel type to conform to the requestDest 194 | // interface's SendRequest signature. This allows for convenient code re-use in 195 | // piping channel-level requests as well as global, connection-level 196 | // requests. Channel request replies carry no payload, so it is always nil.
197 | type channelRequestDest struct { 198 | ssh.Channel 199 | } 200 | 201 | func (c channelRequestDest) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) { 202 | ok, err := c.Channel.SendRequest(name, wantReply, payload) 203 | return ok, nil, err 204 | } 205 | 206 | // requestDest defines a resource capable of receiving requests (global or channel). 207 | type requestDest interface { 208 | SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) 209 | } 210 | 211 | func handleRequest(ctx context.Context, dest requestDest, request *ssh.Request) error { 212 | ok, payload, err := dest.SendRequest(request.Type, request.WantReply, request.Payload) 213 | if err != nil { 214 | if request.WantReply { 215 | if err := request.Reply(ok, payload); err != nil { 216 | return fmt.Errorf("reply after send failure: %w", err) 217 | } 218 | } 219 | return fmt.Errorf("send request: %w", err) 220 | } 221 | 222 | if request.WantReply { 223 | if err := request.Reply(ok, payload); err != nil { 224 | return fmt.Errorf("reply: %w", err) 225 | } 226 | } 227 | return nil 228 | } 229 | -------------------------------------------------------------------------------- /reverseproxy_test.go: -------------------------------------------------------------------------------- 1 | package sshproxy 2 | 3 | import ( 4 | "context" 5 | "crypto/ecdsa" 6 | "crypto/elliptic" 7 | "crypto/rand" 8 | "crypto/x509" 9 | "encoding/pem" 10 | "flag" 11 | "fmt" 12 | "log" 13 | "net" 14 | "sync" 15 | "testing" 16 | "time" 17 | 18 | "golang.org/x/crypto/ssh" 19 | ) 20 | 21 | var ( 22 | addr = flag.String("ssh-addr", "", "specify a target address to dial SSH") 23 | user = flag.String("ssh-user", "", "specify the SSH user with which to dial") 24 | password = flag.String("ssh-passwd", "", "specify the password with which to dial") 25 | ) 26 | 27 | func Test_reverseProxy(t *testing.T) { 28 | t.Parallel() 29 | if *addr == "" || *user == "" || *password == "" { 30 | 
t.Fatalf("-ssh-addr, -ssh-user, and -ssh-passwd are all required flags") 31 | } 32 | ctx, cancel := context.WithCancel(context.Background()) 33 | defer cancel() 34 | 35 | left, right, err := tcpPipeWithDialer(net.Dial, net.Listen) 36 | if err != nil { 37 | t.Fatalf("new net pipe: %v", err) 38 | } 39 | 40 | clientConfig := &ssh.ClientConfig{ 41 | User: *user, 42 | Auth: []ssh.AuthMethod{ssh.Password(*password)}, 43 | HostKeyCallback: ssh.InsecureIgnoreHostKey(), 44 | Timeout: 3 * time.Second, 45 | } 46 | 47 | var wg sync.WaitGroup 48 | wg.Add(1) 49 | go func() { 50 | defer wg.Done() 51 | defer left.Close() 52 | clientSSHConn, clientChans, clientReqs, err := ssh.NewClientConn(left, "localhost", clientConfig) 53 | if err != nil { 54 | t.Errorf("new client conn: %v", err) 55 | return 56 | } 57 | testSSHClient(t, ssh.NewClient(clientSSHConn, clientChans, clientReqs)) 58 | }() 59 | 60 | serverConfig := &ssh.ServerConfig{ 61 | NoClientAuth: true, 62 | } 63 | signer, err := generateSigner() 64 | if err != nil { 65 | t.Fatalf("generate signer: %v", err) 66 | } 67 | serverConfig.AddHostKey(signer) 68 | 69 | serverConn, serverChans, serverReqs, err := ssh.NewServerConn(right, serverConfig) 70 | if err != nil { 71 | t.Fatalf("accept server conn: %v", err) 72 | } 73 | 74 | proxy := New(*addr, clientConfig) 75 | proxy.ErrorLog = log.Default() 76 | err = proxy.Serve(ctx, serverConn, serverChans, serverReqs) 77 | if err == nil { 78 | t.Fatalf("expected error from reverse proxy, got: %v", err) 79 | } 80 | wg.Wait() 81 | } 82 | 83 | func Test_dialFailure(t *testing.T) { 84 | ctx, cancel := context.WithCancel(context.Background()) 85 | defer cancel() 86 | 87 | clientConfig := &ssh.ClientConfig{ 88 | User: *user, 89 | Auth: []ssh.AuthMethod{ssh.Password(*password)}, 90 | HostKeyCallback: ssh.InsecureIgnoreHostKey(), 91 | Timeout: 3 * time.Second, 92 | } 93 | 94 | proxy := New("/tmp/sshproxy-null.sock", clientConfig) 95 | err := proxy.Serve(ctx, nil, nil, nil) 96 | if err == 
nil { 97 | t.Fatalf("expected error from reverse proxy, got: %v", err) 98 | } 99 | if err.Error() != "dial reverse proxy target: dial tcp: address /tmp/sshproxy-null.sock: missing port in address" { 100 | t.Fatalf("unexpected error, got: %v", err) 101 | } 102 | } 103 | 104 | func Test_serverConnFailure(t *testing.T) { 105 | ctx, cancel := context.WithCancel(context.Background()) 106 | defer cancel() 107 | 108 | clientConfig := &ssh.ClientConfig{ 109 | User: *user, 110 | Auth: []ssh.AuthMethod{ssh.Password(*password)}, 111 | HostKeyCallback: ssh.InsecureIgnoreHostKey(), 112 | Timeout: 3 * time.Second, 113 | } 114 | 115 | listener, err := net.Listen("tcp", "127.0.0.1:0") 116 | if err != nil { 117 | t.Fatalf("unexpected error, got: %v", err) 118 | } 119 | 120 | go func() { 121 | conn, err := listener.Accept() 122 | if err != nil { 123 | t.Errorf("unexpected error, got: %v", err) 124 | return 125 | } 126 | err = conn.Close() 127 | if err != nil { 128 | t.Errorf("unexpected error, got: %v", err) 129 | } 130 | }() 131 | 132 | proxy := New(listener.Addr().String(), clientConfig) 133 | err = proxy.Serve(ctx, nil, nil, nil) 134 | if err == nil { 135 | t.Fatalf("expected error from reverse proxy, got: %v", err) 136 | } 137 | } 138 | 139 | func generateSigner() (ssh.Signer, error) { 140 | const blockType = "EC PRIVATE KEY" 141 | pkey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 142 | if err != nil { 143 | return nil, fmt.Errorf("generate ecdsa private key: %w", err) 144 | } 145 | 146 | byt, err := x509.MarshalECPrivateKey(pkey) 147 | if err != nil { 148 | return nil, fmt.Errorf("marshal private key: %w", err) 149 | } 150 | pb := pem.Block{ 151 | Type: blockType, 152 | Headers: nil, 153 | Bytes: byt, 154 | } 155 | p, err := ssh.ParsePrivateKey(pem.EncodeToMemory(&pb)) 156 | if err != nil { 157 | return nil, err 158 | } 159 | return p, nil 160 | } 161 | -------------------------------------------------------------------------------- /ssh_test.go:
-------------------------------------------------------------------------------- 1 | package sshproxy 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "math/rand" 10 | "net" 11 | "os" 12 | "path/filepath" 13 | "strconv" 14 | "strings" 15 | "sync" 16 | "testing" 17 | 18 | "golang.org/x/crypto/ssh" 19 | ) 20 | 21 | func testSSHClient(t *testing.T, client *ssh.Client) { 22 | t.Run("session_exec", func(t *testing.T) { testSessionExec(t, client) }) 23 | t.Run("session_exec_stderr", func(t *testing.T) { testSessionPipes(t, client) }) 24 | t.Run("session_stdin", func(t *testing.T) { testStdin(t, client) }) 25 | t.Run("session_exit_code", func(t *testing.T) { testExitCode(t, client) }) 26 | t.Run("environment_variables", func(t *testing.T) { testEnvironmentVar(t, client) }) 27 | t.Run("tcp_forward_local", func(t *testing.T) { testTCPLocal(t, client) }) 28 | t.Run("tcp_forward_remote", func(t *testing.T) { testTCPRemote(t, client) }) 29 | t.Run("unix_forward", func(t *testing.T) { t.Skip(); testUnixForward(t, client) }) 30 | t.Run("invalid_request", func(t *testing.T) { testRequestError(t, client) }) 31 | t.Run("channel_error", func(t *testing.T) { testChannelError(t, client) }) 32 | t.Run("x11_request", func(t *testing.T) { testX11Forwarding(t, client) }) 33 | } 34 | 35 | func testTCPLocal(t *testing.T, client *ssh.Client) { 36 | left, right, err := tcpPipeWithDialer(client.Dial, net.Listen) 37 | if err != nil { 38 | t.Fatalf("new net pipe: %v", err) 39 | } 40 | testConnPipe(t, left, right) 41 | } 42 | 43 | func testConnPipe(t *testing.T, a, b net.Conn) { 44 | mockdata := strconv.Itoa(rand.Int()) + "\n" 45 | defer b.Close() 46 | _, err := a.Write([]byte(mockdata)) 47 | if err != nil { 48 | t.Fatalf("write mock data: %v", err) 49 | } 50 | 51 | err = a.Close() 52 | if err != nil { 53 | t.Fatalf("close conn: %v", err) 54 | } 55 | 56 | content, err := bufio.NewReader(b).ReadString('\n') 57 | if err != nil { 58 | t.Fatalf("read from net conn: %v", err) 59 | } 60 | 
if content != mockdata { 61 | t.Fatalf("unexpected data, expected %q, got %q", mockdata, content) 62 | } 63 | } 64 | 65 | func testUnixForward(t *testing.T, client *ssh.Client) { 66 | left, right, err := unixSocketPipe(t, client.Dial, net.Listen) 67 | if err != nil { 68 | t.Fatalf("new unix socket pipe: %v", err) 69 | } 70 | testConnPipe(t, left, right) 71 | } 72 | 73 | func testX11Forwarding(t *testing.T, client *ssh.Client) { 74 | _, _, err := client.SendRequest("x11-req", true, nil) 75 | if err != nil { 76 | t.Fatalf("new x11 forward request: %v", err) 77 | } 78 | } 79 | 80 | func testTCPRemote(t *testing.T, client *ssh.Client) { 81 | left, right, err := tcpPipeWithDialer(net.Dial, client.Listen) 82 | if err != nil { 83 | t.Fatalf("new remote tcp conn pipe: %v", err) 84 | } 85 | testConnPipe(t, left, right) 86 | } 87 | 88 | func testSessionExec(t *testing.T, client *ssh.Client) { 89 | session, err := client.NewSession() 90 | if err != nil { 91 | t.Fatalf("new ssh session: %v", err) 92 | } 93 | defer session.Close() 94 | 95 | output, err := session.CombinedOutput("echo 123") 96 | if err != nil { 97 | t.Fatalf("execute command: %v", err) 98 | } 99 | if string(output) != "123\n" { 100 | t.Fatalf("unexpected output, expected %q, got %q", "123\n", string(output)) 101 | } 102 | 103 | session, err = client.NewSession() 104 | if err != nil { 105 | t.Fatalf("new ssh session: %v", err) 106 | } 107 | defer session.Close() 108 | 109 | // create a pipe to simulate a hanging stdin 110 | // never write or close the write end 111 | r, _ := io.Pipe() 112 | session.Stdin = r 113 | 114 | output, err = session.CombinedOutput("echo 123") 115 | if err != nil { 116 | t.Fatalf("execute command with hanging stdin: %v", err) 117 | } 118 | if string(output) != "123\n" { 119 | t.Fatalf("unexpected command output, expected %q, got %q", "123\n", string(output)) 120 | } 121 | } 122 | 123 | func testSessionPipes(t *testing.T, client *ssh.Client) { 124 | session, err := client.NewSession() 125 | if 
err != nil { 126 | t.Fatalf("new ssh session: %v", err) 127 | } 128 | defer session.Close() 129 | 130 | stderrPipe, err := session.StderrPipe() 131 | if err != nil { 132 | t.Fatalf("new stderr pipe: %v", err) 133 | } 134 | 135 | stdoutPipe, err := session.StdoutPipe() 136 | if err != nil { 137 | t.Fatalf("new stdout pipe: %v", err) 138 | } 139 | 140 | var wg sync.WaitGroup 141 | stderr, stdout := bytes.NewBuffer(nil), bytes.NewBuffer(nil) 142 | wg.Add(2) 143 | go func() { 144 | defer wg.Done() 145 | _, _ = io.Copy(stderr, stderrPipe) 146 | }() 147 | go func() { 148 | defer wg.Done() 149 | _, _ = io.Copy(stdout, stdoutPipe) 150 | }() 151 | 152 | if err = session.Run(">&2 echo error"); err != nil { 153 | t.Fatalf("run command: %v", err) 154 | } 155 | wg.Wait() 156 | if stderr.String() != "error\n" || stdout.Len() != 0 { 157 | t.Fatalf("unexpected output, stderr %q, stdout %q", stderr.String(), stdout.String()) 158 | } 159 | } 160 | 161 | func testStdin(t *testing.T, client *ssh.Client) { 162 | session, err := client.NewSession() 163 | if err != nil { 164 | t.Fatalf("new ssh session: %v", err) 165 | } 166 | defer session.Close() 167 | 168 | session.Stdin = strings.NewReader("testing\n") 169 | 170 | output, err := session.CombinedOutput("cat") 171 | if err != nil { 172 | t.Fatalf("execute command: %v", err) 173 | } 174 | if string(output) != "testing\n" { 175 | t.Fatalf("unexpected output, expected %q, got %q", "testing\n", string(output)) 176 | } 177 | } 178 | 179 | func testExitCode(t *testing.T, client *ssh.Client) { 180 | session, err := client.NewSession() 181 | if err != nil { 182 | t.Fatalf("new ssh session: %v", err) 183 | } 184 | defer session.Close() 185 | 186 | err = session.Run("exit 123") 187 | if err == nil { 188 | t.Fatal("expected non-nil exit error, got nil") 189 | } 190 | 191 | var exitErr *ssh.ExitError 192 | ok := errors.As(err, &exitErr) 193 | if !ok { 194 | t.Fatalf("unknown error type, expected ssh.ExitError: %v", err) 195 | } 196 | if exitErr.ExitStatus() != 
123 { 197 | t.Fatalf("unexpected exit status, expected %d, got %d", 123, exitErr.ExitStatus()) 198 | } 199 | } 200 | 
201 | func testEnvironmentVar(t *testing.T, client *ssh.Client) { 202 | session, err := client.NewSession() 203 | if err != nil { 204 | t.Fatalf("new ssh session: %v", err) 205 | } 206 | defer session.Close() 207 | 208 | setEnvs := map[string]string{ 209 | "NEW_ENV": "TEST_VALUE", 210 | "TESTING": "with space", 211 | } 212 | 213 | for k, v := range setEnvs { 214 | err := session.Setenv(k, v) 215 | if err != nil { 216 | t.Fatalf("set environment variable %s: %v", k, err) 217 | } 218 | } 219 | 220 | output, err := session.CombinedOutput("env") 221 | if err != nil { 222 | t.Fatalf("run command: %v", err) 223 | } 224 | 225 | env := string(output) 226 | for k, v := range setEnvs { 227 | contains := strings.Contains(env, fmt.Sprintf("%s=%s", k, v)) 228 | if !contains { 229 | t.Fatalf("environment variable %s=%s not found in output: %q", k, v, env) 230 | } 231 | } 232 | } 233 | 234 | func testChannelError(t *testing.T, client *ssh.Client) { 235 | var openChErr *ssh.OpenChannelError 236 | _, _, err := client.OpenChannel("invalid", []byte{}) 237 | if err == nil { 238 | t.Fatalf("expected error from open invalid channel, got %v", err) 239 | } 240 | if !errors.As(err, &openChErr) { 241 | t.Fatalf("expected *ssh.OpenChannelError, got %T: %v", err, err) 242 | } 243 | if openChErr.Reason != ssh.ConnectionFailed { 244 | t.Fatalf("expected ssh.ConnectionFailed, got: %s", openChErr.Reason.String()) 245 | } 246 | } 247 | 248 | func testRequestError(t *testing.T, client *ssh.Client) { 249 | ok, resp, err := client.SendRequest("invalid", true, nil) 250 | if err != nil { 251 | t.Fatalf("unexpected error: %v", err) 252 | } 253 | if len(resp) != 0 { 254 | t.Fatalf("expected request response to be empty, got %v", resp) 255 | } 256 | if ok { 257 | t.Fatalf("expected false from \"ok\"") 258 | } 259 | } 260 | 261 | type listener func(network string, addr string) (net.Listener, error) 262 | type dialer func(network, addr string) (net.Conn, error) 263 | 264 | func tcpPipeWithDialer(dial 
dialer, listen listener) (net.Conn, net.Conn, 
error) { 265 | // may need to use "[::1]:0" for ipv6 266 | return netPipeWithDialer(dial, listen, "tcp", "127.0.0.1:0") 267 | } 268 | 269 | func unixSocketPipe(t *testing.T, dial dialer, listen listener) (net.Conn, net.Conn, error) { 270 | socket := filepath.Join("/tmp", "sshproxy-unix-test-"+strconv.Itoa(rand.Int())+".sock") 271 | cleanup := func() { _ = os.Remove(socket) } 272 | cleanup() 273 | t.Cleanup(cleanup) 274 | return netPipeWithDialer(dial, listen, "unix", socket) 275 | } 276 | 277 | func netPipeWithDialer(dial dialer, listen listener, network, addr string) (net.Conn, net.Conn, error) { 278 | ln, err := listen(network, addr) 279 | if err != nil { 280 | return nil, nil, err 281 | } 282 | defer ln.Close() 283 | 284 | c1, err := dial(network, ln.Addr().String()) 285 | if err != nil { 286 | return nil, nil, err 287 | } 288 | 289 | c2, err := ln.Accept() 290 | if err != nil { 291 | _ = c1.Close() 292 | return nil, nil, err 293 | } 294 | 295 | return c1, c2, nil 296 | } 297 | -------------------------------------------------------------------------------- /test.Dockerfile: -------------------------------------------------------------------------------- 1 | # This Dockerfile specifies an image for use as a testing target, whereby 2 | # its OpenSSH server is accessed through sshproxy.ReverseProxy and verified 3 | # to respond as expected.
4 | FROM ubuntu:20.04 5 | 6 | RUN apt-get update && DEBIAN_FRONTEND="noninteractive" apt-get install -y \ 7 | bash \ 8 | openssh-server \ 9 | && rm -rf /var/lib/apt/lists/* 10 | 11 | ARG USER=testuser 12 | ARG PASSWORD=testpassword 13 | ARG PORT=2222 14 | 15 | RUN ssh-keygen -A && mkdir -p /run/sshd 16 | 17 | RUN useradd --create-home --shell /bin/bash ${USER} 18 | RUN echo ${USER}:${PASSWORD} | chpasswd 19 | 20 | RUN echo "Port ${PORT}" >> /etc/ssh/sshd_config 21 | RUN echo "AcceptEnv *" >> /etc/ssh/sshd_config 22 | 23 | CMD [ "sh", "-c", "/usr/sbin/sshd && sleep infinity" ] 24 | --------------------------------------------------------------------------------