├── LICENSE ├── README.md ├── _examples ├── ssh-docker │ ├── Dockerfile │ ├── README.md │ └── docker.go ├── ssh-forwardagent │ └── forwardagent.go ├── ssh-pty │ └── pty.go ├── ssh-publickey │ └── public_key.go ├── ssh-remoteforward │ └── portforward.go ├── ssh-sftpserver │ └── sftp.go ├── ssh-simple │ └── simple.go └── ssh-timeouts │ └── timeouts.go ├── agent.go ├── circle.yml ├── conn.go ├── context.go ├── context_test.go ├── doc.go ├── example_test.go ├── go.mod ├── go.sum ├── options.go ├── options_test.go ├── server.go ├── server_test.go ├── session.go ├── session_test.go ├── ssh.go ├── ssh_test.go ├── tcpip.go ├── tcpip_test.go ├── util.go └── wrap.go /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016 Glider Labs. All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are 5 | met: 6 | 7 | * Redistributions of source code must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 9 | * Redistributions in binary form must reproduce the above 10 | copyright notice, this list of conditions and the following disclaimer 11 | in the documentation and/or other materials provided with the 12 | distribution. 13 | * Neither the name of Glider Labs nor the names of its 14 | contributors may be used to endorse or promote products derived from 15 | this software without specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 18 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 19 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 20 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 21 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 22 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 23 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 24 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 25 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # gliderlabs/ssh 2 | 3 | [![GoDoc](https://godoc.org/github.com/gliderlabs/ssh?status.svg)](https://godoc.org/github.com/gliderlabs/ssh) 4 | [![CircleCI](https://img.shields.io/circleci/project/github/gliderlabs/ssh.svg)](https://circleci.com/gh/gliderlabs/ssh) 5 | [![Go Report Card](https://goreportcard.com/badge/github.com/gliderlabs/ssh)](https://goreportcard.com/report/github.com/gliderlabs/ssh) 6 | [![OpenCollective](https://opencollective.com/ssh/sponsors/badge.svg)](#sponsors) 7 | [![Slack](http://slack.gliderlabs.com/badge.svg)](http://slack.gliderlabs.com) 8 | [![Email Updates](https://img.shields.io/badge/updates-subscribe-yellow.svg)](https://app.convertkit.com/landing_pages/243312) 9 | 10 | > The Glider Labs SSH server package is dope. —[@bradfitz](https://twitter.com/bradfitz), Go team member 11 | 12 | This Go package wraps the [crypto/ssh 13 | package](https://godoc.org/golang.org/x/crypto/ssh) with a higher-level API for 14 | building SSH servers. 
The goal of the API was to make it as simple as using 15 | [net/http](https://golang.org/pkg/net/http/), so the API is very similar: 16 | 17 | ```go 18 | package main 19 | 20 | import ( 21 | "github.com/gliderlabs/ssh" 22 | "io" 23 | "log" 24 | ) 25 | 26 | func main() { 27 | ssh.Handle(func(s ssh.Session) { 28 | io.WriteString(s, "Hello world\n") 29 | }) 30 | 31 | log.Fatal(ssh.ListenAndServe(":2222", nil)) 32 | } 33 | 34 | ``` 35 | This package was built by [@progrium](https://twitter.com/progrium) after working on nearly a dozen projects at Glider Labs using SSH and collaborating with [@shazow](https://twitter.com/shazow) (known for [ssh-chat](https://github.com/shazow/ssh-chat)). 36 | 37 | ## Examples 38 | 39 | A bunch of great examples are in the `_examples` directory. 40 | 41 | ## Usage 42 | 43 | [See GoDoc reference.](https://godoc.org/github.com/gliderlabs/ssh) 44 | 45 | ## Contributing 46 | 47 | Pull requests are welcome! However, since this project is very much about API 48 | design, please submit API changes as issues to discuss before submitting PRs. 49 | 50 | Also, you can [join our Slack](http://slack.gliderlabs.com) to discuss as well. 51 | 52 | ## Roadmap 53 | 54 | * Non-session channel handlers 55 | * Cleanup callback API 56 | * 1.0 release 57 | * High-level client? 58 | 59 | ## Sponsors 60 | 61 | Become a sponsor and get your logo on our README on Github with a link to your site. 
[[Become a sponsor](https://opencollective.com/ssh#sponsor)] 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | ## License 95 | 96 | [BSD](LICENSE) 97 | -------------------------------------------------------------------------------- /_examples/ssh-docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM alpine 2 | RUN apk add -U jq 3 | ENTRYPOINT ["jq"] 4 | -------------------------------------------------------------------------------- /_examples/ssh-docker/README.md: -------------------------------------------------------------------------------- 1 | # SSH Docker Example 2 | Run docker containers over SSH. You can even pipe things into them too! 3 | 4 | # Installation / Prep 5 | We're going to build JQ as an SSH service using the Glider Labs SSH package. If you haven't installed Go and Docker yet, see their documentation for help setting up your environment. 6 | 7 | Install the Glider Labs SSH package 8 | `go get github.com/gliderlabs/ssh` 9 | 10 | Build the example Docker image with 11 | `docker build --rm -t jq .` 12 | 13 | # Usage 14 | Run the SSH server with 15 | 16 | `go run docker.go` 17 | 18 | This SSH service uses the user name passed into the connection to select the desired Docker image. So to use jq over our SSH server we would say: 19 | 20 | `ssh jq@localhost -p 2222` 21 | 22 | If we run this command we will see 23 | 24 | ``` 25 | jq - commandline JSON processor [version 1.5] 26 | Usage: jq [options] [file...] 27 | 28 | jq is a tool for processing JSON inputs, applying the 29 | given filter to its JSON text inputs and producing the 30 | filter's results as JSON on standard output. 31 | The simplest filter is ., which is the identity filter, 32 | copying jq's input to its output unmodified (except for 33 | formatting).
34 | For more advanced filters see the jq(1) manpage ("man jq") 35 | and/or https://stedolan.github.io/jq 36 | 37 | Some of the options include: 38 | -c compact instead of pretty-printed output; 39 | -n use `null` as the single input value; 40 | -e set the exit status code based on the output; 41 | -s read (slurp) all inputs into an array; apply filter to it; 42 | -r output raw strings, not JSON texts; 43 | -R read raw strings, not JSON texts; 44 | -C colorize JSON; 45 | -M monochrome (don't colorize JSON); 46 | -S sort keys of objects on output; 47 | --tab use tabs for indentation; 48 | --arg a v set variable $a to value ; 49 | --argjson a v set variable $a to JSON value ; 50 | --slurpfile a f set variable $a to an array of JSON texts read from ; 51 | See the manpage for more options. 52 | Connection to localhost closed. 53 | ``` 54 | 55 | JQ's help text! It's working! Now let's pipe some json into our SSH service and marvel at the awesomeness. 56 | 57 | `curl -s https://api.github.com/orgs/gliderlabs | ssh jq@localhost -p 2222 .` 58 | 59 | ```json 60 | { 61 | "login": "gliderlabs", 62 | "id": 8484931, 63 | "url": "https://api.github.com/orgs/gliderlabs", 64 | "repos_url": "https://api.github.com/orgs/gliderlabs/repos", 65 | "events_url": "https://api.github.com/orgs/gliderlabs/events", 66 | "hooks_url": "https://api.github.com/orgs/gliderlabs/hooks", 67 | "issues_url": "https://api.github.com/orgs/gliderlabs/issues", 68 | "members_url": "https://api.github.com/orgs/gliderlabs/members{/member}", 69 | "public_members_url": "https://api.github.com/orgs/gliderlabs/public_members{/member}", 70 | "avatar_url": "https://avatars3.githubusercontent.com/u/8484931?v=4", 71 | "description": "", 72 | "name": "Glider Labs", 73 | "company": null, 74 | "blog": "http://gliderlabs.com", 75 | "location": "Austin, TX", 76 | "email": "team@gliderlabs.com", 77 | "has_organization_projects": true, 78 | "has_repository_projects": true, 79 | "public_repos": 29, 80 | "public_gists": 0, 81 
| "followers": 0, 82 | "following": 0, 83 | "html_url": "https://github.com/gliderlabs", 84 | "created_at": "2014-08-18T23:25:37Z", 85 | "updated_at": "2017-08-21T09:52:17Z", 86 | "type": "Organization" 87 | } 88 | ``` 89 | 90 | # Conclusion 91 | We built JQ as a service over SSH in Go using the Glider Labs SSH package. We showed how you can run docker containers through the service as well as how to pipe stuff into your SSH service. 92 | 93 | 94 | # Troubleshooting 95 | 96 | A new host key is generated each time we run the server so you'll probably see this error on the second run. 97 | ``` 98 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ 99 | @ WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED! @ 100 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ 101 | IT IS POSSIBLE THAT SOMEONE IS DOING SOMETHING NASTY! 102 | Someone could be eavesdropping on you right now (man-in-the-middle attack)! 103 | It is also possible that a host key has just been changed. 104 | The fingerprint for the RSA key sent by the remote host is 105 | SHA256:9dyZ8KrMCPGvDqECggoDFyz51gA6DdHa9zpfl/5Kgos. 106 | Please contact your system administrator. 107 | Add correct host key in /Users/murr/.ssh/known_hosts to get rid of this message. 108 | Offending RSA key in /Users/murr/.ssh/known_hosts:7 109 | RSA host key for [localhost]:2222 has changed and you have requested strict checking. 110 | Host key verification failed. 
111 | ``` 112 | 113 | To bypass this error, remove the stale host key from your known_hosts file with 114 | `ssh-keygen -R "[localhost]:2222"` 115 | -------------------------------------------------------------------------------- /_examples/ssh-docker/docker.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "log" 8 | 9 | "github.com/docker/docker/api/types" 10 | "github.com/docker/docker/api/types/container" 11 | "github.com/docker/docker/client" 12 | "github.com/docker/docker/pkg/stdcopy" 13 | "github.com/gliderlabs/ssh" 14 | ) 15 | 16 | func main() { 17 | ssh.Handle(func(sess ssh.Session) { 18 | _, _, isTty := sess.Pty() 19 | cfg := &container.Config{ 20 | Image: sess.User(), 21 | Cmd: sess.Command(), 22 | Env: sess.Environ(), 23 | Tty: isTty, 24 | OpenStdin: true, 25 | AttachStderr: true, 26 | AttachStdin: true, 27 | AttachStdout: true, 28 | StdinOnce: true, 29 | Volumes: make(map[string]struct{}), 30 | } 31 | status, cleanup, err := dockerRun(cfg, sess) 32 | defer cleanup() 33 | if err != nil { 34 | fmt.Fprintln(sess, err) 35 | log.Println(err) 36 | } 37 | sess.Exit(int(status)) 38 | }) 39 | 40 | log.Println("starting ssh server on port 2222...") 41 | log.Fatal(ssh.ListenAndServe(":2222", nil)) 42 | } 43 | 44 | func dockerRun(cfg *container.Config, sess ssh.Session) (status int64, cleanup func(), err error) { 45 | status = 255 // reported to the client if the container never returns a status 46 | cleanup = func() {} // must be non-nil on every path: the caller defers it before checking err 47 | docker, err := client.NewEnvClient() 48 | if err != nil { 49 | return // report the error to this session instead of panicking the whole server 50 | } 51 | ctx := context.Background() 52 | res, err := docker.ContainerCreate(ctx, cfg, nil, nil, "") 53 | if err != nil { 54 | return 55 | } 56 | cleanup = func() { 57 | docker.ContainerRemove(ctx, res.ID, types.ContainerRemoveOptions{}) 58 | } 59 | opts := types.ContainerAttachOptions{ 60 | Stdin: cfg.AttachStdin, 61 | Stdout: cfg.AttachStdout, 62 | Stderr: cfg.AttachStderr, 63 | Stream: true, 64 | } 65 | stream, err := docker.ContainerAttach(ctx, res.ID,
opts) 66 | if err != nil { 67 | return 68 | } 69 | cleanup = func() { 70 | docker.ContainerRemove(ctx, res.ID, types.ContainerRemoveOptions{}) 71 | stream.Close() 72 | } 73 | 74 | outputErr := make(chan error, 1) // buffered so the output copier can always finish, even when we return early 75 | 76 | go func() { 77 | var err error 78 | if cfg.Tty { 79 | _, err = io.Copy(sess, stream.Reader) 80 | } else { 81 | _, err = stdcopy.StdCopy(sess, sess.Stderr(), stream.Reader) 82 | } 83 | outputErr <- err 84 | }() 85 | 86 | go func() { 87 | defer stream.CloseWrite() 88 | io.Copy(stream.Conn, sess) 89 | }() 90 | 91 | err = docker.ContainerStart(ctx, res.ID, types.ContainerStartOptions{}) 92 | if err != nil { 93 | return 94 | } 95 | if cfg.Tty { 96 | _, winCh, _ := sess.Pty() 97 | go func() { 98 | for win := range winCh { 99 | err := docker.ContainerResize(ctx, res.ID, types.ResizeOptions{ 100 | Height: uint(win.Height), 101 | Width: uint(win.Width), 102 | }) 103 | if err != nil { 104 | log.Println(err) 105 | break 106 | } 107 | } 108 | }() 109 | } 110 | resultC, errC := docker.ContainerWait(ctx, res.ID, container.WaitConditionNotRunning) 111 | select { 112 | case err = <-errC: 113 | return 114 | case result := <-resultC: 115 | status = result.StatusCode 116 | } 117 | err = <-outputErr 118 | return 119 | } 120 | -------------------------------------------------------------------------------- /_examples/ssh-forwardagent/forwardagent.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "os/exec" 7 | 8 | "github.com/gliderlabs/ssh" 9 | ) 10 | 11 | func main() { 12 | ssh.Handle(func(s ssh.Session) { 13 | cmd := exec.Command("ssh-add", "-l") 14 | if ssh.AgentRequested(s) { 15 | l, err := ssh.NewAgentListener() 16 | if err != nil { 17 | log.Println(err) // log.Fatal here would shut down the server for every client 18 | s.Exit(1) 19 | return 20 | } 21 | defer l.Close() 22 | go ssh.ForwardAgentConnections(l, s) 23 | cmd.Env = append(s.Environ(), fmt.Sprintf("%s=%s", "SSH_AUTH_SOCK", l.Addr().String())) 24 | } else { 25 | cmd.Env = s.Environ() 26 | } 27 |
cmd.Stdout = s 28 | cmd.Stderr = s.Stderr() 29 | if err := cmd.Run(); err != nil { 30 | log.Println(err) 31 | return 32 | } 33 | }) 34 | 35 | log.Println("starting ssh server on port 2222...") 36 | log.Fatal(ssh.ListenAndServe(":2222", nil)) 37 | } 38 | -------------------------------------------------------------------------------- /_examples/ssh-pty/pty.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "log" 7 | "os" 8 | "os/exec" 9 | "syscall" 10 | "unsafe" 11 | 12 | "github.com/creack/pty" 13 | "github.com/gliderlabs/ssh" 14 | ) 15 | 16 | func setWinsize(f *os.File, w, h int) { 17 | syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ), 18 | uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(h), uint16(w), 0, 0}))) 19 | } 20 | 21 | func main() { 22 | ssh.Handle(func(s ssh.Session) { 23 | cmd := exec.Command("top") 24 | ptyReq, winCh, isPty := s.Pty() 25 | if isPty { 26 | cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term)) 27 | f, err := pty.Start(cmd) 28 | if err != nil { 29 | log.Println(err) // panicking here would crash the server for every client 30 | s.Exit(1) 31 | return 32 | } 33 | defer f.Close() // release the pty master when the session ends 34 | go func() { 35 | for win := range winCh { 36 | setWinsize(f, win.Width, win.Height) 37 | } 38 | }() 39 | go func() { 40 | io.Copy(f, s) // stdin 41 | }() 42 | io.Copy(s, f) // stdout 43 | cmd.Wait() 44 | } else { 45 | io.WriteString(s, "No PTY requested.\n") 46 | s.Exit(1) 47 | } 48 | }) 49 | 50 | log.Println("starting ssh server on port 2222...") 51 | log.Fatal(ssh.ListenAndServe(":2222", nil)) 52 | } 53 | -------------------------------------------------------------------------------- /_examples/ssh-publickey/public_key.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "log" 7 | 8 | "github.com/gliderlabs/ssh" 9 | gossh "golang.org/x/crypto/ssh" 10 | ) 11 | 12 | func main() { 13 | ssh.Handle(func(s ssh.Session) { 14 | authorizedKey :=
gossh.MarshalAuthorizedKey(s.PublicKey()) 15 | io.WriteString(s, fmt.Sprintf("public key used by %s:\n", s.User())) 16 | s.Write(authorizedKey) 17 | }) 18 | 19 | publicKeyOption := ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool { 20 | return true // allow all keys, or use ssh.KeysEqual() to compare against known keys 21 | }) 22 | 23 | log.Println("starting ssh server on port 2222...") 24 | log.Fatal(ssh.ListenAndServe(":2222", nil, publicKeyOption)) 25 | } 26 | -------------------------------------------------------------------------------- /_examples/ssh-remoteforward/portforward.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "io" 5 | "log" 6 | 7 | "github.com/gliderlabs/ssh" 8 | ) 9 | 10 | func main() { 11 | 12 | log.Println("starting ssh server on port 2222...") 13 | 14 | forwardHandler := &ssh.ForwardedTCPHandler{} 15 | 16 | server := ssh.Server{ 17 | LocalPortForwardingCallback: ssh.LocalPortForwardingCallback(func(ctx ssh.Context, dhost string, dport uint32) bool { 18 | log.Println("Accepted forward", dhost, dport) 19 | return true 20 | }), 21 | Addr: ":2222", 22 | Handler: ssh.Handler(func(s ssh.Session) { 23 | io.WriteString(s, "Remote forwarding available...\n") 24 | <-s.Context().Done() // hold the session open until the client disconnects; select {} would leak this goroutine forever 25 | }), 26 | ReversePortForwardingCallback: ssh.ReversePortForwardingCallback(func(ctx ssh.Context, host string, port uint32) bool { 27 | log.Println("attempt to bind", host, port, "granted") 28 | return true 29 | }), 30 | RequestHandlers: map[string]ssh.RequestHandler{ 31 | "tcpip-forward": forwardHandler.HandleSSHRequest, 32 | "cancel-tcpip-forward": forwardHandler.HandleSSHRequest, 33 | }, 34 | } 35 | 36 | log.Fatal(server.ListenAndServe()) 37 | } 38 | -------------------------------------------------------------------------------- /_examples/ssh-sftpserver/sftp.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt"
5 | "io" 6 | "log" 7 | 8 | "github.com/gliderlabs/ssh" 9 | "github.com/pkg/sftp" 10 | ) 11 | 12 | // SftpHandler serves the SFTP subsystem on an SSH session. 13 | func SftpHandler(sess ssh.Session) { 14 | debugStream := io.Discard 15 | serverOptions := []sftp.ServerOption{ 16 | sftp.WithDebug(debugStream), 17 | } 18 | server, err := sftp.NewServer( 19 | sess, 20 | serverOptions..., 21 | ) 22 | if err != nil { 23 | log.Printf("sftp server init error: %s\n", err) 24 | return 25 | } 26 | defer server.Close() // close on every exit path, not only after a clean io.EOF 27 | if err := server.Serve(); err == io.EOF { 28 | fmt.Println("sftp client exited session.") 29 | } else if err != nil { 30 | fmt.Println("sftp server completed with error:", err) 31 | } 32 | } 33 | 34 | func main() { 35 | sshServer := ssh.Server{ 36 | Addr: "127.0.0.1:2222", 37 | SubsystemHandlers: map[string]ssh.SubsystemHandler{ 38 | "sftp": SftpHandler, 39 | }, 40 | } 41 | log.Fatal(sshServer.ListenAndServe()) 42 | } 43 | -------------------------------------------------------------------------------- /_examples/ssh-simple/simple.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "log" 7 | 8 | "github.com/gliderlabs/ssh" 9 | ) 10 | 11 | func main() { 12 | ssh.Handle(func(s ssh.Session) { 13 | io.WriteString(s, fmt.Sprintf("Hello %s\n", s.User())) 14 | }) 15 | 16 | log.Println("starting ssh server on port 2222...") 17 | log.Fatal(ssh.ListenAndServe(":2222", nil)) 18 | } 19 | -------------------------------------------------------------------------------- /_examples/ssh-timeouts/timeouts.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | "time" 6 | 7 | "github.com/gliderlabs/ssh" 8 | ) 9 | 10 | var ( 11 | DeadlineTimeout = 30 * time.Second 12 | IdleTimeout = 10 * time.Second 13 | ) 14 | 15 | func main() { 16 | ssh.Handle(func(s ssh.Session) { 17 | log.Println("new connection") 18 | i := 0 19 | for { 20 | i++
21 | log.Println("active seconds:", i) 22 | select { 23 | case <-time.After(time.Second): 24 | continue 25 | case <-s.Context().Done(): 26 | log.Println("connection closed") 27 | return 28 | } 29 | } 30 | }) 31 | 32 | log.Println("starting ssh server on port 2222...") 33 | log.Printf("connections will only last %s\n", DeadlineTimeout) 34 | log.Printf("and timeout after %s of no activity\n", IdleTimeout) 35 | server := &ssh.Server{ 36 | Addr: ":2222", 37 | MaxTimeout: DeadlineTimeout, 38 | IdleTimeout: IdleTimeout, 39 | } 40 | log.Fatal(server.ListenAndServe()) 41 | } 42 | -------------------------------------------------------------------------------- /agent.go: -------------------------------------------------------------------------------- 1 | package ssh 2 | 3 | import ( 4 | "io" 5 | "net" 6 | "os" 7 | "path" 8 | "sync" 9 | 10 | gossh "golang.org/x/crypto/ssh" 11 | ) 12 | 13 | const ( 14 | agentRequestType = "auth-agent-req@openssh.com" 15 | agentChannelType = "auth-agent@openssh.com" 16 | 17 | agentTempDir = "auth-agent" 18 | agentListenFile = "listener.sock" 19 | ) 20 | 21 | // contextKeyAgentRequest is an internal context key for storing if the 22 | // client requested agent forwarding. 23 | var contextKeyAgentRequest = &contextKey{"auth-agent-req"} 24 | 25 | // SetAgentRequested sets up the session context so that AgentRequested 26 | // returns true. 27 | func SetAgentRequested(ctx Context) { 28 | ctx.SetValue(contextKeyAgentRequest, true) 29 | } 30 | 31 | // AgentRequested returns true if the client requested agent forwarding. 32 | func AgentRequested(sess Session) bool { 33 | return sess.Context().Value(contextKeyAgentRequest) == true 34 | } 35 | 36 | // NewAgentListener sets up a temporary Unix socket that can be communicated 37 | // to the session environment and used for forwarding connections.
38 | func NewAgentListener() (net.Listener, error) { 39 | dir, err := os.MkdirTemp("", agentTempDir) 40 | if err != nil { 41 | return nil, err 42 | } 43 | l, err := net.Listen("unix", path.Join(dir, agentListenFile)) 44 | if err != nil { 45 | os.RemoveAll(dir) // don't leave the fresh temp dir behind when the socket can't be created 46 | return nil, err 47 | } 48 | return l, nil // NOTE(review): nothing here removes dir after the listener is closed 49 | } 50 | 51 | // ForwardAgentConnections takes connections from a listener to proxy into the 52 | // session on the OpenSSH channel for agent connections. It blocks and services 53 | // connections until the listener stops accepting. When an accepted connection 54 | // supports half-close (for example *net.UnixConn), its write side is closed 55 | // once the agent channel has finished sending. 56 | func ForwardAgentConnections(l net.Listener, s Session) { 57 | sshConn := s.Context().Value(ContextKeyConn).(gossh.Conn) 58 | for { 59 | conn, err := l.Accept() 60 | if err != nil { 61 | return 62 | } 63 | go func(conn net.Conn) { 64 | defer conn.Close() 65 | channel, reqs, err := sshConn.OpenChannel(agentChannelType, nil) 66 | if err != nil { 67 | return 68 | } 69 | defer channel.Close() 70 | go gossh.DiscardRequests(reqs) 71 | var wg sync.WaitGroup 72 | wg.Add(2) 73 | go func() { 74 | io.Copy(conn, channel) 75 | // Half-close only when supported; asserting *net.UnixConn would panic for other listener types. 76 | if cw, ok := conn.(interface{ CloseWrite() error }); ok { 77 | cw.CloseWrite() 78 | } 79 | wg.Done() 80 | }() 81 | go func() { 82 | io.Copy(channel, conn) 83 | channel.CloseWrite() 84 | wg.Done() 85 | }() 86 | wg.Wait() 87 | }(conn) 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /circle.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | jobs: 3 | build-go-latest: 4 | docker: 5 | - image: golang:latest 6 | working_directory: /go/src/github.com/gliderlabs/ssh 7 | steps: 8 | - checkout 9 | - run: go get 10 | - run: go test -v -race 11 | 12 | build-go-1.20: 13 | docker: 14 | - image: golang:1.20 15 | working_directory: /go/src/github.com/gliderlabs/ssh 16 | steps: 17 | - checkout 18 | - run: go get 19 | - run: go test -v -race 20 | 21 | workflows: 22 | version: 2 23 | build: 24 | jobs: 25 | - build-go-latest 26 | - build-go-1.20 27 |
-------------------------------------------------------------------------------- /conn.go: -------------------------------------------------------------------------------- 1 | package ssh 2 | 3 | import ( 4 | "context" 5 | "net" 6 | "time" 7 | ) 8 | 9 | // serverConn wraps a net.Conn to apply the server's idle, handshake and 10 | // maximum deadlines, and calls closeCanceler when the connection fails with 11 | // a net.Error or is closed. 12 | type serverConn struct { 13 | net.Conn 14 | 15 | idleTimeout time.Duration // when > 0, each Read/Write refreshes the deadline to at most now+idleTimeout 16 | handshakeDeadline time.Time // zero means unset 17 | maxDeadline time.Time // zero means unset 18 | closeCanceler context.CancelFunc // optional; NOTE(review): presumably cancels the connection's context — confirm with the code that sets it 19 | } 20 | 21 | // Write refreshes the deadline (only when an idle timeout is configured), 22 | // writes p to the underlying connection, and calls closeCanceler if the 23 | // write fails with a net.Error. 24 | func (c *serverConn) Write(p []byte) (n int, err error) { 25 | if c.idleTimeout > 0 { 26 | c.updateDeadline() 27 | } 28 | n, err = c.Conn.Write(p) 29 | if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil { 30 | c.closeCanceler() 31 | } 32 | return 33 | } 34 | 35 | // Read refreshes the deadline (only when an idle timeout is configured), 36 | // reads into b from the underlying connection, and calls closeCanceler if 37 | // the read fails with a net.Error. 38 | func (c *serverConn) Read(b []byte) (n int, err error) { 39 | if c.idleTimeout > 0 { 40 | c.updateDeadline() 41 | } 42 | n, err = c.Conn.Read(b) 43 | if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil { 44 | c.closeCanceler() 45 | } 46 | return 47 | } 48 | 49 | // Close closes the underlying connection and then calls closeCanceler, if 50 | // set. It returns the error from the underlying Close. 51 | func (c *serverConn) Close() (err error) { 52 | err = c.Conn.Close() 53 | if c.closeCanceler != nil { 54 | c.closeCanceler() 55 | } 56 | return 57 | } 58 | 59 | // updateDeadline sets the connection's deadline to the earliest of 60 | // maxDeadline, handshakeDeadline and now+idleTimeout, skipping zero 61 | // deadlines and a non-positive idleTimeout. If none applies, the zero time 62 | // is passed, which clears the deadline. Any SetDeadline error is ignored. 63 | func (c *serverConn) updateDeadline() { 64 | deadline := c.maxDeadline 65 | 66 | if !c.handshakeDeadline.IsZero() && (deadline.IsZero() || c.handshakeDeadline.Before(deadline)) { 67 | deadline = c.handshakeDeadline 68 | } 69 | 70 | if c.idleTimeout > 0 { 71 | idleDeadline := time.Now().Add(c.idleTimeout) 72 | if deadline.IsZero() || idleDeadline.Before(deadline) { 73 | deadline = idleDeadline 74 | } 75 | } 76 | 77 | c.Conn.SetDeadline(deadline) 78 | } 79 | -------------------------------------------------------------------------------- /context.go: -------------------------------------------------------------------------------- 1 | package ssh 2 | 3 | import ( 4 | "context" 5 | "encoding/hex" 6 | "net" 7 | "sync" 8 | 9 | gossh "golang.org/x/crypto/ssh" 10 | ) 11 | 12 | // contextKey is a value for use with
context.WithValue. It's used as 13 | // a pointer so it fits in an interface{} without allocation. 14 | type contextKey struct { 15 | name string 16 | } 17 | 18 | var ( 19 | // ContextKeyUser is a context key for use with Contexts in this package. 20 | // The associated value will be of type string. 21 | ContextKeyUser = &contextKey{"user"} 22 | 23 | // ContextKeySessionID is a context key for use with Contexts in this package. 24 | // The associated value will be of type string. 25 | ContextKeySessionID = &contextKey{"session-id"} 26 | 27 | // ContextKeyPermissions is a context key for use with Contexts in this package. 28 | // The associated value will be of type *Permissions. 29 | ContextKeyPermissions = &contextKey{"permissions"} 30 | 31 | // ContextKeyClientVersion is a context key for use with Contexts in this package. 32 | // The associated value will be of type string. 33 | ContextKeyClientVersion = &contextKey{"client-version"} 34 | 35 | // ContextKeyServerVersion is a context key for use with Contexts in this package. 36 | // The associated value will be of type string. 37 | ContextKeyServerVersion = &contextKey{"server-version"} 38 | 39 | // ContextKeyLocalAddr is a context key for use with Contexts in this package. 40 | // The associated value will be of type net.Addr. 41 | ContextKeyLocalAddr = &contextKey{"local-addr"} 42 | 43 | // ContextKeyRemoteAddr is a context key for use with Contexts in this package. 44 | // The associated value will be of type net.Addr. 45 | ContextKeyRemoteAddr = &contextKey{"remote-addr"} 46 | 47 | // ContextKeyServer is a context key for use with Contexts in this package. 48 | // The associated value will be of type *Server. 49 | ContextKeyServer = &contextKey{"ssh-server"} 50 | 51 | // ContextKeyConn is a context key for use with Contexts in this package. 52 | // The associated value will be of type gossh.ServerConn.
53 | ContextKeyConn = &contextKey{"ssh-conn"} 54 | 55 | // ContextKeyPublicKey is a context key for use with Contexts in this package. 56 | // The associated value will be of type PublicKey. 57 | ContextKeyPublicKey = &contextKey{"public-key"} 58 | ) 59 | 60 | // Context is a package specific context interface. It exposes connection 61 | // metadata and allows new values to be easily written to it. It's used in 62 | // authentication handlers and callbacks, and its underlying context.Context is 63 | // exposed on Session in the session Handler. A connection-scoped lock is also 64 | // embedded in the context to make it easier to limit operations per-connection. 65 | type Context interface { 66 | context.Context 67 | sync.Locker 68 | 69 | // User returns the username used when establishing the SSH connection. 70 | User() string 71 | 72 | // SessionID returns the session hash. 73 | SessionID() string 74 | 75 | // ClientVersion returns the version reported by the client. 76 | ClientVersion() string 77 | 78 | // ServerVersion returns the version reported by the server. 79 | ServerVersion() string 80 | 81 | // RemoteAddr returns the remote address for this connection. 82 | RemoteAddr() net.Addr 83 | 84 | // LocalAddr returns the local address for this connection. 85 | LocalAddr() net.Addr 86 | 87 | // Permissions returns the Permissions object used for this connection. 88 | Permissions() *Permissions 89 | 90 | // SetValue allows you to easily write new values into the underlying context.
91 | SetValue(key, value interface{}) 92 | } 93 | 94 | // sshContext is the package's Context implementation. Values written with 95 | // SetValue live in values (guarded by valuesMu) and take precedence over the 96 | // wrapped context.Context; the embedded *sync.Mutex is the connection-scoped 97 | // lock exposed through sync.Locker. 98 | type sshContext struct { 99 | context.Context 100 | *sync.Mutex 101 | 102 | values map[interface{}]interface{} 103 | valuesMu sync.Mutex 104 | } 105 | 106 | // newContext returns a cancellable context holding srv and an empty 107 | // Permissions value, together with the function that cancels it. 108 | func newContext(srv *Server) (*sshContext, context.CancelFunc) { 109 | innerCtx, cancel := context.WithCancel(context.Background()) 110 | ctx := &sshContext{Context: innerCtx, Mutex: &sync.Mutex{}, values: make(map[interface{}]interface{})} 111 | ctx.SetValue(ContextKeyServer, srv) 112 | perms := &Permissions{&gossh.Permissions{}} 113 | ctx.SetValue(ContextKeyPermissions, perms) 114 | return ctx, cancel 115 | } 116 | 117 | // applyConnMetadata records conn's metadata in ctx and is a no-op once a 118 | // session ID is present. It is separate from newContext because we will get 119 | // ConnMetadata at different points so it needs to be applied separately. 120 | func applyConnMetadata(ctx Context, conn gossh.ConnMetadata) { 121 | if ctx.Value(ContextKeySessionID) != nil { 122 | return 123 | } 124 | ctx.SetValue(ContextKeySessionID, hex.EncodeToString(conn.SessionID())) 125 | ctx.SetValue(ContextKeyClientVersion, string(conn.ClientVersion())) 126 | ctx.SetValue(ContextKeyServerVersion, string(conn.ServerVersion())) 127 | ctx.SetValue(ContextKeyUser, conn.User()) 128 | ctx.SetValue(ContextKeyLocalAddr, conn.LocalAddr()) 129 | ctx.SetValue(ContextKeyRemoteAddr, conn.RemoteAddr()) 130 | } 131 | 132 | // Value returns the value stored with SetValue for key, falling back to the 133 | // wrapped context.Context. 134 | func (ctx *sshContext) Value(key interface{}) interface{} { 135 | ctx.valuesMu.Lock() 136 | defer ctx.valuesMu.Unlock() 137 | if v, ok := ctx.values[key]; ok { 138 | return v 139 | } 140 | return ctx.Context.Value(key) 141 | } 142 | 143 | // SetValue stores value under key; like Value it holds valuesMu, so both are safe for concurrent use. 144 | func (ctx *sshContext) SetValue(key, value interface{}) { 145 | ctx.valuesMu.Lock() 146 | defer ctx.valuesMu.Unlock() 147 | ctx.values[key] = value 148 | } 149 | 150 | func (ctx *sshContext) User() string { 151 | return ctx.Value(ContextKeyUser).(string) 152 | } 153 | 154 | func (ctx *sshContext) SessionID() string { 155 | return ctx.Value(ContextKeySessionID).(string) 156 | } 157 | 158 | func (ctx *sshContext) ClientVersion() string {
159 | return ctx.Value(ContextKeyClientVersion).(string) 160 | } 161 | 162 | func (ctx *sshContext) ServerVersion() string { 163 | return ctx.Value(ContextKeyServerVersion).(string) 164 | } 165 | 166 | func (ctx *sshContext) RemoteAddr() net.Addr { 167 | if addr, ok := ctx.Value(ContextKeyRemoteAddr).(net.Addr); ok { 168 | return addr 169 | } 170 | return nil 171 | } 172 | 173 | // LocalAddr returns the local address for this connection, or nil if the 174 | // connection metadata has not been applied yet (mirroring RemoteAddr). 175 | func (ctx *sshContext) LocalAddr() net.Addr { 176 | if addr, ok := ctx.Value(ContextKeyLocalAddr).(net.Addr); ok { 177 | return addr 178 | } 179 | return nil 180 | } 181 | 182 | func (ctx *sshContext) Permissions() *Permissions { 183 | return ctx.Value(ContextKeyPermissions).(*Permissions) 184 | } 185 | -------------------------------------------------------------------------------- /context_test.go: -------------------------------------------------------------------------------- 1 | package ssh 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | ) 7 | 8 | func TestSetPermissions(t *testing.T) { 9 | t.Parallel() 10 | permsExt := map[string]string{ 11 | "foo": "bar", 12 | } 13 | session, _, cleanup := newTestSessionWithOptions(t, &Server{ 14 | Handler: func(s Session) { 15 | if _, ok := s.Permissions().Extensions["foo"]; !ok { 16 | t.Errorf("got %#v; want %#v", s.Permissions().Extensions, permsExt) // handler runs off the test goroutine, where Fatalf is not allowed 17 | } 18 | }, 19 | }, nil, PasswordAuth(func(ctx Context, password string) bool { 20 | ctx.Permissions().Extensions = permsExt 21 | return true 22 | })) 23 | defer cleanup() 24 | if err := session.Run(""); err != nil { 25 | t.Fatal(err) 26 | } 27 | } 28 | 29 | func TestSetValue(t *testing.T) { 30 | t.Parallel() 31 | value := map[string]string{ 32 | "foo": "bar", 33 | } 34 | key := "testValue" 35 | session, _, cleanup := newTestSessionWithOptions(t, &Server{ 36 | Handler: func(s Session) { 37 | v := s.Context().Value(key).(map[string]string) 38 | if v["foo"] != value["foo"] { 39 | t.Errorf("got %#v; want %#v", v, value) // handler runs off the test goroutine, where Fatalf is not allowed 40 | } 41 | }, 42 | }, nil, PasswordAuth(func(ctx Context, password string) bool { 43 | ctx.SetValue(key, value) 44 | return true 45 | })) 46 |
defer cleanup() 47 | if err := session.Run(""); err != nil { 48 | t.Fatal(err) 49 | } 50 | } 51 | 52 | func TestSetValueConcurrency(t *testing.T) { 53 | ctx, cancel := newContext(nil) 54 | defer cancel() 55 | 56 | go func() { 57 | for { // use a loop to access context.Context functions to make sure they are thread-safe with SetValue 58 | _, _ = ctx.Deadline() 59 | _ = ctx.Err() 60 | _ = ctx.Value("foo") 61 | select { 62 | case <-ctx.Done(): 63 | return // a bare break would only leave the select and keep this goroutine spinning forever 64 | default: 65 | } 66 | } 67 | }() 68 | ctx.SetValue("bar", -1) // a context value which never changes 69 | now := time.Now() 70 | var cnt int64 71 | go func() { 72 | for time.Since(now) < 100*time.Millisecond { 73 | cnt++ 74 | ctx.SetValue("foo", cnt) // a context value which changes a lot 75 | } 76 | cancel() 77 | }() 78 | <-ctx.Done() 79 | if ctx.Value("foo") != cnt { 80 | t.Fatal("context.Value(foo) doesn't match latest SetValue") 81 | } 82 | if ctx.Value("bar") != -1 { 83 | t.Fatal("context.Value(bar) doesn't match latest SetValue") 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package ssh wraps the crypto/ssh package with a higher-level API for building 3 | SSH servers. The goal of the API was to make it as simple as using net/http, so 4 | the API is very similar. 5 | 6 | You should be able to build any SSH server using only this package, which wraps 7 | relevant types and some functions from crypto/ssh. However, you still need to 8 | use crypto/ssh for building SSH clients. 9 | 10 | ListenAndServe starts an SSH server with a given address, handler, and options. The 11 | handler is usually nil, which means to use DefaultHandler.
Handle sets DefaultHandler: 12 | 13 | ssh.Handle(func(s ssh.Session) { 14 | io.WriteString(s, "Hello world\n") 15 | }) 16 | 17 | log.Fatal(ssh.ListenAndServe(":2222", nil)) 18 | 19 | If you don't specify a host key, it will generate one every time. This is convenient 20 | except you'll have to deal with clients being confused that the host key is different. 21 | It's a better idea to generate or point to an existing key on your system: 22 | 23 | log.Fatal(ssh.ListenAndServe(":2222", nil, ssh.HostKeyFile("/Users/progrium/.ssh/id_rsa"))) 24 | 25 | Although all options have functional option helpers, another way to control the 26 | server's behavior is by creating a custom Server: 27 | 28 | s := &ssh.Server{ 29 | Addr: ":2222", 30 | Handler: sessionHandler, 31 | PublicKeyHandler: authHandler, 32 | } 33 | s.AddHostKey(hostKeySigner) 34 | 35 | log.Fatal(s.ListenAndServe()) 36 | 37 | This package automatically handles basic SSH requests like setting environment 38 | variables, requesting PTY, and changing window size. These requests are 39 | processed, responded to, and any relevant state is updated. This state is then 40 | exposed to you via the Session interface. 41 | 42 | The one big feature missing from the Session abstraction is signals. This was 43 | started, but not completed. Pull Requests welcome! 
44 | */ 45 | package ssh 46 | -------------------------------------------------------------------------------- /example_test.go: -------------------------------------------------------------------------------- 1 | package ssh_test 2 | 3 | import ( 4 | "io" 5 | "os" 6 | 7 | "github.com/gliderlabs/ssh" 8 | ) 9 | 10 | func ExampleListenAndServe() { 11 | ssh.ListenAndServe(":2222", func(s ssh.Session) { 12 | io.WriteString(s, "Hello world\n") 13 | }) 14 | } 15 | 16 | func ExamplePasswordAuth() { 17 | ssh.ListenAndServe(":2222", nil, 18 | ssh.PasswordAuth(func(ctx ssh.Context, pass string) bool { 19 | return pass == "secret" 20 | }), 21 | ) 22 | } 23 | 24 | func ExampleNoPty() { 25 | ssh.ListenAndServe(":2222", nil, ssh.NoPty()) 26 | } 27 | 28 | func ExamplePublicKeyAuth() { 29 | ssh.ListenAndServe(":2222", nil, 30 | ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool { 31 | data, _ := os.ReadFile("/path/to/allowed/key.pub") 32 | allowed, _, _, _, _ := ssh.ParseAuthorizedKey(data) 33 | return ssh.KeysEqual(key, allowed) 34 | }), 35 | ) 36 | } 37 | 38 | func ExampleHostKeyFile() { 39 | ssh.ListenAndServe(":2222", nil, ssh.HostKeyFile("/path/to/host/key")) 40 | } 41 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/gliderlabs/ssh 2 | 3 | go 1.20 4 | 5 | require ( 6 | github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be 7 | golang.org/x/crypto v0.31.0 8 | ) 9 | 10 | require golang.org/x/sys v0.28.0 // indirect 11 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= 2 | github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod 
h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= 3 | golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= 4 | golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= 5 | golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= 6 | golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 7 | golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q= 8 | -------------------------------------------------------------------------------- /options.go: -------------------------------------------------------------------------------- 1 | package ssh 2 | 3 | import ( 4 | "os" 5 | 6 | gossh "golang.org/x/crypto/ssh" 7 | ) 8 | 9 | // PasswordAuth returns a functional option that sets PasswordHandler on the server. 10 | func PasswordAuth(fn PasswordHandler) Option { 11 | return func(srv *Server) error { 12 | srv.PasswordHandler = fn 13 | return nil 14 | } 15 | } 16 | 17 | // PublicKeyAuth returns a functional option that sets PublicKeyHandler on the server. 18 | func PublicKeyAuth(fn PublicKeyHandler) Option { 19 | return func(srv *Server) error { 20 | srv.PublicKeyHandler = fn 21 | return nil 22 | } 23 | } 24 | 25 | // HostKeyFile returns a functional option that adds HostSigners to the server 26 | // from a PEM file at filepath. 27 | func HostKeyFile(filepath string) Option { 28 | return func(srv *Server) error { 29 | pemBytes, err := os.ReadFile(filepath) 30 | if err != nil { 31 | return err 32 | } 33 | 34 | signer, err := gossh.ParsePrivateKey(pemBytes) 35 | if err != nil { 36 | return err 37 | } 38 | 39 | srv.AddHostKey(signer) 40 | 41 | return nil 42 | } 43 | } 44 | 45 | func KeyboardInteractiveAuth(fn KeyboardInteractiveHandler) Option { 46 | return func(srv *Server) error { 47 | srv.KeyboardInteractiveHandler = fn 48 | return nil 49 | } 50 | } 51 | 52 | // HostKeyPEM returns a functional option that adds HostSigners to the server 53 | // from a PEM file as bytes. 
54 | func HostKeyPEM(bytes []byte) Option { 55 | return func(srv *Server) error { 56 | signer, err := gossh.ParsePrivateKey(bytes) 57 | if err != nil { 58 | return err 59 | } 60 | 61 | srv.AddHostKey(signer) 62 | 63 | return nil 64 | } 65 | } 66 | 67 | // NoPty returns a functional option that sets PtyCallback to return false, 68 | // denying PTY requests. 69 | func NoPty() Option { 70 | return func(srv *Server) error { 71 | srv.PtyCallback = func(ctx Context, pty Pty) bool { 72 | return false 73 | } 74 | return nil 75 | } 76 | } 77 | 78 | // WrapConn returns a functional option that sets ConnCallback on the server. 79 | func WrapConn(fn ConnCallback) Option { 80 | return func(srv *Server) error { 81 | srv.ConnCallback = fn 82 | return nil 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /options_test.go: -------------------------------------------------------------------------------- 1 | package ssh 2 | 3 | import ( 4 | "net" 5 | "strings" 6 | "sync/atomic" 7 | "testing" 8 | 9 | gossh "golang.org/x/crypto/ssh" 10 | ) 11 | 12 | func newTestSessionWithOptions(t *testing.T, srv *Server, cfg *gossh.ClientConfig, options ...Option) (*gossh.Session, *gossh.Client, func()) { 13 | for _, option := range options { 14 | if err := srv.SetOption(option); err != nil { 15 | t.Fatal(err) 16 | } 17 | } 18 | return newTestSession(t, srv, cfg) 19 | } 20 | 21 | func TestPasswordAuth(t *testing.T) { 22 | t.Parallel() 23 | testUser := "testuser" 24 | testPass := "testpass" 25 | session, _, cleanup := newTestSessionWithOptions(t, &Server{ 26 | Handler: func(s Session) { 27 | // noop 28 | }, 29 | }, &gossh.ClientConfig{ 30 | User: testUser, 31 | Auth: []gossh.AuthMethod{ 32 | gossh.Password(testPass), 33 | }, 34 | HostKeyCallback: gossh.InsecureIgnoreHostKey(), 35 | }, PasswordAuth(func(ctx Context, password string) bool { 36 | if ctx.User() != testUser { 37 | t.Fatalf("user = %#v; want %#v", ctx.User(), testUser) 38 | } 39 | if 
password != testPass { 40 | t.Fatalf("user = %#v; want %#v", password, testPass) 41 | } 42 | return true 43 | })) 44 | defer cleanup() 45 | if err := session.Run(""); err != nil { 46 | t.Fatal(err) 47 | } 48 | } 49 | 50 | func TestPasswordAuthBadPass(t *testing.T) { 51 | t.Parallel() 52 | l := newLocalListener() 53 | srv := &Server{Handler: func(s Session) {}} 54 | srv.SetOption(PasswordAuth(func(ctx Context, password string) bool { 55 | return false 56 | })) 57 | go srv.serveOnce(l) 58 | _, err := gossh.Dial("tcp", l.Addr().String(), &gossh.ClientConfig{ 59 | User: "testuser", 60 | Auth: []gossh.AuthMethod{ 61 | gossh.Password("testpass"), 62 | }, 63 | HostKeyCallback: gossh.InsecureIgnoreHostKey(), 64 | }) 65 | if err != nil { 66 | if !strings.Contains(err.Error(), "unable to authenticate") { 67 | t.Fatal(err) 68 | } 69 | } 70 | } 71 | 72 | type wrappedConn struct { 73 | net.Conn 74 | written int32 75 | } 76 | 77 | func (c *wrappedConn) Write(p []byte) (n int, err error) { 78 | n, err = c.Conn.Write(p) 79 | atomic.AddInt32(&(c.written), int32(n)) 80 | return 81 | } 82 | 83 | func TestConnWrapping(t *testing.T) { 84 | t.Parallel() 85 | var wrapped *wrappedConn 86 | session, _, cleanup := newTestSessionWithOptions(t, &Server{ 87 | Handler: func(s Session) { 88 | // nothing 89 | }, 90 | }, &gossh.ClientConfig{ 91 | User: "testuser", 92 | Auth: []gossh.AuthMethod{ 93 | gossh.Password("testpass"), 94 | }, 95 | HostKeyCallback: gossh.InsecureIgnoreHostKey(), 96 | }, PasswordAuth(func(ctx Context, password string) bool { 97 | return true 98 | }), WrapConn(func(ctx Context, conn net.Conn) net.Conn { 99 | wrapped = &wrappedConn{conn, 0} 100 | return wrapped 101 | })) 102 | defer cleanup() 103 | if err := session.Shell(); err != nil { 104 | t.Fatal(err) 105 | } 106 | if atomic.LoadInt32(&(wrapped.written)) == 0 { 107 | t.Fatal("wrapped conn not written to") 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /server.go: 
-------------------------------------------------------------------------------- 1 | package ssh 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "net" 8 | "sync" 9 | "time" 10 | 11 | gossh "golang.org/x/crypto/ssh" 12 | ) 13 | 14 | // ErrServerClosed is returned by the Server's Serve, ListenAndServe, 15 | // and ListenAndServeTLS methods after a call to Shutdown or Close. 16 | var ErrServerClosed = errors.New("ssh: Server closed") 17 | 18 | type SubsystemHandler func(s Session) 19 | 20 | var DefaultSubsystemHandlers = map[string]SubsystemHandler{} 21 | 22 | type RequestHandler func(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) 23 | 24 | var DefaultRequestHandlers = map[string]RequestHandler{} 25 | 26 | type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) 27 | 28 | var DefaultChannelHandlers = map[string]ChannelHandler{ 29 | "session": DefaultSessionHandler, 30 | } 31 | 32 | // Server defines parameters for running an SSH server. The zero value for 33 | // Server is a valid configuration. When both PasswordHandler and 34 | // PublicKeyHandler are nil, no client authentication is performed. 
35 | type Server struct { 36 | Addr string // TCP address to listen on, ":22" if empty 37 | Handler Handler // handler to invoke, ssh.DefaultHandler if nil 38 | HostSigners []Signer // private keys for the host key, must have at least one 39 | Version string // server version to be sent before the initial handshake 40 | Banner string // server banner 41 | 42 | BannerHandler BannerHandler // server banner handler, overrides Banner 43 | KeyboardInteractiveHandler KeyboardInteractiveHandler // keyboard-interactive authentication handler 44 | PasswordHandler PasswordHandler // password authentication handler 45 | PublicKeyHandler PublicKeyHandler // public key authentication handler 46 | PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil 47 | ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling 48 | LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil 49 | ReversePortForwardingCallback ReversePortForwardingCallback // callback for allowing reverse port forwarding, denies all if nil 50 | ServerConfigCallback ServerConfigCallback // callback for configuring detailed SSH options 51 | SessionRequestCallback SessionRequestCallback // callback for allowing or denying SSH sessions 52 | 53 | ConnectionFailedCallback ConnectionFailedCallback // callback to report connection failures 54 | 55 | HandshakeTimeout time.Duration // connection timeout until successful handshake, none if empty 56 | IdleTimeout time.Duration // connection timeout when no activity, none if empty 57 | MaxTimeout time.Duration // absolute connection timeout, none if empty 58 | 59 | // ChannelHandlers allow overriding the built-in session handlers or provide 60 | // extensions to the protocol, such as tcpip forwarding. By default only the 61 | // "session" handler is enabled. 
62 | ChannelHandlers map[string]ChannelHandler 63 | 64 | // RequestHandlers allow overriding the server-level request handlers or 65 | // provide extensions to the protocol, such as tcpip forwarding. By default 66 | // no handlers are enabled. 67 | RequestHandlers map[string]RequestHandler 68 | 69 | // SubsystemHandlers are handlers which are similar to the usual SSH command 70 | // handlers, but handle named subsystems. 71 | SubsystemHandlers map[string]SubsystemHandler 72 | 73 | listenerWg sync.WaitGroup 74 | mu sync.RWMutex 75 | listeners map[net.Listener]struct{} 76 | conns map[*gossh.ServerConn]struct{} 77 | connWg sync.WaitGroup 78 | doneChan chan struct{} 79 | } 80 | 81 | func (srv *Server) ensureHostSigner() error { 82 | srv.mu.Lock() 83 | defer srv.mu.Unlock() 84 | 85 | if len(srv.HostSigners) == 0 { 86 | signer, err := generateSigner() 87 | if err != nil { 88 | return err 89 | } 90 | srv.HostSigners = append(srv.HostSigners, signer) 91 | } 92 | return nil 93 | } 94 | 95 | func (srv *Server) ensureHandlers() { 96 | srv.mu.Lock() 97 | defer srv.mu.Unlock() 98 | 99 | if srv.RequestHandlers == nil { 100 | srv.RequestHandlers = map[string]RequestHandler{} 101 | for k, v := range DefaultRequestHandlers { 102 | srv.RequestHandlers[k] = v 103 | } 104 | } 105 | if srv.ChannelHandlers == nil { 106 | srv.ChannelHandlers = map[string]ChannelHandler{} 107 | for k, v := range DefaultChannelHandlers { 108 | srv.ChannelHandlers[k] = v 109 | } 110 | } 111 | if srv.SubsystemHandlers == nil { 112 | srv.SubsystemHandlers = map[string]SubsystemHandler{} 113 | for k, v := range DefaultSubsystemHandlers { 114 | srv.SubsystemHandlers[k] = v 115 | } 116 | } 117 | } 118 | 119 | func (srv *Server) config(ctx Context) *gossh.ServerConfig { 120 | srv.mu.RLock() 121 | defer srv.mu.RUnlock() 122 | 123 | var config *gossh.ServerConfig 124 | if srv.ServerConfigCallback == nil { 125 | config = &gossh.ServerConfig{} 126 | } else { 127 | config = srv.ServerConfigCallback(ctx) 128 | } 129 | 
for _, signer := range srv.HostSigners { 130 | config.AddHostKey(signer) 131 | } 132 | if srv.PasswordHandler == nil && srv.PublicKeyHandler == nil && srv.KeyboardInteractiveHandler == nil { 133 | config.NoClientAuth = true 134 | } 135 | if srv.Version != "" { 136 | config.ServerVersion = "SSH-2.0-" + srv.Version 137 | } 138 | if srv.Banner != "" { 139 | config.BannerCallback = func(_ gossh.ConnMetadata) string { 140 | return srv.Banner 141 | } 142 | } 143 | if srv.BannerHandler != nil { 144 | config.BannerCallback = func(conn gossh.ConnMetadata) string { 145 | applyConnMetadata(ctx, conn) 146 | return srv.BannerHandler(ctx) 147 | } 148 | } 149 | if srv.PasswordHandler != nil { 150 | config.PasswordCallback = func(conn gossh.ConnMetadata, password []byte) (*gossh.Permissions, error) { 151 | applyConnMetadata(ctx, conn) 152 | if ok := srv.PasswordHandler(ctx, string(password)); !ok { 153 | return ctx.Permissions().Permissions, fmt.Errorf("permission denied") 154 | } 155 | return ctx.Permissions().Permissions, nil 156 | } 157 | } 158 | if srv.PublicKeyHandler != nil { 159 | config.PublicKeyCallback = func(conn gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) { 160 | applyConnMetadata(ctx, conn) 161 | if ok := srv.PublicKeyHandler(ctx, key); !ok { 162 | return ctx.Permissions().Permissions, fmt.Errorf("permission denied") 163 | } 164 | ctx.SetValue(ContextKeyPublicKey, key) 165 | return ctx.Permissions().Permissions, nil 166 | } 167 | } 168 | if srv.KeyboardInteractiveHandler != nil { 169 | config.KeyboardInteractiveCallback = func(conn gossh.ConnMetadata, challenger gossh.KeyboardInteractiveChallenge) (*gossh.Permissions, error) { 170 | applyConnMetadata(ctx, conn) 171 | if ok := srv.KeyboardInteractiveHandler(ctx, challenger); !ok { 172 | return ctx.Permissions().Permissions, fmt.Errorf("permission denied") 173 | } 174 | return ctx.Permissions().Permissions, nil 175 | } 176 | } 177 | return config 178 | } 179 | 180 | // Handle sets the Handler 
for the server. 181 | func (srv *Server) Handle(fn Handler) { 182 | srv.mu.Lock() 183 | defer srv.mu.Unlock() 184 | 185 | srv.Handler = fn 186 | } 187 | 188 | // Close immediately closes all active listeners and all active 189 | // connections. 190 | // 191 | // Close returns any error returned from closing the Server's 192 | // underlying Listener(s). 193 | func (srv *Server) Close() error { 194 | srv.mu.Lock() 195 | defer srv.mu.Unlock() 196 | 197 | srv.closeDoneChanLocked() 198 | err := srv.closeListenersLocked() 199 | for c := range srv.conns { 200 | c.Close() 201 | delete(srv.conns, c) 202 | } 203 | return err 204 | } 205 | 206 | // Shutdown gracefully shuts down the server without interrupting any 207 | // active connections. Shutdown works by first closing all open 208 | // listeners, and then waiting indefinitely for connections to close. 209 | // If the provided context expires before the shutdown is complete, 210 | // then the context's error is returned. 211 | func (srv *Server) Shutdown(ctx context.Context) error { 212 | srv.mu.Lock() 213 | lnerr := srv.closeListenersLocked() 214 | srv.closeDoneChanLocked() 215 | srv.mu.Unlock() 216 | 217 | finished := make(chan struct{}, 1) 218 | go func() { 219 | srv.listenerWg.Wait() 220 | srv.connWg.Wait() 221 | finished <- struct{}{} 222 | }() 223 | 224 | select { 225 | case <-ctx.Done(): 226 | return ctx.Err() 227 | case <-finished: 228 | return lnerr 229 | } 230 | } 231 | 232 | // Serve accepts incoming connections on the Listener l, creating a new 233 | // connection goroutine for each. The connection goroutines read requests and then 234 | // calls srv.Handler to handle sessions. 235 | // 236 | // Serve always returns a non-nil error. 
237 | func (srv *Server) Serve(l net.Listener) error { 238 | srv.ensureHandlers() 239 | defer l.Close() 240 | if err := srv.ensureHostSigner(); err != nil { 241 | return err 242 | } 243 | if srv.Handler == nil { 244 | srv.Handler = DefaultHandler 245 | } 246 | var tempDelay time.Duration 247 | 248 | srv.trackListener(l, true) 249 | defer srv.trackListener(l, false) 250 | for { 251 | conn, e := l.Accept() 252 | if e != nil { 253 | select { 254 | case <-srv.getDoneChan(): 255 | return ErrServerClosed 256 | default: 257 | } 258 | if ne, ok := e.(net.Error); ok && ne.Temporary() { 259 | if tempDelay == 0 { 260 | tempDelay = 5 * time.Millisecond 261 | } else { 262 | tempDelay *= 2 263 | } 264 | if max := 1 * time.Second; tempDelay > max { 265 | tempDelay = max 266 | } 267 | time.Sleep(tempDelay) 268 | continue 269 | } 270 | return e 271 | } 272 | go srv.HandleConn(conn) 273 | } 274 | } 275 | 276 | func (srv *Server) HandleConn(newConn net.Conn) { 277 | ctx, cancel := newContext(srv) 278 | if srv.ConnCallback != nil { 279 | cbConn := srv.ConnCallback(ctx, newConn) 280 | if cbConn == nil { 281 | newConn.Close() 282 | return 283 | } 284 | newConn = cbConn 285 | } 286 | conn := &serverConn{ 287 | Conn: newConn, 288 | idleTimeout: srv.IdleTimeout, 289 | closeCanceler: cancel, 290 | } 291 | if srv.MaxTimeout > 0 { 292 | conn.maxDeadline = time.Now().Add(srv.MaxTimeout) 293 | } 294 | if srv.HandshakeTimeout > 0 { 295 | conn.handshakeDeadline = time.Now().Add(srv.HandshakeTimeout) 296 | } 297 | conn.updateDeadline() 298 | defer conn.Close() 299 | sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx)) 300 | if err != nil { 301 | if srv.ConnectionFailedCallback != nil { 302 | srv.ConnectionFailedCallback(conn, err) 303 | } 304 | return 305 | } 306 | conn.handshakeDeadline = time.Time{} 307 | conn.updateDeadline() 308 | srv.trackConn(sshConn, true) 309 | defer srv.trackConn(sshConn, false) 310 | 311 | ctx.SetValue(ContextKeyConn, sshConn) 312 | 
applyConnMetadata(ctx, sshConn) 313 | //go gossh.DiscardRequests(reqs) 314 | go srv.handleRequests(ctx, reqs) 315 | for ch := range chans { 316 | handler := srv.ChannelHandlers[ch.ChannelType()] 317 | if handler == nil { 318 | handler = srv.ChannelHandlers["default"] 319 | } 320 | if handler == nil { 321 | ch.Reject(gossh.UnknownChannelType, "unsupported channel type") 322 | continue 323 | } 324 | go handler(srv, sshConn, ch, ctx) 325 | } 326 | } 327 | 328 | func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) { 329 | for req := range in { 330 | handler := srv.RequestHandlers[req.Type] 331 | if handler == nil { 332 | handler = srv.RequestHandlers["default"] 333 | } 334 | if handler == nil { 335 | req.Reply(false, nil) 336 | continue 337 | } 338 | /*reqCtx, cancel := context.WithCancel(ctx) 339 | defer cancel() */ 340 | ret, payload := handler(ctx, srv, req) 341 | req.Reply(ret, payload) 342 | } 343 | } 344 | 345 | // ListenAndServe listens on the TCP network address srv.Addr and then calls 346 | // Serve to handle incoming connections. If srv.Addr is blank, ":22" is used. 347 | // ListenAndServe always returns a non-nil error. 348 | func (srv *Server) ListenAndServe() error { 349 | addr := srv.Addr 350 | if addr == "" { 351 | addr = ":22" 352 | } 353 | ln, err := net.Listen("tcp", addr) 354 | if err != nil { 355 | return err 356 | } 357 | return srv.Serve(ln) 358 | } 359 | 360 | // AddHostKey adds a private key as a host key. If an existing host key exists 361 | // with the same algorithm, it is overwritten. Each server config must have at 362 | // least one host key. 363 | func (srv *Server) AddHostKey(key Signer) { 364 | srv.mu.Lock() 365 | defer srv.mu.Unlock() 366 | 367 | // these are later added via AddHostKey on ServerConfig, which performs the 368 | // check for one of every algorithm. 369 | 370 | // This check is based on the AddHostKey method from the x/crypto/ssh 371 | // library. 
This allows us to only keep one active key for each type on a 372 | // server at once. So, if you're dynamically updating keys at runtime, this 373 | // list will not keep growing. 374 | for i, k := range srv.HostSigners { 375 | if k.PublicKey().Type() == key.PublicKey().Type() { 376 | srv.HostSigners[i] = key 377 | return 378 | } 379 | } 380 | 381 | srv.HostSigners = append(srv.HostSigners, key) 382 | } 383 | 384 | // SetOption runs a functional option against the server. 385 | func (srv *Server) SetOption(option Option) error { 386 | // NOTE: there is a potential race here for any option that doesn't call an 387 | // internal method. We can't actually lock here because if something calls 388 | // (as an example) AddHostKey, it will deadlock. 389 | 390 | //srv.mu.Lock() 391 | //defer srv.mu.Unlock() 392 | 393 | return option(srv) 394 | } 395 | 396 | func (srv *Server) getDoneChan() <-chan struct{} { 397 | srv.mu.Lock() 398 | defer srv.mu.Unlock() 399 | 400 | return srv.getDoneChanLocked() 401 | } 402 | 403 | func (srv *Server) getDoneChanLocked() chan struct{} { 404 | if srv.doneChan == nil { 405 | srv.doneChan = make(chan struct{}) 406 | } 407 | return srv.doneChan 408 | } 409 | 410 | func (srv *Server) closeDoneChanLocked() { 411 | ch := srv.getDoneChanLocked() 412 | select { 413 | case <-ch: 414 | // Already closed. Don't close again. 415 | default: 416 | // Safe to close here. We're the only closer, guarded 417 | // by srv.mu. 
418 | close(ch) 419 | } 420 | } 421 | 422 | func (srv *Server) closeListenersLocked() error { 423 | var err error 424 | for ln := range srv.listeners { 425 | if cerr := ln.Close(); cerr != nil && err == nil { 426 | err = cerr 427 | } 428 | delete(srv.listeners, ln) 429 | } 430 | return err 431 | } 432 | 433 | func (srv *Server) trackListener(ln net.Listener, add bool) { 434 | srv.mu.Lock() 435 | defer srv.mu.Unlock() 436 | 437 | if srv.listeners == nil { 438 | srv.listeners = make(map[net.Listener]struct{}) 439 | } 440 | if add { 441 | // If the *Server is being reused after a previous 442 | // Close or Shutdown, reset its doneChan: 443 | if len(srv.listeners) == 0 && len(srv.conns) == 0 { 444 | srv.doneChan = nil 445 | } 446 | srv.listeners[ln] = struct{}{} 447 | srv.listenerWg.Add(1) 448 | } else { 449 | delete(srv.listeners, ln) 450 | srv.listenerWg.Done() 451 | } 452 | } 453 | 454 | func (srv *Server) trackConn(c *gossh.ServerConn, add bool) { 455 | srv.mu.Lock() 456 | defer srv.mu.Unlock() 457 | 458 | if srv.conns == nil { 459 | srv.conns = make(map[*gossh.ServerConn]struct{}) 460 | } 461 | if add { 462 | srv.conns[c] = struct{}{} 463 | srv.connWg.Add(1) 464 | } else { 465 | delete(srv.conns, c) 466 | srv.connWg.Done() 467 | } 468 | } 469 | -------------------------------------------------------------------------------- /server_test.go: -------------------------------------------------------------------------------- 1 | package ssh 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "io" 7 | "net" 8 | "testing" 9 | "time" 10 | ) 11 | 12 | func TestAddHostKey(t *testing.T) { 13 | s := Server{} 14 | signer, err := generateSigner() 15 | if err != nil { 16 | t.Fatal(err) 17 | } 18 | s.AddHostKey(signer) 19 | if len(s.HostSigners) != 1 { 20 | t.Fatal("Key was not properly added") 21 | } 22 | signer, err = generateSigner() 23 | if err != nil { 24 | t.Fatal(err) 25 | } 26 | s.AddHostKey(signer) 27 | if len(s.HostSigners) != 1 { 28 | t.Fatal("Key was not properly 
replaced") 29 | } 30 | } 31 | 32 | func TestServerShutdown(t *testing.T) { 33 | l := newLocalListener() 34 | testBytes := []byte("Hello world\n") 35 | s := &Server{ 36 | Handler: func(s Session) { 37 | s.Write(testBytes) 38 | time.Sleep(50 * time.Millisecond) 39 | }, 40 | } 41 | go func() { 42 | err := s.Serve(l) 43 | if err != nil && err != ErrServerClosed { 44 | t.Fatal(err) 45 | } 46 | }() 47 | sessDone := make(chan struct{}) 48 | sess, _, cleanup := newClientSession(t, l.Addr().String(), nil) 49 | go func() { 50 | defer cleanup() 51 | defer close(sessDone) 52 | var stdout bytes.Buffer 53 | sess.Stdout = &stdout 54 | if err := sess.Run(""); err != nil { 55 | t.Fatal(err) 56 | } 57 | if !bytes.Equal(stdout.Bytes(), testBytes) { 58 | t.Fatalf("expected = %s; got %s", testBytes, stdout.Bytes()) 59 | } 60 | }() 61 | 62 | srvDone := make(chan struct{}) 63 | go func() { 64 | defer close(srvDone) 65 | err := s.Shutdown(context.Background()) 66 | if err != nil { 67 | t.Fatal(err) 68 | } 69 | }() 70 | 71 | timeout := time.After(2 * time.Second) 72 | select { 73 | case <-timeout: 74 | t.Fatal("timeout") 75 | return 76 | case <-srvDone: 77 | // TODO: add timeout for sessDone 78 | <-sessDone 79 | return 80 | } 81 | } 82 | 83 | func TestServerClose(t *testing.T) { 84 | l := newLocalListener() 85 | s := &Server{ 86 | Handler: func(s Session) { 87 | time.Sleep(5 * time.Second) 88 | }, 89 | } 90 | go func() { 91 | err := s.Serve(l) 92 | if err != nil && err != ErrServerClosed { 93 | t.Fatal(err) 94 | } 95 | }() 96 | 97 | clientDoneChan := make(chan struct{}) 98 | closeDoneChan := make(chan struct{}) 99 | 100 | sess, _, cleanup := newClientSession(t, l.Addr().String(), nil) 101 | go func() { 102 | defer cleanup() 103 | defer close(clientDoneChan) 104 | <-closeDoneChan 105 | if err := sess.Run(""); err != nil && err != io.EOF { 106 | t.Fatal(err) 107 | } 108 | }() 109 | 110 | go func() { 111 | err := s.Close() 112 | if err != nil { 113 | t.Fatal(err) 114 | } 115 | 
close(closeDoneChan) 116 | }() 117 | 118 | timeout := time.After(100 * time.Millisecond) 119 | select { 120 | case <-timeout: 121 | t.Error("timeout") 122 | return 123 | case <-s.getDoneChan(): 124 | <-clientDoneChan 125 | return 126 | } 127 | } 128 | 129 | func TestServerHandshakeTimeout(t *testing.T) { 130 | l := newLocalListener() 131 | 132 | s := &Server{ 133 | HandshakeTimeout: time.Millisecond, 134 | } 135 | go func() { 136 | if err := s.Serve(l); err != nil { 137 | t.Error(err) 138 | } 139 | }() 140 | 141 | conn, err := net.Dial("tcp", l.Addr().String()) 142 | if err != nil { 143 | t.Fatal(err) 144 | } 145 | defer conn.Close() 146 | 147 | ch := make(chan struct{}) 148 | go func() { 149 | defer close(ch) 150 | io.Copy(io.Discard, conn) 151 | }() 152 | 153 | select { 154 | case <-ch: 155 | return 156 | case <-time.After(time.Second): 157 | t.Fatal("client connection was not force-closed") 158 | return 159 | } 160 | } 161 | -------------------------------------------------------------------------------- /session.go: -------------------------------------------------------------------------------- 1 | package ssh 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "fmt" 7 | "net" 8 | "sync" 9 | 10 | "github.com/anmitsu/go-shlex" 11 | gossh "golang.org/x/crypto/ssh" 12 | ) 13 | 14 | // Session provides access to information about an SSH session and methods 15 | // to read and write to the SSH channel with an embedded Channel interface from 16 | // crypto/ssh. 17 | // 18 | // When Command() returns an empty slice, the user requested a shell. Otherwise 19 | // the user is performing an exec with those command arguments. 20 | // 21 | // TODO: Signals 22 | type Session interface { 23 | gossh.Channel 24 | 25 | // User returns the username used when establishing the SSH connection. 26 | User() string 27 | 28 | // RemoteAddr returns the net.Addr of the client side of the connection. 
29 | RemoteAddr() net.Addr 30 | 31 | // LocalAddr returns the net.Addr of the server side of the connection. 32 | LocalAddr() net.Addr 33 | 34 | // Environ returns a copy of strings representing the environment set by the 35 | // user for this session, in the form "key=value". 36 | Environ() []string 37 | 38 | // Exit sends an exit status and then closes the session. 39 | Exit(code int) error 40 | 41 | // Command returns a shell parsed slice of arguments that were provided by the 42 | // user. Shell parsing splits the command string according to POSIX shell rules, 43 | // which considers quoting not just whitespace. 44 | Command() []string 45 | 46 | // RawCommand returns the exact command that was provided by the user. 47 | RawCommand() string 48 | 49 | // Subsystem returns the subsystem requested by the user. 50 | Subsystem() string 51 | 52 | // PublicKey returns the PublicKey used to authenticate. If a public key was not 53 | // used it will return nil. 54 | PublicKey() PublicKey 55 | 56 | // Context returns the connection's context. The returned context is always 57 | // non-nil and holds the same data as the Context passed into auth 58 | // handlers and callbacks. 59 | // 60 | // The context is canceled when the client's connection closes or I/O 61 | // operation fails. 62 | Context() Context 63 | 64 | // Permissions returns a copy of the Permissions object that was available for 65 | // setup in the auth handlers via the Context. 66 | Permissions() Permissions 67 | 68 | // Pty returns PTY information, a channel of window size changes, and a boolean 69 | // of whether or not a PTY was accepted for this session. 70 | Pty() (Pty, <-chan Window, bool) 71 | 72 | // Signals registers a channel to receive signals sent from the client. The 73 | // channel must handle signal sends or it will block the SSH request loop. 74 | // Registering nil will unregister the channel from signal sends. 
During the 75 | // time no channel is registered signals are buffered up to a reasonable amount. 76 | // If there are buffered signals when a channel is registered, they will be 77 | // sent in order on the channel immediately after registering. 78 | Signals(c chan<- Signal) 79 | 80 | // Break regisers a channel to receive notifications of break requests sent 81 | // from the client. The channel must handle break requests, or it will block 82 | // the request handling loop. Registering nil will unregister the channel. 83 | // During the time that no channel is registered, breaks are ignored. 84 | Break(c chan<- bool) 85 | } 86 | 87 | // maxSigBufSize is how many signals will be buffered 88 | // when there is no signal channel specified 89 | const maxSigBufSize = 128 90 | 91 | func DefaultSessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { 92 | ch, reqs, err := newChan.Accept() 93 | if err != nil { 94 | // TODO: trigger event callback 95 | return 96 | } 97 | sess := &session{ 98 | Channel: ch, 99 | conn: conn, 100 | handler: srv.Handler, 101 | ptyCb: srv.PtyCallback, 102 | sessReqCb: srv.SessionRequestCallback, 103 | subsystemHandlers: srv.SubsystemHandlers, 104 | ctx: ctx, 105 | } 106 | sess.handleRequests(reqs) 107 | } 108 | 109 | type session struct { 110 | sync.Mutex 111 | gossh.Channel 112 | conn *gossh.ServerConn 113 | handler Handler 114 | subsystemHandlers map[string]SubsystemHandler 115 | handled bool 116 | exited bool 117 | pty *Pty 118 | winch chan Window 119 | env []string 120 | ptyCb PtyCallback 121 | sessReqCb SessionRequestCallback 122 | rawCmd string 123 | subsystem string 124 | ctx Context 125 | sigCh chan<- Signal 126 | sigBuf []Signal 127 | breakCh chan<- bool 128 | } 129 | 130 | func (sess *session) Write(p []byte) (n int, err error) { 131 | if sess.pty != nil { 132 | m := len(p) 133 | // normalize \n to \r\n when pty is accepted. 134 | // this is a hardcoded shortcut since we don't support terminal modes. 
135 | p = bytes.Replace(p, []byte{'\n'}, []byte{'\r', '\n'}, -1) 136 | p = bytes.Replace(p, []byte{'\r', '\r', '\n'}, []byte{'\r', '\n'}, -1) 137 | n, err = sess.Channel.Write(p) 138 | if n > m { 139 | n = m 140 | } 141 | return 142 | } 143 | return sess.Channel.Write(p) 144 | } 145 | 146 | func (sess *session) PublicKey() PublicKey { 147 | sessionkey := sess.ctx.Value(ContextKeyPublicKey) 148 | if sessionkey == nil { 149 | return nil 150 | } 151 | return sessionkey.(PublicKey) 152 | } 153 | 154 | func (sess *session) Permissions() Permissions { 155 | // use context permissions because its properly 156 | // wrapped and easier to dereference 157 | perms := sess.ctx.Value(ContextKeyPermissions).(*Permissions) 158 | return *perms 159 | } 160 | 161 | func (sess *session) Context() Context { 162 | return sess.ctx 163 | } 164 | 165 | func (sess *session) Exit(code int) error { 166 | sess.Lock() 167 | defer sess.Unlock() 168 | if sess.exited { 169 | return errors.New("Session.Exit called multiple times") 170 | } 171 | sess.exited = true 172 | 173 | status := struct{ Status uint32 }{uint32(code)} 174 | _, err := sess.SendRequest("exit-status", false, gossh.Marshal(&status)) 175 | if err != nil { 176 | return err 177 | } 178 | return sess.Close() 179 | } 180 | 181 | func (sess *session) User() string { 182 | return sess.conn.User() 183 | } 184 | 185 | func (sess *session) RemoteAddr() net.Addr { 186 | return sess.conn.RemoteAddr() 187 | } 188 | 189 | func (sess *session) LocalAddr() net.Addr { 190 | return sess.conn.LocalAddr() 191 | } 192 | 193 | func (sess *session) Environ() []string { 194 | return append([]string(nil), sess.env...) 195 | } 196 | 197 | func (sess *session) RawCommand() string { 198 | return sess.rawCmd 199 | } 200 | 201 | func (sess *session) Command() []string { 202 | cmd, _ := shlex.Split(sess.rawCmd, true) 203 | return append([]string(nil), cmd...) 
204 | } 205 | 206 | func (sess *session) Subsystem() string { 207 | return sess.subsystem 208 | } 209 | 210 | func (sess *session) Pty() (Pty, <-chan Window, bool) { 211 | if sess.pty != nil { 212 | return *sess.pty, sess.winch, true 213 | } 214 | return Pty{}, sess.winch, false 215 | } 216 | 217 | func (sess *session) Signals(c chan<- Signal) { 218 | sess.Lock() 219 | defer sess.Unlock() 220 | sess.sigCh = c 221 | if len(sess.sigBuf) > 0 { 222 | go func() { 223 | for _, sig := range sess.sigBuf { 224 | sess.sigCh <- sig 225 | } 226 | }() 227 | } 228 | } 229 | 230 | func (sess *session) Break(c chan<- bool) { 231 | sess.Lock() 232 | defer sess.Unlock() 233 | sess.breakCh = c 234 | } 235 | 236 | func (sess *session) handleRequests(reqs <-chan *gossh.Request) { 237 | for req := range reqs { 238 | switch req.Type { 239 | case "shell", "exec": 240 | if sess.handled { 241 | req.Reply(false, nil) 242 | continue 243 | } 244 | 245 | var payload = struct{ Value string }{} 246 | gossh.Unmarshal(req.Payload, &payload) 247 | sess.rawCmd = payload.Value 248 | 249 | // If there's a session policy callback, we need to confirm before 250 | // accepting the session. 251 | if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) { 252 | sess.rawCmd = "" 253 | req.Reply(false, nil) 254 | continue 255 | } 256 | 257 | sess.handled = true 258 | req.Reply(true, nil) 259 | 260 | go func() { 261 | sess.handler(sess) 262 | sess.Exit(0) 263 | }() 264 | case "subsystem": 265 | if sess.handled { 266 | req.Reply(false, nil) 267 | continue 268 | } 269 | 270 | var payload = struct{ Value string }{} 271 | gossh.Unmarshal(req.Payload, &payload) 272 | sess.subsystem = payload.Value 273 | 274 | // If there's a session policy callback, we need to confirm before 275 | // accepting the session. 
276 | if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) { 277 | sess.rawCmd = "" 278 | req.Reply(false, nil) 279 | continue 280 | } 281 | 282 | handler := sess.subsystemHandlers[payload.Value] 283 | if handler == nil { 284 | handler = sess.subsystemHandlers["default"] 285 | } 286 | if handler == nil { 287 | req.Reply(false, nil) 288 | continue 289 | } 290 | 291 | sess.handled = true 292 | req.Reply(true, nil) 293 | 294 | go func() { 295 | handler(sess) 296 | sess.Exit(0) 297 | }() 298 | case "env": 299 | if sess.handled { 300 | req.Reply(false, nil) 301 | continue 302 | } 303 | var kv struct{ Key, Value string } 304 | gossh.Unmarshal(req.Payload, &kv) 305 | sess.env = append(sess.env, fmt.Sprintf("%s=%s", kv.Key, kv.Value)) 306 | req.Reply(true, nil) 307 | case "signal": 308 | var payload struct{ Signal string } 309 | gossh.Unmarshal(req.Payload, &payload) 310 | sess.Lock() 311 | if sess.sigCh != nil { 312 | sess.sigCh <- Signal(payload.Signal) 313 | } else { 314 | if len(sess.sigBuf) < maxSigBufSize { 315 | sess.sigBuf = append(sess.sigBuf, Signal(payload.Signal)) 316 | } 317 | } 318 | sess.Unlock() 319 | case "pty-req": 320 | if sess.handled || sess.pty != nil { 321 | req.Reply(false, nil) 322 | continue 323 | } 324 | ptyReq, ok := parsePtyRequest(req.Payload) 325 | if !ok { 326 | req.Reply(false, nil) 327 | continue 328 | } 329 | if sess.ptyCb != nil { 330 | ok := sess.ptyCb(sess.ctx, ptyReq) 331 | if !ok { 332 | req.Reply(false, nil) 333 | continue 334 | } 335 | } 336 | sess.pty = &ptyReq 337 | sess.winch = make(chan Window, 1) 338 | sess.winch <- ptyReq.Window 339 | defer func() { 340 | // when reqs is closed 341 | close(sess.winch) 342 | }() 343 | req.Reply(ok, nil) 344 | case "window-change": 345 | if sess.pty == nil { 346 | req.Reply(false, nil) 347 | continue 348 | } 349 | win, ok := parseWinchRequest(req.Payload) 350 | if ok { 351 | sess.pty.Window = win 352 | sess.winch <- win 353 | } 354 | req.Reply(ok, nil) 355 | case agentRequestType: 356 | 
// TODO: option/callback to allow agent forwarding 357 | SetAgentRequested(sess.ctx) 358 | req.Reply(true, nil) 359 | case "break": 360 | ok := false 361 | sess.Lock() 362 | if sess.breakCh != nil { 363 | sess.breakCh <- true 364 | ok = true 365 | } 366 | req.Reply(ok, nil) 367 | sess.Unlock() 368 | default: 369 | // TODO: debug log 370 | req.Reply(false, nil) 371 | } 372 | } 373 | } 374 | -------------------------------------------------------------------------------- /session_test.go: -------------------------------------------------------------------------------- 1 | package ssh 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "net" 8 | "testing" 9 | 10 | gossh "golang.org/x/crypto/ssh" 11 | ) 12 | 13 | func (srv *Server) serveOnce(l net.Listener) error { 14 | srv.ensureHandlers() 15 | if err := srv.ensureHostSigner(); err != nil { 16 | return err 17 | } 18 | conn, e := l.Accept() 19 | if e != nil { 20 | return e 21 | } 22 | srv.ChannelHandlers = map[string]ChannelHandler{ 23 | "session": DefaultSessionHandler, 24 | "direct-tcpip": DirectTCPIPHandler, 25 | } 26 | srv.HandleConn(conn) 27 | return nil 28 | } 29 | 30 | func newLocalListener() net.Listener { 31 | l, err := net.Listen("tcp", "127.0.0.1:0") 32 | if err != nil { 33 | if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { 34 | panic(fmt.Sprintf("failed to listen on a port: %v", err)) 35 | } 36 | } 37 | return l 38 | } 39 | 40 | func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) { 41 | if config == nil { 42 | config = &gossh.ClientConfig{ 43 | User: "testuser", 44 | Auth: []gossh.AuthMethod{ 45 | gossh.Password("testpass"), 46 | }, 47 | } 48 | } 49 | if config.HostKeyCallback == nil { 50 | config.HostKeyCallback = gossh.InsecureIgnoreHostKey() 51 | } 52 | client, err := gossh.Dial("tcp", addr, config) 53 | if err != nil { 54 | t.Fatal(err) 55 | } 56 | session, err := client.NewSession() 57 | if err != nil { 58 | t.Fatal(err) 59 | } 60 
| return session, client, func() { 61 | session.Close() 62 | client.Close() 63 | } 64 | } 65 | 66 | func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) { 67 | l := newLocalListener() 68 | go srv.serveOnce(l) 69 | return newClientSession(t, l.Addr().String(), cfg) 70 | } 71 | 72 | func TestStdout(t *testing.T) { 73 | t.Parallel() 74 | testBytes := []byte("Hello world\n") 75 | session, _, cleanup := newTestSession(t, &Server{ 76 | Handler: func(s Session) { 77 | s.Write(testBytes) 78 | }, 79 | }, nil) 80 | defer cleanup() 81 | var stdout bytes.Buffer 82 | session.Stdout = &stdout 83 | if err := session.Run(""); err != nil { 84 | t.Fatal(err) 85 | } 86 | if !bytes.Equal(stdout.Bytes(), testBytes) { 87 | t.Fatalf("stdout = %#v; want %#v", stdout.Bytes(), testBytes) 88 | } 89 | } 90 | 91 | func TestStderr(t *testing.T) { 92 | t.Parallel() 93 | testBytes := []byte("Hello world\n") 94 | session, _, cleanup := newTestSession(t, &Server{ 95 | Handler: func(s Session) { 96 | s.Stderr().Write(testBytes) 97 | }, 98 | }, nil) 99 | defer cleanup() 100 | var stderr bytes.Buffer 101 | session.Stderr = &stderr 102 | if err := session.Run(""); err != nil { 103 | t.Fatal(err) 104 | } 105 | if !bytes.Equal(stderr.Bytes(), testBytes) { 106 | t.Fatalf("stderr = %#v; want %#v", stderr.Bytes(), testBytes) 107 | } 108 | } 109 | 110 | func TestStdin(t *testing.T) { 111 | t.Parallel() 112 | testBytes := []byte("Hello world\n") 113 | session, _, cleanup := newTestSession(t, &Server{ 114 | Handler: func(s Session) { 115 | io.Copy(s, s) // stdin back into stdout 116 | }, 117 | }, nil) 118 | defer cleanup() 119 | var stdout bytes.Buffer 120 | session.Stdout = &stdout 121 | session.Stdin = bytes.NewBuffer(testBytes) 122 | if err := session.Run(""); err != nil { 123 | t.Fatal(err) 124 | } 125 | if !bytes.Equal(stdout.Bytes(), testBytes) { 126 | t.Fatalf("stdout = %#v; want %#v given stdin = %#v", stdout.Bytes(), testBytes, testBytes) 127 
| } 128 | } 129 | 130 | func TestUser(t *testing.T) { 131 | t.Parallel() 132 | testUser := []byte("progrium") 133 | session, _, cleanup := newTestSession(t, &Server{ 134 | Handler: func(s Session) { 135 | io.WriteString(s, s.User()) 136 | }, 137 | }, &gossh.ClientConfig{ 138 | User: string(testUser), 139 | }) 140 | defer cleanup() 141 | var stdout bytes.Buffer 142 | session.Stdout = &stdout 143 | if err := session.Run(""); err != nil { 144 | t.Fatal(err) 145 | } 146 | if !bytes.Equal(stdout.Bytes(), testUser) { 147 | t.Fatalf("stdout = %#v; want %#v given user = %#v", stdout.Bytes(), testUser, string(testUser)) 148 | } 149 | } 150 | 151 | func TestDefaultExitStatusZero(t *testing.T) { 152 | t.Parallel() 153 | session, _, cleanup := newTestSession(t, &Server{ 154 | Handler: func(s Session) { 155 | // noop 156 | }, 157 | }, nil) 158 | defer cleanup() 159 | err := session.Run("") 160 | if err != nil { 161 | t.Fatalf("expected nil but got %v", err) 162 | } 163 | } 164 | 165 | func TestExplicitExitStatusZero(t *testing.T) { 166 | t.Parallel() 167 | session, _, cleanup := newTestSession(t, &Server{ 168 | Handler: func(s Session) { 169 | s.Exit(0) 170 | }, 171 | }, nil) 172 | defer cleanup() 173 | err := session.Run("") 174 | if err != nil { 175 | t.Fatalf("expected nil but got %v", err) 176 | } 177 | } 178 | 179 | func TestExitStatusNonZero(t *testing.T) { 180 | t.Parallel() 181 | session, _, cleanup := newTestSession(t, &Server{ 182 | Handler: func(s Session) { 183 | s.Exit(1) 184 | }, 185 | }, nil) 186 | defer cleanup() 187 | err := session.Run("") 188 | e, ok := err.(*gossh.ExitError) 189 | if !ok { 190 | t.Fatalf("expected ExitError but got %T", err) 191 | } 192 | if e.ExitStatus() != 1 { 193 | t.Fatalf("exit-status = %#v; want %#v", e.ExitStatus(), 1) 194 | } 195 | } 196 | 197 | func TestPty(t *testing.T) { 198 | t.Parallel() 199 | term := "xterm" 200 | winWidth := 40 201 | winHeight := 80 202 | done := make(chan bool) 203 | session, _, cleanup := newTestSession(t, 
&Server{ 204 | Handler: func(s Session) { 205 | ptyReq, _, isPty := s.Pty() 206 | if !isPty { 207 | t.Fatalf("expected pty but none requested") 208 | } 209 | if ptyReq.Term != term { 210 | t.Fatalf("expected term %#v but got %#v", term, ptyReq.Term) 211 | } 212 | if ptyReq.Window.Width != winWidth { 213 | t.Fatalf("expected window width %#v but got %#v", winWidth, ptyReq.Window.Width) 214 | } 215 | if ptyReq.Window.Height != winHeight { 216 | t.Fatalf("expected window height %#v but got %#v", winHeight, ptyReq.Window.Height) 217 | } 218 | close(done) 219 | }, 220 | }, nil) 221 | defer cleanup() 222 | if err := session.RequestPty(term, winHeight, winWidth, gossh.TerminalModes{}); err != nil { 223 | t.Fatalf("expected nil but got %v", err) 224 | } 225 | if err := session.Shell(); err != nil { 226 | t.Fatalf("expected nil but got %v", err) 227 | } 228 | <-done 229 | } 230 | 231 | func TestPtyResize(t *testing.T) { 232 | t.Parallel() 233 | winch0 := Window{40, 80} 234 | winch1 := Window{80, 160} 235 | winch2 := Window{20, 40} 236 | winches := make(chan Window) 237 | done := make(chan bool) 238 | session, _, cleanup := newTestSession(t, &Server{ 239 | Handler: func(s Session) { 240 | ptyReq, winCh, isPty := s.Pty() 241 | if !isPty { 242 | t.Fatalf("expected pty but none requested") 243 | } 244 | if ptyReq.Window != winch0 { 245 | t.Fatalf("expected window %#v but got %#v", winch0, ptyReq.Window) 246 | } 247 | for win := range winCh { 248 | winches <- win 249 | } 250 | close(done) 251 | }, 252 | }, nil) 253 | defer cleanup() 254 | // winch0 255 | if err := session.RequestPty("xterm", winch0.Height, winch0.Width, gossh.TerminalModes{}); err != nil { 256 | t.Fatalf("expected nil but got %v", err) 257 | } 258 | if err := session.Shell(); err != nil { 259 | t.Fatalf("expected nil but got %v", err) 260 | } 261 | gotWinch := <-winches 262 | if gotWinch != winch0 { 263 | t.Fatalf("expected window %#v but got %#v", winch0, gotWinch) 264 | } 265 | // winch1 266 | winchMsg := 
struct{ w, h uint32 }{uint32(winch1.Width), uint32(winch1.Height)} 267 | ok, err := session.SendRequest("window-change", true, gossh.Marshal(&winchMsg)) 268 | if err == nil && !ok { 269 | t.Fatalf("unexpected error or bad reply on send request") 270 | } 271 | gotWinch = <-winches 272 | if gotWinch != winch1 { 273 | t.Fatalf("expected window %#v but got %#v", winch1, gotWinch) 274 | } 275 | // winch2 276 | winchMsg = struct{ w, h uint32 }{uint32(winch2.Width), uint32(winch2.Height)} 277 | ok, err = session.SendRequest("window-change", true, gossh.Marshal(&winchMsg)) 278 | if err == nil && !ok { 279 | t.Fatalf("unexpected error or bad reply on send request") 280 | } 281 | gotWinch = <-winches 282 | if gotWinch != winch2 { 283 | t.Fatalf("expected window %#v but got %#v", winch2, gotWinch) 284 | } 285 | session.Close() 286 | <-done 287 | } 288 | 289 | func TestSignals(t *testing.T) { 290 | t.Parallel() 291 | 292 | // errChan lets us get errors back from the session 293 | errChan := make(chan error, 5) 294 | 295 | // doneChan lets us specify that we should exit. 296 | doneChan := make(chan interface{}) 297 | 298 | session, _, cleanup := newTestSession(t, &Server{ 299 | Handler: func(s Session) { 300 | // We need to use a buffered channel here, otherwise it's possible for the 301 | // second call to Signal to get discarded. 
302 | signals := make(chan Signal, 2) 303 | s.Signals(signals) 304 | 305 | select { 306 | case sig := <-signals: 307 | if sig != SIGINT { 308 | errChan <- fmt.Errorf("expected signal %v but got %v", SIGINT, sig) 309 | return 310 | } 311 | case <-doneChan: 312 | errChan <- fmt.Errorf("Unexpected done") 313 | return 314 | } 315 | 316 | select { 317 | case sig := <-signals: 318 | if sig != SIGKILL { 319 | errChan <- fmt.Errorf("expected signal %v but got %v", SIGKILL, sig) 320 | return 321 | } 322 | case <-doneChan: 323 | errChan <- fmt.Errorf("Unexpected done") 324 | return 325 | } 326 | }, 327 | }, nil) 328 | defer cleanup() 329 | 330 | go func() { 331 | session.Signal(gossh.SIGINT) 332 | session.Signal(gossh.SIGKILL) 333 | }() 334 | 335 | go func() { 336 | errChan <- session.Run("") 337 | }() 338 | 339 | err := <-errChan 340 | close(doneChan) 341 | 342 | if err != nil { 343 | t.Fatalf("expected nil but got %v", err) 344 | } 345 | } 346 | 347 | func TestBreakWithChanRegistered(t *testing.T) { 348 | t.Parallel() 349 | 350 | // errChan lets us get errors back from the session 351 | errChan := make(chan error, 5) 352 | 353 | // doneChan lets us specify that we should exit. 
354 | doneChan := make(chan interface{}) 355 | 356 | breakChan := make(chan bool) 357 | 358 | readyToReceiveBreak := make(chan bool) 359 | 360 | session, _, cleanup := newTestSession(t, &Server{ 361 | Handler: func(s Session) { 362 | s.Break(breakChan) // register a break channel with the session 363 | readyToReceiveBreak <- true 364 | 365 | select { 366 | case <-breakChan: 367 | io.WriteString(s, "break") 368 | case <-doneChan: 369 | errChan <- fmt.Errorf("Unexpected done") 370 | return 371 | } 372 | }, 373 | }, nil) 374 | defer cleanup() 375 | var stdout bytes.Buffer 376 | session.Stdout = &stdout 377 | go func() { 378 | errChan <- session.Run("") 379 | }() 380 | 381 | <-readyToReceiveBreak 382 | ok, err := session.SendRequest("break", true, nil) 383 | if err != nil { 384 | t.Fatalf("expected nil but got %v", err) 385 | } 386 | if ok != true { 387 | t.Fatalf("expected true but got %v", ok) 388 | } 389 | 390 | err = <-errChan 391 | close(doneChan) 392 | 393 | if err != nil { 394 | t.Fatalf("expected nil but got %v", err) 395 | } 396 | if !bytes.Equal(stdout.Bytes(), []byte("break")) { 397 | t.Fatalf("stdout = %#v, expected 'break'", stdout.Bytes()) 398 | } 399 | } 400 | 401 | func TestBreakWithoutChanRegistered(t *testing.T) { 402 | t.Parallel() 403 | 404 | // errChan lets us get errors back from the session 405 | errChan := make(chan error, 5) 406 | 407 | // doneChan lets us specify that we should exit. 
408 | doneChan := make(chan interface{}) 409 | 410 | waitUntilAfterBreakSent := make(chan bool) 411 | 412 | session, _, cleanup := newTestSession(t, &Server{ 413 | Handler: func(s Session) { 414 | <-waitUntilAfterBreakSent 415 | }, 416 | }, nil) 417 | defer cleanup() 418 | var stdout bytes.Buffer 419 | session.Stdout = &stdout 420 | go func() { 421 | errChan <- session.Run("") 422 | }() 423 | 424 | ok, err := session.SendRequest("break", true, nil) 425 | if err != nil { 426 | t.Fatalf("expected nil but got %v", err) 427 | } 428 | if ok != false { 429 | t.Fatalf("expected false but got %v", ok) 430 | } 431 | waitUntilAfterBreakSent <- true 432 | 433 | err = <-errChan 434 | close(doneChan) 435 | if err != nil { 436 | t.Fatalf("expected nil but got %v", err) 437 | } 438 | } 439 | -------------------------------------------------------------------------------- /ssh.go: -------------------------------------------------------------------------------- 1 | package ssh 2 | 3 | import ( 4 | "crypto/subtle" 5 | "net" 6 | 7 | gossh "golang.org/x/crypto/ssh" 8 | ) 9 | 10 | type Signal string 11 | 12 | // POSIX signals as listed in RFC 4254 Section 6.10. 13 | const ( 14 | SIGABRT Signal = "ABRT" 15 | SIGALRM Signal = "ALRM" 16 | SIGFPE Signal = "FPE" 17 | SIGHUP Signal = "HUP" 18 | SIGILL Signal = "ILL" 19 | SIGINT Signal = "INT" 20 | SIGKILL Signal = "KILL" 21 | SIGPIPE Signal = "PIPE" 22 | SIGQUIT Signal = "QUIT" 23 | SIGSEGV Signal = "SEGV" 24 | SIGTERM Signal = "TERM" 25 | SIGUSR1 Signal = "USR1" 26 | SIGUSR2 Signal = "USR2" 27 | ) 28 | 29 | // DefaultHandler is the default Handler used by Serve. 30 | var DefaultHandler Handler 31 | 32 | // Option is a functional option handler for Server. 33 | type Option func(*Server) error 34 | 35 | // Handler is a callback for handling established SSH sessions. 36 | type Handler func(Session) 37 | 38 | // BannerHandler is a callback for displaying the server banner. 
39 | type BannerHandler func(ctx Context) string 40 | 41 | // PublicKeyHandler is a callback for performing public key authentication. 42 | type PublicKeyHandler func(ctx Context, key PublicKey) bool 43 | 44 | // PasswordHandler is a callback for performing password authentication. 45 | type PasswordHandler func(ctx Context, password string) bool 46 | 47 | // KeyboardInteractiveHandler is a callback for performing keyboard-interactive authentication. 48 | type KeyboardInteractiveHandler func(ctx Context, challenger gossh.KeyboardInteractiveChallenge) bool 49 | 50 | // PtyCallback is a hook for allowing PTY sessions. 51 | type PtyCallback func(ctx Context, pty Pty) bool 52 | 53 | // SessionRequestCallback is a callback for allowing or denying SSH sessions. 54 | type SessionRequestCallback func(sess Session, requestType string) bool 55 | 56 | // ConnCallback is a hook for new connections before handling. 57 | // It allows wrapping for timeouts and limiting by returning 58 | // the net.Conn that will be used as the underlying connection. 59 | type ConnCallback func(ctx Context, conn net.Conn) net.Conn 60 | 61 | // LocalPortForwardingCallback is a hook for allowing port forwarding 62 | type LocalPortForwardingCallback func(ctx Context, destinationHost string, destinationPort uint32) bool 63 | 64 | // ReversePortForwardingCallback is a hook for allowing reverse port forwarding 65 | type ReversePortForwardingCallback func(ctx Context, bindHost string, bindPort uint32) bool 66 | 67 | // ServerConfigCallback is a hook for creating custom default server configs 68 | type ServerConfigCallback func(ctx Context) *gossh.ServerConfig 69 | 70 | // ConnectionFailedCallback is a hook for reporting failed connections 71 | // Please note: the net.Conn is likely to be closed at this point 72 | type ConnectionFailedCallback func(conn net.Conn, err error) 73 | 74 | // Window represents the size of a PTY window. 
75 | type Window struct { 76 | Width int 77 | Height int 78 | } 79 | 80 | // Pty represents a PTY request and configuration. 81 | type Pty struct { 82 | Term string 83 | Window Window 84 | // HELP WANTED: terminal modes! 85 | } 86 | 87 | // Serve accepts incoming SSH connections on the listener l, creating a new 88 | // connection goroutine for each. The connection goroutines read requests and 89 | // then calls handler to handle sessions. Handler is typically nil, in which 90 | // case the DefaultHandler is used. 91 | func Serve(l net.Listener, handler Handler, options ...Option) error { 92 | srv := &Server{Handler: handler} 93 | for _, option := range options { 94 | if err := srv.SetOption(option); err != nil { 95 | return err 96 | } 97 | } 98 | return srv.Serve(l) 99 | } 100 | 101 | // ListenAndServe listens on the TCP network address addr and then calls Serve 102 | // with handler to handle sessions on incoming connections. Handler is typically 103 | // nil, in which case the DefaultHandler is used. 104 | func ListenAndServe(addr string, handler Handler, options ...Option) error { 105 | srv := &Server{Addr: addr, Handler: handler} 106 | for _, option := range options { 107 | if err := srv.SetOption(option); err != nil { 108 | return err 109 | } 110 | } 111 | return srv.ListenAndServe() 112 | } 113 | 114 | // Handle registers the handler as the DefaultHandler. 115 | func Handle(handler Handler) { 116 | DefaultHandler = handler 117 | } 118 | 119 | // KeysEqual is constant time compare of the keys to avoid timing attacks. 
120 | func KeysEqual(ak, bk PublicKey) bool { 121 | // avoid panic if one of the keys is nil, return false instead 122 | if ak == nil || bk == nil { 123 | return false 124 | } 125 | 126 | a := ak.Marshal() 127 | b := bk.Marshal() 128 | return (len(a) == len(b) && subtle.ConstantTimeCompare(a, b) == 1) 129 | } 130 | -------------------------------------------------------------------------------- /ssh_test.go: -------------------------------------------------------------------------------- 1 | package ssh 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestKeysEqual(t *testing.T) { 8 | defer func() { 9 | if r := recover(); r != nil { 10 | t.Errorf("The code did panic") 11 | } 12 | }() 13 | 14 | if KeysEqual(nil, nil) { 15 | t.Error("two nil keys should not return true") 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /tcpip.go: -------------------------------------------------------------------------------- 1 | package ssh 2 | 3 | import ( 4 | "io" 5 | "log" 6 | "net" 7 | "strconv" 8 | "sync" 9 | 10 | gossh "golang.org/x/crypto/ssh" 11 | ) 12 | 13 | const ( 14 | forwardedTCPChannelType = "forwarded-tcpip" 15 | ) 16 | 17 | // direct-tcpip data struct as specified in RFC4254, Section 7.2 18 | type localForwardChannelData struct { 19 | DestAddr string 20 | DestPort uint32 21 | 22 | OriginAddr string 23 | OriginPort uint32 24 | } 25 | 26 | // DirectTCPIPHandler can be enabled by adding it to the server's 27 | // ChannelHandlers under direct-tcpip. 
28 | func DirectTCPIPHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { 29 | d := localForwardChannelData{} 30 | if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil { 31 | newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error()) 32 | return 33 | } 34 | 35 | if srv.LocalPortForwardingCallback == nil || !srv.LocalPortForwardingCallback(ctx, d.DestAddr, d.DestPort) { 36 | newChan.Reject(gossh.Prohibited, "port forwarding is disabled") 37 | return 38 | } 39 | 40 | dest := net.JoinHostPort(d.DestAddr, strconv.FormatInt(int64(d.DestPort), 10)) 41 | 42 | var dialer net.Dialer 43 | dconn, err := dialer.DialContext(ctx, "tcp", dest) 44 | if err != nil { 45 | newChan.Reject(gossh.ConnectionFailed, err.Error()) 46 | return 47 | } 48 | 49 | ch, reqs, err := newChan.Accept() 50 | if err != nil { 51 | dconn.Close() 52 | return 53 | } 54 | go gossh.DiscardRequests(reqs) 55 | 56 | go func() { 57 | defer ch.Close() 58 | defer dconn.Close() 59 | io.Copy(ch, dconn) 60 | }() 61 | go func() { 62 | defer ch.Close() 63 | defer dconn.Close() 64 | io.Copy(dconn, ch) 65 | }() 66 | } 67 | 68 | type remoteForwardRequest struct { 69 | BindAddr string 70 | BindPort uint32 71 | } 72 | 73 | type remoteForwardSuccess struct { 74 | BindPort uint32 75 | } 76 | 77 | type remoteForwardCancelRequest struct { 78 | BindAddr string 79 | BindPort uint32 80 | } 81 | 82 | type remoteForwardChannelData struct { 83 | DestAddr string 84 | DestPort uint32 85 | OriginAddr string 86 | OriginPort uint32 87 | } 88 | 89 | // ForwardedTCPHandler can be enabled by creating a ForwardedTCPHandler and 90 | // adding the HandleSSHRequest callback to the server's RequestHandlers under 91 | // tcpip-forward and cancel-tcpip-forward. 
92 | type ForwardedTCPHandler struct {
93 | 	forwards map[string]net.Listener
94 | 	sync.Mutex
95 | }
96 | 
97 | // HandleSSHRequest handles "tcpip-forward" and "cancel-tcpip-forward"
98 | // global requests. For "tcpip-forward" it asks
99 | // srv.ReversePortForwardingCallback for permission, listens on the
100 | // requested address, and for each accepted connection opens a
101 | // forwardedTCPChannelType channel back to the client and copies data in
102 | // both directions. Listeners are closed when ctx is done or when the
103 | // client cancels the forward. Any other request type is rejected.
104 | func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) {
105 | 	h.Lock()
106 | 	if h.forwards == nil {
107 | 		h.forwards = make(map[string]net.Listener)
108 | 	}
109 | 	h.Unlock()
110 | 	conn := ctx.Value(ContextKeyConn).(*gossh.ServerConn)
111 | 	switch req.Type {
112 | 	case "tcpip-forward":
113 | 		var reqPayload remoteForwardRequest
114 | 		if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
115 | 			// TODO: log parse failure
116 | 			return false, []byte{}
117 | 		}
118 | 		if srv.ReversePortForwardingCallback == nil || !srv.ReversePortForwardingCallback(ctx, reqPayload.BindAddr, reqPayload.BindPort) {
119 | 			return false, []byte("port forwarding is disabled")
120 | 		}
121 | 		addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort)))
122 | 		ln, err := net.Listen("tcp", addr)
123 | 		if err != nil {
124 | 			// TODO: log listen failure
125 | 			return false, []byte{}
126 | 		}
127 | 		_, destPortStr, _ := net.SplitHostPort(ln.Addr().String())
128 | 		destPort, _ := strconv.Atoi(destPortStr)
129 | 		// Register the listener under the port that was actually bound.
130 | 		// When BindPort is 0 the OS picks a free port, so keying by the
131 | 		// requested "host:0" would make concurrent dynamic forwards
132 | 		// overwrite each other, and clients such as OpenSSH cancel a
133 | 		// dynamic forward using the allocated port rather than 0.
134 | 		key := net.JoinHostPort(reqPayload.BindAddr, destPortStr)
135 | 		h.Lock()
136 | 		h.forwards[key] = ln
137 | 		h.Unlock()
138 | 		go func() {
139 | 			// Close this request's own listener when the connection ends,
140 | 			// rather than whatever is currently registered under key.
141 | 			// Closing a listener that was already cancelled is harmless.
142 | 			<-ctx.Done()
143 | 			ln.Close()
144 | 		}()
145 | 		go func() {
146 | 			for {
147 | 				c, err := ln.Accept()
148 | 				if err != nil {
149 | 					// TODO: log accept failure
150 | 					break
151 | 				}
152 | 				originAddr, originPortStr, _ := net.SplitHostPort(c.RemoteAddr().String())
153 | 				originPort, _ := strconv.Atoi(originPortStr)
154 | 				payload := gossh.Marshal(&remoteForwardChannelData{
155 | 					DestAddr:   reqPayload.BindAddr,
156 | 					DestPort:   uint32(destPort),
157 | 					OriginAddr: originAddr,
158 | 					OriginPort: uint32(originPort),
159 | 				})
160 | 				go func() {
161 | 					ch, reqs, err := conn.OpenChannel(forwardedTCPChannelType, payload)
162 | 					if err != nil {
163 | 						// TODO: log failure to open channel
164 | 						log.Println(err)
165 | 						c.Close()
166 | 						return
167 | 					}
168 | 					go gossh.DiscardRequests(reqs)
169 | 					go func() {
170 | 						defer ch.Close()
171 | 						defer c.Close()
172 | 						io.Copy(ch, c)
173 | 					}()
174 | 					go func() {
175 | 						defer ch.Close()
176 | 						defer c.Close()
177 | 						io.Copy(c, ch)
178 | 					}()
179 | 				}()
180 | 			}
181 | 			// The listener is closed. Drop its entry only if it still
182 | 			// refers to this listener: after a cancel, a newer request may
183 | 			// already have re-bound the same address under the same key.
184 | 			h.Lock()
185 | 			if cur, ok := h.forwards[key]; ok && cur == ln {
186 | 				delete(h.forwards, key)
187 | 			}
188 | 			h.Unlock()
189 | 		}()
190 | 		return true, gossh.Marshal(&remoteForwardSuccess{uint32(destPort)})
191 | 
192 | 	case "cancel-tcpip-forward":
193 | 		var reqPayload remoteForwardCancelRequest
194 | 		if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
195 | 			// TODO: log parse failure
196 | 			return false, []byte{}
197 | 		}
198 | 		addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort)))
199 | 		h.Lock()
200 | 		ln, ok := h.forwards[addr]
201 | 		h.Unlock()
202 | 		if ok {
203 | 			ln.Close()
204 | 		}
205 | 		return true, nil
206 | 	default:
207 | 		return false, nil
208 | 	}
209 | }
210 | 
--------------------------------------------------------------------------------
/tcpip_test.go:
--------------------------------------------------------------------------------
1 | package ssh
2 | 
3 | import (
4 | 	"bytes"
5 | 	"io"
6 | 	"net"
7 | 	"strconv"
8 | 	"strings"
9 | 	"testing"
10 | 
11 | 	gossh "golang.org/x/crypto/ssh"
12 | )
13 | 
14 | var sampleServerResponse = []byte("Hello world")
15 | 
16 | func sampleSocketServer() net.Listener {
17 | 	l := newLocalListener()
18 | 
19 | 	go func() {
20 | 		conn, err := l.Accept()
21 | 		if err != nil {
22 | 			return
23 | 		}
24 | 		conn.Write(sampleServerResponse)
25 | 		conn.Close()
26 | 	}()
27 | 
28 | 	return l
29 | }
30 | 
31 | func newTestSessionWithForwarding(t *testing.T, forwardingEnabled bool) (net.Listener, *gossh.Client, func()) {
32 | 	l := sampleSocketServer()
33 | 
34 | 	_, client, cleanup := newTestSession(t, &Server{
35 | 		Handler: func(s Session) {},
36 | 		LocalPortForwardingCallback: func(ctx Context, destinationHost string, destinationPort uint32) bool {
37 | 			addr := net.JoinHostPort(destinationHost, strconv.FormatInt(int64(destinationPort), 10))
38 | 			if addr != l.Addr().String() {
39 | 				panic("unexpected destinationHost: " + addr)
40 | 			}
41 | 			return forwardingEnabled
42 | 		},
43 | 	}, nil)
44 | 
45 | 	return l, client, func() {
46 | 		cleanup()
47 | 		l.Close()
48 | 	}
49 | }
50 | 
51 | func TestLocalPortForwardingWorks(t *testing.T) {
52 | 	t.Parallel()
53 | 
54 | 	l, client, cleanup := newTestSessionWithForwarding(t, true)
55 | 	defer cleanup()
56 | 
57 | 	conn, err := client.Dial("tcp", l.Addr().String())
58 | 	if err != nil {
59 | 		t.Fatalf("Error connecting to %v: %v", l.Addr().String(), err)
60 | 	}
61 | 	result, err := io.ReadAll(conn)
62 | 	if err != nil {
63 | 		t.Fatal(err)
64 | 	}
65 | 	if !bytes.Equal(result, sampleServerResponse) {
66 | 		t.Fatalf("result = %#v; want %#v", result, sampleServerResponse)
67 | 	}
68 | }
69 | 
70 | func TestLocalPortForwardingRespectsCallback(t *testing.T) {
71 | 	t.Parallel()
72 | 
73 | 	l, client, cleanup := newTestSessionWithForwarding(t, false)
74 | 	defer cleanup()
75 | 
76 | 	_, err := client.Dial("tcp", l.Addr().String())
77 | 	if err == nil {
78 | 		t.Fatalf("Expected error connecting to %v but it succeeded", l.Addr().String())
79 | 	}
80 | 	if !strings.Contains(err.Error(), "port forwarding is disabled") {
81 | 		t.Fatalf("Expected permission error but got %#v", err)
82 | 	}
83 | }
84 | 
--------------------------------------------------------------------------------
/util.go:
--------------------------------------------------------------------------------
1 | package ssh
2 | 
3 | import (
4 | 	"crypto/rand"
5 | 	"crypto/rsa"
6 | 	"encoding/binary"
7 | 
8 | 	"golang.org/x/crypto/ssh"
9 | )
10 | 
11 | // generateSigner creates a signer backed by a freshly generated 2048-bit
12 | // RSA key.
13 | func generateSigner() (ssh.Signer, error) {
14 | 	key, err := rsa.GenerateKey(rand.Reader, 2048)
15 | 	if err != nil {
16 | 		return nil, err
17 | 	}
18 | 	return ssh.NewSignerFromKey(key)
19 | }
20 | 
21 | // parsePtyRequest decodes the leading fields of a pty request payload: a
22 | // terminal name string followed by width and height as uint32 values.
23 | // Trailing bytes are ignored. ok is false if the payload is truncated.
24 | func parsePtyRequest(s []byte) (pty Pty, ok bool) {
25 | 	term, s, ok := parseString(s)
26 | 	if !ok {
27 | 		return
28 | 	}
29 | 	width32, s, ok := parseUint32(s)
30 | 	if !ok {
31 | 		return
32 | 	}
33 | 	height32, _, ok := parseUint32(s)
34 | 	if !ok {
35 | 		return
36 | 	}
37 | 	pty = Pty{
38 | 		Term: term,
39 | 		Window: Window{
40 | 			Width:  int(width32),
41 | 			Height: int(height32),
42 | 		},
43 | 	}
44 | 	return
45 | }
46 | 
47 | // parseWinchRequest decodes a width and height (uint32 each) from the start
48 | // of s. Unlike parsePtyRequest, it also rejects a zero width or height.
49 | func parseWinchRequest(s []byte) (win Window, ok bool) {
50 | 	width32, s, ok := parseUint32(s)
51 | 	if width32 < 1 {
52 | 		ok = false
53 | 	}
54 | 	if !ok {
55 | 		return
56 | 	}
57 | 	height32, _, ok := parseUint32(s)
58 | 	if height32 < 1 {
59 | 		ok = false
60 | 	}
61 | 	if !ok {
62 | 		return
63 | 	}
64 | 	win = Window{
65 | 		Width:  int(width32),
66 | 		Height: int(height32),
67 | 	}
68 | 	return
69 | }
70 | 
71 | // parseString reads a length-prefixed string (a big-endian uint32 length
72 | // followed by that many bytes) from the start of in and returns it along
73 | // with the remaining bytes. ok is false if in is too short.
74 | func parseString(in []byte) (out string, rest []byte, ok bool) {
75 | 	if len(in) < 4 {
76 | 		return
77 | 	}
78 | 	length := binary.BigEndian.Uint32(in)
79 | 	// Do the bounds arithmetic in 64 bits: as a uint32, 4+length wraps
80 | 	// for lengths near 2^32, which let a malformed payload pass the check
81 | 	// and then panic on the slice expression below.
82 | 	end := 4 + uint64(length)
83 | 	if uint64(len(in)) < end {
84 | 		return
85 | 	}
86 | 	out = string(in[4:end])
87 | 	rest = in[end:]
88 | 	ok = true
89 | 	return
90 | }
91 | 
92 | // parseUint32 reads a big-endian uint32 from the start of in and returns it
93 | // with the remaining bytes. The bool is false if in has fewer than 4 bytes.
94 | func parseUint32(in []byte) (uint32, []byte, bool) {
95 | 	if len(in) < 4 {
96 | 		return 0, nil, false
97 | 	}
98 | 	return binary.BigEndian.Uint32(in), in[4:], true
99 | }
100 | 
--------------------------------------------------------------------------------
/wrap.go:
--------------------------------------------------------------------------------
1 | package ssh
2 | 
3 | import gossh "golang.org/x/crypto/ssh"
4 | 
5 | // PublicKey is an abstraction of different types of public keys.
6 | type PublicKey interface {
7 | 	gossh.PublicKey
8 | }
9 | 
10 | // The Permissions type holds fine-grained permissions that are specific to a
11 | // user or a specific authentication method for a user. Permissions, except for
12 | // "source-address", must be enforced in the server application layer, after
13 | // successful authentication.
14 | type Permissions struct {
15 | 	*gossh.Permissions
16 | }
17 | 
18 | // A Signer can create signatures that verify against a public key.
19 | type Signer interface {
20 | 	gossh.Signer
21 | }
22 | 
23 | // ParseAuthorizedKey parses a public key from an authorized_keys file used in
24 | // OpenSSH according to the sshd(8) manual page.
25 | func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) {
26 | 	return gossh.ParseAuthorizedKey(in)
27 | }
28 | 
29 | // ParsePublicKey parses an SSH public key formatted for use in
30 | // the SSH wire protocol according to RFC 4253, section 6.6.
31 | func ParsePublicKey(in []byte) (out PublicKey, err error) {
32 | 	return gossh.ParsePublicKey(in)
33 | }
34 | 
--------------------------------------------------------------------------------