├── LICENSE
├── README.md
├── _examples
│   ├── ssh-docker
│   │   ├── Dockerfile
│   │   ├── README.md
│   │   └── docker.go
│   ├── ssh-forwardagent
│   │   └── forwardagent.go
│   ├── ssh-pty
│   │   └── pty.go
│   ├── ssh-publickey
│   │   └── public_key.go
│   ├── ssh-remoteforward
│   │   └── portforward.go
│   ├── ssh-sftpserver
│   │   └── sftp.go
│   ├── ssh-simple
│   │   └── simple.go
│   └── ssh-timeouts
│       └── timeouts.go
├── agent.go
├── circle.yml
├── conn.go
├── context.go
├── context_test.go
├── doc.go
├── example_test.go
├── go.mod
├── go.sum
├── options.go
├── options_test.go
├── server.go
├── server_test.go
├── session.go
├── session_test.go
├── ssh.go
├── ssh_test.go
├── tcpip.go
├── tcpip_test.go
├── util.go
└── wrap.go
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2016 Glider Labs. All rights reserved.
2 |
3 | Redistribution and use in source and binary forms, with or without
4 | modification, are permitted provided that the following conditions are
5 | met:
6 |
7 | * Redistributions of source code must retain the above copyright
8 | notice, this list of conditions and the following disclaimer.
9 | * Redistributions in binary form must reproduce the above
10 | copyright notice, this list of conditions and the following disclaimer
11 | in the documentation and/or other materials provided with the
12 | distribution.
13 | * Neither the name of Glider Labs nor the names of its
14 | contributors may be used to endorse or promote products derived from
15 | this software without specific prior written permission.
16 |
17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
18 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
19 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
20 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
21 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
22 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
23 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
24 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
25 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # gliderlabs/ssh
2 |
3 | [GoDoc](https://godoc.org/github.com/gliderlabs/ssh)
4 | [CircleCI](https://circleci.com/gh/gliderlabs/ssh)
5 | [Go Report Card](https://goreportcard.com/report/github.com/gliderlabs/ssh)
6 | [OpenCollective](#sponsors)
7 | [Slack](http://slack.gliderlabs.com)
8 | [Email Updates](https://app.convertkit.com/landing_pages/243312)
9 |
10 | > The Glider Labs SSH server package is dope. —[@bradfitz](https://twitter.com/bradfitz), Go team member
11 |
12 | This Go package wraps the [crypto/ssh
13 | package](https://godoc.org/golang.org/x/crypto/ssh) with a higher-level API for
14 | building SSH servers. The goal of the API was to make it as simple as using
15 | [net/http](https://golang.org/pkg/net/http/), so the API is very similar:
16 |
17 | ```go
18 | package main
19 |
20 | import (
21 | "github.com/gliderlabs/ssh"
22 | "io"
23 | "log"
24 | )
25 |
26 | func main() {
27 | ssh.Handle(func(s ssh.Session) {
28 | io.WriteString(s, "Hello world\n")
29 | })
30 |
31 | log.Fatal(ssh.ListenAndServe(":2222", nil))
32 | }
33 |
34 | ```
35 | This package was built by [@progrium](https://twitter.com/progrium) after working on nearly a dozen projects at Glider Labs using SSH and collaborating with [@shazow](https://twitter.com/shazow) (known for [ssh-chat](https://github.com/shazow/ssh-chat)).
36 |
37 | ## Examples
38 |
39 | A bunch of great examples are in the `_examples` directory.
40 |
41 | ## Usage
42 |
43 | [See GoDoc reference.](https://godoc.org/github.com/gliderlabs/ssh)
44 |
45 | ## Contributing
46 |
47 | Pull requests are welcome! However, since this project is very much about API
48 | design, please submit API changes as issues to discuss before submitting PRs.
49 |
50 | Also, you can [join our Slack](http://slack.gliderlabs.com) to discuss as well.
51 |
52 | ## Roadmap
53 |
54 | * Non-session channel handlers
55 | * Cleanup callback API
56 | * 1.0 release
57 | * High-level client?
58 |
59 | ## Sponsors
60 |
61 | Become a sponsor and get your logo on our README on GitHub with a link to your site. [[Become a sponsor](https://opencollective.com/ssh#sponsor)]
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 | ## License
95 |
96 | [BSD](LICENSE)
97 |
--------------------------------------------------------------------------------
/_examples/ssh-docker/Dockerfile:
--------------------------------------------------------------------------------
FROM alpine
# --no-cache fetches the package index on the fly instead of leaving it in the image.
RUN apk add --no-cache jq
ENTRYPOINT ["jq"]
4 |
--------------------------------------------------------------------------------
/_examples/ssh-docker/README.md:
--------------------------------------------------------------------------------
1 | # SSH Docker Example
2 | Run docker containers over SSH. You can even pipe things into them too!
3 |
4 | # Installation / Prep
5 | We're going to build JQ as an SSH service using the Glider Labs SSH package. If you haven't installed Go and Docker yet, see their documentation for help getting your environment set up.
6 |
7 | Install the Glider Labs SSH package
8 | `go get github.com/gliderlabs/ssh`
9 |
10 | Build the example docker container with
11 | `docker build --rm -t jq .`
12 |
13 | # Usage
14 | Run the SSH server with
15 |
16 | `go run docker.go`
17 |
18 | This SSH service uses the user name passed into the connection to key the desired docker image. So to use jq over our SSH server we would say:
19 |
20 | `ssh jq@localhost -p 2222`
21 |
22 | If we run this command we will see
23 |
24 | ```
25 | jq - commandline JSON processor [version 1.5]
26 | Usage: jq [options] <jq filter> [file...]
27 |
28 | jq is a tool for processing JSON inputs, applying the
29 | given filter to its JSON text inputs and producing the
30 | filter's results as JSON on standard output.
31 | The simplest filter is ., which is the identity filter,
32 | copying jq's input to its output unmodified (except for
33 | formatting).
34 | For more advanced filters see the jq(1) manpage ("man jq")
35 | and/or https://stedolan.github.io/jq
36 |
37 | Some of the options include:
38 | -c compact instead of pretty-printed output;
39 | -n use `null` as the single input value;
40 | -e set the exit status code based on the output;
41 | -s read (slurp) all inputs into an array; apply filter to it;
42 | -r output raw strings, not JSON texts;
43 | -R read raw strings, not JSON texts;
44 | -C colorize JSON;
45 | -M monochrome (don't colorize JSON);
46 | -S sort keys of objects on output;
47 | --tab use tabs for indentation;
48 | --arg a v set variable $a to value <v>;
49 | --argjson a v set variable $a to JSON value <v>;
50 | --slurpfile a f set variable $a to an array of JSON texts read from <f>;
51 | See the manpage for more options.
52 | Connection to localhost closed.
53 | ```
54 |
55 | JQ's help text! It's working! Now let's pipe some json into our SSH service and marvel at the awesomeness.
56 |
57 | `curl -s https://api.github.com/orgs/gliderlabs | ssh jq@localhost -p 2222 .`
58 |
59 | ```json
60 | {
61 | "login": "gliderlabs",
62 | "id": 8484931,
63 | "url": "https://api.github.com/orgs/gliderlabs",
64 | "repos_url": "https://api.github.com/orgs/gliderlabs/repos",
65 | "events_url": "https://api.github.com/orgs/gliderlabs/events",
66 | "hooks_url": "https://api.github.com/orgs/gliderlabs/hooks",
67 | "issues_url": "https://api.github.com/orgs/gliderlabs/issues",
68 | "members_url": "https://api.github.com/orgs/gliderlabs/members{/member}",
69 | "public_members_url": "https://api.github.com/orgs/gliderlabs/public_members{/member}",
70 | "avatar_url": "https://avatars3.githubusercontent.com/u/8484931?v=4",
71 | "description": "",
72 | "name": "Glider Labs",
73 | "company": null,
74 | "blog": "http://gliderlabs.com",
75 | "location": "Austin, TX",
76 | "email": "team@gliderlabs.com",
77 | "has_organization_projects": true,
78 | "has_repository_projects": true,
79 | "public_repos": 29,
80 | "public_gists": 0,
81 | "followers": 0,
82 | "following": 0,
83 | "html_url": "https://github.com/gliderlabs",
84 | "created_at": "2014-08-18T23:25:37Z",
85 | "updated_at": "2017-08-21T09:52:17Z",
86 | "type": "Organization"
87 | }
88 | ```
89 |
90 | # Conclusion
91 | We built JQ as a service over SSH in Go using the Glider Labs SSH package. We showed how you can run docker containers through the service as well as how to pipe stuff into your SSH service.
92 |
93 |
94 | # Troubleshooting
95 |
96 | A new host key is generated each time we run the server so you'll probably see this error on the second run.
97 | ```
98 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
99 | @ WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED! @
100 | @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
101 | IT IS POSSIBLE THAT SOMEONE IS DOING SOMETHING NASTY!
102 | Someone could be eavesdropping on you right now (man-in-the-middle attack)!
103 | It is also possible that a host key has just been changed.
104 | The fingerprint for the RSA key sent by the remote host is
105 | SHA256:9dyZ8KrMCPGvDqECggoDFyz51gA6DdHa9zpfl/5Kgos.
106 | Please contact your system administrator.
107 | Add correct host key in /Users/murr/.ssh/known_hosts to get rid of this message.
108 | Offending RSA key in /Users/murr/.ssh/known_hosts:7
109 | RSA host key for [localhost]:2222 has changed and you have requested strict checking.
110 | Host key verification failed.
111 | ```
112 |
113 | To bypass this error, remove the stale host key for `[localhost]:2222` from your known_hosts file with
114 | `ssh-keygen -R "[localhost]:2222"`
115 |
--------------------------------------------------------------------------------
/_examples/ssh-docker/docker.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "io"
7 | "log"
8 |
9 | "github.com/docker/docker/api/types"
10 | "github.com/docker/docker/api/types/container"
11 | "github.com/docker/docker/client"
12 | "github.com/docker/docker/pkg/stdcopy"
13 | "github.com/gliderlabs/ssh"
14 | )
15 |
16 | func main() {
17 | ssh.Handle(func(sess ssh.Session) {
18 | _, _, isTty := sess.Pty()
19 | cfg := &container.Config{
20 | Image: sess.User(),
21 | Cmd: sess.Command(),
22 | Env: sess.Environ(),
23 | Tty: isTty,
24 | OpenStdin: true,
25 | AttachStderr: true,
26 | AttachStdin: true,
27 | AttachStdout: true,
28 | StdinOnce: true,
29 | Volumes: make(map[string]struct{}),
30 | }
31 | status, cleanup, err := dockerRun(cfg, sess)
32 | defer cleanup()
33 | if err != nil {
34 | fmt.Fprintln(sess, err)
35 | log.Println(err)
36 | }
37 | sess.Exit(int(status))
38 | })
39 |
40 | log.Println("starting ssh server on port 2222...")
41 | log.Fatal(ssh.ListenAndServe(":2222", nil))
42 | }
43 |
44 | func dockerRun(cfg *container.Config, sess ssh.Session) (status int64, cleanup func(), err error) {
45 | docker, err := client.NewEnvClient()
46 | if err != nil {
47 | panic(err)
48 | }
49 | status = 255
50 | cleanup = func() {}
51 | ctx := context.Background()
52 | res, err := docker.ContainerCreate(ctx, cfg, nil, nil, "")
53 | if err != nil {
54 | return
55 | }
56 | cleanup = func() {
57 | docker.ContainerRemove(ctx, res.ID, types.ContainerRemoveOptions{})
58 | }
59 | opts := types.ContainerAttachOptions{
60 | Stdin: cfg.AttachStdin,
61 | Stdout: cfg.AttachStdout,
62 | Stderr: cfg.AttachStderr,
63 | Stream: true,
64 | }
65 | stream, err := docker.ContainerAttach(ctx, res.ID, opts)
66 | if err != nil {
67 | return
68 | }
69 | cleanup = func() {
70 | docker.ContainerRemove(ctx, res.ID, types.ContainerRemoveOptions{})
71 | stream.Close()
72 | }
73 |
74 | outputErr := make(chan error)
75 |
76 | go func() {
77 | var err error
78 | if cfg.Tty {
79 | _, err = io.Copy(sess, stream.Reader)
80 | } else {
81 | _, err = stdcopy.StdCopy(sess, sess.Stderr(), stream.Reader)
82 | }
83 | outputErr <- err
84 | }()
85 |
86 | go func() {
87 | defer stream.CloseWrite()
88 | io.Copy(stream.Conn, sess)
89 | }()
90 |
91 | err = docker.ContainerStart(ctx, res.ID, types.ContainerStartOptions{})
92 | if err != nil {
93 | return
94 | }
95 | if cfg.Tty {
96 | _, winCh, _ := sess.Pty()
97 | go func() {
98 | for win := range winCh {
99 | err := docker.ContainerResize(ctx, res.ID, types.ResizeOptions{
100 | Height: uint(win.Height),
101 | Width: uint(win.Width),
102 | })
103 | if err != nil {
104 | log.Println(err)
105 | break
106 | }
107 | }
108 | }()
109 | }
110 | resultC, errC := docker.ContainerWait(ctx, res.ID, container.WaitConditionNotRunning)
111 | select {
112 | case err = <-errC:
113 | return
114 | case result := <-resultC:
115 | status = result.StatusCode
116 | }
117 | err = <-outputErr
118 | return
119 | }
120 |
--------------------------------------------------------------------------------
/_examples/ssh-forwardagent/forwardagent.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "log"
6 | "os/exec"
7 |
8 | "github.com/gliderlabs/ssh"
9 | )
10 |
11 | func main() {
12 | ssh.Handle(func(s ssh.Session) {
13 | cmd := exec.Command("ssh-add", "-l")
14 | if ssh.AgentRequested(s) {
15 | l, err := ssh.NewAgentListener()
16 | if err != nil {
17 | log.Fatal(err)
18 | }
19 | defer l.Close()
20 | go ssh.ForwardAgentConnections(l, s)
21 | cmd.Env = append(s.Environ(), fmt.Sprintf("%s=%s", "SSH_AUTH_SOCK", l.Addr().String()))
22 | } else {
23 | cmd.Env = s.Environ()
24 | }
25 | cmd.Stdout = s
26 | cmd.Stderr = s.Stderr()
27 | if err := cmd.Run(); err != nil {
28 | log.Println(err)
29 | return
30 | }
31 | })
32 |
33 | log.Println("starting ssh server on port 2222...")
34 | log.Fatal(ssh.ListenAndServe(":2222", nil))
35 | }
36 |
--------------------------------------------------------------------------------
/_examples/ssh-pty/pty.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "io"
6 | "log"
7 | "os"
8 | "os/exec"
9 | "syscall"
10 | "unsafe"
11 |
12 | "github.com/gliderlabs/ssh"
13 | "github.com/creack/pty"
14 | )
15 |
// setWinsize issues a TIOCSWINSZ ioctl on f to set the terminal size to
// w columns by h rows. The ioctl result is ignored.
func setWinsize(f *os.File, w, h int) {
	size := struct{ h, w, x, y uint16 }{h: uint16(h), w: uint16(w)}
	syscall.Syscall(
		syscall.SYS_IOCTL,
		f.Fd(),
		uintptr(syscall.TIOCSWINSZ),
		uintptr(unsafe.Pointer(&size)),
	)
}
20 |
21 | func main() {
22 | ssh.Handle(func(s ssh.Session) {
23 | cmd := exec.Command("top")
24 | ptyReq, winCh, isPty := s.Pty()
25 | if isPty {
26 | cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
27 | f, err := pty.Start(cmd)
28 | if err != nil {
29 | panic(err)
30 | }
31 | go func() {
32 | for win := range winCh {
33 | setWinsize(f, win.Width, win.Height)
34 | }
35 | }()
36 | go func() {
37 | io.Copy(f, s) // stdin
38 | }()
39 | io.Copy(s, f) // stdout
40 | cmd.Wait()
41 | } else {
42 | io.WriteString(s, "No PTY requested.\n")
43 | s.Exit(1)
44 | }
45 | })
46 |
47 | log.Println("starting ssh server on port 2222...")
48 | log.Fatal(ssh.ListenAndServe(":2222", nil))
49 | }
50 |
--------------------------------------------------------------------------------
/_examples/ssh-publickey/public_key.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "io"
6 | "log"
7 |
8 | "github.com/gliderlabs/ssh"
9 | gossh "golang.org/x/crypto/ssh"
10 | )
11 |
12 | func main() {
13 | ssh.Handle(func(s ssh.Session) {
14 | authorizedKey := gossh.MarshalAuthorizedKey(s.PublicKey())
15 | io.WriteString(s, fmt.Sprintf("public key used by %s:\n", s.User()))
16 | s.Write(authorizedKey)
17 | })
18 |
19 | publicKeyOption := ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
20 | return true // allow all keys, or use ssh.KeysEqual() to compare against known keys
21 | })
22 |
23 | log.Println("starting ssh server on port 2222...")
24 | log.Fatal(ssh.ListenAndServe(":2222", nil, publicKeyOption))
25 | }
26 |
--------------------------------------------------------------------------------
/_examples/ssh-remoteforward/portforward.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "io"
5 | "log"
6 |
7 | "github.com/gliderlabs/ssh"
8 | )
9 |
10 | func main() {
11 |
12 | log.Println("starting ssh server on port 2222...")
13 |
14 | forwardHandler := &ssh.ForwardedTCPHandler{}
15 |
16 | server := ssh.Server{
17 | LocalPortForwardingCallback: ssh.LocalPortForwardingCallback(func(ctx ssh.Context, dhost string, dport uint32) bool {
18 | log.Println("Accepted forward", dhost, dport)
19 | return true
20 | }),
21 | Addr: ":2222",
22 | Handler: ssh.Handler(func(s ssh.Session) {
23 | io.WriteString(s, "Remote forwarding available...\n")
24 | select {}
25 | }),
26 | ReversePortForwardingCallback: ssh.ReversePortForwardingCallback(func(ctx ssh.Context, host string, port uint32) bool {
27 | log.Println("attempt to bind", host, port, "granted")
28 | return true
29 | }),
30 | RequestHandlers: map[string]ssh.RequestHandler{
31 | "tcpip-forward": forwardHandler.HandleSSHRequest,
32 | "cancel-tcpip-forward": forwardHandler.HandleSSHRequest,
33 | },
34 | }
35 |
36 | log.Fatal(server.ListenAndServe())
37 | }
38 |
--------------------------------------------------------------------------------
/_examples/ssh-sftpserver/sftp.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "io"
6 | "log"
7 |
8 | "github.com/gliderlabs/ssh"
9 | "github.com/pkg/sftp"
10 | )
11 |
12 | // SftpHandler handler for SFTP subsystem
13 | func SftpHandler(sess ssh.Session) {
14 | debugStream := io.Discard
15 | serverOptions := []sftp.ServerOption{
16 | sftp.WithDebug(debugStream),
17 | }
18 | server, err := sftp.NewServer(
19 | sess,
20 | serverOptions...,
21 | )
22 | if err != nil {
23 | log.Printf("sftp server init error: %s\n", err)
24 | return
25 | }
26 | if err := server.Serve(); err == io.EOF {
27 | server.Close()
28 | fmt.Println("sftp client exited session.")
29 | } else if err != nil {
30 | fmt.Println("sftp server completed with error:", err)
31 | }
32 | }
33 |
34 | func main() {
35 | ssh_server := ssh.Server{
36 | Addr: "127.0.0.1:2222",
37 | SubsystemHandlers: map[string]ssh.SubsystemHandler{
38 | "sftp": SftpHandler,
39 | },
40 | }
41 | log.Fatal(ssh_server.ListenAndServe())
42 | }
43 |
--------------------------------------------------------------------------------
/_examples/ssh-simple/simple.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "io"
6 | "log"
7 |
8 | "github.com/gliderlabs/ssh"
9 | )
10 |
11 | func main() {
12 | ssh.Handle(func(s ssh.Session) {
13 | io.WriteString(s, fmt.Sprintf("Hello %s\n", s.User()))
14 | })
15 |
16 | log.Println("starting ssh server on port 2222...")
17 | log.Fatal(ssh.ListenAndServe(":2222", nil))
18 | }
19 |
--------------------------------------------------------------------------------
/_examples/ssh-timeouts/timeouts.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "log"
5 | "time"
6 |
7 | "github.com/gliderlabs/ssh"
8 | )
9 |
10 | var (
11 | DeadlineTimeout = 30 * time.Second
12 | IdleTimeout = 10 * time.Second
13 | )
14 |
15 | func main() {
16 | ssh.Handle(func(s ssh.Session) {
17 | log.Println("new connection")
18 | i := 0
19 | for {
20 | i += 1
21 | log.Println("active seconds:", i)
22 | select {
23 | case <-time.After(time.Second):
24 | continue
25 | case <-s.Context().Done():
26 | log.Println("connection closed")
27 | return
28 | }
29 | }
30 | })
31 |
32 | log.Println("starting ssh server on port 2222...")
33 | log.Printf("connections will only last %s\n", DeadlineTimeout)
34 | log.Printf("and timeout after %s of no activity\n", IdleTimeout)
35 | server := &ssh.Server{
36 | Addr: ":2222",
37 | MaxTimeout: DeadlineTimeout,
38 | IdleTimeout: IdleTimeout,
39 | }
40 | log.Fatal(server.ListenAndServe())
41 | }
42 |
--------------------------------------------------------------------------------
/agent.go:
--------------------------------------------------------------------------------
1 | package ssh
2 |
3 | import (
4 | "io"
5 | "net"
6 | "os"
7 | "path"
8 | "sync"
9 |
10 | gossh "golang.org/x/crypto/ssh"
11 | )
12 |
// OpenSSH agent-forwarding identifiers.
const (
	// agentRequestType is the request name a client uses to ask for agent
	// forwarding.
	agentRequestType = "auth-agent-req@openssh.com"
	// agentChannelType is the channel type opened back to the client to
	// reach its agent.
	agentChannelType = "auth-agent@openssh.com"
)

// Naming of the temporary Unix socket created by NewAgentListener.
const (
	agentTempDir    = "auth-agent"
	agentListenFile = "listener.sock"
)
21 | // contextKeyAgentRequest is an internal context key for storing if the
22 | // client requested agent forwarding
23 | var contextKeyAgentRequest = &contextKey{"auth-agent-req"}
24 |
25 | // SetAgentRequested sets up the session context so that AgentRequested
26 | // returns true.
27 | func SetAgentRequested(ctx Context) {
28 | ctx.SetValue(contextKeyAgentRequest, true)
29 | }
30 |
31 | // AgentRequested returns true if the client requested agent forwarding.
32 | func AgentRequested(sess Session) bool {
33 | return sess.Context().Value(contextKeyAgentRequest) == true
34 | }
35 |
36 | // NewAgentListener sets up a temporary Unix socket that can be communicated
37 | // to the session environment and used for forwarding connections.
38 | func NewAgentListener() (net.Listener, error) {
39 | dir, err := os.MkdirTemp("", agentTempDir)
40 | if err != nil {
41 | return nil, err
42 | }
43 | l, err := net.Listen("unix", path.Join(dir, agentListenFile))
44 | if err != nil {
45 | return nil, err
46 | }
47 | return l, nil
48 | }
49 |
50 | // ForwardAgentConnections takes connections from a listener to proxy into the
51 | // session on the OpenSSH channel for agent connections. It blocks and services
52 | // connections until the listener stop accepting.
53 | func ForwardAgentConnections(l net.Listener, s Session) {
54 | sshConn := s.Context().Value(ContextKeyConn).(gossh.Conn)
55 | for {
56 | conn, err := l.Accept()
57 | if err != nil {
58 | return
59 | }
60 | go func(conn net.Conn) {
61 | defer conn.Close()
62 | channel, reqs, err := sshConn.OpenChannel(agentChannelType, nil)
63 | if err != nil {
64 | return
65 | }
66 | defer channel.Close()
67 | go gossh.DiscardRequests(reqs)
68 | var wg sync.WaitGroup
69 | wg.Add(2)
70 | go func() {
71 | io.Copy(conn, channel)
72 | conn.(*net.UnixConn).CloseWrite()
73 | wg.Done()
74 | }()
75 | go func() {
76 | io.Copy(channel, conn)
77 | channel.CloseWrite()
78 | wg.Done()
79 | }()
80 | wg.Wait()
81 | }(conn)
82 | }
83 | }
84 |
--------------------------------------------------------------------------------
/circle.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | jobs:
3 | build-go-latest:
4 | docker:
5 | - image: golang:latest
6 | working_directory: /go/src/github.com/gliderlabs/ssh
7 | steps:
8 | - checkout
9 | - run: go get
10 | - run: go test -v -race
11 |
12 | build-go-1.20:
13 | docker:
14 | - image: golang:1.20
15 | working_directory: /go/src/github.com/gliderlabs/ssh
16 | steps:
17 | - checkout
18 | - run: go get
19 | - run: go test -v -race
20 |
21 | workflows:
22 | version: 2
23 | build:
24 | jobs:
25 | - build-go-latest
26 | - build-go-1.20
27 |
--------------------------------------------------------------------------------
/conn.go:
--------------------------------------------------------------------------------
1 | package ssh
2 |
3 | import (
4 | "context"
5 | "net"
6 | "time"
7 | )
8 |
// serverConn wraps a net.Conn to enforce the server's timeouts. When an idle
// timeout is configured, every Read and Write first pushes the connection
// deadline forward. That deadline is capped by the handshake and max
// deadlines. Any net.Error from the underlying conn, and Close itself, invoke
// closeCanceler (when set).
type serverConn struct {
	net.Conn

	idleTimeout       time.Duration
	handshakeDeadline time.Time
	maxDeadline       time.Time
	closeCanceler     context.CancelFunc
}

func (c *serverConn) Write(p []byte) (n int, err error) {
	if c.idleTimeout > 0 {
		c.updateDeadline()
	}
	n, err = c.Conn.Write(p)
	c.cancelOnNetError(err)
	return n, err
}

func (c *serverConn) Read(b []byte) (n int, err error) {
	if c.idleTimeout > 0 {
		c.updateDeadline()
	}
	n, err = c.Conn.Read(b)
	c.cancelOnNetError(err)
	return n, err
}

// Close closes the underlying conn and then runs closeCanceler, if any.
func (c *serverConn) Close() error {
	err := c.Conn.Close()
	if c.closeCanceler != nil {
		c.closeCanceler()
	}
	return err
}

// cancelOnNetError runs closeCanceler when err is a net.Error.
func (c *serverConn) cancelOnNetError(err error) {
	if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
		c.closeCanceler()
	}
}

// updateDeadline sets the conn deadline to the earliest of maxDeadline,
// handshakeDeadline and now+idleTimeout, ignoring the ones that are unset.
func (c *serverConn) updateDeadline() {
	deadline := c.maxDeadline
	tighten := func(candidate time.Time) {
		if candidate.IsZero() {
			return
		}
		if deadline.IsZero() || candidate.Before(deadline) {
			deadline = candidate
		}
	}
	tighten(c.handshakeDeadline)
	if c.idleTimeout > 0 {
		tighten(time.Now().Add(c.idleTimeout))
	}
	c.Conn.SetDeadline(deadline)
}
64 |
--------------------------------------------------------------------------------
/context.go:
--------------------------------------------------------------------------------
1 | package ssh
2 |
3 | import (
4 | "context"
5 | "encoding/hex"
6 | "net"
7 | "sync"
8 |
9 | gossh "golang.org/x/crypto/ssh"
10 | )
11 |
// contextKey is a value for use with context.WithValue. Keys are handled as
// pointers, so they fit in an interface{} without allocation and compare by
// identity: two keys with the same name are still distinct keys.
type contextKey struct {
	name string
}

var (
	// ContextKeyUser is a context key for use with Contexts in this package.
	// Its value is the connecting user's name, of type string.
	ContextKeyUser = &contextKey{name: "user"}

	// ContextKeySessionID is a context key for use with Contexts in this package.
	// Its value is the hex-encoded session ID, of type string.
	ContextKeySessionID = &contextKey{name: "session-id"}

	// ContextKeyPermissions is a context key for use with Contexts in this package.
	// Its value is of type *Permissions.
	ContextKeyPermissions = &contextKey{name: "permissions"}

	// ContextKeyClientVersion is a context key for use with Contexts in this package.
	// Its value is the client's version string, of type string.
	ContextKeyClientVersion = &contextKey{name: "client-version"}

	// ContextKeyServerVersion is a context key for use with Contexts in this package.
	// Its value is the server's version string, of type string.
	ContextKeyServerVersion = &contextKey{name: "server-version"}

	// ContextKeyLocalAddr is a context key for use with Contexts in this package.
	// Its value is of type net.Addr.
	ContextKeyLocalAddr = &contextKey{name: "local-addr"}

	// ContextKeyRemoteAddr is a context key for use with Contexts in this package.
	// Its value is of type net.Addr.
	ContextKeyRemoteAddr = &contextKey{name: "remote-addr"}

	// ContextKeyServer is a context key for use with Contexts in this package.
	// Its value is of type *Server.
	ContextKeyServer = &contextKey{name: "ssh-server"}

	// ContextKeyConn is a context key for use with Contexts in this package.
	// Its value is of type gossh.ServerConn.
	ContextKeyConn = &contextKey{name: "ssh-conn"}

	// ContextKeyPublicKey is a context key for use with Contexts in this package.
	// Its value is of type PublicKey.
	ContextKeyPublicKey = &contextKey{name: "public-key"}
)
59 |
60 | // Context is a package specific context interface. It exposes connection
61 | // metadata and allows new values to be easily written to it. It's used in
62 | // authentication handlers and callbacks, and its underlying context.Context is
63 | // exposed on Session in the session Handler. A connection-scoped lock is also
64 | // embedded in the context to make it easier to limit operations per-connection.
65 | type Context interface {
66 | context.Context
67 | sync.Locker
68 |
69 | // User returns the username used when establishing the SSH connection.
70 | User() string
71 |
72 | // SessionID returns the session hash.
73 | SessionID() string
74 |
75 | // ClientVersion returns the version reported by the client.
76 | ClientVersion() string
77 |
78 | // ServerVersion returns the version reported by the server.
79 | ServerVersion() string
80 |
81 | // RemoteAddr returns the remote address for this connection.
82 | RemoteAddr() net.Addr
83 |
84 | // LocalAddr returns the local address for this connection.
85 | LocalAddr() net.Addr
86 |
87 | // Permissions returns the Permissions object used for this connection.
88 | Permissions() *Permissions
89 |
90 | // SetValue allows you to easily write new values into the underlying context.
91 | SetValue(key, value interface{})
92 | }
93 |
// sshContext is the concrete Context implementation. Values written with
// SetValue live in a mutex-guarded map consulted before the embedded
// context.Context. The embedded *sync.Mutex supplies Lock/Unlock.
type sshContext struct {
	context.Context // cancellation, deadlines and fallback Value lookups
	*sync.Mutex     // lock exposed through sync.Locker

	values   map[interface{}]interface{} // entries written via SetValue
	valuesMu sync.Mutex                  // guards values
}
101 |
102 | func newContext(srv *Server) (*sshContext, context.CancelFunc) {
103 | innerCtx, cancel := context.WithCancel(context.Background())
104 | ctx := &sshContext{Context: innerCtx, Mutex: &sync.Mutex{}, values: make(map[interface{}]interface{})}
105 | ctx.SetValue(ContextKeyServer, srv)
106 | perms := &Permissions{&gossh.Permissions{}}
107 | ctx.SetValue(ContextKeyPermissions, perms)
108 | return ctx, cancel
109 | }
110 |
111 | // this is separate from newContext because we will get ConnMetadata
112 | // at different points so it needs to be applied separately
113 | func applyConnMetadata(ctx Context, conn gossh.ConnMetadata) {
114 | if ctx.Value(ContextKeySessionID) != nil {
115 | return
116 | }
117 | ctx.SetValue(ContextKeySessionID, hex.EncodeToString(conn.SessionID()))
118 | ctx.SetValue(ContextKeyClientVersion, string(conn.ClientVersion()))
119 | ctx.SetValue(ContextKeyServerVersion, string(conn.ServerVersion()))
120 | ctx.SetValue(ContextKeyUser, conn.User())
121 | ctx.SetValue(ContextKeyLocalAddr, conn.LocalAddr())
122 | ctx.SetValue(ContextKeyRemoteAddr, conn.RemoteAddr())
123 | }
124 |
125 | func (ctx *sshContext) Value(key interface{}) interface{} {
126 | ctx.valuesMu.Lock()
127 | defer ctx.valuesMu.Unlock()
128 | if v, ok := ctx.values[key]; ok {
129 | return v
130 | }
131 | return ctx.Context.Value(key)
132 | }
133 |
134 | func (ctx *sshContext) SetValue(key, value interface{}) {
135 | ctx.valuesMu.Lock()
136 | defer ctx.valuesMu.Unlock()
137 | ctx.values[key] = value
138 | }
139 |
140 | func (ctx *sshContext) User() string {
141 | return ctx.Value(ContextKeyUser).(string)
142 | }
143 |
144 | func (ctx *sshContext) SessionID() string {
145 | return ctx.Value(ContextKeySessionID).(string)
146 | }
147 |
148 | func (ctx *sshContext) ClientVersion() string {
149 | return ctx.Value(ContextKeyClientVersion).(string)
150 | }
151 |
152 | func (ctx *sshContext) ServerVersion() string {
153 | return ctx.Value(ContextKeyServerVersion).(string)
154 | }
155 |
156 | func (ctx *sshContext) RemoteAddr() net.Addr {
157 | if addr, ok := ctx.Value(ContextKeyRemoteAddr).(net.Addr); ok {
158 | return addr
159 | }
160 | return nil
161 | }
162 |
163 | func (ctx *sshContext) LocalAddr() net.Addr {
164 | return ctx.Value(ContextKeyLocalAddr).(net.Addr)
165 | }
166 |
167 | func (ctx *sshContext) Permissions() *Permissions {
168 | return ctx.Value(ContextKeyPermissions).(*Permissions)
169 | }
170 |
--------------------------------------------------------------------------------
/context_test.go:
--------------------------------------------------------------------------------
1 | package ssh
2 |
3 | import (
4 | "testing"
5 | "time"
6 | )
7 |
// TestSetPermissions checks that permission extensions set on the Context
// during password auth are visible from the session handler.
func TestSetPermissions(t *testing.T) {
	t.Parallel()
	permsExt := map[string]string{
		"foo": "bar",
	}
	session, _, cleanup := newTestSessionWithOptions(t, &Server{
		Handler: func(s Session) {
			// The handler runs on a server goroutine; t.Fatalf may only be
			// called from the test goroutine, so report with t.Errorf.
			if _, ok := s.Permissions().Extensions["foo"]; !ok {
				t.Errorf("got %#v; want %#v", s.Permissions().Extensions, permsExt)
			}
		},
	}, nil, PasswordAuth(func(ctx Context, password string) bool {
		ctx.Permissions().Extensions = permsExt
		return true
	}))
	defer cleanup()
	if err := session.Run(""); err != nil {
		t.Fatal(err)
	}
}
28 |
// TestSetValue checks that a value stored on the Context during password
// auth can be read back through Session.Context() in the handler.
func TestSetValue(t *testing.T) {
	t.Parallel()
	value := map[string]string{
		"foo": "bar",
	}
	key := "testValue"
	session, _, cleanup := newTestSessionWithOptions(t, &Server{
		Handler: func(s Session) {
			// The handler runs on a server goroutine, so use t.Errorf rather
			// than t.Fatalf. Use a checked assertion so a missing value fails
			// the test instead of panicking and killing the test binary.
			got := s.Context().Value(key)
			v, ok := got.(map[string]string)
			if !ok || v["foo"] != value["foo"] {
				t.Errorf("got %#v; want %#v", got, value)
			}
		},
	}, nil, PasswordAuth(func(ctx Context, password string) bool {
		ctx.SetValue(key, value)
		return true
	}))
	defer cleanup()
	if err := session.Run(""); err != nil {
		t.Fatal(err)
	}
}
51 |
// TestSetValueConcurrency hammers SetValue while another goroutine calls the
// context.Context methods, then checks the final values are intact.
func TestSetValueConcurrency(t *testing.T) {
	ctx, cancel := newContext(nil)
	defer cancel()

	go func() {
		// Loop over the context.Context methods to make sure they are
		// thread-safe with SetValue. Exit once the context is done: a bare
		// `break` inside the select would only leave the select, and the
		// goroutine would spin for the rest of the test binary's life.
		for {
			_, _ = ctx.Deadline()
			_ = ctx.Err()
			_ = ctx.Value("foo")
			select {
			case <-ctx.Done():
				return
			default:
			}
		}
	}()
	ctx.SetValue("bar", -1) // a context value which never changes
	now := time.Now()
	var cnt int64
	go func() {
		for time.Since(now) < 100*time.Millisecond {
			cnt++
			ctx.SetValue("foo", cnt) // a context value which changes a lot
		}
		// cancel happens after the last write to cnt, and receiving from
		// ctx.Done() below happens after cancel, so reading cnt is race-free.
		cancel()
	}()
	<-ctx.Done()
	if ctx.Value("foo") != cnt {
		t.Fatal("context.Value(foo) doesn't match latest SetValue")
	}
	if ctx.Value("bar") != -1 {
		t.Fatal("context.Value(bar) doesn't match latest SetValue")
	}
}
86 |
--------------------------------------------------------------------------------
/doc.go:
--------------------------------------------------------------------------------
1 | /*
2 | Package ssh wraps the crypto/ssh package with a higher-level API for building
3 | SSH servers. The goal of the API was to make it as simple as using net/http, so
4 | the API is very similar.
5 |
6 | You should be able to build any SSH server using only this package, which wraps
7 | relevant types and some functions from crypto/ssh. However, you still need to
8 | use crypto/ssh for building SSH clients.
9 |
10 | ListenAndServe starts an SSH server with a given address, handler, and options. The
11 | handler is usually nil, which means to use DefaultHandler. Handle sets DefaultHandler:
12 |
13 | ssh.Handle(func(s ssh.Session) {
14 | io.WriteString(s, "Hello world\n")
15 | })
16 |
17 | log.Fatal(ssh.ListenAndServe(":2222", nil))
18 |
19 | If you don't specify a host key, it will generate one every time. This is convenient
20 | except you'll have to deal with clients being confused that the host key is different.
21 | It's a better idea to generate or point to an existing key on your system:
22 |
23 | log.Fatal(ssh.ListenAndServe(":2222", nil, ssh.HostKeyFile("/Users/progrium/.ssh/id_rsa")))
24 |
25 | Although all options have functional option helpers, another way to control the
26 | server's behavior is by creating a custom Server:
27 |
28 | s := &ssh.Server{
29 | Addr: ":2222",
30 | Handler: sessionHandler,
31 | PublicKeyHandler: authHandler,
32 | }
33 | s.AddHostKey(hostKeySigner)
34 |
35 | log.Fatal(s.ListenAndServe())
36 |
37 | This package automatically handles basic SSH requests like setting environment
38 | variables, requesting PTY, and changing window size. These requests are
39 | processed, responded to, and any relevant state is updated. This state is then
40 | exposed to you via the Session interface.
41 |
42 | The one big feature missing from the Session abstraction is signals. This was
43 | started, but not completed. Pull Requests welcome!
44 | */
45 | package ssh
46 |
--------------------------------------------------------------------------------
/example_test.go:
--------------------------------------------------------------------------------
1 | package ssh_test
2 |
3 | import (
4 | "io"
5 | "os"
6 |
7 | "github.com/gliderlabs/ssh"
8 | )
9 |
func ExampleListenAndServe() {
	hello := func(s ssh.Session) {
		io.WriteString(s, "Hello world\n")
	}
	ssh.ListenAndServe(":2222", hello)
}
15 |
16 | func ExamplePasswordAuth() {
17 | ssh.ListenAndServe(":2222", nil,
18 | ssh.PasswordAuth(func(ctx ssh.Context, pass string) bool {
19 | return pass == "secret"
20 | }),
21 | )
22 | }
23 |
24 | func ExampleNoPty() {
25 | ssh.ListenAndServe(":2222", nil, ssh.NoPty())
26 | }
27 |
func ExamplePublicKeyAuth() {
	ssh.ListenAndServe(":2222", nil,
		ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
			// The allowed key is re-read on every attempt. If it cannot be
			// read or parsed, deny access instead of comparing against a nil
			// key with the errors silently dropped.
			data, err := os.ReadFile("/path/to/allowed/key.pub")
			if err != nil {
				return false
			}
			allowed, _, _, _, err := ssh.ParseAuthorizedKey(data)
			if err != nil {
				return false
			}
			return ssh.KeysEqual(key, allowed)
		}),
	)
}
37 |
38 | func ExampleHostKeyFile() {
39 | ssh.ListenAndServe(":2222", nil, ssh.HostKeyFile("/path/to/host/key"))
40 | }
41 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/gliderlabs/ssh
2 |
3 | go 1.20
4 |
5 | require (
6 | github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
7 | golang.org/x/crypto v0.31.0
8 | )
9 |
10 | require golang.org/x/sys v0.28.0 // indirect
11 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
2 | github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
3 | golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
4 | golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
5 | golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
6 | golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
7 | golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q=
8 |
--------------------------------------------------------------------------------
/options.go:
--------------------------------------------------------------------------------
1 | package ssh
2 |
3 | import (
4 | "os"
5 |
6 | gossh "golang.org/x/crypto/ssh"
7 | )
8 |
9 | // PasswordAuth returns a functional option that sets PasswordHandler on the server.
10 | func PasswordAuth(fn PasswordHandler) Option {
11 | return func(srv *Server) error {
12 | srv.PasswordHandler = fn
13 | return nil
14 | }
15 | }
16 |
17 | // PublicKeyAuth returns a functional option that sets PublicKeyHandler on the server.
18 | func PublicKeyAuth(fn PublicKeyHandler) Option {
19 | return func(srv *Server) error {
20 | srv.PublicKeyHandler = fn
21 | return nil
22 | }
23 | }
24 |
// HostKeyFile returns a functional option that reads the PEM-encoded private
// key at filepath and adds it to the server's HostSigners. The work is the
// same as HostKeyPEM once the file is loaded. Read and parse errors are
// returned from the option unchanged.
func HostKeyFile(filepath string) Option {
	return func(srv *Server) error {
		pemBytes, err := os.ReadFile(filepath)
		if err != nil {
			return err
		}
		return HostKeyPEM(pemBytes)(srv)
	}
}
44 |
45 | func KeyboardInteractiveAuth(fn KeyboardInteractiveHandler) Option {
46 | return func(srv *Server) error {
47 | srv.KeyboardInteractiveHandler = fn
48 | return nil
49 | }
50 | }
51 |
52 | // HostKeyPEM returns a functional option that adds HostSigners to the server
53 | // from a PEM file as bytes.
54 | func HostKeyPEM(bytes []byte) Option {
55 | return func(srv *Server) error {
56 | signer, err := gossh.ParsePrivateKey(bytes)
57 | if err != nil {
58 | return err
59 | }
60 |
61 | srv.AddHostKey(signer)
62 |
63 | return nil
64 | }
65 | }
66 |
67 | // NoPty returns a functional option that sets PtyCallback to return false,
68 | // denying PTY requests.
69 | func NoPty() Option {
70 | return func(srv *Server) error {
71 | srv.PtyCallback = func(ctx Context, pty Pty) bool {
72 | return false
73 | }
74 | return nil
75 | }
76 | }
77 |
78 | // WrapConn returns a functional option that sets ConnCallback on the server.
79 | func WrapConn(fn ConnCallback) Option {
80 | return func(srv *Server) error {
81 | srv.ConnCallback = fn
82 | return nil
83 | }
84 | }
85 |
--------------------------------------------------------------------------------
/options_test.go:
--------------------------------------------------------------------------------
1 | package ssh
2 |
3 | import (
4 | "net"
5 | "strings"
6 | "sync/atomic"
7 | "testing"
8 |
9 | gossh "golang.org/x/crypto/ssh"
10 | )
11 |
// newTestSessionWithOptions applies each option to srv in order, failing the
// test on the first error, and then delegates to newTestSession.
func newTestSessionWithOptions(t *testing.T, srv *Server, cfg *gossh.ClientConfig, options ...Option) (*gossh.Session, *gossh.Client, func()) {
	for i := range options {
		if err := srv.SetOption(options[i]); err != nil {
			t.Fatal(err)
		}
	}
	return newTestSession(t, srv, cfg)
}
20 |
// TestPasswordAuth checks that the PasswordHandler receives the client's
// user name and password.
func TestPasswordAuth(t *testing.T) {
	t.Parallel()
	testUser := "testuser"
	testPass := "testpass"
	session, _, cleanup := newTestSessionWithOptions(t, &Server{
		Handler: func(s Session) {
			// noop
		},
	}, &gossh.ClientConfig{
		User: testUser,
		Auth: []gossh.AuthMethod{
			gossh.Password(testPass),
		},
		HostKeyCallback: gossh.InsecureIgnoreHostKey(),
	}, PasswordAuth(func(ctx Context, password string) bool {
		// This runs on a server goroutine, where t.Fatalf is not allowed.
		// Report with t.Errorf and reject the login instead.
		if ctx.User() != testUser {
			t.Errorf("user = %#v; want %#v", ctx.User(), testUser)
			return false
		}
		if password != testPass {
			t.Errorf("password = %#v; want %#v", password, testPass)
			return false
		}
		return true
	}))
	defer cleanup()
	if err := session.Run(""); err != nil {
		t.Fatal(err)
	}
}
49 |
// TestPasswordAuthBadPass checks that a PasswordHandler returning false makes
// the client's handshake fail with an authentication error.
func TestPasswordAuthBadPass(t *testing.T) {
	t.Parallel()
	l := newLocalListener()
	srv := &Server{Handler: func(s Session) {}}
	if err := srv.SetOption(PasswordAuth(func(ctx Context, password string) bool {
		return false
	})); err != nil {
		t.Fatal(err)
	}
	go srv.serveOnce(l)
	client, err := gossh.Dial("tcp", l.Addr().String(), &gossh.ClientConfig{
		User: "testuser",
		Auth: []gossh.AuthMethod{
			gossh.Password("testpass"),
		},
		HostKeyCallback: gossh.InsecureIgnoreHostKey(),
	})
	// A successful dial means the rejected password was accepted; that must
	// fail the test rather than pass silently.
	if err == nil {
		client.Close()
		t.Fatal("expected authentication to fail, but dial succeeded")
	}
	if !strings.Contains(err.Error(), "unable to authenticate") {
		t.Fatal(err)
	}
}
71 |
72 | type wrappedConn struct {
73 | net.Conn
74 | written int32
75 | }
76 |
77 | func (c *wrappedConn) Write(p []byte) (n int, err error) {
78 | n, err = c.Conn.Write(p)
79 | atomic.AddInt32(&(c.written), int32(n))
80 | return
81 | }
82 |
// TestConnWrapping checks that the net.Conn returned by a WrapConn callback
// is the connection the server subsequently writes to.
func TestConnWrapping(t *testing.T) {
	t.Parallel()
	var wrapped *wrappedConn
	acceptAll := PasswordAuth(func(ctx Context, password string) bool {
		return true
	})
	wrap := WrapConn(func(ctx Context, conn net.Conn) net.Conn {
		wrapped = &wrappedConn{conn, 0}
		return wrapped
	})
	clientCfg := &gossh.ClientConfig{
		User:            "testuser",
		Auth:            []gossh.AuthMethod{gossh.Password("testpass")},
		HostKeyCallback: gossh.InsecureIgnoreHostKey(),
	}
	srv := &Server{Handler: func(s Session) {}}

	session, _, cleanup := newTestSessionWithOptions(t, srv, clientCfg, acceptAll, wrap)
	defer cleanup()
	if err := session.Shell(); err != nil {
		t.Fatal(err)
	}
	if atomic.LoadInt32(&(wrapped.written)) == 0 {
		t.Fatal("wrapped conn not written to")
	}
}
110 |
--------------------------------------------------------------------------------
/server.go:
--------------------------------------------------------------------------------
1 | package ssh
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "net"
8 | "sync"
9 | "time"
10 |
11 | gossh "golang.org/x/crypto/ssh"
12 | )
13 |
// ErrServerClosed is returned by the Server's Serve and ListenAndServe
// methods once the server has been stopped by a call to Shutdown or Close.
var ErrServerClosed = errors.New("ssh: Server closed")
17 |
18 | type SubsystemHandler func(s Session)
19 |
20 | var DefaultSubsystemHandlers = map[string]SubsystemHandler{}
21 |
22 | type RequestHandler func(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte)
23 |
24 | var DefaultRequestHandlers = map[string]RequestHandler{}
25 |
26 | type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context)
27 |
28 | var DefaultChannelHandlers = map[string]ChannelHandler{
29 | "session": DefaultSessionHandler,
30 | }
31 |
// Server defines parameters for running an SSH server. The zero value for
// Server is a valid configuration. When PasswordHandler, PublicKeyHandler
// and KeyboardInteractiveHandler are all nil, no client authentication is
// performed.
type Server struct {
	Addr        string   // TCP address to listen on, ":22" if empty
	Handler     Handler  // handler to invoke, ssh.DefaultHandler if nil
	HostSigners []Signer // private keys for the host key; if empty, Serve generates one
	Version     string   // server version sent before the initial handshake; "SSH-2.0-" is prepended
	Banner      string   // server banner

	BannerHandler                 BannerHandler                 // server banner handler, overrides Banner
	KeyboardInteractiveHandler    KeyboardInteractiveHandler    // keyboard-interactive authentication handler
	PasswordHandler               PasswordHandler               // password authentication handler
	PublicKeyHandler              PublicKeyHandler              // public key authentication handler
	PtyCallback                   PtyCallback                   // callback for allowing PTY sessions, allows all if nil
	ConnCallback                  ConnCallback                  // optional callback for wrapping net.Conn before handling; returning nil rejects the conn
	LocalPortForwardingCallback   LocalPortForwardingCallback   // callback for allowing local port forwarding, denies all if nil
	ReversePortForwardingCallback ReversePortForwardingCallback // callback for allowing reverse port forwarding, denies all if nil
	ServerConfigCallback          ServerConfigCallback          // callback supplying the base gossh.ServerConfig for each connection
	SessionRequestCallback        SessionRequestCallback        // callback for allowing or denying SSH sessions

	ConnectionFailedCallback ConnectionFailedCallback // callback to report connections whose SSH handshake failed

	HandshakeTimeout time.Duration // connection timeout until successful handshake, none if zero
	IdleTimeout      time.Duration // connection timeout when no activity, none if zero
	MaxTimeout       time.Duration // absolute connection timeout, none if zero

	// ChannelHandlers allow overriding the built-in session handlers or provide
	// extensions to the protocol, such as tcpip forwarding. If nil at Serve
	// time, it is initialized from DefaultChannelHandlers, which only enables
	// the "session" handler. A "default" entry, if present, handles channel
	// types without their own entry.
	ChannelHandlers map[string]ChannelHandler

	// RequestHandlers allow overriding the server-level request handlers or
	// provide extensions to the protocol, such as tcpip forwarding. If nil at
	// Serve time, it is initialized from DefaultRequestHandlers (empty by
	// default). A "default" entry, if present, handles request types without
	// their own entry; otherwise such requests are rejected.
	RequestHandlers map[string]RequestHandler

	// SubsystemHandlers are handlers which are similar to the usual SSH command
	// handlers, but handle named subsystems. If nil at Serve time, it is
	// initialized from DefaultSubsystemHandlers.
	SubsystemHandlers map[string]SubsystemHandler

	// listenerWg counts listeners currently inside Serve (see trackListener);
	// Shutdown waits on it.
	listenerWg sync.WaitGroup
	// mu guards listeners, conns and doneChan, and is also taken by the
	// methods that initialize or mutate HostSigners, Handler and the handler
	// maps.
	mu        sync.RWMutex
	listeners map[net.Listener]struct{}
	conns     map[*gossh.ServerConn]struct{}
	// connWg counts connections registered via trackConn; Shutdown waits on it.
	connWg sync.WaitGroup
	// doneChan is closed by Close/Shutdown to tell Serve loops to return
	// ErrServerClosed; it is created lazily by getDoneChanLocked.
	doneChan chan struct{}
}
80 |
81 | func (srv *Server) ensureHostSigner() error {
82 | srv.mu.Lock()
83 | defer srv.mu.Unlock()
84 |
85 | if len(srv.HostSigners) == 0 {
86 | signer, err := generateSigner()
87 | if err != nil {
88 | return err
89 | }
90 | srv.HostSigners = append(srv.HostSigners, signer)
91 | }
92 | return nil
93 | }
94 |
95 | func (srv *Server) ensureHandlers() {
96 | srv.mu.Lock()
97 | defer srv.mu.Unlock()
98 |
99 | if srv.RequestHandlers == nil {
100 | srv.RequestHandlers = map[string]RequestHandler{}
101 | for k, v := range DefaultRequestHandlers {
102 | srv.RequestHandlers[k] = v
103 | }
104 | }
105 | if srv.ChannelHandlers == nil {
106 | srv.ChannelHandlers = map[string]ChannelHandler{}
107 | for k, v := range DefaultChannelHandlers {
108 | srv.ChannelHandlers[k] = v
109 | }
110 | }
111 | if srv.SubsystemHandlers == nil {
112 | srv.SubsystemHandlers = map[string]SubsystemHandler{}
113 | for k, v := range DefaultSubsystemHandlers {
114 | srv.SubsystemHandlers[k] = v
115 | }
116 | }
117 | }
118 |
// config builds the gossh.ServerConfig for one connection's handshake. ctx
// is that connection's Context. The banner and auth callbacks installed here
// close over ctx, record the connection metadata into it via
// applyConnMetadata, and pass it to the user-supplied handlers.
//
// The base config comes from ServerConfigCallback when set, otherwise it is
// a fresh zero config. Host keys, NoClientAuth, ServerVersion and the
// callbacks are then layered on top of it.
// NOTE(review): a nil return from ServerConfigCallback would be dereferenced
// below; confirm callers never return nil.
func (srv *Server) config(ctx Context) *gossh.ServerConfig {
	srv.mu.RLock()
	defer srv.mu.RUnlock()

	var config *gossh.ServerConfig
	if srv.ServerConfigCallback == nil {
		config = &gossh.ServerConfig{}
	} else {
		config = srv.ServerConfigCallback(ctx)
	}
	for _, signer := range srv.HostSigners {
		config.AddHostKey(signer)
	}
	// With no auth handler of any kind configured, clients are not asked to
	// authenticate at all.
	if srv.PasswordHandler == nil && srv.PublicKeyHandler == nil && srv.KeyboardInteractiveHandler == nil {
		config.NoClientAuth = true
	}
	if srv.Version != "" {
		config.ServerVersion = "SSH-2.0-" + srv.Version
	}
	if srv.Banner != "" {
		config.BannerCallback = func(_ gossh.ConnMetadata) string {
			return srv.Banner
		}
	}
	// Assigned after the static Banner on purpose, so BannerHandler wins when
	// both are set.
	if srv.BannerHandler != nil {
		config.BannerCallback = func(conn gossh.ConnMetadata) string {
			applyConnMetadata(ctx, conn)
			return srv.BannerHandler(ctx)
		}
	}
	if srv.PasswordHandler != nil {
		config.PasswordCallback = func(conn gossh.ConnMetadata, password []byte) (*gossh.Permissions, error) {
			applyConnMetadata(ctx, conn)
			if ok := srv.PasswordHandler(ctx, string(password)); !ok {
				return ctx.Permissions().Permissions, fmt.Errorf("permission denied")
			}
			return ctx.Permissions().Permissions, nil
		}
	}
	if srv.PublicKeyHandler != nil {
		config.PublicKeyCallback = func(conn gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) {
			applyConnMetadata(ctx, conn)
			if ok := srv.PublicKeyHandler(ctx, key); !ok {
				return ctx.Permissions().Permissions, fmt.Errorf("permission denied")
			}
			// Remember the accepted key on the context so it can be looked up
			// later (presumably by Session.PublicKey — verify).
			ctx.SetValue(ContextKeyPublicKey, key)
			return ctx.Permissions().Permissions, nil
		}
	}
	if srv.KeyboardInteractiveHandler != nil {
		config.KeyboardInteractiveCallback = func(conn gossh.ConnMetadata, challenger gossh.KeyboardInteractiveChallenge) (*gossh.Permissions, error) {
			applyConnMetadata(ctx, conn)
			if ok := srv.KeyboardInteractiveHandler(ctx, challenger); !ok {
				return ctx.Permissions().Permissions, fmt.Errorf("permission denied")
			}
			return ctx.Permissions().Permissions, nil
		}
	}
	return config
}
179 |
180 | // Handle sets the Handler for the server.
181 | func (srv *Server) Handle(fn Handler) {
182 | srv.mu.Lock()
183 | defer srv.mu.Unlock()
184 |
185 | srv.Handler = fn
186 | }
187 |
188 | // Close immediately closes all active listeners and all active
189 | // connections.
190 | //
191 | // Close returns any error returned from closing the Server's
192 | // underlying Listener(s).
193 | func (srv *Server) Close() error {
194 | srv.mu.Lock()
195 | defer srv.mu.Unlock()
196 |
197 | srv.closeDoneChanLocked()
198 | err := srv.closeListenersLocked()
199 | for c := range srv.conns {
200 | c.Close()
201 | delete(srv.conns, c)
202 | }
203 | return err
204 | }
205 |
// Shutdown gracefully stops the server without interrupting active
// connections. It closes every listener and signals Serve loops to stop.
// It then waits until all Serve loops have returned and all tracked
// connections have finished.
//
// It returns the first listener-close error once everything has drained,
// or ctx.Err() if ctx is done first.
func (srv *Server) Shutdown(ctx context.Context) error {
	srv.mu.Lock()
	lnErr := srv.closeListenersLocked()
	srv.closeDoneChanLocked()
	srv.mu.Unlock()

	drained := make(chan struct{})
	go func() {
		defer close(drained)
		srv.listenerWg.Wait()
		srv.connWg.Wait()
	}()

	select {
	case <-drained:
		return lnErr
	case <-ctx.Done():
		return ctx.Err()
	}
}
231 |
// Serve accepts incoming connections on the Listener l, creating a new
// connection goroutine for each. The connection goroutines read requests and then
// calls srv.Handler to handle sessions.
//
// Temporary Accept errors are retried with exponential back-off (5ms
// doubling up to 1s). Once Close or Shutdown has been called, Serve returns
// ErrServerClosed. l is closed when Serve returns.
//
// Serve always returns a non-nil error.
func (srv *Server) Serve(l net.Listener) error {
	srv.ensureHandlers()
	defer l.Close()
	if err := srv.ensureHostSigner(); err != nil {
		return err
	}
	if srv.Handler == nil {
		srv.Handler = DefaultHandler
	}
	// tempDelay is the current back-off after a temporary Accept error. It
	// must be reset after every successful Accept. Otherwise an isolated
	// failure hours later would resume from the previous (possibly maximal)
	// delay instead of the minimum.
	var tempDelay time.Duration

	srv.trackListener(l, true)
	defer srv.trackListener(l, false)
	for {
		conn, e := l.Accept()
		if e != nil {
			select {
			case <-srv.getDoneChan():
				return ErrServerClosed
			default:
			}
			// net.Error.Temporary is deprecated but still the signal net/http
			// uses for retryable Accept failures (e.g. EMFILE).
			if ne, ok := e.(net.Error); ok && ne.Temporary() {
				if tempDelay == 0 {
					tempDelay = 5 * time.Millisecond
				} else {
					tempDelay *= 2
				}
				if maxDelay := 1 * time.Second; tempDelay > maxDelay {
					tempDelay = maxDelay
				}
				time.Sleep(tempDelay)
				continue
			}
			return e
		}
		tempDelay = 0
		go srv.HandleConn(conn)
	}
}
275 |
// HandleConn runs the SSH protocol over newConn, typically a connection
// returned by a listener's Accept. It blocks until the client's stream of
// new channels ends. The connection is closed before HandleConn returns.
//
// A fresh Context is created for the connection. If ConnCallback is set, it
// may substitute a wrapped conn or reject the connection by returning nil.
// In that case newConn is closed and the Context is canceled. If the
// handshake fails, ConnectionFailedCallback (when set) is told why.
func (srv *Server) HandleConn(newConn net.Conn) {
	ctx, cancel := newContext(srv)
	if srv.ConnCallback != nil {
		cbConn := srv.ConnCallback(ctx, newConn)
		if cbConn == nil {
			newConn.Close()
			// The serverConn that would normally own cancel (as its
			// closeCanceler) was never created, so cancel here. Otherwise
			// anything the callback started that waits on ctx.Done() would
			// never be released.
			cancel()
			return
		}
		newConn = cbConn
	}
	conn := &serverConn{
		Conn:          newConn,
		idleTimeout:   srv.IdleTimeout,
		closeCanceler: cancel,
	}
	if srv.MaxTimeout > 0 {
		conn.maxDeadline = time.Now().Add(srv.MaxTimeout)
	}
	if srv.HandshakeTimeout > 0 {
		conn.handshakeDeadline = time.Now().Add(srv.HandshakeTimeout)
	}
	conn.updateDeadline()
	defer conn.Close()
	sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx))
	if err != nil {
		if srv.ConnectionFailedCallback != nil {
			srv.ConnectionFailedCallback(conn, err)
		}
		return
	}
	// The handshake deadline only applies until the handshake succeeds.
	conn.handshakeDeadline = time.Time{}
	conn.updateDeadline()
	srv.trackConn(sshConn, true)
	defer srv.trackConn(sshConn, false)

	ctx.SetValue(ContextKeyConn, sshConn)
	applyConnMetadata(ctx, sshConn)
	go srv.handleRequests(ctx, reqs)
	for ch := range chans {
		handler := srv.ChannelHandlers[ch.ChannelType()]
		if handler == nil {
			handler = srv.ChannelHandlers["default"]
		}
		if handler == nil {
			ch.Reject(gossh.UnknownChannelType, "unsupported channel type")
			continue
		}
		go handler(srv, sshConn, ch, ctx)
	}
}
327 |
328 | func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) {
329 | for req := range in {
330 | handler := srv.RequestHandlers[req.Type]
331 | if handler == nil {
332 | handler = srv.RequestHandlers["default"]
333 | }
334 | if handler == nil {
335 | req.Reply(false, nil)
336 | continue
337 | }
338 | /*reqCtx, cancel := context.WithCancel(ctx)
339 | defer cancel() */
340 | ret, payload := handler(ctx, srv, req)
341 | req.Reply(ret, payload)
342 | }
343 | }
344 |
// ListenAndServe opens a TCP listener on srv.Addr (":22" when Addr is empty)
// and hands it to Serve. A listen failure is returned directly; otherwise
// the result is whatever Serve returns, which is never nil.
func (srv *Server) ListenAndServe() error {
	addr := ":22"
	if srv.Addr != "" {
		addr = srv.Addr
	}
	ln, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}
	return srv.Serve(ln)
}
359 |
360 | // AddHostKey adds a private key as a host key. If an existing host key exists
361 | // with the same algorithm, it is overwritten. Each server config must have at
362 | // least one host key.
363 | func (srv *Server) AddHostKey(key Signer) {
364 | srv.mu.Lock()
365 | defer srv.mu.Unlock()
366 |
367 | // these are later added via AddHostKey on ServerConfig, which performs the
368 | // check for one of every algorithm.
369 |
370 | // This check is based on the AddHostKey method from the x/crypto/ssh
371 | // library. This allows us to only keep one active key for each type on a
372 | // server at once. So, if you're dynamically updating keys at runtime, this
373 | // list will not keep growing.
374 | for i, k := range srv.HostSigners {
375 | if k.PublicKey().Type() == key.PublicKey().Type() {
376 | srv.HostSigners[i] = key
377 | return
378 | }
379 | }
380 |
381 | srv.HostSigners = append(srv.HostSigners, key)
382 | }
383 |
384 | // SetOption runs a functional option against the server.
385 | func (srv *Server) SetOption(option Option) error {
386 | // NOTE: there is a potential race here for any option that doesn't call an
387 | // internal method. We can't actually lock here because if something calls
388 | // (as an example) AddHostKey, it will deadlock.
389 |
390 | //srv.mu.Lock()
391 | //defer srv.mu.Unlock()
392 |
393 | return option(srv)
394 | }
395 |
396 | func (srv *Server) getDoneChan() <-chan struct{} {
397 | srv.mu.Lock()
398 | defer srv.mu.Unlock()
399 |
400 | return srv.getDoneChanLocked()
401 | }
402 |
403 | func (srv *Server) getDoneChanLocked() chan struct{} {
404 | if srv.doneChan == nil {
405 | srv.doneChan = make(chan struct{})
406 | }
407 | return srv.doneChan
408 | }
409 |
410 | func (srv *Server) closeDoneChanLocked() {
411 | ch := srv.getDoneChanLocked()
412 | select {
413 | case <-ch:
414 | // Already closed. Don't close again.
415 | default:
416 | // Safe to close here. We're the only closer, guarded
417 | // by srv.mu.
418 | close(ch)
419 | }
420 | }
421 |
422 | func (srv *Server) closeListenersLocked() error {
423 | var err error
424 | for ln := range srv.listeners {
425 | if cerr := ln.Close(); cerr != nil && err == nil {
426 | err = cerr
427 | }
428 | delete(srv.listeners, ln)
429 | }
430 | return err
431 | }
432 |
// trackListener registers (add == true) or unregisters ln in srv.listeners
// and adjusts listenerWg to match, so Shutdown can wait for Serve loops.
func (srv *Server) trackListener(ln net.Listener, add bool) {
	srv.mu.Lock()
	defer srv.mu.Unlock()

	if srv.listeners == nil {
		srv.listeners = make(map[net.Listener]struct{})
	}
	if !add {
		delete(srv.listeners, ln)
		srv.listenerWg.Done()
		return
	}
	// A Server reused after Close or Shutdown still holds the closed
	// doneChan; drop it so getDoneChanLocked allocates a fresh one.
	if len(srv.listeners) == 0 && len(srv.conns) == 0 {
		srv.doneChan = nil
	}
	srv.listeners[ln] = struct{}{}
	srv.listenerWg.Add(1)
}
453 |
454 | func (srv *Server) trackConn(c *gossh.ServerConn, add bool) {
455 | srv.mu.Lock()
456 | defer srv.mu.Unlock()
457 |
458 | if srv.conns == nil {
459 | srv.conns = make(map[*gossh.ServerConn]struct{})
460 | }
461 | if add {
462 | srv.conns[c] = struct{}{}
463 | srv.connWg.Add(1)
464 | } else {
465 | delete(srv.conns, c)
466 | srv.connWg.Done()
467 | }
468 | }
469 |
--------------------------------------------------------------------------------
/server_test.go:
--------------------------------------------------------------------------------
1 | package ssh
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "io"
7 | "net"
8 | "testing"
9 | "time"
10 | )
11 |
12 | func TestAddHostKey(t *testing.T) {
13 | s := Server{}
14 | signer, err := generateSigner()
15 | if err != nil {
16 | t.Fatal(err)
17 | }
18 | s.AddHostKey(signer)
19 | if len(s.HostSigners) != 1 {
20 | t.Fatal("Key was not properly added")
21 | }
22 | signer, err = generateSigner()
23 | if err != nil {
24 | t.Fatal(err)
25 | }
26 | s.AddHostKey(signer)
27 | if len(s.HostSigners) != 1 {
28 | t.Fatal("Key was not properly replaced")
29 | }
30 | }
31 |
// TestServerShutdown checks that Shutdown lets an in-flight session finish
// and deliver its output before returning.
//
// t.Fatal/FailNow must only be called from the goroutine running the test,
// so the helper goroutines report failures with t.Error/t.Errorf and return.
func TestServerShutdown(t *testing.T) {
	l := newLocalListener()
	testBytes := []byte("Hello world\n")
	s := &Server{
		Handler: func(s Session) {
			s.Write(testBytes)
			time.Sleep(50 * time.Millisecond)
		},
	}
	go func() {
		err := s.Serve(l)
		if err != nil && err != ErrServerClosed {
			t.Error(err)
		}
	}()
	sessDone := make(chan struct{})
	sess, _, cleanup := newClientSession(t, l.Addr().String(), nil)
	go func() {
		defer cleanup()
		defer close(sessDone)
		var stdout bytes.Buffer
		sess.Stdout = &stdout
		if err := sess.Run(""); err != nil {
			t.Error(err)
			return
		}
		if !bytes.Equal(stdout.Bytes(), testBytes) {
			t.Errorf("expected = %s; got %s", testBytes, stdout.Bytes())
		}
	}()

	srvDone := make(chan struct{})
	go func() {
		defer close(srvDone)
		if err := s.Shutdown(context.Background()); err != nil {
			t.Error(err)
		}
	}()

	// One overall deadline bounds both the shutdown and the client session.
	timeout := time.After(2 * time.Second)
	select {
	case <-timeout:
		t.Fatal("timeout")
	case <-srvDone:
		select {
		case <-sessDone:
		case <-timeout:
			t.Fatal("timeout waiting for client session")
		}
	}
}
82 |
// TestServerClose checks that Close stops the server promptly even while a
// session handler would otherwise keep running for seconds.
//
// The helper goroutines use t.Error instead of t.Fatal: FailNow must only be
// called from the goroutine running the test.
func TestServerClose(t *testing.T) {
	l := newLocalListener()
	s := &Server{
		Handler: func(s Session) {
			time.Sleep(5 * time.Second)
		},
	}
	go func() {
		err := s.Serve(l)
		if err != nil && err != ErrServerClosed {
			t.Error(err)
		}
	}()

	clientDoneChan := make(chan struct{})
	closeDoneChan := make(chan struct{})

	sess, _, cleanup := newClientSession(t, l.Addr().String(), nil)
	go func() {
		defer cleanup()
		defer close(clientDoneChan)
		<-closeDoneChan
		if err := sess.Run(""); err != nil && err != io.EOF {
			t.Error(err)
		}
	}()

	go func() {
		if err := s.Close(); err != nil {
			t.Error(err)
		}
		// Always signal, even after a failed Close, so the client goroutine
		// above is never left blocked forever.
		close(closeDoneChan)
	}()

	timeout := time.After(100 * time.Millisecond)
	select {
	case <-timeout:
		t.Error("timeout")
	case <-s.getDoneChan():
		<-clientDoneChan
	}
}
128 |
// TestServerHandshakeTimeout opens a raw TCP connection that never speaks
// SSH. It checks that the server's HandshakeTimeout makes the server close
// it, which shows up on the client as its read stream ending.
func TestServerHandshakeTimeout(t *testing.T) {
	l := newLocalListener()

	s := &Server{HandshakeTimeout: time.Millisecond}
	go func() {
		if err := s.Serve(l); err != nil {
			t.Error(err)
		}
	}()

	conn, err := net.Dial("tcp", l.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	drained := make(chan struct{})
	go func() {
		defer close(drained)
		io.Copy(io.Discard, conn)
	}()

	select {
	case <-drained:
	case <-time.After(time.Second):
		t.Fatal("client connection was not force-closed")
	}
}
161 |
--------------------------------------------------------------------------------
/session.go:
--------------------------------------------------------------------------------
1 | package ssh
2 |
3 | import (
4 | "bytes"
5 | "errors"
6 | "fmt"
7 | "net"
8 | "sync"
9 |
10 | "github.com/anmitsu/go-shlex"
11 | gossh "golang.org/x/crypto/ssh"
12 | )
13 |
14 | // Session provides access to information about an SSH session and methods
15 | // to read and write to the SSH channel with an embedded Channel interface from
16 | // crypto/ssh.
17 | //
18 | // When Command() returns an empty slice, the user requested a shell. Otherwise
19 | // the user is performing an exec with those command arguments.
20 | //
21 | // TODO: Signals
22 | type Session interface {
23 | gossh.Channel
24 |
25 | // User returns the username used when establishing the SSH connection.
26 | User() string
27 |
28 | // RemoteAddr returns the net.Addr of the client side of the connection.
29 | RemoteAddr() net.Addr
30 |
31 | // LocalAddr returns the net.Addr of the server side of the connection.
32 | LocalAddr() net.Addr
33 |
34 | // Environ returns a copy of strings representing the environment set by the
35 | // user for this session, in the form "key=value".
36 | Environ() []string
37 |
38 | // Exit sends an exit status and then closes the session.
39 | Exit(code int) error
40 |
41 | // Command returns a shell parsed slice of arguments that were provided by the
42 | // user. Shell parsing splits the command string according to POSIX shell rules,
43 | // which considers quoting not just whitespace.
44 | Command() []string
45 |
46 | // RawCommand returns the exact command that was provided by the user.
47 | RawCommand() string
48 |
49 | // Subsystem returns the subsystem requested by the user.
50 | Subsystem() string
51 |
52 | // PublicKey returns the PublicKey used to authenticate. If a public key was not
53 | // used it will return nil.
54 | PublicKey() PublicKey
55 |
56 | // Context returns the connection's context. The returned context is always
57 | // non-nil and holds the same data as the Context passed into auth
58 | // handlers and callbacks.
59 | //
60 | // The context is canceled when the client's connection closes or I/O
61 | // operation fails.
62 | Context() Context
63 |
64 | // Permissions returns a copy of the Permissions object that was available for
65 | // setup in the auth handlers via the Context.
66 | Permissions() Permissions
67 |
68 | // Pty returns PTY information, a channel of window size changes, and a boolean
69 | // of whether or not a PTY was accepted for this session.
70 | Pty() (Pty, <-chan Window, bool)
71 |
72 | // Signals registers a channel to receive signals sent from the client. The
73 | // channel must handle signal sends or it will block the SSH request loop.
74 | // Registering nil will unregister the channel from signal sends. During the
75 | // time no channel is registered signals are buffered up to a reasonable amount.
76 | // If there are buffered signals when a channel is registered, they will be
77 | // sent in order on the channel immediately after registering.
78 | Signals(c chan<- Signal)
79 |
80 | // Break regisers a channel to receive notifications of break requests sent
81 | // from the client. The channel must handle break requests, or it will block
82 | // the request handling loop. Registering nil will unregister the channel.
83 | // During the time that no channel is registered, breaks are ignored.
84 | Break(c chan<- bool)
85 | }
86 |
// maxSigBufSize caps how many client signals a session keeps while no signal
// channel is registered; any further signals are dropped.
const maxSigBufSize = 128
90 |
91 | func DefaultSessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) {
92 | ch, reqs, err := newChan.Accept()
93 | if err != nil {
94 | // TODO: trigger event callback
95 | return
96 | }
97 | sess := &session{
98 | Channel: ch,
99 | conn: conn,
100 | handler: srv.Handler,
101 | ptyCb: srv.PtyCallback,
102 | sessReqCb: srv.SessionRequestCallback,
103 | subsystemHandlers: srv.SubsystemHandlers,
104 | ctx: ctx,
105 | }
106 | sess.handleRequests(reqs)
107 | }
108 |
109 | type session struct {
110 | sync.Mutex
111 | gossh.Channel
112 | conn *gossh.ServerConn
113 | handler Handler
114 | subsystemHandlers map[string]SubsystemHandler
115 | handled bool
116 | exited bool
117 | pty *Pty
118 | winch chan Window
119 | env []string
120 | ptyCb PtyCallback
121 | sessReqCb SessionRequestCallback
122 | rawCmd string
123 | subsystem string
124 | ctx Context
125 | sigCh chan<- Signal
126 | sigBuf []Signal
127 | breakCh chan<- bool
128 | }
129 |
130 | func (sess *session) Write(p []byte) (n int, err error) {
131 | if sess.pty != nil {
132 | m := len(p)
133 | // normalize \n to \r\n when pty is accepted.
134 | // this is a hardcoded shortcut since we don't support terminal modes.
135 | p = bytes.Replace(p, []byte{'\n'}, []byte{'\r', '\n'}, -1)
136 | p = bytes.Replace(p, []byte{'\r', '\r', '\n'}, []byte{'\r', '\n'}, -1)
137 | n, err = sess.Channel.Write(p)
138 | if n > m {
139 | n = m
140 | }
141 | return
142 | }
143 | return sess.Channel.Write(p)
144 | }
145 |
146 | func (sess *session) PublicKey() PublicKey {
147 | sessionkey := sess.ctx.Value(ContextKeyPublicKey)
148 | if sessionkey == nil {
149 | return nil
150 | }
151 | return sessionkey.(PublicKey)
152 | }
153 |
154 | func (sess *session) Permissions() Permissions {
155 | // use context permissions because its properly
156 | // wrapped and easier to dereference
157 | perms := sess.ctx.Value(ContextKeyPermissions).(*Permissions)
158 | return *perms
159 | }
160 |
161 | func (sess *session) Context() Context {
162 | return sess.ctx
163 | }
164 |
165 | func (sess *session) Exit(code int) error {
166 | sess.Lock()
167 | defer sess.Unlock()
168 | if sess.exited {
169 | return errors.New("Session.Exit called multiple times")
170 | }
171 | sess.exited = true
172 |
173 | status := struct{ Status uint32 }{uint32(code)}
174 | _, err := sess.SendRequest("exit-status", false, gossh.Marshal(&status))
175 | if err != nil {
176 | return err
177 | }
178 | return sess.Close()
179 | }
180 |
181 | func (sess *session) User() string {
182 | return sess.conn.User()
183 | }
184 |
185 | func (sess *session) RemoteAddr() net.Addr {
186 | return sess.conn.RemoteAddr()
187 | }
188 |
189 | func (sess *session) LocalAddr() net.Addr {
190 | return sess.conn.LocalAddr()
191 | }
192 |
193 | func (sess *session) Environ() []string {
194 | return append([]string(nil), sess.env...)
195 | }
196 |
197 | func (sess *session) RawCommand() string {
198 | return sess.rawCmd
199 | }
200 |
201 | func (sess *session) Command() []string {
202 | cmd, _ := shlex.Split(sess.rawCmd, true)
203 | return append([]string(nil), cmd...)
204 | }
205 |
206 | func (sess *session) Subsystem() string {
207 | return sess.subsystem
208 | }
209 |
210 | func (sess *session) Pty() (Pty, <-chan Window, bool) {
211 | if sess.pty != nil {
212 | return *sess.pty, sess.winch, true
213 | }
214 | return Pty{}, sess.winch, false
215 | }
216 |
217 | func (sess *session) Signals(c chan<- Signal) {
218 | sess.Lock()
219 | defer sess.Unlock()
220 | sess.sigCh = c
221 | if len(sess.sigBuf) > 0 {
222 | go func() {
223 | for _, sig := range sess.sigBuf {
224 | sess.sigCh <- sig
225 | }
226 | }()
227 | }
228 | }
229 |
230 | func (sess *session) Break(c chan<- bool) {
231 | sess.Lock()
232 | defer sess.Unlock()
233 | sess.breakCh = c
234 | }
235 |
236 | func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
237 | for req := range reqs {
238 | switch req.Type {
239 | case "shell", "exec":
240 | if sess.handled {
241 | req.Reply(false, nil)
242 | continue
243 | }
244 |
245 | var payload = struct{ Value string }{}
246 | gossh.Unmarshal(req.Payload, &payload)
247 | sess.rawCmd = payload.Value
248 |
249 | // If there's a session policy callback, we need to confirm before
250 | // accepting the session.
251 | if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) {
252 | sess.rawCmd = ""
253 | req.Reply(false, nil)
254 | continue
255 | }
256 |
257 | sess.handled = true
258 | req.Reply(true, nil)
259 |
260 | go func() {
261 | sess.handler(sess)
262 | sess.Exit(0)
263 | }()
264 | case "subsystem":
265 | if sess.handled {
266 | req.Reply(false, nil)
267 | continue
268 | }
269 |
270 | var payload = struct{ Value string }{}
271 | gossh.Unmarshal(req.Payload, &payload)
272 | sess.subsystem = payload.Value
273 |
274 | // If there's a session policy callback, we need to confirm before
275 | // accepting the session.
276 | if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) {
277 | sess.rawCmd = ""
278 | req.Reply(false, nil)
279 | continue
280 | }
281 |
282 | handler := sess.subsystemHandlers[payload.Value]
283 | if handler == nil {
284 | handler = sess.subsystemHandlers["default"]
285 | }
286 | if handler == nil {
287 | req.Reply(false, nil)
288 | continue
289 | }
290 |
291 | sess.handled = true
292 | req.Reply(true, nil)
293 |
294 | go func() {
295 | handler(sess)
296 | sess.Exit(0)
297 | }()
298 | case "env":
299 | if sess.handled {
300 | req.Reply(false, nil)
301 | continue
302 | }
303 | var kv struct{ Key, Value string }
304 | gossh.Unmarshal(req.Payload, &kv)
305 | sess.env = append(sess.env, fmt.Sprintf("%s=%s", kv.Key, kv.Value))
306 | req.Reply(true, nil)
307 | case "signal":
308 | var payload struct{ Signal string }
309 | gossh.Unmarshal(req.Payload, &payload)
310 | sess.Lock()
311 | if sess.sigCh != nil {
312 | sess.sigCh <- Signal(payload.Signal)
313 | } else {
314 | if len(sess.sigBuf) < maxSigBufSize {
315 | sess.sigBuf = append(sess.sigBuf, Signal(payload.Signal))
316 | }
317 | }
318 | sess.Unlock()
319 | case "pty-req":
320 | if sess.handled || sess.pty != nil {
321 | req.Reply(false, nil)
322 | continue
323 | }
324 | ptyReq, ok := parsePtyRequest(req.Payload)
325 | if !ok {
326 | req.Reply(false, nil)
327 | continue
328 | }
329 | if sess.ptyCb != nil {
330 | ok := sess.ptyCb(sess.ctx, ptyReq)
331 | if !ok {
332 | req.Reply(false, nil)
333 | continue
334 | }
335 | }
336 | sess.pty = &ptyReq
337 | sess.winch = make(chan Window, 1)
338 | sess.winch <- ptyReq.Window
339 | defer func() {
340 | // when reqs is closed
341 | close(sess.winch)
342 | }()
343 | req.Reply(ok, nil)
344 | case "window-change":
345 | if sess.pty == nil {
346 | req.Reply(false, nil)
347 | continue
348 | }
349 | win, ok := parseWinchRequest(req.Payload)
350 | if ok {
351 | sess.pty.Window = win
352 | sess.winch <- win
353 | }
354 | req.Reply(ok, nil)
355 | case agentRequestType:
356 | // TODO: option/callback to allow agent forwarding
357 | SetAgentRequested(sess.ctx)
358 | req.Reply(true, nil)
359 | case "break":
360 | ok := false
361 | sess.Lock()
362 | if sess.breakCh != nil {
363 | sess.breakCh <- true
364 | ok = true
365 | }
366 | req.Reply(ok, nil)
367 | sess.Unlock()
368 | default:
369 | // TODO: debug log
370 | req.Reply(false, nil)
371 | }
372 | }
373 | }
374 |
--------------------------------------------------------------------------------
/session_test.go:
--------------------------------------------------------------------------------
1 | package ssh
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | "io"
7 | "net"
8 | "testing"
9 |
10 | gossh "golang.org/x/crypto/ssh"
11 | )
12 |
13 | func (srv *Server) serveOnce(l net.Listener) error {
14 | srv.ensureHandlers()
15 | if err := srv.ensureHostSigner(); err != nil {
16 | return err
17 | }
18 | conn, e := l.Accept()
19 | if e != nil {
20 | return e
21 | }
22 | srv.ChannelHandlers = map[string]ChannelHandler{
23 | "session": DefaultSessionHandler,
24 | "direct-tcpip": DirectTCPIPHandler,
25 | }
26 | srv.HandleConn(conn)
27 | return nil
28 | }
29 |
// newLocalListener listens on an ephemeral loopback port, trying IPv4 first
// and falling back to IPv6. It panics if neither can be opened.
func newLocalListener() net.Listener {
	if l, err := net.Listen("tcp", "127.0.0.1:0"); err == nil {
		return l
	}
	l, err := net.Listen("tcp6", "[::1]:0")
	if err != nil {
		panic(fmt.Sprintf("failed to listen on a port: %v", err))
	}
	return l
}
39 |
40 | func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) {
41 | if config == nil {
42 | config = &gossh.ClientConfig{
43 | User: "testuser",
44 | Auth: []gossh.AuthMethod{
45 | gossh.Password("testpass"),
46 | },
47 | }
48 | }
49 | if config.HostKeyCallback == nil {
50 | config.HostKeyCallback = gossh.InsecureIgnoreHostKey()
51 | }
52 | client, err := gossh.Dial("tcp", addr, config)
53 | if err != nil {
54 | t.Fatal(err)
55 | }
56 | session, err := client.NewSession()
57 | if err != nil {
58 | t.Fatal(err)
59 | }
60 | return session, client, func() {
61 | session.Close()
62 | client.Close()
63 | }
64 | }
65 |
66 | func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) {
67 | l := newLocalListener()
68 | go srv.serveOnce(l)
69 | return newClientSession(t, l.Addr().String(), cfg)
70 | }
71 |
72 | func TestStdout(t *testing.T) {
73 | t.Parallel()
74 | testBytes := []byte("Hello world\n")
75 | session, _, cleanup := newTestSession(t, &Server{
76 | Handler: func(s Session) {
77 | s.Write(testBytes)
78 | },
79 | }, nil)
80 | defer cleanup()
81 | var stdout bytes.Buffer
82 | session.Stdout = &stdout
83 | if err := session.Run(""); err != nil {
84 | t.Fatal(err)
85 | }
86 | if !bytes.Equal(stdout.Bytes(), testBytes) {
87 | t.Fatalf("stdout = %#v; want %#v", stdout.Bytes(), testBytes)
88 | }
89 | }
90 |
91 | func TestStderr(t *testing.T) {
92 | t.Parallel()
93 | testBytes := []byte("Hello world\n")
94 | session, _, cleanup := newTestSession(t, &Server{
95 | Handler: func(s Session) {
96 | s.Stderr().Write(testBytes)
97 | },
98 | }, nil)
99 | defer cleanup()
100 | var stderr bytes.Buffer
101 | session.Stderr = &stderr
102 | if err := session.Run(""); err != nil {
103 | t.Fatal(err)
104 | }
105 | if !bytes.Equal(stderr.Bytes(), testBytes) {
106 | t.Fatalf("stderr = %#v; want %#v", stderr.Bytes(), testBytes)
107 | }
108 | }
109 |
110 | func TestStdin(t *testing.T) {
111 | t.Parallel()
112 | testBytes := []byte("Hello world\n")
113 | session, _, cleanup := newTestSession(t, &Server{
114 | Handler: func(s Session) {
115 | io.Copy(s, s) // stdin back into stdout
116 | },
117 | }, nil)
118 | defer cleanup()
119 | var stdout bytes.Buffer
120 | session.Stdout = &stdout
121 | session.Stdin = bytes.NewBuffer(testBytes)
122 | if err := session.Run(""); err != nil {
123 | t.Fatal(err)
124 | }
125 | if !bytes.Equal(stdout.Bytes(), testBytes) {
126 | t.Fatalf("stdout = %#v; want %#v given stdin = %#v", stdout.Bytes(), testBytes, testBytes)
127 | }
128 | }
129 |
130 | func TestUser(t *testing.T) {
131 | t.Parallel()
132 | testUser := []byte("progrium")
133 | session, _, cleanup := newTestSession(t, &Server{
134 | Handler: func(s Session) {
135 | io.WriteString(s, s.User())
136 | },
137 | }, &gossh.ClientConfig{
138 | User: string(testUser),
139 | })
140 | defer cleanup()
141 | var stdout bytes.Buffer
142 | session.Stdout = &stdout
143 | if err := session.Run(""); err != nil {
144 | t.Fatal(err)
145 | }
146 | if !bytes.Equal(stdout.Bytes(), testUser) {
147 | t.Fatalf("stdout = %#v; want %#v given user = %#v", stdout.Bytes(), testUser, string(testUser))
148 | }
149 | }
150 |
151 | func TestDefaultExitStatusZero(t *testing.T) {
152 | t.Parallel()
153 | session, _, cleanup := newTestSession(t, &Server{
154 | Handler: func(s Session) {
155 | // noop
156 | },
157 | }, nil)
158 | defer cleanup()
159 | err := session.Run("")
160 | if err != nil {
161 | t.Fatalf("expected nil but got %v", err)
162 | }
163 | }
164 |
165 | func TestExplicitExitStatusZero(t *testing.T) {
166 | t.Parallel()
167 | session, _, cleanup := newTestSession(t, &Server{
168 | Handler: func(s Session) {
169 | s.Exit(0)
170 | },
171 | }, nil)
172 | defer cleanup()
173 | err := session.Run("")
174 | if err != nil {
175 | t.Fatalf("expected nil but got %v", err)
176 | }
177 | }
178 |
179 | func TestExitStatusNonZero(t *testing.T) {
180 | t.Parallel()
181 | session, _, cleanup := newTestSession(t, &Server{
182 | Handler: func(s Session) {
183 | s.Exit(1)
184 | },
185 | }, nil)
186 | defer cleanup()
187 | err := session.Run("")
188 | e, ok := err.(*gossh.ExitError)
189 | if !ok {
190 | t.Fatalf("expected ExitError but got %T", err)
191 | }
192 | if e.ExitStatus() != 1 {
193 | t.Fatalf("exit-status = %#v; want %#v", e.ExitStatus(), 1)
194 | }
195 | }
196 |
197 | func TestPty(t *testing.T) {
198 | t.Parallel()
199 | term := "xterm"
200 | winWidth := 40
201 | winHeight := 80
202 | done := make(chan bool)
203 | session, _, cleanup := newTestSession(t, &Server{
204 | Handler: func(s Session) {
205 | ptyReq, _, isPty := s.Pty()
206 | if !isPty {
207 | t.Fatalf("expected pty but none requested")
208 | }
209 | if ptyReq.Term != term {
210 | t.Fatalf("expected term %#v but got %#v", term, ptyReq.Term)
211 | }
212 | if ptyReq.Window.Width != winWidth {
213 | t.Fatalf("expected window width %#v but got %#v", winWidth, ptyReq.Window.Width)
214 | }
215 | if ptyReq.Window.Height != winHeight {
216 | t.Fatalf("expected window height %#v but got %#v", winHeight, ptyReq.Window.Height)
217 | }
218 | close(done)
219 | },
220 | }, nil)
221 | defer cleanup()
222 | if err := session.RequestPty(term, winHeight, winWidth, gossh.TerminalModes{}); err != nil {
223 | t.Fatalf("expected nil but got %v", err)
224 | }
225 | if err := session.Shell(); err != nil {
226 | t.Fatalf("expected nil but got %v", err)
227 | }
228 | <-done
229 | }
230 |
231 | func TestPtyResize(t *testing.T) {
232 | t.Parallel()
233 | winch0 := Window{40, 80}
234 | winch1 := Window{80, 160}
235 | winch2 := Window{20, 40}
236 | winches := make(chan Window)
237 | done := make(chan bool)
238 | session, _, cleanup := newTestSession(t, &Server{
239 | Handler: func(s Session) {
240 | ptyReq, winCh, isPty := s.Pty()
241 | if !isPty {
242 | t.Fatalf("expected pty but none requested")
243 | }
244 | if ptyReq.Window != winch0 {
245 | t.Fatalf("expected window %#v but got %#v", winch0, ptyReq.Window)
246 | }
247 | for win := range winCh {
248 | winches <- win
249 | }
250 | close(done)
251 | },
252 | }, nil)
253 | defer cleanup()
254 | // winch0
255 | if err := session.RequestPty("xterm", winch0.Height, winch0.Width, gossh.TerminalModes{}); err != nil {
256 | t.Fatalf("expected nil but got %v", err)
257 | }
258 | if err := session.Shell(); err != nil {
259 | t.Fatalf("expected nil but got %v", err)
260 | }
261 | gotWinch := <-winches
262 | if gotWinch != winch0 {
263 | t.Fatalf("expected window %#v but got %#v", winch0, gotWinch)
264 | }
265 | // winch1
266 | winchMsg := struct{ w, h uint32 }{uint32(winch1.Width), uint32(winch1.Height)}
267 | ok, err := session.SendRequest("window-change", true, gossh.Marshal(&winchMsg))
268 | if err == nil && !ok {
269 | t.Fatalf("unexpected error or bad reply on send request")
270 | }
271 | gotWinch = <-winches
272 | if gotWinch != winch1 {
273 | t.Fatalf("expected window %#v but got %#v", winch1, gotWinch)
274 | }
275 | // winch2
276 | winchMsg = struct{ w, h uint32 }{uint32(winch2.Width), uint32(winch2.Height)}
277 | ok, err = session.SendRequest("window-change", true, gossh.Marshal(&winchMsg))
278 | if err == nil && !ok {
279 | t.Fatalf("unexpected error or bad reply on send request")
280 | }
281 | gotWinch = <-winches
282 | if gotWinch != winch2 {
283 | t.Fatalf("expected window %#v but got %#v", winch2, gotWinch)
284 | }
285 | session.Close()
286 | <-done
287 | }
288 |
289 | func TestSignals(t *testing.T) {
290 | t.Parallel()
291 |
292 | // errChan lets us get errors back from the session
293 | errChan := make(chan error, 5)
294 |
295 | // doneChan lets us specify that we should exit.
296 | doneChan := make(chan interface{})
297 |
298 | session, _, cleanup := newTestSession(t, &Server{
299 | Handler: func(s Session) {
300 | // We need to use a buffered channel here, otherwise it's possible for the
301 | // second call to Signal to get discarded.
302 | signals := make(chan Signal, 2)
303 | s.Signals(signals)
304 |
305 | select {
306 | case sig := <-signals:
307 | if sig != SIGINT {
308 | errChan <- fmt.Errorf("expected signal %v but got %v", SIGINT, sig)
309 | return
310 | }
311 | case <-doneChan:
312 | errChan <- fmt.Errorf("Unexpected done")
313 | return
314 | }
315 |
316 | select {
317 | case sig := <-signals:
318 | if sig != SIGKILL {
319 | errChan <- fmt.Errorf("expected signal %v but got %v", SIGKILL, sig)
320 | return
321 | }
322 | case <-doneChan:
323 | errChan <- fmt.Errorf("Unexpected done")
324 | return
325 | }
326 | },
327 | }, nil)
328 | defer cleanup()
329 |
330 | go func() {
331 | session.Signal(gossh.SIGINT)
332 | session.Signal(gossh.SIGKILL)
333 | }()
334 |
335 | go func() {
336 | errChan <- session.Run("")
337 | }()
338 |
339 | err := <-errChan
340 | close(doneChan)
341 |
342 | if err != nil {
343 | t.Fatalf("expected nil but got %v", err)
344 | }
345 | }
346 |
347 | func TestBreakWithChanRegistered(t *testing.T) {
348 | t.Parallel()
349 |
350 | // errChan lets us get errors back from the session
351 | errChan := make(chan error, 5)
352 |
353 | // doneChan lets us specify that we should exit.
354 | doneChan := make(chan interface{})
355 |
356 | breakChan := make(chan bool)
357 |
358 | readyToReceiveBreak := make(chan bool)
359 |
360 | session, _, cleanup := newTestSession(t, &Server{
361 | Handler: func(s Session) {
362 | s.Break(breakChan) // register a break channel with the session
363 | readyToReceiveBreak <- true
364 |
365 | select {
366 | case <-breakChan:
367 | io.WriteString(s, "break")
368 | case <-doneChan:
369 | errChan <- fmt.Errorf("Unexpected done")
370 | return
371 | }
372 | },
373 | }, nil)
374 | defer cleanup()
375 | var stdout bytes.Buffer
376 | session.Stdout = &stdout
377 | go func() {
378 | errChan <- session.Run("")
379 | }()
380 |
381 | <-readyToReceiveBreak
382 | ok, err := session.SendRequest("break", true, nil)
383 | if err != nil {
384 | t.Fatalf("expected nil but got %v", err)
385 | }
386 | if ok != true {
387 | t.Fatalf("expected true but got %v", ok)
388 | }
389 |
390 | err = <-errChan
391 | close(doneChan)
392 |
393 | if err != nil {
394 | t.Fatalf("expected nil but got %v", err)
395 | }
396 | if !bytes.Equal(stdout.Bytes(), []byte("break")) {
397 | t.Fatalf("stdout = %#v, expected 'break'", stdout.Bytes())
398 | }
399 | }
400 |
401 | func TestBreakWithoutChanRegistered(t *testing.T) {
402 | t.Parallel()
403 |
404 | // errChan lets us get errors back from the session
405 | errChan := make(chan error, 5)
406 |
407 | // doneChan lets us specify that we should exit.
408 | doneChan := make(chan interface{})
409 |
410 | waitUntilAfterBreakSent := make(chan bool)
411 |
412 | session, _, cleanup := newTestSession(t, &Server{
413 | Handler: func(s Session) {
414 | <-waitUntilAfterBreakSent
415 | },
416 | }, nil)
417 | defer cleanup()
418 | var stdout bytes.Buffer
419 | session.Stdout = &stdout
420 | go func() {
421 | errChan <- session.Run("")
422 | }()
423 |
424 | ok, err := session.SendRequest("break", true, nil)
425 | if err != nil {
426 | t.Fatalf("expected nil but got %v", err)
427 | }
428 | if ok != false {
429 | t.Fatalf("expected false but got %v", ok)
430 | }
431 | waitUntilAfterBreakSent <- true
432 |
433 | err = <-errChan
434 | close(doneChan)
435 | if err != nil {
436 | t.Fatalf("expected nil but got %v", err)
437 | }
438 | }
439 |
--------------------------------------------------------------------------------
/ssh.go:
--------------------------------------------------------------------------------
1 | package ssh
2 |
3 | import (
4 | "crypto/subtle"
5 | "net"
6 |
7 | gossh "golang.org/x/crypto/ssh"
8 | )
9 |
// Signal is the name of a POSIX signal as carried in SSH "signal" requests,
// without the "SIG" prefix (for example "INT" for SIGINT).
type Signal string

// POSIX signals as listed in RFC 4254 Section 6.10.
const (
	SIGABRT Signal = "ABRT"
	SIGALRM Signal = "ALRM"
	SIGFPE  Signal = "FPE"
	SIGHUP  Signal = "HUP"
	SIGILL  Signal = "ILL"
	SIGINT  Signal = "INT"
	SIGKILL Signal = "KILL"
	SIGPIPE Signal = "PIPE"
	SIGQUIT Signal = "QUIT"
	SIGSEGV Signal = "SEGV"
	SIGTERM Signal = "TERM"
	SIGUSR1 Signal = "USR1"
	SIGUSR2 Signal = "USR2"
)
28 |
29 | // DefaultHandler is the default Handler used by Serve.
30 | var DefaultHandler Handler
31 |
32 | // Option is a functional option handler for Server.
33 | type Option func(*Server) error
34 |
35 | // Handler is a callback for handling established SSH sessions.
36 | type Handler func(Session)
37 |
38 | // BannerHandler is a callback for displaying the server banner.
39 | type BannerHandler func(ctx Context) string
40 |
41 | // PublicKeyHandler is a callback for performing public key authentication.
42 | type PublicKeyHandler func(ctx Context, key PublicKey) bool
43 |
44 | // PasswordHandler is a callback for performing password authentication.
45 | type PasswordHandler func(ctx Context, password string) bool
46 |
47 | // KeyboardInteractiveHandler is a callback for performing keyboard-interactive authentication.
48 | type KeyboardInteractiveHandler func(ctx Context, challenger gossh.KeyboardInteractiveChallenge) bool
49 |
50 | // PtyCallback is a hook for allowing PTY sessions.
51 | type PtyCallback func(ctx Context, pty Pty) bool
52 |
53 | // SessionRequestCallback is a callback for allowing or denying SSH sessions.
54 | type SessionRequestCallback func(sess Session, requestType string) bool
55 |
56 | // ConnCallback is a hook for new connections before handling.
57 | // It allows wrapping for timeouts and limiting by returning
58 | // the net.Conn that will be used as the underlying connection.
59 | type ConnCallback func(ctx Context, conn net.Conn) net.Conn
60 |
61 | // LocalPortForwardingCallback is a hook for allowing port forwarding
62 | type LocalPortForwardingCallback func(ctx Context, destinationHost string, destinationPort uint32) bool
63 |
64 | // ReversePortForwardingCallback is a hook for allowing reverse port forwarding
65 | type ReversePortForwardingCallback func(ctx Context, bindHost string, bindPort uint32) bool
66 |
67 | // ServerConfigCallback is a hook for creating custom default server configs
68 | type ServerConfigCallback func(ctx Context) *gossh.ServerConfig
69 |
70 | // ConnectionFailedCallback is a hook for reporting failed connections
71 | // Please note: the net.Conn is likely to be closed at this point
72 | type ConnectionFailedCallback func(conn net.Conn, err error)
73 |
// Window holds the dimensions of a PTY window.
type Window struct {
	Width, Height int
}

// Pty describes a PTY request: the terminal type and its window size.
// Terminal modes are not captured (HELP WANTED).
type Pty struct {
	Term   string
	Window Window
}
86 |
87 | // Serve accepts incoming SSH connections on the listener l, creating a new
88 | // connection goroutine for each. The connection goroutines read requests and
89 | // then calls handler to handle sessions. Handler is typically nil, in which
90 | // case the DefaultHandler is used.
91 | func Serve(l net.Listener, handler Handler, options ...Option) error {
92 | srv := &Server{Handler: handler}
93 | for _, option := range options {
94 | if err := srv.SetOption(option); err != nil {
95 | return err
96 | }
97 | }
98 | return srv.Serve(l)
99 | }
100 |
101 | // ListenAndServe listens on the TCP network address addr and then calls Serve
102 | // with handler to handle sessions on incoming connections. Handler is typically
103 | // nil, in which case the DefaultHandler is used.
104 | func ListenAndServe(addr string, handler Handler, options ...Option) error {
105 | srv := &Server{Addr: addr, Handler: handler}
106 | for _, option := range options {
107 | if err := srv.SetOption(option); err != nil {
108 | return err
109 | }
110 | }
111 | return srv.ListenAndServe()
112 | }
113 |
114 | // Handle registers the handler as the DefaultHandler.
115 | func Handle(handler Handler) {
116 | DefaultHandler = handler
117 | }
118 |
119 | // KeysEqual is constant time compare of the keys to avoid timing attacks.
120 | func KeysEqual(ak, bk PublicKey) bool {
121 | // avoid panic if one of the keys is nil, return false instead
122 | if ak == nil || bk == nil {
123 | return false
124 | }
125 |
126 | a := ak.Marshal()
127 | b := bk.Marshal()
128 | return (len(a) == len(b) && subtle.ConstantTimeCompare(a, b) == 1)
129 | }
130 |
--------------------------------------------------------------------------------
/ssh_test.go:
--------------------------------------------------------------------------------
1 | package ssh
2 |
3 | import (
4 | "testing"
5 | )
6 |
7 | func TestKeysEqual(t *testing.T) {
8 | defer func() {
9 | if r := recover(); r != nil {
10 | t.Errorf("The code did panic")
11 | }
12 | }()
13 |
14 | if KeysEqual(nil, nil) {
15 | t.Error("two nil keys should not return true")
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/tcpip.go:
--------------------------------------------------------------------------------
1 | package ssh
2 |
3 | import (
4 | "io"
5 | "log"
6 | "net"
7 | "strconv"
8 | "sync"
9 |
10 | gossh "golang.org/x/crypto/ssh"
11 | )
12 |
// forwardedTCPChannelType is the channel type opened toward the client for
// each connection accepted on a reverse ("tcpip-forward") listener.
const forwardedTCPChannelType = "forwarded-tcpip"
16 |
// direct-tcpip data struct as specified in RFC4254, Section 7.2.
//
// It is decoded from the channel-open extra data with gossh.Unmarshal, which
// reads fields in declaration order, so the field order is the wire format
// and must not change.
type localForwardChannelData struct {
	// Host and port the client asks the server to connect to.
	DestAddr string
	DestPort uint32

	// Host and port the forwarded connection originated from, as reported
	// by the client.
	OriginAddr string
	OriginPort uint32
}
25 |
26 | // DirectTCPIPHandler can be enabled by adding it to the server's
27 | // ChannelHandlers under direct-tcpip.
28 | func DirectTCPIPHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) {
29 | d := localForwardChannelData{}
30 | if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil {
31 | newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error())
32 | return
33 | }
34 |
35 | if srv.LocalPortForwardingCallback == nil || !srv.LocalPortForwardingCallback(ctx, d.DestAddr, d.DestPort) {
36 | newChan.Reject(gossh.Prohibited, "port forwarding is disabled")
37 | return
38 | }
39 |
40 | dest := net.JoinHostPort(d.DestAddr, strconv.FormatInt(int64(d.DestPort), 10))
41 |
42 | var dialer net.Dialer
43 | dconn, err := dialer.DialContext(ctx, "tcp", dest)
44 | if err != nil {
45 | newChan.Reject(gossh.ConnectionFailed, err.Error())
46 | return
47 | }
48 |
49 | ch, reqs, err := newChan.Accept()
50 | if err != nil {
51 | dconn.Close()
52 | return
53 | }
54 | go gossh.DiscardRequests(reqs)
55 |
56 | go func() {
57 | defer ch.Close()
58 | defer dconn.Close()
59 | io.Copy(ch, dconn)
60 | }()
61 | go func() {
62 | defer ch.Close()
63 | defer dconn.Close()
64 | io.Copy(dconn, ch)
65 | }()
66 | }
67 |
// remoteForwardRequest is the payload of a "tcpip-forward" global request
// (RFC 4254, Section 7.1). It is decoded with gossh.Unmarshal, so the field
// order is the wire order.
type remoteForwardRequest struct {
	BindAddr string
	BindPort uint32
}

// remoteForwardSuccess is the reply payload sent when a "tcpip-forward"
// request is accepted; BindPort carries the port that was actually bound.
type remoteForwardSuccess struct {
	BindPort uint32
}

// remoteForwardCancelRequest is the payload of a "cancel-tcpip-forward"
// global request (RFC 4254, Section 7.1).
type remoteForwardCancelRequest struct {
	BindAddr string
	BindPort uint32
}

// remoteForwardChannelData is the extra data sent when opening a
// "forwarded-tcpip" channel to the client (RFC 4254, Section 7.2). It is
// encoded with gossh.Marshal, so the field order is the wire order.
type remoteForwardChannelData struct {
	DestAddr   string
	DestPort   uint32
	OriginAddr string
	OriginPort uint32
}
88 |
// ForwardedTCPHandler can be enabled by creating a ForwardedTCPHandler and
// adding the HandleSSHRequest callback to the server's RequestHandlers under
// tcpip-forward and cancel-tcpip-forward.
//
// The zero value is usable: the forwards map is created on first use.
type ForwardedTCPHandler struct {
	// forwards maps a net.JoinHostPort(bind host, port) key to the listener
	// serving that reverse forward.
	forwards map[string]net.Listener
	// Mutex guards forwards. Because it is embedded, Lock and Unlock are
	// also part of ForwardedTCPHandler's method set.
	sync.Mutex
}
96 |
97 | func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) {
98 | h.Lock()
99 | if h.forwards == nil {
100 | h.forwards = make(map[string]net.Listener)
101 | }
102 | h.Unlock()
103 | conn := ctx.Value(ContextKeyConn).(*gossh.ServerConn)
104 | switch req.Type {
105 | case "tcpip-forward":
106 | var reqPayload remoteForwardRequest
107 | if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
108 | // TODO: log parse failure
109 | return false, []byte{}
110 | }
111 | if srv.ReversePortForwardingCallback == nil || !srv.ReversePortForwardingCallback(ctx, reqPayload.BindAddr, reqPayload.BindPort) {
112 | return false, []byte("port forwarding is disabled")
113 | }
114 | addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort)))
115 | ln, err := net.Listen("tcp", addr)
116 | if err != nil {
117 | // TODO: log listen failure
118 | return false, []byte{}
119 | }
120 | _, destPortStr, _ := net.SplitHostPort(ln.Addr().String())
121 | destPort, _ := strconv.Atoi(destPortStr)
122 | h.Lock()
123 | h.forwards[addr] = ln
124 | h.Unlock()
125 | go func() {
126 | <-ctx.Done()
127 | h.Lock()
128 | ln, ok := h.forwards[addr]
129 | h.Unlock()
130 | if ok {
131 | ln.Close()
132 | }
133 | }()
134 | go func() {
135 | for {
136 | c, err := ln.Accept()
137 | if err != nil {
138 | // TODO: log accept failure
139 | break
140 | }
141 | originAddr, orignPortStr, _ := net.SplitHostPort(c.RemoteAddr().String())
142 | originPort, _ := strconv.Atoi(orignPortStr)
143 | payload := gossh.Marshal(&remoteForwardChannelData{
144 | DestAddr: reqPayload.BindAddr,
145 | DestPort: uint32(destPort),
146 | OriginAddr: originAddr,
147 | OriginPort: uint32(originPort),
148 | })
149 | go func() {
150 | ch, reqs, err := conn.OpenChannel(forwardedTCPChannelType, payload)
151 | if err != nil {
152 | // TODO: log failure to open channel
153 | log.Println(err)
154 | c.Close()
155 | return
156 | }
157 | go gossh.DiscardRequests(reqs)
158 | go func() {
159 | defer ch.Close()
160 | defer c.Close()
161 | io.Copy(ch, c)
162 | }()
163 | go func() {
164 | defer ch.Close()
165 | defer c.Close()
166 | io.Copy(c, ch)
167 | }()
168 | }()
169 | }
170 | h.Lock()
171 | delete(h.forwards, addr)
172 | h.Unlock()
173 | }()
174 | return true, gossh.Marshal(&remoteForwardSuccess{uint32(destPort)})
175 |
176 | case "cancel-tcpip-forward":
177 | var reqPayload remoteForwardCancelRequest
178 | if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
179 | // TODO: log parse failure
180 | return false, []byte{}
181 | }
182 | addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort)))
183 | h.Lock()
184 | ln, ok := h.forwards[addr]
185 | h.Unlock()
186 | if ok {
187 | ln.Close()
188 | }
189 | return true, nil
190 | default:
191 | return false, nil
192 | }
193 | }
194 |
--------------------------------------------------------------------------------
/tcpip_test.go:
--------------------------------------------------------------------------------
1 | package ssh
2 |
3 | import (
4 | "bytes"
5 | "io"
6 | "net"
7 | "strconv"
8 | "strings"
9 | "testing"
10 |
11 | gossh "golang.org/x/crypto/ssh"
12 | )
13 |
// sampleServerResponse is the payload sampleSocketServer writes to the first
// client that connects to it. Tests compare the bytes read through a
// forwarded connection against it.
var sampleServerResponse = []byte("Hello world")
15 |
16 | func sampleSocketServer() net.Listener {
17 | l := newLocalListener()
18 |
19 | go func() {
20 | conn, err := l.Accept()
21 | if err != nil {
22 | return
23 | }
24 | conn.Write(sampleServerResponse)
25 | conn.Close()
26 | }()
27 |
28 | return l
29 | }
30 |
31 | func newTestSessionWithForwarding(t *testing.T, forwardingEnabled bool) (net.Listener, *gossh.Client, func()) {
32 | l := sampleSocketServer()
33 |
34 | _, client, cleanup := newTestSession(t, &Server{
35 | Handler: func(s Session) {},
36 | LocalPortForwardingCallback: func(ctx Context, destinationHost string, destinationPort uint32) bool {
37 | addr := net.JoinHostPort(destinationHost, strconv.FormatInt(int64(destinationPort), 10))
38 | if addr != l.Addr().String() {
39 | panic("unexpected destinationHost: " + addr)
40 | }
41 | return forwardingEnabled
42 | },
43 | }, nil)
44 |
45 | return l, client, func() {
46 | cleanup()
47 | l.Close()
48 | }
49 | }
50 |
51 | func TestLocalPortForwardingWorks(t *testing.T) {
52 | t.Parallel()
53 |
54 | l, client, cleanup := newTestSessionWithForwarding(t, true)
55 | defer cleanup()
56 |
57 | conn, err := client.Dial("tcp", l.Addr().String())
58 | if err != nil {
59 | t.Fatalf("Error connecting to %v: %v", l.Addr().String(), err)
60 | }
61 | result, err := io.ReadAll(conn)
62 | if err != nil {
63 | t.Fatal(err)
64 | }
65 | if !bytes.Equal(result, sampleServerResponse) {
66 | t.Fatalf("result = %#v; want %#v", result, sampleServerResponse)
67 | }
68 | }
69 |
70 | func TestLocalPortForwardingRespectsCallback(t *testing.T) {
71 | t.Parallel()
72 |
73 | l, client, cleanup := newTestSessionWithForwarding(t, false)
74 | defer cleanup()
75 |
76 | _, err := client.Dial("tcp", l.Addr().String())
77 | if err == nil {
78 | t.Fatalf("Expected error connecting to %v but it succeeded", l.Addr().String())
79 | }
80 | if !strings.Contains(err.Error(), "port forwarding is disabled") {
81 | t.Fatalf("Expected permission error but got %#v", err)
82 | }
83 | }
84 |
--------------------------------------------------------------------------------
/util.go:
--------------------------------------------------------------------------------
1 | package ssh
2 |
3 | import (
4 | "crypto/rand"
5 | "crypto/rsa"
6 | "encoding/binary"
7 |
8 | "golang.org/x/crypto/ssh"
9 | )
10 |
11 | func generateSigner() (ssh.Signer, error) {
12 | key, err := rsa.GenerateKey(rand.Reader, 2048)
13 | if err != nil {
14 | return nil, err
15 | }
16 | return ssh.NewSignerFromKey(key)
17 | }
18 |
19 | func parsePtyRequest(s []byte) (pty Pty, ok bool) {
20 | term, s, ok := parseString(s)
21 | if !ok {
22 | return
23 | }
24 | width32, s, ok := parseUint32(s)
25 | if !ok {
26 | return
27 | }
28 | height32, _, ok := parseUint32(s)
29 | if !ok {
30 | return
31 | }
32 | pty = Pty{
33 | Term: term,
34 | Window: Window{
35 | Width: int(width32),
36 | Height: int(height32),
37 | },
38 | }
39 | return
40 | }
41 |
42 | func parseWinchRequest(s []byte) (win Window, ok bool) {
43 | width32, s, ok := parseUint32(s)
44 | if width32 < 1 {
45 | ok = false
46 | }
47 | if !ok {
48 | return
49 | }
50 | height32, _, ok := parseUint32(s)
51 | if height32 < 1 {
52 | ok = false
53 | }
54 | if !ok {
55 | return
56 | }
57 | win = Window{
58 | Width: int(width32),
59 | Height: int(height32),
60 | }
61 | return
62 | }
63 |
// parseString reads a length-prefixed string from the front of in: a
// big-endian uint32 byte count followed by that many bytes. It returns the
// string and the bytes after it. If in is too short to hold the 4-byte prefix
// or the declared number of bytes, it returns "", nil, false.
func parseString(in []byte) (out string, rest []byte, ok bool) {
	if len(in) < 4 {
		return "", nil, false
	}
	length := binary.BigEndian.Uint32(in)
	// The length comes straight from the payload, so compare in 64 bits.
	// Computing 4+length in uint32 wraps for values near 2^32, which would
	// let a huge length pass the check and panic on the slice below.
	if uint64(length) > uint64(len(in)-4) {
		return "", nil, false
	}
	end := 4 + int(length) // fits in int: length <= len(in)-4 was just checked
	return string(in[4:end]), in[end:], true
}
77 |
// parseUint32 reads a big-endian uint32 from the first four bytes of in and
// returns it together with the bytes that follow. If in holds fewer than four
// bytes, it returns 0, nil, false.
func parseUint32(in []byte) (uint32, []byte, bool) {
	const size = 4
	if len(in) >= size {
		return binary.BigEndian.Uint32(in[:size]), in[size:], true
	}
	return 0, nil, false
}
84 |
--------------------------------------------------------------------------------
/wrap.go:
--------------------------------------------------------------------------------
1 | package ssh
2 |
3 | import gossh "golang.org/x/crypto/ssh"
4 |
// PublicKey is an abstraction of different types of public keys.
//
// It embeds gossh.PublicKey and adds no methods of its own. Any
// gossh.PublicKey value therefore satisfies PublicKey, and a PublicKey can be
// passed anywhere a gossh.PublicKey is expected.
type PublicKey interface {
	gossh.PublicKey
}
9 |
// The Permissions type holds fine-grained permissions that are specific to a
// user or a specific authentication method for a user. Permissions, except for
// "source-address", must be enforced in the server application layer, after
// successful authentication.
//
// It embeds a *gossh.Permissions, so that type's fields and methods are
// promoted onto Permissions. In a zero Permissions value the embedded pointer
// is nil, and reading a promoted field would panic.
type Permissions struct {
	*gossh.Permissions
}
17 |
18 | // A Signer can create signatures that verify against a public key.
19 | type Signer interface {
20 | gossh.Signer
21 | }
22 |
23 | // ParseAuthorizedKey parses a public key from an authorized_keys file used in
24 | // OpenSSH according to the sshd(8) manual page.
25 | func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) {
26 | return gossh.ParseAuthorizedKey(in)
27 | }
28 |
29 | // ParsePublicKey parses an SSH public key formatted for use in
30 | // the SSH wire protocol according to RFC 4253, section 6.6.
31 | func ParsePublicKey(in []byte) (out PublicKey, err error) {
32 | return gossh.ParsePublicKey(in)
33 | }
34 |
--------------------------------------------------------------------------------