├── .github
├── CODEOWNERS
├── dependabot.yml
└── workflows
│ ├── build.yml
│ ├── dependabot-sync.yml
│ ├── examples.yml
│ ├── goreleaser.yml
│ └── lint.yml
├── .gitignore
├── .golangci.yml
├── .goreleaser.yml
├── LICENSE
├── README.md
├── accesscontrol
├── accesscontrol.go
└── accesscontrol_test.go
├── activeterm
├── activeterm.go
└── activeterm_test.go
├── bubbletea
├── query.go
├── tea.go
├── tea_other.go
└── tea_unix.go
├── cmd.go
├── cmd_test.go
├── cmd_unix.go
├── cmd_windows.go
├── comment
├── comment.go
└── comment_test.go
├── elapsed
├── elapsed.go
└── elapsed_test.go
├── examples
├── .gitignore
├── README.md
├── banner
│ ├── banner.txt
│ └── main.go
├── bubbletea-exec
│ └── main.go
├── bubbletea
│ └── main.go
├── bubbleteaprogram
│ └── main.go
├── cobra
│ └── main.go
├── exec
│ ├── example.sh
│ └── main.go
├── forward
│ └── main.go
├── git
│ └── main.go
├── go.mod
├── go.sum
├── graceful-shutdown
│ └── main.go
├── identity
│ └── main.go
├── multi-auth
│ └── main.go
├── multichat
│ └── main.go
├── pty
│ └── main.go
├── scp
│ ├── main.go
│ └── testdata
│ │ └── .gitkeep
└── simple
│ └── main.go
├── git
├── git.go
└── git_test.go
├── go.mod
├── go.sum
├── logging
├── logging.go
└── logging_test.go
├── options.go
├── options_test.go
├── ratelimiter
├── ratelimiter.go
└── ratelimiter_test.go
├── recover
├── recover.go
└── recover_test.go
├── scp
├── copy_from_client.go
├── copy_to_client.go
├── filesystem.go
├── filesystem_test.go
├── fs.go
├── fs_test.go
├── limit_reader.go
├── limit_reader_test.go
├── scp.go
├── scp_test.go
└── testdata
│ ├── TestFS
│ ├── file.test
│ ├── glob.test
│ ├── recursive.test
│ ├── recursive_folder.test
│ └── recursive_glob.test
│ ├── TestFilesystem
│ └── scp_-f
│ │ ├── file.test
│ │ ├── glob.test
│ │ ├── recursive.test
│ │ ├── recursive_folder.test
│ │ └── recursive_glob.test
│ └── TestNoDirRootEntry.test
├── testdata
├── another-ca-cert.pub
├── authorized_keys
├── ca
├── ca.pub
├── expired-cert.pub
├── foo
├── foo.pub
├── invalid_authorized_keys
└── valid-cert.pub
├── testsession
├── testsession.go
└── testsession_test.go
├── wish.go
└── wish_test.go
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | * @charmbracelet/everyone
2 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | updates:
4 | - package-ecosystem: "gomod"
5 | directory: "/"
6 | schedule:
7 | interval: "weekly"
8 | day: "monday"
9 | time: "05:00"
10 | timezone: "America/New_York"
11 | labels:
12 | - "dependencies"
13 | commit-message:
14 | prefix: "chore"
15 | include: "scope"
16 |
17 | - package-ecosystem: "github-actions"
18 | directory: "/"
19 | schedule:
20 | interval: "weekly"
21 | day: "monday"
22 | time: "05:00"
23 | timezone: "America/New_York"
24 | labels:
25 | - "dependencies"
26 | commit-message:
27 | prefix: "chore"
28 | include: "scope"
29 |
30 | - package-ecosystem: "docker"
31 | directory: "/"
32 | schedule:
33 | interval: "weekly"
34 | day: "monday"
35 | time: "05:00"
36 | timezone: "America/New_York"
37 | labels:
38 | - "dependencies"
39 | commit-message:
40 | prefix: "chore"
41 | include: "scope"
42 |
43 | - package-ecosystem: "gomod"
44 | directory: "/examples"
45 | schedule:
46 | interval: "weekly"
47 | day: "monday"
48 | time: "05:00"
49 | timezone: "America/New_York"
50 | labels:
51 | - "dependencies"
52 | commit-message:
53 | prefix: "chore"
54 | include: "scope"
55 |
--------------------------------------------------------------------------------
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
1 | name: Build
2 |
3 | on:
4 | push:
5 | pull_request:
6 |
7 | jobs:
8 | build:
9 | uses: charmbracelet/meta/.github/workflows/build.yml@main
10 |
11 | codecov:
12 | runs-on: ubuntu-latest
13 | steps:
14 | - uses: actions/checkout@v4
15 | - uses: actions/setup-go@v5
16 | with:
17 | go-version: "stable"
18 | cache: true
19 | - run: go test -failfast -race -coverpkg=./... -covermode=atomic -coverprofile=coverage.txt ./... -timeout 5m
20 | - uses: codecov/codecov-action@v5
21 | with:
22 | files: ./coverage.txt
23 |
--------------------------------------------------------------------------------
/.github/workflows/dependabot-sync.yml:
--------------------------------------------------------------------------------
1 | name: dependabot-sync
2 | on:
3 | schedule:
4 | - cron: "0 0 * * 0" # every Sunday at midnight
5 | workflow_dispatch: # allows manual triggering
6 |
7 | permissions:
8 | contents: write
9 | pull-requests: write
10 |
11 | jobs:
12 | dependabot-sync:
13 | uses: charmbracelet/meta/.github/workflows/dependabot-sync.yml@main
14 | with:
15 | repo_name: ${{ github.event.repository.name }}
16 | secrets:
17 | gh_token: ${{ secrets.PERSONAL_ACCESS_TOKEN }}
18 |
--------------------------------------------------------------------------------
/.github/workflows/examples.yml:
--------------------------------------------------------------------------------
1 | name: examples
2 |
3 | on:
4 | push:
5 | branches:
6 | - 'main'
7 | paths:
8 | - '.github/workflows/examples.yml'
9 | - 'examples/go.mod'
10 | - 'examples/go.sum'
11 | - 'go.mod'
12 | - 'go.sum'
13 | workflow_dispatch: {}
14 |
15 | jobs:
16 | tidy:
17 | permissions:
18 | contents: write
19 | runs-on: ubuntu-latest
20 | steps:
21 | - uses: actions/checkout@v4
22 | - uses: actions/setup-go@v5
23 | with:
24 | go-version: '^1'
25 | cache: true
26 | - shell: bash
27 | run: |
28 | (cd ./examples && go mod tidy)
29 | - uses: stefanzweifel/git-auto-commit-action@v5
30 | with:
31 | commit_message: "chore: go mod tidy examples"
32 | branch: main
33 | commit_user_name: actions-user
34 | commit_user_email: actions@github.com
35 |
36 |
37 |
--------------------------------------------------------------------------------
/.github/workflows/goreleaser.yml:
--------------------------------------------------------------------------------
1 | name: goreleaser
2 |
3 | on:
4 | push:
5 | tags:
6 | - v*.*.*
7 |
8 | concurrency:
9 | group: goreleaser
10 | cancel-in-progress: true
11 |
12 | jobs:
13 | goreleaser:
14 | uses: charmbracelet/meta/.github/workflows/goreleaser.yml@main
15 | secrets:
16 | docker_username: ${{ secrets.DOCKERHUB_USERNAME }}
17 | docker_token: ${{ secrets.DOCKERHUB_TOKEN }}
18 | gh_pat: ${{ secrets.PERSONAL_ACCESS_TOKEN }}
19 | goreleaser_key: ${{ secrets.GORELEASER_KEY }}
20 | twitter_consumer_key: ${{ secrets.TWITTER_CONSUMER_KEY }}
21 | twitter_consumer_secret: ${{ secrets.TWITTER_CONSUMER_SECRET }}
22 | twitter_access_token: ${{ secrets.TWITTER_ACCESS_TOKEN }}
23 | twitter_access_token_secret: ${{ secrets.TWITTER_ACCESS_TOKEN_SECRET }}
24 | mastodon_client_id: ${{ secrets.MASTODON_CLIENT_ID }}
25 | mastodon_client_secret: ${{ secrets.MASTODON_CLIENT_SECRET }}
26 | mastodon_access_token: ${{ secrets.MASTODON_ACCESS_TOKEN }}
27 | discord_webhook_id: ${{ secrets.DISCORD_WEBHOOK_ID }}
28 | discord_webhook_token: ${{ secrets.DISCORD_WEBHOOK_TOKEN }}
29 |
30 | # yaml-language-server: $schema=https://json.schemastore.org/github-workflow.json
31 |
--------------------------------------------------------------------------------
/.github/workflows/lint.yml:
--------------------------------------------------------------------------------
1 | name: lint
2 | on:
3 | push:
4 | pull_request:
5 |
6 | jobs:
7 | golangci:
8 | name: lint
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/checkout@v4
12 | - uses: actions/setup-go@v5
13 | with:
14 | go-version: ^1
15 | cache: true
16 | - uses: golangci/golangci-lint-action@v8
17 | with:
18 | # Optional: golangci-lint command line arguments.
19 | args: --issues-exit-code=0
20 | # Optional: show only new issues if it's a pull request. The default value is `false`.
21 | only-new-issues: true
22 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | examples/bubbletea/bubbletea
2 | examples/bubbletea/.ssh
3 | examples/git/git
4 | examples/git/.ssh
5 | examples/git/.repos
6 | .repos
7 | .ssh
8 | coverage.txt
9 | id_ed25519
10 | id_ed25519.pub
11 |
12 | # MacOS specific
13 | .DS_Store
14 |
--------------------------------------------------------------------------------
/.golangci.yml:
--------------------------------------------------------------------------------
1 | run:
2 | tests: false
3 |
4 | issues:
5 | include:
6 | - EXC0001
7 | - EXC0005
8 | - EXC0011
9 | - EXC0012
10 | - EXC0013
11 |
12 | max-issues-per-linter: 0
13 | max-same-issues: 0
14 |
15 | linters:
16 | enable:
17 | - bodyclose
18 | - depguard
19 | - dupl
20 | - goconst
21 | - godot
22 | - godox
23 | - gofumpt
24 | - goimports
25 | - goprintffuncname
26 | - gosec
27 | - misspell
28 | - prealloc
29 | - revive
30 | - rowserrcheck
31 | - sqlclosecheck
32 | - unconvert
33 | - unparam
34 | - whitespace
35 |
36 | linters-settings:
37 | depguard:
38 | rules:
39 | main:
40 | deny:
41 | - pkg: "github.com/gliderlabs/ssh"
42 | desc: "use github.com/charmbracelet/ssh instead"
43 |
--------------------------------------------------------------------------------
/.goreleaser.yml:
--------------------------------------------------------------------------------
1 | # yaml-language-server: $schema=https://goreleaser.com/static/schema-pro.json
2 | version: 2
3 | includes:
4 | - from_url:
5 | url: charmbracelet/meta/main/goreleaser-lib.yaml
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019-2023 Charmbracelet, Inc
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Wish
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 | Make SSH apps, just like that! 💫
15 |
16 | SSH is an excellent platform for building remotely accessible applications. It
17 | offers:
18 | * secure communication without the hassle of HTTPS certificates
19 | * user identification with SSH keys
20 | * access from any terminal
21 |
22 | Powerful protocols like Git work over SSH and you can even render TUIs directly over an SSH connection.
23 |
24 | Wish is an SSH server with sensible defaults and a collection of middlewares that
25 | make building SSH apps really easy. Wish is built on [charmbracelet/ssh](https://github.com/charmbracelet/ssh),
26 | a fork of [gliderlabs/ssh][gliderlabs/ssh], and should be easy to integrate into any existing project.
27 |
28 | ## What are SSH Apps?
29 |
30 | Usually, when we think about SSH, we think about remote shell access into servers,
31 | most commonly through `openssh-server`.
32 |
33 | That's a perfectly valid (and probably the most common) use of SSH, but it can do so much more than that.
34 | Just like HTTP, SMTP, FTP and others, SSH is a protocol!
35 | It is a cryptographic network protocol for operating network services securely over an unsecured network. [^1]
36 |
37 | [^1]: https://en.wikipedia.org/wiki/Secure_Shell
38 |
39 | That means, among other things, that we can write custom SSH servers without touching `openssh-server`,
40 | so we can securely do more things than just providing a shell.
41 |
42 | Wish is a library that helps you write these kinds of apps in Go.
43 |
44 | ## Middleware
45 |
46 | Wish middlewares are analogous to those in several HTTP frameworks.
47 | They are essentially SSH handlers that you can use to do specific tasks,
48 | and then call the next middleware.
49 |
50 | Notice that middlewares are composed from first to last,
51 | which means the last one is executed first.
52 |
53 | ### Bubble Tea
54 |
55 | The [`bubbletea`](bubbletea) middleware makes it easy to serve any
56 | [Bubble Tea][bubbletea] application over SSH. Each SSH session will get its own
57 | `tea.Program` with the SSH pty input and output connected. Client window
58 | dimension and resize messages are also natively handled by the `tea.Program`.
59 |
60 | You can see a demo of the Wish middleware in action at: `ssh git.charm.sh`
61 |
62 | ### Git
63 |
64 | The [`git`](git) middleware adds `git` server functionality to any SSH server.
65 | It supports repo creation on initial push and custom public key based auth.
66 |
67 | This middleware requires that `git` is installed on the server.
68 |
69 | ### Logging
70 |
71 | The [`logging`](logging) middleware provides basic connection logging. Connects
72 | are logged with the remote address, invoked command, TERM setting, window
73 | dimensions and whether the auth was public key based. Disconnects log the remote
74 | address and connection duration.
75 |
76 | ### Access Control
77 |
78 | Not all applications will support general SSH connections. To restrict access
79 | to supported methods, you can use the [`activeterm`](activeterm) middleware to
80 | only allow connections with active terminals connected and the
81 | [`accesscontrol`](accesscontrol) middleware that lets you specify allowed
82 | commands.
83 |
84 | ## Default Server
85 |
86 | Wish includes the ability to easily create an always authenticating default SSH
87 | server with automatic server key generation.
88 |
89 | ## Examples
90 |
91 | There are examples for a standalone [Bubble Tea application](examples/bubbletea)
92 | and [Git server](examples/git) in the [examples](examples) folder.
93 |
94 | ## Apps Built With Wish
95 |
96 | * [Soft Serve](https://github.com/charmbracelet/soft-serve)
97 | * [Wishlist](https://github.com/charmbracelet/wishlist)
98 | * [SSHWordle](https://github.com/davidcroda/sshwordle)
99 | * [clidle](https://github.com/ajeetdsouza/clidle)
100 | * [ssh-warm-welcome](https://git.coopcloud.tech/decentral1se/ssh-warm-welcome)
101 |
102 | [bubbletea]: https://github.com/charmbracelet/bubbletea
103 | [gliderlabs/ssh]: https://github.com/gliderlabs/ssh
104 |
105 | ## Pro tip
106 |
107 | When building various Wish applications locally you can add the following to
108 | your `~/.ssh/config` to avoid having to clear out `localhost` entries in your
109 | `~/.ssh/known_hosts` file:
110 |
111 | ```
112 | Host localhost
113 | UserKnownHostsFile /dev/null
114 | ```
115 |
116 | ## How does it work?
117 |
118 | Wish uses [gliderlabs/ssh][gliderlabs/ssh] to implement its SSH server, and
119 | OpenSSH is never used nor needed — you can even uninstall it if you want to.
120 |
121 | Incidentally, there's no risk of accidentally sharing a shell because there's no
122 | default behavior that does that on Wish.
123 |
124 | ## Running with SystemD
125 |
126 | If you want to run a Wish app with `systemd`, you can create a unit like so:
127 |
128 | `/etc/systemd/system/myapp.service`:
129 | ```service
130 | [Unit]
131 | Description=My App
132 | After=network.target
133 |
134 | [Service]
135 | Type=simple
136 | User=myapp
137 | Group=myapp
138 | WorkingDirectory=/home/myapp/
139 | ExecStart=/usr/bin/myapp
140 | Restart=on-failure
141 |
142 | [Install]
143 | WantedBy=multi-user.target
144 | ```
145 |
146 | You can tune the values above, and once you're happy with them, you can run:
147 |
148 | ```bash
149 | # need to run this every time you change the unit file
150 | sudo systemctl daemon-reload
151 |
152 | # start/restart/stop/etc:
153 | sudo systemctl start myapp
154 | ```
155 |
156 | If you use a new user for each app (which is good), you'll need to create them
157 | first:
158 |
159 | ```bash
160 | useradd --system --user-group --create-home myapp
161 | ```
162 |
163 | That should do it.
164 |
165 |
166 |
167 | ## Feedback
168 |
169 | We’d love to hear your thoughts on this project. Feel free to drop us a note!
170 |
171 | * [Twitter](https://twitter.com/charmcli)
172 | * [The Fediverse](https://mastodon.social/@charmcli)
173 | * [Discord](https://charm.sh/chat)
174 |
175 | ## License
176 |
177 | [MIT](https://github.com/charmbracelet/wish/raw/main/LICENSE)
178 |
179 | ***
180 |
181 | Part of [Charm](https://charm.sh).
182 |
183 |
184 |
185 | Charm热爱开源 • Charm loves open source
186 |
--------------------------------------------------------------------------------
/accesscontrol/accesscontrol.go:
--------------------------------------------------------------------------------
1 | // Package accesscontrol provides a middleware that allows you to restrict the commands the user can execute.
2 | package accesscontrol
3 |
4 | import (
5 | "fmt"
6 |
7 | "github.com/charmbracelet/ssh"
8 | "github.com/charmbracelet/wish"
9 | )
10 |
11 | // Middleware will exit 1 connections trying to execute commands that are not allowed.
12 | // If no allowed commands are provided, no commands will be allowed.
13 | func Middleware(cmds ...string) wish.Middleware {
14 | return func(sh ssh.Handler) ssh.Handler {
15 | return func(s ssh.Session) {
16 | if len(s.Command()) == 0 {
17 | sh(s)
18 | return
19 | }
20 | for _, cmd := range cmds {
21 | if s.Command()[0] == cmd {
22 | sh(s)
23 | return
24 | }
25 | }
26 | _, _ = fmt.Fprintln(s, "Command is not allowed: "+s.Command()[0])
27 | s.Exit(1) // nolint: errcheck
28 | }
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/accesscontrol/accesscontrol_test.go:
--------------------------------------------------------------------------------
1 | package accesscontrol_test
2 |
3 | import (
4 | "fmt"
5 | "testing"
6 |
7 | "github.com/charmbracelet/ssh"
8 | "github.com/charmbracelet/wish/accesscontrol"
9 | "github.com/charmbracelet/wish/testsession"
10 | gossh "golang.org/x/crypto/ssh"
11 | )
12 |
13 | const out = "hello world"
14 |
15 | func TestMiddleware(t *testing.T) {
16 | requireDenied := func(tb testing.TB, s, cmd string) {
17 | tb.Helper()
18 | expected := fmt.Sprintf("Command is not allowed: %s\n", cmd)
19 | if s != expected {
20 | t.Errorf("expected %q, got %q", expected, s)
21 | }
22 | }
23 |
24 | requireOutput := func(tb testing.TB, s string) {
25 | tb.Helper()
26 | if out != s {
27 | t.Errorf("expected %q, got %q", out, s)
28 | }
29 | }
30 |
31 | t.Run("no allowed cmds no cmd", func(t *testing.T) {
32 | out, err := setup(t).Output("")
33 | if err != nil {
34 | t.Error(err)
35 | }
36 | requireOutput(t, string(out))
37 | })
38 |
39 | t.Run("no allowed cmds with cmd", func(t *testing.T) {
40 | out, err := setup(t).Output("echo")
41 | if err == nil {
42 | t.Errorf("should have errored")
43 | }
44 | requireDenied(t, string(out), "echo")
45 | })
46 |
47 | t.Run("allowed cmds no cmd", func(t *testing.T) {
48 | out, err := setup(t, "echo").Output("")
49 | if err != nil {
50 | t.Error(err)
51 | }
52 | requireOutput(t, string(out))
53 | })
54 |
55 | t.Run("allowed cmds with allowed cmd", func(t *testing.T) {
56 | out, err := setup(t, "echo").Output("echo")
57 | if err != nil {
58 | t.Error(err)
59 | }
60 | requireOutput(t, string(out))
61 | })
62 |
63 | t.Run("allowed cmds with disallowed cmd", func(t *testing.T) {
64 | out, err := setup(t, "echo").Output("cat")
65 | if err == nil {
66 | t.Error(err)
67 | }
68 | requireDenied(t, string(out), "cat")
69 | })
70 |
71 | t.Run("allowed cmds with allowed cmd followed disallowed cmd", func(t *testing.T) {
72 | out, err := setup(t, "echo").Output("cat echo")
73 | if err == nil {
74 | t.Error(err)
75 | }
76 | requireDenied(t, string(out), "cat")
77 | })
78 | }
79 |
80 | func setup(tb testing.TB, allowedCmds ...string) *gossh.Session {
81 | tb.Helper()
82 | return testsession.New(tb, &ssh.Server{
83 | Handler: accesscontrol.Middleware(allowedCmds...)(func(s ssh.Session) {
84 | s.Write([]byte(out))
85 | }),
86 | }, nil)
87 | }
88 |
--------------------------------------------------------------------------------
/activeterm/activeterm.go:
--------------------------------------------------------------------------------
1 | // Package activeterm provides a middleware to block inactive PTYs.
2 | package activeterm
3 |
4 | import (
5 | "github.com/charmbracelet/ssh"
6 | "github.com/charmbracelet/wish"
7 | )
8 |
9 | // Middleware will exit 1 connections trying with no active terminals.
10 | func Middleware() wish.Middleware {
11 | return func(next ssh.Handler) ssh.Handler {
12 | return func(sess ssh.Session) {
13 | _, _, active := sess.Pty()
14 | if active {
15 | next(sess)
16 | return
17 | }
18 | wish.Println(sess, "Requires an active PTY")
19 | _ = sess.Exit(1)
20 | }
21 | }
22 | }
23 |
--------------------------------------------------------------------------------
/activeterm/activeterm_test.go:
--------------------------------------------------------------------------------
1 | package activeterm_test
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/charmbracelet/ssh"
7 | "github.com/charmbracelet/wish/activeterm"
8 | "github.com/charmbracelet/wish/testsession"
9 | gossh "golang.org/x/crypto/ssh"
10 | )
11 |
12 | func TestMiddleware(t *testing.T) {
13 | t.Run("inactive term", func(t *testing.T) {
14 | out, err := setup(t).Output("")
15 | if err == nil {
16 | t.Errorf("tests should be an inactive pty")
17 | }
18 | if string(out) != "Requires an active PTY\n" {
19 | t.Errorf("invalid output: %q", string(out))
20 | }
21 | })
22 | }
23 |
24 | func setup(tb testing.TB) *gossh.Session {
25 | tb.Helper()
26 | return testsession.New(tb, &ssh.Server{
27 | Handler: activeterm.Middleware()(func(s ssh.Session) {
28 | s.Write([]byte("hello"))
29 | }),
30 | }, nil)
31 | }
32 |
--------------------------------------------------------------------------------
/bubbletea/query.go:
--------------------------------------------------------------------------------
1 | package bubbletea
2 |
3 | import (
4 | "image/color"
5 | "io"
6 | "time"
7 |
8 | "github.com/charmbracelet/x/ansi"
9 | "github.com/charmbracelet/x/input"
10 | )
11 |
12 | // queryBackgroundColor queries the terminal for the background color.
13 | // If the terminal does not support querying the background color, nil is
14 | // returned.
15 | //
16 | // Note: you will need to set the input to raw mode before calling this
17 | // function.
18 | //
19 | // state, _ := term.MakeRaw(in.Fd())
20 | // defer term.Restore(in.Fd(), state)
21 | //
22 | // copied from x/term@v0.1.3.
23 | func queryBackgroundColor(in io.Reader, out io.Writer) (c color.Color, err error) {
24 | // nolint: errcheck
25 | err = queryTerminal(in, out, defaultQueryTimeout,
26 | func(events []input.Event) bool {
27 | for _, e := range events {
28 | switch e := e.(type) {
29 | case input.BackgroundColorEvent:
30 | c = e.Color
31 | continue // we need to consume the next DA1 event
32 | case input.PrimaryDeviceAttributesEvent:
33 | return false
34 | }
35 | }
36 | return true
37 | }, ansi.RequestBackgroundColor+ansi.RequestPrimaryDeviceAttributes)
38 | return
39 | }
40 |
41 | const defaultQueryTimeout = time.Second * 2
42 |
43 | // QueryTerminalFilter is a function that filters input events using a type
44 | // switch. If false is returned, the QueryTerminal function will stop reading
45 | // input.
46 | type QueryTerminalFilter func(events []input.Event) bool
47 |
48 | // queryTerminal queries the terminal for support of various features and
49 | // returns a list of response events.
50 | // Most of the time, you will need to set stdin to raw mode before calling this
51 | // function.
52 | // Note: This function will block until the terminal responds or the timeout
53 | // is reached.
54 | // copied from x/term@v0.1.3.
55 | func queryTerminal(
56 | in io.Reader,
57 | out io.Writer,
58 | timeout time.Duration,
59 | filter QueryTerminalFilter,
60 | query string,
61 | ) error {
62 | rd, err := input.NewReader(in, "", 0)
63 | if err != nil {
64 | return err
65 | }
66 |
67 | defer rd.Close() // nolint: errcheck
68 |
69 | done := make(chan struct{}, 1)
70 | defer close(done)
71 | go func() {
72 | select {
73 | case <-done:
74 | case <-time.After(timeout):
75 | rd.Cancel()
76 | }
77 | }()
78 |
79 | if _, err := io.WriteString(out, query); err != nil {
80 | return err
81 | }
82 |
83 | for {
84 | events, err := rd.ReadEvents()
85 | if err != nil {
86 | return err
87 | }
88 |
89 | if !filter(events) {
90 | break
91 | }
92 | }
93 |
94 | return nil
95 | }
96 |
--------------------------------------------------------------------------------
/bubbletea/tea.go:
--------------------------------------------------------------------------------
1 | // Package bubbletea provides middleware for serving bubbletea apps over SSH.
2 | package bubbletea
3 |
4 | import (
5 | "context"
6 | "fmt"
7 | "strings"
8 |
9 | tea "github.com/charmbracelet/bubbletea"
10 | "github.com/charmbracelet/lipgloss"
11 | "github.com/charmbracelet/log"
12 | "github.com/charmbracelet/ssh"
13 | "github.com/charmbracelet/wish"
14 | "github.com/muesli/termenv"
15 | )
16 |
17 | // BubbleTeaHandler is the function Bubble Tea apps implement to hook into the
18 | // SSH Middleware. This will create a new tea.Program for every connection and
19 | // start it with the tea.ProgramOptions returned.
20 | //
21 | // Deprecated: use Handler instead.
22 | type BubbleTeaHandler = Handler // nolint: revive
23 |
24 | // Handler is the function Bubble Tea apps implement to hook into the
25 | // SSH Middleware. This will create a new tea.Program for every connection and
26 | // start it with the tea.ProgramOptions returned.
27 | type Handler func(sess ssh.Session) (tea.Model, []tea.ProgramOption)
28 |
29 | // ProgramHandler is the function Bubble Tea apps implement to hook into the SSH
30 | // Middleware. This should return a new tea.Program. This handler is different
31 | // from the default handler in that it returns a tea.Program instead of
32 | // (tea.Model, tea.ProgramOptions).
33 | //
34 | // Make sure to set the tea.WithInput and tea.WithOutput to the ssh.Session
35 | // otherwise the program will not function properly.
36 | type ProgramHandler func(sess ssh.Session) *tea.Program
37 |
38 | // Middleware takes a Handler and hooks the input and output for the
39 | // ssh.Session into the tea.Program.
40 | //
41 | // It also captures window resize events and sends them to the tea.Program
42 | // as tea.WindowSizeMsgs.
43 | func Middleware(handler Handler) wish.Middleware {
44 | return MiddlewareWithProgramHandler(newDefaultProgramHandler(handler), termenv.Ascii)
45 | }
46 |
47 | // MiddlewareWithColorProfile allows you to specify the minimum number of colors
48 | // this program needs to work properly.
49 | //
50 | // If the client's color profile has less colors than p, p will be forced.
51 | // Use with caution.
52 | func MiddlewareWithColorProfile(handler Handler, profile termenv.Profile) wish.Middleware {
53 | return MiddlewareWithProgramHandler(newDefaultProgramHandler(handler), profile)
54 | }
55 |
56 | // MiddlewareWithProgramHandler allows you to specify the ProgramHandler to be
57 | // able to access the underlying tea.Program, and the minimum supported color
58 | // profile.
59 | //
60 | // This is useful for creating custom middlewares that need access to
61 | // tea.Program for instance to use p.Send() to send messages to tea.Program.
62 | //
63 | // Make sure to set the tea.WithInput and tea.WithOutput to the ssh.Session
64 | // otherwise the program will not function properly. The recommended way
65 | // of doing so is by using MakeOptions.
66 | //
67 | // If the client's color profile has less colors than p, p will be forced.
68 | // Use with caution.
69 | func MiddlewareWithProgramHandler(handler ProgramHandler, profile termenv.Profile) wish.Middleware {
70 | return func(next ssh.Handler) ssh.Handler {
71 | return func(sess ssh.Session) {
72 | sess.Context().SetValue(minColorProfileKey, profile)
73 | program := handler(sess)
74 | if program == nil {
75 | next(sess)
76 | return
77 | }
78 | _, windowChanges, ok := sess.Pty()
79 | if !ok {
80 | wish.Fatalln(sess, "no active terminal, skipping")
81 | return
82 | }
83 | ctx, cancel := context.WithCancel(sess.Context())
84 | go func() {
85 | for {
86 | select {
87 | case <-ctx.Done():
88 | program.Quit()
89 | return
90 | case w := <-windowChanges:
91 | program.Send(tea.WindowSizeMsg{Width: w.Width, Height: w.Height})
92 | }
93 | }
94 | }()
95 | if _, err := program.Run(); err != nil {
96 | log.Error("app exit with error", "error", err)
97 | }
98 | // p.Kill() will force kill the program if it's still running,
99 | // and restore the terminal to its original state in case of a
100 | // tui crash
101 | program.Kill()
102 | cancel()
103 | next(sess)
104 | }
105 | }
106 | }
107 |
// minColorProfileKeyType is the unexported type of the context key for the
// minimum color profile. A dedicated named type ensures the key cannot
// collide with keys stored by other packages. An anonymous struct{} key
// would compare equal to any other struct{}{} key in the same context.
type minColorProfileKeyType struct{}

// minColorProfileKey is the session-context key under which
// MiddlewareWithProgramHandler stores the minimum termenv.Profile that
// MakeRenderer enforces.
var minColorProfileKey minColorProfileKeyType

// profileNames gives a human-readable name for each termenv.Profile value,
// indexed by the profile. It is used in MakeRenderer's warning message.
var profileNames = [4]string{"TrueColor", "ANSI256", "ANSI", "Ascii"}
111 |
112 | // MakeRenderer returns a lipgloss renderer for the current session.
113 | // This function handle PTYs as well, and should be used to style your application.
114 | func MakeRenderer(sess ssh.Session) *lipgloss.Renderer {
115 | cp, ok := sess.Context().Value(minColorProfileKey).(termenv.Profile)
116 | if !ok {
117 | cp = termenv.Ascii
118 | }
119 |
120 | r := newRenderer(sess)
121 |
122 | // We only force the color profile if the requested session is a PTY.
123 | _, _, ok = sess.Pty()
124 | if !ok {
125 | return r
126 | }
127 |
128 | if r.ColorProfile() > cp {
129 | _, _ = fmt.Fprintf(sess.Stderr(), "Warning: Client's terminal is %q, forcing %q\r\n",
130 | profileNames[r.ColorProfile()], profileNames[cp])
131 | r.SetColorProfile(cp)
132 | }
133 | return r
134 | }
135 |
136 | // MakeOptions returns the tea.WithInput and tea.WithOutput program options
137 | // taking into account possible Emulated or Allocated PTYs.
138 | func MakeOptions(sess ssh.Session) []tea.ProgramOption {
139 | return makeOpts(sess)
140 | }
141 |
142 | type sshEnviron []string
143 |
144 | var _ termenv.Environ = sshEnviron(nil)
145 |
146 | // Environ implements termenv.Environ.
147 | func (e sshEnviron) Environ() []string {
148 | return e
149 | }
150 |
151 | // Getenv implements termenv.Environ.
152 | func (e sshEnviron) Getenv(k string) string {
153 | for _, v := range e {
154 | if strings.HasPrefix(v, k+"=") {
155 | return v[len(k)+1:]
156 | }
157 | }
158 | return ""
159 | }
160 |
161 | func newDefaultProgramHandler(handler Handler) ProgramHandler {
162 | return func(s ssh.Session) *tea.Program {
163 | m, opts := handler(s)
164 | if m == nil {
165 | return nil
166 | }
167 | return tea.NewProgram(m, append(opts, makeOpts(s)...)...)
168 | }
169 | }
170 |
--------------------------------------------------------------------------------
/bubbletea/tea_other.go:
--------------------------------------------------------------------------------
1 | //go:build !linux && !darwin && !freebsd && !dragonfly && !netbsd && !openbsd && !solaris
2 | // +build !linux,!darwin,!freebsd,!dragonfly,!netbsd,!openbsd,!solaris
3 |
4 | package bubbletea
5 |
6 | import (
7 | tea "github.com/charmbracelet/bubbletea"
8 | "github.com/charmbracelet/lipgloss"
9 | "github.com/charmbracelet/ssh"
10 | "github.com/muesli/termenv"
11 | )
12 |
13 | func makeOpts(s ssh.Session) []tea.ProgramOption {
14 | return []tea.ProgramOption{
15 | tea.WithInput(s),
16 | tea.WithOutput(s),
17 | }
18 | }
19 |
20 | func newRenderer(s ssh.Session) *lipgloss.Renderer {
21 | pty, _, _ := s.Pty()
22 | env := sshEnviron(append(s.Environ(), "TERM="+pty.Term))
23 | return lipgloss.NewRenderer(s, termenv.WithEnvironment(env), termenv.WithUnsafe(), termenv.WithColorCache(true))
24 | }
25 |
--------------------------------------------------------------------------------
/bubbletea/tea_unix.go:
--------------------------------------------------------------------------------
1 | //go:build darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
2 | // +build darwin dragonfly freebsd linux netbsd openbsd solaris
3 |
4 | package bubbletea
5 |
6 | import (
7 | "image/color"
8 | "time"
9 |
10 | tea "github.com/charmbracelet/bubbletea"
11 | "github.com/charmbracelet/lipgloss"
12 | "github.com/charmbracelet/ssh"
13 | "github.com/charmbracelet/x/ansi"
14 | "github.com/charmbracelet/x/input"
15 | "github.com/charmbracelet/x/term"
16 | "github.com/lucasb-eyer/go-colorful"
17 | "github.com/muesli/termenv"
18 | )
19 |
20 | func makeOpts(s ssh.Session) []tea.ProgramOption {
21 | pty, _, ok := s.Pty()
22 | if !ok || s.EmulatedPty() {
23 | return []tea.ProgramOption{
24 | tea.WithInput(s),
25 | tea.WithOutput(s),
26 | }
27 | }
28 |
29 | return []tea.ProgramOption{
30 | tea.WithInput(pty.Slave),
31 | tea.WithOutput(pty.Slave),
32 | }
33 | }
34 |
35 | func newRenderer(s ssh.Session) *lipgloss.Renderer {
36 | pty, _, ok := s.Pty()
37 | if !ok || pty.Term == "" || pty.Term == "dumb" {
38 | return lipgloss.NewRenderer(s, termenv.WithProfile(termenv.Ascii))
39 | }
40 | env := sshEnviron(append(s.Environ(), "TERM="+pty.Term))
41 | var r *lipgloss.Renderer
42 | var bg color.Color
43 | if ok && pty.Slave != nil {
44 | r = lipgloss.NewRenderer(
45 | pty.Slave,
46 | termenv.WithEnvironment(env),
47 | termenv.WithColorCache(true),
48 | )
49 | state, err := term.MakeRaw(pty.Slave.Fd())
50 | if err == nil {
51 | bg, _ = queryBackgroundColor(pty.Slave, pty.Slave)
52 | _ = term.Restore(pty.Slave.Fd(), state)
53 | }
54 | } else {
55 | r = lipgloss.NewRenderer(
56 | s,
57 | termenv.WithEnvironment(env),
58 | termenv.WithUnsafe(),
59 | termenv.WithColorCache(true),
60 | )
61 | bg = querySessionBackgroundColor(s)
62 | }
63 | if bg != nil {
64 | c, ok := colorful.MakeColor(bg)
65 | if ok {
66 | _, _, l := c.Hsl()
67 | r.SetHasDarkBackground(l < 0.5)
68 | }
69 | }
70 | return r
71 | }
72 |
73 | // copied from x/term@v0.1.3.
74 | func querySessionBackgroundColor(s ssh.Session) (bg color.Color) {
75 | _ = queryTerminal(s, s, time.Second, func(events []input.Event) bool {
76 | for _, e := range events {
77 | switch e := e.(type) {
78 | case input.BackgroundColorEvent:
79 | bg = e.Color
80 | continue // we need to consume the next DA1 event
81 | case input.PrimaryDeviceAttributesEvent:
82 | return false
83 | }
84 | }
85 | return true
86 | }, ansi.RequestBackgroundColor+ansi.RequestPrimaryDeviceAttributes)
87 | return
88 | }
89 |
--------------------------------------------------------------------------------
/cmd.go:
--------------------------------------------------------------------------------
1 | package wish
2 |
3 | import (
4 | "context"
5 | "io"
6 | "os/exec"
7 |
8 | tea "github.com/charmbracelet/bubbletea"
9 | "github.com/charmbracelet/ssh"
10 | )
11 |
12 | // CommandContext is like Command but includes a context.
13 | //
14 | // If the current session does not have a PTY, it sets them to the session
15 | // itself.
16 | func CommandContext(ctx context.Context, s ssh.Session, name string, args ...string) *Cmd {
17 | cmd := exec.CommandContext(ctx, name, args...)
18 | return &Cmd{s, cmd}
19 | }
20 |
21 | // Command sets stdin, stdout, and stderr to the current session's PTY.
22 | //
23 | // If the current session does not have a PTY, it sets them to the session
24 | // itself.
25 | //
26 | // This will use the session's context as the context for exec.Command.
27 | func Command(s ssh.Session, name string, args ...string) *Cmd {
28 | return CommandContext(s.Context(), s, name, args...)
29 | }
30 |
31 | // Cmd wraps a *exec.Cmd and a ssh.Pty so a command can be properly run.
32 | type Cmd struct {
33 | sess ssh.Session
34 | cmd *exec.Cmd
35 | }
36 |
37 | // SetDir set the underlying exec.Cmd env.
38 | func (c *Cmd) SetEnv(env []string) {
39 | c.cmd.Env = env
40 | }
41 |
42 | // Environ returns the underlying exec.Cmd environment.
43 | func (c *Cmd) Environ() []string {
44 | return c.cmd.Environ()
45 | }
46 |
47 | // SetDir set the underlying exec.Cmd dir.
48 | func (c *Cmd) SetDir(dir string) {
49 | c.cmd.Dir = dir
50 | }
51 |
52 | // Run runs the program and waits for it to finish.
53 | func (c *Cmd) Run() error {
54 | ppty, winCh, ok := c.sess.Pty()
55 | if !ok {
56 | c.cmd.Stdin, c.cmd.Stdout, c.cmd.Stderr = c.sess, c.sess, c.sess.Stderr()
57 | return c.cmd.Run()
58 | }
59 | return c.doRun(ppty, winCh)
60 | }
61 |
62 | var _ tea.ExecCommand = &Cmd{}
63 |
64 | // SetStderr conforms with tea.ExecCommand.
65 | func (*Cmd) SetStderr(io.Writer) {}
66 |
67 | // SetStdin conforms with tea.ExecCommand.
68 | func (*Cmd) SetStdin(io.Reader) {}
69 |
70 | // SetStdout conforms with tea.ExecCommand.
71 | func (*Cmd) SetStdout(io.Writer) {}
72 |
--------------------------------------------------------------------------------
/cmd_test.go:
--------------------------------------------------------------------------------
1 | package wish
2 |
3 | import (
4 | "bytes"
5 | "runtime"
6 | "strings"
7 | "testing"
8 | "time"
9 |
10 | "github.com/charmbracelet/ssh"
11 | "github.com/charmbracelet/wish/testsession"
12 | )
13 |
14 | func TestCommandNoPty(t *testing.T) {
15 | tmp := t.TempDir()
16 | sess := testsession.New(t, &ssh.Server{
17 | Handler: func(s ssh.Session) {
18 | runEcho(s, "hello")
19 | runEnv(s, []string{"HELLO=world"})
20 | runPwd(s, tmp)
21 | },
22 | }, nil)
23 | var stdout bytes.Buffer
24 | var stderr bytes.Buffer
25 | sess.Stdout = &stdout
26 | sess.Stderr = &stderr
27 | if err := sess.Run(""); err != nil {
28 | t.Errorf("expected no error, got %v: %s", err, stderr.String())
29 | }
30 | out := stdout.String()
31 | expectContains(t, out, "hello")
32 | expectContains(t, out, "HELLO=world")
33 | expectContains(t, out, tmp)
34 | }
35 |
36 | func TestCommandPty(t *testing.T) {
37 | tmp := t.TempDir()
38 | srv := &ssh.Server{
39 | Handler: func(s ssh.Session) {
40 | runEcho(s, "hello")
41 | runEnv(s, []string{"HELLO=world"})
42 | runPwd(s, tmp)
43 | // for some reason sometimes on macos github action runners,
44 | // it cuts parts of the output.
45 | time.Sleep(100 * time.Millisecond)
46 | },
47 | }
48 | if err := ssh.AllocatePty()(srv); err != nil {
49 | t.Fatalf("expected no error, got %v", err)
50 | }
51 |
52 | sess := testsession.New(t, srv, nil)
53 | if err := sess.RequestPty("xterm", 500, 200, nil); err != nil {
54 | t.Fatalf("expected no error, got %v", err)
55 | }
56 |
57 | var stdout bytes.Buffer
58 | var stderr bytes.Buffer
59 | sess.Stdout = &stdout
60 | sess.Stderr = &stderr
61 | if err := sess.Run(""); err != nil {
62 | t.Errorf("expected no error, got %v: %s", err, stderr.String())
63 | }
64 | out := stdout.String()
65 | expectContains(t, out, "hello")
66 | expectContains(t, out, "HELLO=world")
67 | expectContains(t, out, tmp)
68 | }
69 |
70 | func TestCommandPtyError(t *testing.T) {
71 | if runtime.GOOS == "windows" {
72 | t.Skip()
73 | }
74 | srv := &ssh.Server{
75 | Handler: func(s ssh.Session) {
76 | if err := Command(s, "nopenopenope").Run(); err != nil {
77 | Fatal(s, err)
78 | }
79 | },
80 | }
81 | if err := ssh.AllocatePty()(srv); err != nil {
82 | t.Fatalf("expected no error, got %v", err)
83 | }
84 |
85 | sess := testsession.New(t, srv, nil)
86 | if err := sess.RequestPty("xterm", 500, 200, nil); err != nil {
87 | t.Fatalf("expected no error, got %v", err)
88 | }
89 |
90 | var stderr bytes.Buffer
91 | sess.Stderr = &stderr
92 | if err := sess.Run(""); err == nil {
93 | t.Errorf("expected an error, got nil")
94 | }
95 | expect := `exec: "nopenopenope"`
96 | if s := stderr.String(); !strings.Contains(s, expect) {
97 | t.Errorf("expected output to contain %q, got %q", expect, s)
98 | }
99 | }
100 |
101 | func runEcho(s ssh.Session, str string) {
102 | cmd := Command(s, "echo", str)
103 | if runtime.GOOS == "windows" {
104 | cmd = Command(s, "cmd", "/C", "echo", str)
105 | }
106 | // these should do nothing...
107 | cmd.SetStderr(nil)
108 | cmd.SetStdin(nil)
109 | cmd.SetStdout(nil)
110 | if err := cmd.Run(); err != nil {
111 | Fatal(s, err)
112 | }
113 | }
114 |
115 | func runEnv(s ssh.Session, env []string) {
116 | cmd := Command(s, "env")
117 | if runtime.GOOS == "windows" {
118 | cmd = Command(s, "cmd", "/C", "set")
119 | }
120 | cmd.SetEnv(env)
121 | if err := cmd.Run(); err != nil {
122 | Fatal(s, err)
123 | }
124 | if len(cmd.Environ()) == 0 {
125 | Fatal(s, "cmd.Environ() should not be empty")
126 | }
127 | }
128 |
129 | func runPwd(s ssh.Session, dir string) {
130 | cmd := Command(s, "pwd")
131 | if runtime.GOOS == "windows" {
132 | cmd = Command(s, "cmd", "/C", "cd")
133 | }
134 | cmd.SetDir(dir)
135 | if err := cmd.Run(); err != nil {
136 | Fatal(s, err)
137 | }
138 | }
139 |
// expectContains reports a (non-fatal) test error when s does not contain
// substr. It is marked as a helper so failures point at the caller's line.
func expectContains(tb testing.TB, s, substr string) {
	tb.Helper()
	if !strings.Contains(s, substr) {
		tb.Errorf("expected output %q to contain %q", s, substr)
	}
}
145 |
--------------------------------------------------------------------------------
/cmd_unix.go:
--------------------------------------------------------------------------------
1 | //go:build darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
2 | // +build darwin dragonfly freebsd linux netbsd openbsd solaris
3 |
4 | package wish
5 |
6 | import "github.com/charmbracelet/ssh"
7 |
8 | func (c *Cmd) doRun(ppty ssh.Pty, _ <-chan ssh.Window) error {
9 | if err := ppty.Start(c.cmd); err != nil {
10 | return err
11 | }
12 | return c.cmd.Wait()
13 | }
14 |
--------------------------------------------------------------------------------
/cmd_windows.go:
--------------------------------------------------------------------------------
1 | //go:build windows
2 | // +build windows
3 |
4 | package wish
5 |
6 | import (
7 | "fmt"
8 | "time"
9 |
10 | "github.com/charmbracelet/ssh"
11 | )
12 |
13 | func (c *Cmd) doRun(ppty ssh.Pty, _ <-chan ssh.Window) error {
14 | if err := ppty.Start(c.cmd); err != nil {
15 | return err
16 | }
17 |
18 | start := time.Now()
19 | for c.cmd.ProcessState == nil {
20 | if time.Since(start) > time.Second*10 {
21 | return fmt.Errorf("could not start process")
22 | }
23 | time.Sleep(100 * time.Millisecond)
24 | }
25 | if !c.cmd.ProcessState.Success() {
26 | return fmt.Errorf("process failed: exit %d", c.cmd.ProcessState.ExitCode())
27 | }
28 | return nil
29 | }
30 |
--------------------------------------------------------------------------------
/comment/comment.go:
--------------------------------------------------------------------------------
1 | package comment
2 |
3 | import (
4 | "github.com/charmbracelet/ssh"
5 | "github.com/charmbracelet/wish"
6 | )
7 |
8 | // Middleware prints a comment at the end of the session.
9 | func Middleware(comment string) wish.Middleware {
10 | return func(sh ssh.Handler) ssh.Handler {
11 | return func(s ssh.Session) {
12 | sh(s)
13 | wish.Println(s, comment)
14 | }
15 | }
16 | }
17 |
--------------------------------------------------------------------------------
/comment/comment_test.go:
--------------------------------------------------------------------------------
1 | package comment
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/charmbracelet/ssh"
7 | "github.com/charmbracelet/wish/testsession"
8 | gossh "golang.org/x/crypto/ssh"
9 | )
10 |
11 | func TestMiddleware(t *testing.T) {
12 | t.Run("recover session", func(t *testing.T) {
13 | b, err := setup(t).Output("")
14 | requireNoError(t, err)
15 | if string(b) != "test\n" {
16 | t.Errorf("expected comment to be 'test', got %q", string(b))
17 | }
18 | })
19 | }
20 |
21 | func setup(tb testing.TB) *gossh.Session {
22 | tb.Helper()
23 | return testsession.New(tb, &ssh.Server{
24 | Handler: Middleware("test")(func(s ssh.Session) {}),
25 | }, nil)
26 | }
27 |
28 | func requireNoError(t *testing.T, err error) {
29 | t.Helper()
30 |
31 | if err != nil {
32 | t.Fatalf("expected no error, got %q", err.Error())
33 | }
34 | }
35 |
--------------------------------------------------------------------------------
/elapsed/elapsed.go:
--------------------------------------------------------------------------------
1 | package elapsed
2 |
3 | import (
4 | "time"
5 |
6 | "github.com/charmbracelet/ssh"
7 | "github.com/charmbracelet/wish"
8 | )
9 |
10 | // MiddlewareWithFormat returns a middleware that logs the elapsed time of the
11 | // session. It accepts a format string to print the elapsed time.
12 | //
13 | // In order to provide an accurate elapsed time for the entire session,
14 | // this must be called as the last middleware in the chain.
15 | func MiddlewareWithFormat(format string) wish.Middleware {
16 | return func(sh ssh.Handler) ssh.Handler {
17 | return func(s ssh.Session) {
18 | now := time.Now()
19 | sh(s)
20 | wish.Printf(s, format, time.Since(now))
21 | }
22 | }
23 | }
24 |
25 | // Middleware returns a middleware that logs the elapsed time of the session.
26 | //
27 | // In order to provide an accurate elapsed time for the entire session,
28 | // this must be called as the last middleware in the chain.
29 | func Middleware() wish.Middleware {
30 | return MiddlewareWithFormat("elapsed time: %v\n")
31 | }
32 |
--------------------------------------------------------------------------------
/elapsed/elapsed_test.go:
--------------------------------------------------------------------------------
1 | package elapsed
2 |
3 | import (
4 | "testing"
5 | "time"
6 |
7 | "github.com/charmbracelet/ssh"
8 | "github.com/charmbracelet/wish/testsession"
9 | gossh "golang.org/x/crypto/ssh"
10 | )
11 |
12 | var waitDuration = time.Second
13 |
14 | func TestMiddleware(t *testing.T) {
15 | t.Run("recover session", func(t *testing.T) {
16 | b, err := setup(t).Output("")
17 | requireNoError(t, err)
18 | dur, err := time.ParseDuration(string(b))
19 | requireNoError(t, err)
20 | if dur < waitDuration {
21 | t.Errorf("expected elapsed time to be at least 1s, got %v", dur)
22 | }
23 | })
24 | }
25 |
26 | func setup(tb testing.TB) *gossh.Session {
27 | tb.Helper()
28 | return testsession.New(tb, &ssh.Server{
29 | Handler: MiddlewareWithFormat("%v")(func(s ssh.Session) {
30 | time.Sleep(waitDuration)
31 | }),
32 | }, nil)
33 | }
34 |
35 | func requireNoError(t *testing.T, err error) {
36 | t.Helper()
37 |
38 | if err != nil {
39 | t.Fatalf("expected no error, got %q", err.Error())
40 | }
41 | }
42 |
--------------------------------------------------------------------------------
/examples/.gitignore:
--------------------------------------------------------------------------------
1 | id_ed25519*
2 | file.txt
3 |
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # Wish Examples
2 |
3 | We recommend you follow the examples in the following order:
4 |
5 | ## Basics
6 |
7 | 1. [Simple](./simple)
8 | 1. [Graceful Shutdown](./graceful-shutdown)
9 | 1. [Server banner and middleware](./banner)
10 | 1. [Identifying Users](./identity)
11 | 1. [Multiple authentication types](./multi-auth)
12 |
13 | ## Making SSH apps
14 |
15 | 1. [Using spf13/cobra](./cobra)
16 | 1. [Serving Bubble Tea apps](./bubbletea)
17 | 1. [Serving Bubble Tea programs](./bubbleteaprogram)
18 | 1. [Reverse Port Forwarding](./forward)
19 | 1. [Multichat](./multichat)
20 |
21 | ## SCP, SFTP, and Git
22 |
23 | 1. [Serving a Git repository](./git)
24 | 1. [SCP and SFTP](./scp)
25 |
26 | ## Pseudo Terminals
27 |
28 | 1. [Allocate a PTY](./pty)
29 | 1. [Running Bubble Tea and executing another program on an allocated PTY](./bubbletea-exec)
30 |
--------------------------------------------------------------------------------
/examples/banner/banner.txt:
--------------------------------------------------------------------------------
1 |
2 | Hello %s, welcome to an SSH server powered by
3 | __ ___ _ _ _ _
4 | \ \ / (_)___| |__ | | | |
5 | \ \ /\ / /| / __| '_ \| | | |
6 | \ V V / | \__ \ | | |_|_|_|
7 | \_/\_/ |_|___/_| |_(_|_|_)
8 |
9 |
10 | PS: The password is "asd123".
11 | PPS: Visit https://charm.sh to learn more!
12 |
13 | ---
14 |
15 |
--------------------------------------------------------------------------------
/examples/banner/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "net"
8 | "os"
9 | "os/signal"
10 | "syscall"
11 | "time"
12 |
13 | _ "embed"
14 |
15 | "github.com/charmbracelet/log"
16 | "github.com/charmbracelet/ssh"
17 | "github.com/charmbracelet/wish"
18 | "github.com/charmbracelet/wish/elapsed"
19 | "github.com/charmbracelet/wish/logging"
20 | )
21 |
22 | const (
23 | host = "localhost"
24 | port = "23234"
25 | )
26 |
27 | //go:embed banner.txt
28 | var banner string
29 |
30 | func main() {
31 | s, err := wish.NewServer(
32 | wish.WithAddress(net.JoinHostPort(host, port)),
33 | wish.WithHostKeyPath(".ssh/id_ed25519"),
34 | // A banner is always shown, even before authentication.
35 | wish.WithBannerHandler(func(ctx ssh.Context) string {
36 | return fmt.Sprintf(banner, ctx.User())
37 | }),
38 | wish.WithPasswordAuth(func(ctx ssh.Context, password string) bool {
39 | return password == "asd123"
40 | }),
41 | wish.WithMiddleware(
42 | func(next ssh.Handler) ssh.Handler {
43 | return func(sess ssh.Session) {
44 | wish.Println(sess, fmt.Sprintf("Hello, %s!", sess.User()))
45 | next(sess)
46 | }
47 | },
48 | logging.Middleware(),
49 | // This middleware prints the session duration before disconnecting.
50 | elapsed.Middleware(),
51 | ),
52 | )
53 | if err != nil {
54 | log.Error("Could not start server", "error", err)
55 | }
56 |
57 | done := make(chan os.Signal, 1)
58 | signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
59 | log.Info("Starting SSH server", "host", host, "port", port)
60 | go func() {
61 | if err = s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
62 | log.Error("Could not start server", "error", err)
63 | done <- nil
64 | }
65 | }()
66 |
67 | <-done
68 | log.Info("Stopping SSH server")
69 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
70 | defer func() { cancel() }()
71 | if err := s.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
72 | log.Error("Could not stop server", "error", err)
73 | }
74 | }
75 |
--------------------------------------------------------------------------------
/examples/bubbletea-exec/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "net"
7 | "os"
8 | "os/signal"
9 | "runtime"
10 | "syscall"
11 | "time"
12 |
13 | tea "github.com/charmbracelet/bubbletea"
14 | "github.com/charmbracelet/lipgloss"
15 | "github.com/charmbracelet/log"
16 | "github.com/charmbracelet/ssh"
17 | "github.com/charmbracelet/wish"
18 | "github.com/charmbracelet/wish/activeterm"
19 | "github.com/charmbracelet/wish/bubbletea"
20 | "github.com/charmbracelet/wish/logging"
21 | "github.com/charmbracelet/x/editor"
22 | )
23 |
24 | const (
25 | host = "localhost"
26 | port = "23234"
27 | )
28 |
29 | func main() {
30 | s, err := wish.NewServer(
31 | wish.WithAddress(net.JoinHostPort(host, port)),
32 |
33 | // Allocate a pty.
34 | // This creates a pseudoconsole on windows, compatibility is limited in
35 | // that case, see the open issues for more details.
36 | ssh.AllocatePty(),
37 | wish.WithMiddleware(
38 | // run our Bubble Tea handler
39 | bubbletea.Middleware(teaHandler),
40 |
41 | // ensure the user has requested a tty
42 | activeterm.Middleware(),
43 | logging.Middleware(),
44 | ),
45 | )
46 | if err != nil {
47 | log.Error("Could not start server", "error", err)
48 | }
49 |
50 | done := make(chan os.Signal, 1)
51 | signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
52 | log.Info("Starting SSH server", "host", host, "port", port)
53 | go func() {
54 | if err = s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
55 | log.Error("Could not start server", "error", err)
56 | done <- nil
57 | }
58 | }()
59 |
60 | <-done
61 | log.Info("Stopping SSH server")
62 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
63 | defer func() { cancel() }()
64 | if err := s.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
65 | log.Error("Could not stop server", "error", err)
66 | }
67 | }
68 |
69 | func teaHandler(s ssh.Session) (tea.Model, []tea.ProgramOption) {
70 | // Create a lipgloss.Renderer for the session
71 | renderer := bubbletea.MakeRenderer(s)
72 | // Set up the model with the current session and styles.
73 | // We'll use the session to call wish.Command, which makes it compatible
74 | // with tea.Command.
75 | m := model{
76 | sess: s,
77 | style: renderer.NewStyle().Foreground(lipgloss.Color("8")),
78 | errStyle: renderer.NewStyle().Foreground(lipgloss.Color("3")),
79 | }
80 | return m, []tea.ProgramOption{tea.WithAltScreen()}
81 | }
82 |
83 | type model struct {
84 | err error
85 | sess ssh.Session
86 | style lipgloss.Style
87 | errStyle lipgloss.Style
88 | }
89 |
90 | func (m model) Init() tea.Cmd {
91 | return nil
92 | }
93 |
94 | type cmdFinishedMsg struct{ err error }
95 |
96 | func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
97 | switch msg := msg.(type) {
98 | case tea.KeyMsg:
99 | switch msg.String() {
100 | case "e":
101 | // Open file.txt in the default editor.
102 | edit, err := editor.Cmd("wish", "file.txt")
103 | if err != nil {
104 | m.err = err
105 | return m, nil
106 | }
107 | // Creates a wish.Cmd from the exec.Cmd
108 | wishCmd := wish.Command(m.sess, edit.Path, edit.Args...)
109 | // Runs the cmd through Bubble Tea.
110 | // Bubble Tea should handle the IO to the program, and get it back
111 | // once the program quits.
112 | cmd := tea.Exec(wishCmd, func(err error) tea.Msg {
113 | if err != nil {
114 | log.Error("editor finished", "error", err)
115 | }
116 | return cmdFinishedMsg{err: err}
117 | })
118 | return m, cmd
119 | case "s":
120 | // We can also execute a shell and give it over to the user.
121 | // Note that this session won't have control, so it can't run tasks
122 | // in background, suspend, etc.
123 | c := wish.Command(m.sess, "bash", "-im")
124 | if runtime.GOOS == "windows" {
125 | c = wish.Command(m.sess, "powershell")
126 | }
127 | cmd := tea.Exec(c, func(err error) tea.Msg {
128 | if err != nil {
129 | log.Error("shell finished", "error", err)
130 | }
131 | return cmdFinishedMsg{err: err}
132 | })
133 | return m, cmd
134 | case "q", "ctrl+c":
135 | return m, tea.Quit
136 | }
137 | case cmdFinishedMsg:
138 | m.err = msg.err
139 | return m, nil
140 | }
141 |
142 | return m, nil
143 | }
144 |
145 | func (m model) View() string {
146 | if m.err != nil {
147 | return m.errStyle.Render(m.err.Error() + "\n")
148 | }
149 | return m.style.Render("Press 'e' to edit, 's' to hop into a shell, or 'q' to quit...\n")
150 | }
151 |
--------------------------------------------------------------------------------
/examples/bubbletea/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | // An example Bubble Tea server. This will put an ssh session into alt screen
4 | // and continually print up to date terminal information.
5 |
6 | import (
7 | "context"
8 | "errors"
9 | "fmt"
10 | "net"
11 | "os"
12 | "os/signal"
13 | "syscall"
14 | "time"
15 |
16 | tea "github.com/charmbracelet/bubbletea"
17 | "github.com/charmbracelet/lipgloss"
18 | "github.com/charmbracelet/log"
19 | "github.com/charmbracelet/ssh"
20 | "github.com/charmbracelet/wish"
21 | "github.com/charmbracelet/wish/activeterm"
22 | "github.com/charmbracelet/wish/bubbletea"
23 | "github.com/charmbracelet/wish/logging"
24 | )
25 |
26 | const (
27 | host = "localhost"
28 | port = "23234"
29 | )
30 |
31 | func main() {
32 | s, err := wish.NewServer(
33 | wish.WithAddress(net.JoinHostPort(host, port)),
34 | wish.WithHostKeyPath(".ssh/id_ed25519"),
35 | wish.WithMiddleware(
36 | bubbletea.Middleware(teaHandler),
37 | activeterm.Middleware(), // Bubble Tea apps usually require a PTY.
38 | logging.Middleware(),
39 | ),
40 | )
41 | if err != nil {
42 | log.Error("Could not start server", "error", err)
43 | }
44 |
45 | done := make(chan os.Signal, 1)
46 | signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
47 | log.Info("Starting SSH server", "host", host, "port", port)
48 | go func() {
49 | if err = s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
50 | log.Error("Could not start server", "error", err)
51 | done <- nil
52 | }
53 | }()
54 |
55 | <-done
56 | log.Info("Stopping SSH server")
57 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
58 | defer func() { cancel() }()
59 | if err := s.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
60 | log.Error("Could not stop server", "error", err)
61 | }
62 | }
63 |
64 | // You can wire any Bubble Tea model up to the middleware with a function that
65 | // handles the incoming ssh.Session. Here we just grab the terminal info and
66 | // pass it to the new model. You can also return tea.ProgramOptions (such as
67 | // tea.WithAltScreen) on a session by session basis.
68 | func teaHandler(s ssh.Session) (tea.Model, []tea.ProgramOption) {
69 | // This should never fail, as we are using the activeterm middleware.
70 | pty, _, _ := s.Pty()
71 |
72 | // When running a Bubble Tea app over SSH, you shouldn't use the default
73 | // lipgloss.NewStyle function.
74 | // That function will use the color profile from the os.Stdin, which is the
75 | // server, not the client.
76 | // We provide a MakeRenderer function in the bubbletea middleware package,
77 | // so you can easily get the correct renderer for the current session, and
78 | // use it to create the styles.
79 | // The recommended way to use these styles is to then pass them down to
80 | // your Bubble Tea model.
81 | renderer := bubbletea.MakeRenderer(s)
82 | txtStyle := renderer.NewStyle().Foreground(lipgloss.Color("10"))
83 | quitStyle := renderer.NewStyle().Foreground(lipgloss.Color("8"))
84 |
85 | bg := "light"
86 | if renderer.HasDarkBackground() {
87 | bg = "dark"
88 | }
89 |
90 | m := model{
91 | term: pty.Term,
92 | profile: renderer.ColorProfile().Name(),
93 | width: pty.Window.Width,
94 | height: pty.Window.Height,
95 | bg: bg,
96 | txtStyle: txtStyle,
97 | quitStyle: quitStyle,
98 | }
99 | return m, []tea.ProgramOption{tea.WithAltScreen()}
100 | }
101 |
102 | // Just a generic tea.Model to demo terminal information of ssh.
103 | type model struct {
104 | term string
105 | profile string
106 | width int
107 | height int
108 | bg string
109 | txtStyle lipgloss.Style
110 | quitStyle lipgloss.Style
111 | }
112 |
113 | func (m model) Init() tea.Cmd {
114 | return nil
115 | }
116 |
117 | func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
118 | switch msg := msg.(type) {
119 | case tea.WindowSizeMsg:
120 | m.height = msg.Height
121 | m.width = msg.Width
122 | case tea.KeyMsg:
123 | switch msg.String() {
124 | case "q", "ctrl+c":
125 | return m, tea.Quit
126 | }
127 | }
128 | return m, nil
129 | }
130 |
131 | func (m model) View() string {
132 | s := fmt.Sprintf("Your term is %s\nYour window size is %dx%d\nBackground: %s\nColor Profile: %s", m.term, m.width, m.height, m.bg, m.profile)
133 | return m.txtStyle.Render(s) + "\n\n" + m.quitStyle.Render("Press 'q' to quit\n")
134 | }
135 |
--------------------------------------------------------------------------------
/examples/bubbleteaprogram/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | // An example Bubble Tea server. This will put an ssh session into alt screen
4 | // and continually print up to date terminal information.
5 |
6 | import (
7 | "context"
8 | "errors"
9 | "fmt"
10 | "net"
11 | "os"
12 | "os/signal"
13 | "syscall"
14 | "time"
15 |
16 | tea "github.com/charmbracelet/bubbletea"
17 | "github.com/charmbracelet/log"
18 | "github.com/charmbracelet/ssh"
19 | "github.com/charmbracelet/wish"
20 | "github.com/charmbracelet/wish/bubbletea"
21 | "github.com/charmbracelet/wish/logging"
22 | "github.com/muesli/termenv"
23 | )
24 |
25 | const (
26 | host = "localhost"
27 | port = "23234"
28 | )
29 |
30 | func main() {
31 | s, err := wish.NewServer(
32 | wish.WithAddress(net.JoinHostPort(host, port)),
33 | wish.WithHostKeyPath(".ssh/id_ed25519"),
34 | wish.WithMiddleware(
35 | myCustomBubbleteaMiddleware(),
36 | logging.Middleware(),
37 | ),
38 | )
39 | if err != nil {
40 | log.Error("Could not start server", "error", err)
41 | }
42 |
43 | done := make(chan os.Signal, 1)
44 | signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
45 | log.Info("Starting SSH server", "host", host, "port", port)
46 | go func() {
47 | if err = s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
48 | log.Error("Could not start server", "error", err)
49 | done <- nil
50 | }
51 | }()
52 |
53 | <-done
54 | log.Info("Stopping SSH server")
55 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
56 | defer func() { cancel() }()
57 | if err := s.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
58 | log.Error("Could not stop server", "error", err)
59 | }
60 | }
61 |
62 | // You can write your own custom bubbletea middleware that wraps tea.Program.
63 | // Make sure you set the program input and output to ssh.Session.
64 | func myCustomBubbleteaMiddleware() wish.Middleware {
65 | newProg := func(m tea.Model, opts ...tea.ProgramOption) *tea.Program {
66 | p := tea.NewProgram(m, opts...)
67 | go func() {
68 | for {
69 | <-time.After(1 * time.Second)
70 | p.Send(timeMsg(time.Now()))
71 | }
72 | }()
73 | return p
74 | }
75 | teaHandler := func(s ssh.Session) *tea.Program {
76 | pty, _, active := s.Pty()
77 | if !active {
78 | wish.Fatalln(s, "no active terminal, skipping")
79 | return nil
80 | }
81 | m := model{
82 | term: pty.Term,
83 | width: pty.Window.Width,
84 | height: pty.Window.Height,
85 | time: time.Now(),
86 | }
87 | return newProg(m, append(bubbletea.MakeOptions(s), tea.WithAltScreen())...)
88 | }
89 | return bubbletea.MiddlewareWithProgramHandler(teaHandler, termenv.ANSI256)
90 | }
91 |
92 | // Just a generic tea.Model to demo terminal information of ssh.
93 | type model struct {
94 | term string
95 | width int
96 | height int
97 | time time.Time
98 | }
99 |
100 | type timeMsg time.Time
101 |
102 | func (m model) Init() tea.Cmd {
103 | return nil
104 | }
105 |
106 | func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
107 | switch msg := msg.(type) {
108 | case timeMsg:
109 | m.time = time.Time(msg)
110 | case tea.WindowSizeMsg:
111 | m.height = msg.Height
112 | m.width = msg.Width
113 | case tea.KeyMsg:
114 | switch msg.String() {
115 | case "q", "ctrl+c":
116 | return m, tea.Quit
117 | }
118 | }
119 | return m, nil
120 | }
121 |
122 | func (m model) View() string {
123 | s := "Your term is %s\n"
124 | s += "Your window size is x: %d y: %d\n"
125 | s += "Time: " + m.time.Format(time.RFC1123) + "\n\n"
126 | s += "Press 'q' to quit\n"
127 | return fmt.Sprintf(s, m.term, m.width, m.height)
128 | }
129 |
--------------------------------------------------------------------------------
/examples/cobra/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "net"
7 | "os"
8 | "os/signal"
9 | "syscall"
10 | "time"
11 |
12 | "github.com/charmbracelet/log"
13 | "github.com/charmbracelet/ssh"
14 | "github.com/charmbracelet/wish"
15 | "github.com/charmbracelet/wish/logging"
16 | "github.com/spf13/cobra"
17 | )
18 |
19 | const (
20 | host = "localhost"
21 | port = "23235"
22 | )
23 |
24 | func cmd() *cobra.Command {
25 | var reverse bool
26 | cmd := &cobra.Command{
27 | Use: "echo [string]",
28 | Args: cobra.ExactArgs(1),
29 | RunE: func(cmd *cobra.Command, args []string) error {
30 | s := args[0]
31 | if reverse {
32 | ss := make([]byte, 0, len(s))
33 | for i := len(s) - 1; i >= 0; i-- {
34 | ss = append(ss, s[i])
35 | }
36 | s = string(ss)
37 | }
38 | cmd.Println(s)
39 | return nil
40 | },
41 | }
42 |
43 | cmd.PersistentFlags().BoolVarP(&reverse, "reverse", "r", false, "Reverse string on echo")
44 | return cmd
45 | }
46 |
47 | func main() {
48 | s, err := wish.NewServer(
49 | wish.WithAddress(net.JoinHostPort(host, port)),
50 | wish.WithHostKeyPath(".ssh/id_ed25519"),
51 | wish.WithMiddleware(
52 | func(next ssh.Handler) ssh.Handler {
53 | return func(sess ssh.Session) {
54 | // Here we wire our command's args and IO to the user
55 | // session's
56 | rootCmd := cmd()
57 | rootCmd.SetArgs(sess.Command())
58 | rootCmd.SetIn(sess)
59 | rootCmd.SetOut(sess)
60 | rootCmd.SetErr(sess.Stderr())
61 | rootCmd.CompletionOptions.DisableDefaultCmd = true
62 | if err := rootCmd.Execute(); err != nil {
63 | _ = sess.Exit(1)
64 | return
65 | }
66 |
67 | next(sess)
68 | }
69 | },
70 | logging.Middleware(),
71 | ),
72 | )
73 | if err != nil {
74 | log.Error("Could not start server", "error", err)
75 | }
76 |
77 | done := make(chan os.Signal, 1)
78 | signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
79 | log.Info("Starting SSH server", "host", host, "port", port)
80 | go func() {
81 | if err = s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
82 | log.Error("Could not start server", "error", err)
83 | done <- nil
84 | }
85 | }()
86 |
87 | <-done
88 | log.Info("Stopping SSH server")
89 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
90 | defer func() { cancel() }()
91 | if err := s.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
92 | log.Error("Could not stop server", "error", err)
93 | }
94 | }
95 |
--------------------------------------------------------------------------------
/examples/exec/example.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Ask the user to pick one of four options with gum's interactive chooser.
gum choose "a" "b" "c" "d"
4 |
--------------------------------------------------------------------------------
/examples/exec/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "net"
7 | "os"
8 | "os/signal"
9 | "syscall"
10 | "time"
11 |
12 | "github.com/charmbracelet/log"
13 | "github.com/charmbracelet/ssh"
14 | "github.com/charmbracelet/wish"
15 | "github.com/charmbracelet/wish/activeterm"
16 | "github.com/charmbracelet/wish/logging"
17 | )
18 |
19 | const (
20 | host = "localhost"
21 | port = "23234"
22 | )
23 |
24 | func main() {
25 | s, err := wish.NewServer(
26 | wish.WithAddress(net.JoinHostPort(host, port)),
27 |
28 | // Allocate a pty.
29 | // This creates a pseudoconsole on windows, compatibility is limited in
30 | // that case, see the open issues for more details.
31 | ssh.AllocatePty(),
32 | wish.WithMiddleware(
33 | func(next ssh.Handler) ssh.Handler {
34 | return func(s ssh.Session) {
35 | cmd := wish.Command(s, "bash", "example.sh")
36 | if err := cmd.Run(); err != nil {
37 | wish.Fatalln(s, err)
38 | }
39 | next(s)
40 | }
41 | },
42 | // ensure the user has requested a tty
43 | activeterm.Middleware(),
44 | logging.Middleware(),
45 | ),
46 | )
47 | if err != nil {
48 | log.Error("Could not start server", "error", err)
49 | }
50 |
51 | done := make(chan os.Signal, 1)
52 | signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
53 | log.Info("Starting SSH server", "host", host, "port", port)
54 | go func() {
55 | if err = s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
56 | log.Error("Could not start server", "error", err)
57 | done <- nil
58 | }
59 | }()
60 |
61 | <-done
62 | log.Info("Stopping SSH server")
63 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
64 | defer func() { cancel() }()
65 | if err := s.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
66 | log.Error("Could not stop server", "error", err)
67 | }
68 | }
69 |
--------------------------------------------------------------------------------
/examples/forward/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "net"
7 | "os"
8 | "os/signal"
9 | "syscall"
10 | "time"
11 |
12 | "github.com/charmbracelet/log"
13 | "github.com/charmbracelet/ssh"
14 | "github.com/charmbracelet/wish"
15 | "github.com/charmbracelet/wish/logging"
16 | )
17 |
18 | const (
19 | host = "localhost"
20 | port = "23234"
21 | )
22 |
23 | // example usage: ssh -N -R 23236:localhost:23235 -p 23234 localhost
24 |
25 | func main() {
26 | // Create a new SSH ForwardedTCPHandler.
27 | forwardHandler := &ssh.ForwardedTCPHandler{}
28 | s, err := wish.NewServer(
29 | wish.WithAddress(net.JoinHostPort(host, port)),
30 | wish.WithHostKeyPath(".ssh/id_ed25519"),
31 | func(s *ssh.Server) error {
32 | // Set the Reverse TCP Handler up:
33 | s.ReversePortForwardingCallback = func(_ ssh.Context, bindHost string, bindPort uint32) bool {
34 | log.Info("reverse port forwarding allowed", "host", bindHost, "port", bindPort)
35 | return true
36 | }
37 | s.RequestHandlers = map[string]ssh.RequestHandler{
38 | "tcpip-forward": forwardHandler.HandleSSHRequest,
39 | "cancel-tcpip-forward": forwardHandler.HandleSSHRequest,
40 | }
41 | return nil
42 | },
43 | wish.WithMiddleware(
44 | func(h ssh.Handler) ssh.Handler {
45 | return func(s ssh.Session) {
46 | wish.Println(s, "Remote port forwarding available!")
47 | wish.Println(s, "Try it with:")
48 | wish.Println(s, " ssh -N -R 23236:localhost:23235 -p 23234 localhost")
49 | h(s)
50 | }
51 | },
52 | logging.Middleware(),
53 | ),
54 | )
55 | if err != nil {
56 | log.Error("Could not start server", "error", err)
57 | }
58 |
59 | done := make(chan os.Signal, 1)
60 | signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
61 | log.Info("Starting SSH server", "host", host, "port", port)
62 | go func() {
63 | if err = s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
64 | log.Error("Could not start server", "error", err)
65 | done <- nil
66 | }
67 | }()
68 |
69 | <-done
70 | log.Info("Stopping SSH server")
71 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
72 | defer func() { cancel() }()
73 | if err := s.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
74 | log.Error("Could not stop server", "error", err)
75 | }
76 | }
77 |
--------------------------------------------------------------------------------
/examples/git/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | // An example git server. This will list all available repos if you ssh
4 | // directly to the server. To test `ssh -p 23233 localhost` once it's running.
5 |
6 | import (
7 | "context"
8 | "errors"
9 | "fmt"
10 | "io/fs"
11 | "net"
12 | "os"
13 | "os/signal"
14 | "syscall"
15 | "time"
16 |
17 | "github.com/charmbracelet/log"
18 | "github.com/charmbracelet/ssh"
19 | "github.com/charmbracelet/wish"
20 | "github.com/charmbracelet/wish/git"
21 | "github.com/charmbracelet/wish/logging"
22 | )
23 |
24 | const (
25 | port = "23233"
26 | host = "localhost"
27 | repoDir = ".repos"
28 | )
29 |
30 | type app struct {
31 | access git.AccessLevel
32 | }
33 |
34 | func (a app) AuthRepo(string, ssh.PublicKey) git.AccessLevel {
35 | return a.access
36 | }
37 |
38 | func (a app) Push(repo string, _ ssh.PublicKey) {
39 | log.Info("push", "repo", repo)
40 | }
41 |
42 | func (a app) Fetch(repo string, _ ssh.PublicKey) {
43 | log.Info("fetch", "repo", repo)
44 | }
45 |
46 | func main() {
47 | // A simple GitHooks implementation to allow global read write access.
48 | a := app{git.ReadWriteAccess}
49 |
50 | s, err := wish.NewServer(
51 | wish.WithAddress(net.JoinHostPort(host, port)),
52 | wish.WithHostKeyPath(".ssh/id_ed25519"),
53 | // Accept any public key.
54 | ssh.PublicKeyAuth(func(ssh.Context, ssh.PublicKey) bool { return true }),
55 | // Do not accept password auth.
56 | ssh.PasswordAuth(func(ssh.Context, string) bool { return false }),
57 | wish.WithMiddleware(
58 | // Setup the git middleware.
59 | git.Middleware(repoDir, a),
60 | // Adds a middleware to list all available repositories to the user.
61 | gitListMiddleware,
62 | logging.Middleware(),
63 | ),
64 | )
65 | if err != nil {
66 | log.Error("Could not start server", "error", err)
67 | }
68 |
69 | done := make(chan os.Signal, 1)
70 | signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
71 | log.Info("Starting SSH server", "host", host, "port", port)
72 | go func() {
73 | if err = s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
74 | log.Error("Could not start server", "error", err)
75 | done <- nil
76 | }
77 | }()
78 |
79 | <-done
80 | log.Info("Stopping SSH server")
81 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
82 | defer func() { cancel() }()
83 | if err := s.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
84 | log.Error("Could not stop server", "error", err)
85 | }
86 | }
87 |
88 | // Normally we would use a Bubble Tea program for the TUI but for simplicity,
89 | // we'll just write a list of the pushed repos to the terminal and exit the ssh
90 | // session.
91 | func gitListMiddleware(next ssh.Handler) ssh.Handler {
92 | return func(sess ssh.Session) {
93 | // Git will have a command included so only run this if there are no
94 | // commands passed to ssh.
95 | if len(sess.Command()) != 0 {
96 | next(sess)
97 | return
98 | }
99 |
100 | dest, err := os.ReadDir(repoDir)
101 | if err != nil && err != fs.ErrNotExist {
102 | log.Error("Invalid repository", "error", err)
103 | }
104 | if len(dest) > 0 {
105 | fmt.Fprintf(sess, "\n### Repo Menu ###\n\n")
106 | }
107 | for _, dir := range dest {
108 | wish.Println(sess, fmt.Sprintf("• %s - ", dir.Name()))
109 | wish.Println(sess, fmt.Sprintf("git clone ssh://%s/%s", net.JoinHostPort(host, port), dir.Name()))
110 | }
111 | wish.Printf(sess, "\n\n### Add some repos! ###\n\n")
112 | wish.Printf(sess, "> cd some_repo\n")
113 | wish.Printf(sess, "> git remote add wish_test ssh://%s/some_repo\n", net.JoinHostPort(host, port))
114 | wish.Printf(sess, "> git push wish_test\n\n\n")
115 | next(sess)
116 | }
117 | }
118 |
--------------------------------------------------------------------------------
/examples/go.mod:
--------------------------------------------------------------------------------
1 | module examples
2 |
3 | go 1.23.0
4 |
5 | toolchain go1.24.1
6 |
7 | require (
8 | github.com/charmbracelet/bubbles v0.21.0
9 | github.com/charmbracelet/bubbletea v1.3.5
10 | github.com/charmbracelet/lipgloss v1.1.0
11 | github.com/charmbracelet/log v0.4.2
12 | github.com/charmbracelet/ssh v0.0.0-20250128164007-98fd5ae11894
13 | github.com/charmbracelet/wish v0.5.0
14 | github.com/charmbracelet/x/editor v0.1.0
15 | github.com/muesli/termenv v0.16.0
16 | github.com/pkg/sftp v1.13.9
17 | github.com/spf13/cobra v1.9.1
18 | golang.org/x/crypto v0.38.0
19 | )
20 |
21 | require (
22 | dario.cat/mergo v1.0.0 // indirect
23 | github.com/Microsoft/go-winio v0.6.2 // indirect
24 | github.com/ProtonMail/go-crypto v1.1.6 // indirect
25 | github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
26 | github.com/atotto/clipboard v0.1.4 // indirect
27 | github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
28 | github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
29 | github.com/charmbracelet/keygen v0.5.3 // indirect
30 | github.com/charmbracelet/x/ansi v0.9.2 // indirect
31 | github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
32 | github.com/charmbracelet/x/conpty v0.1.0 // indirect
33 | github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86 // indirect
34 | github.com/charmbracelet/x/input v0.3.4 // indirect
35 | github.com/charmbracelet/x/term v0.2.1 // indirect
36 | github.com/charmbracelet/x/termios v0.1.0 // indirect
37 | github.com/charmbracelet/x/windows v0.2.0 // indirect
38 | github.com/cloudflare/circl v1.6.1 // indirect
39 | github.com/creack/pty v1.1.21 // indirect
40 | github.com/cyphar/filepath-securejoin v0.4.1 // indirect
41 | github.com/emirpasic/gods v1.18.1 // indirect
42 | github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
43 | github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect
44 | github.com/go-git/go-billy/v5 v5.6.2 // indirect
45 | github.com/go-git/go-git/v5 v5.16.0 // indirect
46 | github.com/go-logfmt/logfmt v0.6.0 // indirect
47 | github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
48 | github.com/inconshreveable/mousetrap v1.1.0 // indirect
49 | github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect
50 | github.com/kevinburke/ssh_config v1.2.0 // indirect
51 | github.com/kr/fs v0.1.0 // indirect
52 | github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
53 | github.com/mattn/go-isatty v0.0.20 // indirect
54 | github.com/mattn/go-localereader v0.0.1 // indirect
55 | github.com/mattn/go-runewidth v0.0.16 // indirect
56 | github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
57 | github.com/muesli/cancelreader v0.2.2 // indirect
58 | github.com/pjbgf/sha1cd v0.3.2 // indirect
59 | github.com/rivo/uniseg v0.4.7 // indirect
60 | github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect
61 | github.com/skeema/knownhosts v1.3.1 // indirect
62 | github.com/spf13/pflag v1.0.6 // indirect
63 | github.com/xanzy/ssh-agent v0.3.3 // indirect
64 | github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
65 | golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect
66 | golang.org/x/net v0.39.0 // indirect
67 | golang.org/x/sync v0.14.0 // indirect
68 | golang.org/x/sys v0.33.0 // indirect
69 | golang.org/x/text v0.25.0 // indirect
70 | gopkg.in/warnings.v0 v0.1.2 // indirect
71 | )
72 |
73 | replace github.com/charmbracelet/wish => ../
74 |
--------------------------------------------------------------------------------
/examples/graceful-shutdown/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "net"
7 | "os"
8 | "os/signal"
9 | "syscall"
10 | "time"
11 |
12 | "github.com/charmbracelet/log"
13 | "github.com/charmbracelet/ssh"
14 | "github.com/charmbracelet/wish"
15 | "github.com/charmbracelet/wish/logging"
16 | )
17 |
18 | const (
19 | host = "localhost"
20 | port = "23234"
21 | )
22 |
23 | func main() {
24 | srv, err := wish.NewServer(
25 | wish.WithAddress(net.JoinHostPort(host, port)),
26 | wish.WithHostKeyPath(".ssh/id_ed25519"),
27 | wish.WithMiddleware(
28 | func(next ssh.Handler) ssh.Handler {
29 | return func(sess ssh.Session) {
30 | wish.Println(sess, "Hello, world!")
31 | next(sess)
32 | }
33 | },
34 | logging.Middleware(),
35 | ),
36 | )
37 | if err != nil {
38 | log.Error("Could not start server", "error", err)
39 | }
40 |
41 | // Before starting our server, we create a channel and listen for some
42 | // common interrupt signals.
43 | done := make(chan os.Signal, 1)
44 | signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
45 |
46 | // We then start the server in a goroutine, as we'll listen for the done
47 | // signal later.
48 | go func() {
49 | log.Info("Starting SSH server", "host", host, "port", port)
50 | if err = srv.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
51 | // We ignore ErrServerClosed because it is expected.
52 | log.Error("Could not start server", "error", err)
53 | done <- nil
54 | }
55 | }()
56 |
57 | // Here we wait for the done signal: this can be either an interrupt, or
58 | // the server shutting down for any other reason.
59 | <-done
60 |
61 | // When it arrives, we create a context with a timeout.
62 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
63 | defer func() { cancel() }()
64 |
65 | // When we start the shutdown, the server will no longer accept new
66 | // connections, but will wait as much as the given context allows for the
67 | // active connections to finish.
68 | // After the timeout, it shuts down anyway.
69 | log.Info("Stopping SSH server")
70 | if err := srv.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
71 | log.Error("Could not stop server", "error", err)
72 | }
73 | }
74 |
--------------------------------------------------------------------------------
/examples/identity/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "net"
8 | "os"
9 | "os/signal"
10 | "syscall"
11 | "time"
12 |
13 | "github.com/charmbracelet/log"
14 | "github.com/charmbracelet/ssh"
15 | "github.com/charmbracelet/wish"
16 | "github.com/charmbracelet/wish/logging"
17 | )
18 |
19 | const (
20 | host = "localhost"
21 | port = "23234"
22 | )
23 |
24 | var users = map[string]string{
25 | "Carlos": "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILxWe2rXKoiO6W14LYPVfJKzRfJ1f3Jhzxrgjc/D4tU7",
26 | // You can add add your name and public key here :)
27 | }
28 |
29 | func main() {
30 | s, err := wish.NewServer(
31 | wish.WithAddress(net.JoinHostPort(host, port)),
32 | wish.WithHostKeyPath(".ssh/id_ed25519"),
33 | // This will allow anyone to log in, as long as they have given an
34 | // ed25519 public key.
35 | // You can test this by doing something like:
36 | // ssh -i ~/.ssh/id_ed25519 -p 23234 localhost
37 | // ssh -i ~/.ssh/id_rsa -p 23234 localhost
38 | // ssh -o PreferredAuthentications=password -p 23234 localhost
39 | wish.WithPublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
40 | return key.Type() == "ssh-ed25519"
41 | }),
42 | wish.WithMiddleware(
43 | func(next ssh.Handler) ssh.Handler {
44 | return func(sess ssh.Session) {
45 | // if the current session's user public key is one of the
46 | // known users, we greet them and return.
47 | for name, pubkey := range users {
48 | parsed, _, _, _, _ := ssh.ParseAuthorizedKey(
49 | []byte(pubkey),
50 | )
51 | if ssh.KeysEqual(sess.PublicKey(), parsed) {
52 | wish.Println(sess, fmt.Sprintf("Hey %s!", name))
53 | next(sess)
54 | return
55 | }
56 | }
57 | wish.Println(sess, "Hey, I don't know who you are!")
58 | next(sess)
59 | }
60 | },
61 | logging.Middleware(),
62 | ),
63 | )
64 | if err != nil {
65 | log.Error("Could not start server", "error", err)
66 | }
67 |
68 | done := make(chan os.Signal, 1)
69 | signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
70 | log.Info("Starting SSH server", "host", host, "port", port)
71 | go func() {
72 | if err = s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
73 | log.Error("Could not start server", "error", err)
74 | done <- nil
75 | }
76 | }()
77 |
78 | <-done
79 | log.Info("Stopping SSH server")
80 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
81 | defer func() { cancel() }()
82 | if err := s.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
83 | log.Error("Could not stop server", "error", err)
84 | }
85 | }
86 |
--------------------------------------------------------------------------------
/examples/multi-auth/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "net"
7 | "os"
8 | "os/signal"
9 | "syscall"
10 | "time"
11 |
12 | "github.com/charmbracelet/log"
13 | "github.com/charmbracelet/ssh"
14 | "github.com/charmbracelet/wish"
15 | "github.com/charmbracelet/wish/logging"
16 | gossh "golang.org/x/crypto/ssh"
17 | )
18 |
19 | const (
20 | host = "localhost"
21 | port = "23234"
22 | validPassword = "asd123"
23 | )
24 |
25 | var users = map[string]string{
26 | "Carlos": "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILxWe2rXKoiO6W14LYPVfJKzRfJ1f3Jhzxrgjc/D4tU7",
27 | // You can add add your name and public key here :)
28 | }
29 |
30 | func main() {
31 | s, err := wish.NewServer(
32 | wish.WithAddress(net.JoinHostPort(host, port)),
33 | wish.WithHostKeyPath(".ssh/id_ed25519"),
34 |
35 | // In this example, we'll have multiple possible authentication methods.
36 | // The order of preference is defined by the user (via
37 | // PreferredAuthentications), and if all of them fails, they aren't
38 | // allowed in.
39 | //
40 | // You can SSH into the server like so:
41 | // ssh -o PreferredAuthentications=none -p 23234 localhost
42 | // ssh -o PreferredAuthentications=password -p 23234 localhost
43 | // ssh -o PreferredAuthentications=publickey -p 23234 localhost
44 | // ssh -o PreferredAuthentications=keyboard-interactive -p 23234 localhost
45 |
46 | // First, public-key authentication:
47 | wish.WithPublicKeyAuth(func(_ ssh.Context, key ssh.PublicKey) bool {
48 | log.Info("publickey")
49 | for _, pubkey := range users {
50 | parsed, _, _, _, _ := ssh.ParseAuthorizedKey(
51 | []byte(pubkey),
52 | )
53 | if ssh.KeysEqual(key, parsed) {
54 | return true
55 | }
56 | }
57 | return false
58 | }),
59 |
60 | // Then, password.
61 | wish.WithPasswordAuth(func(_ ssh.Context, password string) bool {
62 | log.Info("password")
63 | return password == validPassword
64 | }),
65 |
66 | // Finally, keyboard-interactive, which you can use to ask the user to
67 | // answer a challenge:
68 | wish.WithKeyboardInteractiveAuth(func(_ ssh.Context, challenger gossh.KeyboardInteractiveChallenge) bool {
69 | log.Info("keyboard-interactive")
70 | answers, err := challenger(
71 | "", "",
72 | []string{
73 | "♦ How much is 2+3: ",
74 | "♦ Which editor is best, vim or emacs? ",
75 | },
76 | []bool{true, true},
77 | )
78 | if err != nil {
79 | return false
80 | }
81 | // here we check for the correct answers:
82 | return len(answers) == 2 && answers[0] == "5" && answers[1] == "vim"
83 | }),
84 |
85 | wish.WithMiddleware(
86 | logging.Middleware(),
87 | func(next ssh.Handler) ssh.Handler {
88 | return func(sess ssh.Session) {
89 | wish.Println(sess, "Authorized!")
90 | wish.Println(sess, sess.PublicKey())
91 | }
92 | },
93 | ),
94 | )
95 | if err != nil {
96 | log.Error("Could not start server", "error", err)
97 | }
98 |
99 | done := make(chan os.Signal, 1)
100 | signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
101 | log.Info("Starting SSH server", "host", host, "port", port)
102 | go func() {
103 | if err = s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
104 | log.Error("Could not start server", "error", err)
105 | done <- nil
106 | }
107 | }()
108 |
109 | <-done
110 | log.Info("Stopping SSH server")
111 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
112 | defer func() { cancel() }()
113 | if err := s.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
114 | log.Error("Could not stop server", "error", err)
115 | }
116 | }
117 |
--------------------------------------------------------------------------------
/examples/multichat/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
import (
	"context"
	"errors"
	"fmt"
	"net"
	"os"
	"os/signal"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/charmbracelet/bubbles/textarea"
	"github.com/charmbracelet/bubbles/viewport"
	tea "github.com/charmbracelet/bubbletea"
	"github.com/charmbracelet/lipgloss"
	"github.com/charmbracelet/log"
	"github.com/charmbracelet/ssh"
	"github.com/charmbracelet/wish"
	"github.com/charmbracelet/wish/activeterm"
	"github.com/charmbracelet/wish/bubbletea"
	"github.com/charmbracelet/wish/logging"
	"github.com/muesli/termenv"
)
25 |
26 | const (
27 | host = "localhost"
28 | port = "23234"
29 | )
30 |
31 | // app contains a wish server and the list of running programs.
32 | type app struct {
33 | *ssh.Server
34 | progs []*tea.Program
35 | }
36 |
37 | // send dispatches a message to all running programs.
38 | func (a *app) send(msg tea.Msg) {
39 | for _, p := range a.progs {
40 | go p.Send(msg)
41 | }
42 | }
43 |
44 | func newApp() *app {
45 | a := new(app)
46 | s, err := wish.NewServer(
47 | wish.WithAddress(net.JoinHostPort(host, port)),
48 | wish.WithHostKeyPath(".ssh/id_ed25519"),
49 | wish.WithMiddleware(
50 | bubbletea.MiddlewareWithProgramHandler(a.ProgramHandler, termenv.ANSI256),
51 | activeterm.Middleware(),
52 | logging.Middleware(),
53 | ),
54 | )
55 | if err != nil {
56 | log.Error("Could not start server", "error", err)
57 | }
58 |
59 | a.Server = s
60 | return a
61 | }
62 |
63 | func (a *app) Start() {
64 | var err error
65 | done := make(chan os.Signal, 1)
66 | signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
67 | log.Info("Starting SSH server", "host", host, "port", port)
68 | go func() {
69 | if err = a.ListenAndServe(); err != nil {
70 | log.Error("Could not start server", "error", err)
71 | done <- nil
72 | }
73 | }()
74 |
75 | <-done
76 | log.Info("Stopping SSH server")
77 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
78 | defer func() { cancel() }()
79 | if err := a.Shutdown(ctx); err != nil {
80 | log.Error("Could not stop server", "error", err)
81 | }
82 | }
83 |
84 | func (a *app) ProgramHandler(s ssh.Session) *tea.Program {
85 | model := initialModel()
86 | model.app = a
87 | model.id = s.User()
88 |
89 | p := tea.NewProgram(model, bubbletea.MakeOptions(s)...)
90 | a.progs = append(a.progs, p)
91 |
92 | return p
93 | }
94 |
95 | func main() {
96 | app := newApp()
97 | app.Start()
98 | }
99 |
100 | type (
101 | errMsg error
102 | chatMsg struct {
103 | id string
104 | text string
105 | }
106 | )
107 |
108 | type model struct {
109 | *app
110 | viewport viewport.Model
111 | messages []string
112 | id string
113 | textarea textarea.Model
114 | senderStyle lipgloss.Style
115 | err error
116 | }
117 |
118 | func initialModel() model {
119 | ta := textarea.New()
120 | ta.Placeholder = "Send a message..."
121 | ta.Focus()
122 |
123 | ta.Prompt = "┃ "
124 | ta.CharLimit = 280
125 |
126 | ta.SetWidth(30)
127 | ta.SetHeight(3)
128 |
129 | // Remove cursor line styling
130 | ta.FocusedStyle.CursorLine = lipgloss.NewStyle()
131 |
132 | ta.ShowLineNumbers = false
133 |
134 | vp := viewport.New(30, 5)
135 | vp.SetContent(`Welcome to the chat room!
136 | Type a message and press Enter to send.`)
137 |
138 | ta.KeyMap.InsertNewline.SetEnabled(false)
139 |
140 | return model{
141 | textarea: ta,
142 | messages: []string{},
143 | viewport: vp,
144 | senderStyle: lipgloss.NewStyle().Foreground(lipgloss.Color("5")),
145 | err: nil,
146 | }
147 | }
148 |
149 | func (m model) Init() tea.Cmd {
150 | return textarea.Blink
151 | }
152 |
153 | func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
154 | var (
155 | tiCmd tea.Cmd
156 | vpCmd tea.Cmd
157 | )
158 |
159 | m.textarea, tiCmd = m.textarea.Update(msg)
160 | m.viewport, vpCmd = m.viewport.Update(msg)
161 |
162 | switch msg := msg.(type) {
163 | case tea.KeyMsg:
164 | switch msg.Type {
165 | case tea.KeyCtrlC, tea.KeyEsc:
166 | return m, tea.Quit
167 | case tea.KeyEnter:
168 | m.app.send(chatMsg{
169 | id: m.id,
170 | text: m.textarea.Value(),
171 | })
172 | m.textarea.Reset()
173 | }
174 |
175 | case chatMsg:
176 | m.messages = append(m.messages, m.senderStyle.Render(msg.id)+": "+msg.text)
177 | m.viewport.SetContent(strings.Join(m.messages, "\n"))
178 | m.viewport.GotoBottom()
179 |
180 | // We handle errors just like any other message
181 | case errMsg:
182 | m.err = msg
183 | return m, nil
184 | }
185 |
186 | return m, tea.Batch(tiCmd, vpCmd)
187 | }
188 |
189 | func (m model) View() string {
190 | return fmt.Sprintf(
191 | "%s\n\n%s",
192 | m.viewport.View(),
193 | m.textarea.View(),
194 | ) + "\n\n"
195 | }
196 |
--------------------------------------------------------------------------------
/examples/pty/main.go:
--------------------------------------------------------------------------------
1 | //go:build !windows
2 |
3 | package main
4 |
5 | import (
6 | "context"
7 | "errors"
8 | "net"
9 | "os"
10 | "os/signal"
11 | "syscall"
12 | "time"
13 |
14 | "github.com/charmbracelet/log"
15 | "github.com/charmbracelet/ssh"
16 | "github.com/charmbracelet/wish"
17 | "github.com/charmbracelet/wish/activeterm"
18 | "github.com/charmbracelet/wish/bubbletea"
19 | "github.com/charmbracelet/wish/logging"
20 | )
21 |
22 | const (
23 | host = "localhost"
24 | port = "23234"
25 | )
26 |
27 | func main() {
28 | srv, err := wish.NewServer(
29 | wish.WithAddress(net.JoinHostPort(host, port)),
30 | wish.WithHostKeyPath(".ssh/id_ed25519"),
31 |
32 | // Wish can allocate a PTY per user session.
33 | ssh.AllocatePty(),
34 |
35 | wish.WithMiddleware(
36 | func(next ssh.Handler) ssh.Handler {
37 | return func(sess ssh.Session) {
38 | pty, _, _ := sess.Pty()
39 | renderer := bubbletea.MakeRenderer(sess)
40 |
41 | bg := "light"
42 | if renderer.HasDarkBackground() {
43 | bg = "dark"
44 | }
45 |
46 | wish.Printf(sess, "Hello, world!\r\n")
47 | wish.Printf(sess, "Term: %s\r\n", pty.Term)
48 | wish.Printf(sess, "PTY: %s\r\n", pty.Slave.Name())
49 | wish.Printf(sess, "FD: %d\r\n", pty.Slave.Fd())
50 | wish.Printf(sess, "Background: %v\r\n", bg)
51 | next(sess)
52 | }
53 | },
54 |
55 | activeterm.Middleware(),
56 | logging.Middleware(),
57 | ),
58 | )
59 | if err != nil {
60 | log.Error("Could not start server", "error", err)
61 | }
62 |
63 | done := make(chan os.Signal, 1)
64 | signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
65 |
66 | go func() {
67 | log.Info("Starting SSH server", "host", host, "port", port)
68 | if err = srv.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
69 | log.Error("Could not start server", "error", err)
70 | done <- nil
71 | }
72 | }()
73 |
74 | <-done
75 |
76 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
77 | defer func() { cancel() }()
78 | log.Info("Stopping SSH server")
79 | if err := srv.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
80 | log.Error("Could not stop server", "error", err)
81 | }
82 | }
83 |
--------------------------------------------------------------------------------
/examples/scp/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | // An example SCP server. This will serve files from and to ./examples/scp/testdata.
4 |
5 | import (
6 | "context"
7 | "errors"
8 | "fmt"
9 | "io"
10 | "io/fs"
11 | "net"
12 | "os"
13 | "os/signal"
14 | "path/filepath"
15 | "syscall"
16 | "time"
17 |
18 | "github.com/charmbracelet/log"
19 | "github.com/charmbracelet/ssh"
20 | "github.com/charmbracelet/wish"
21 | "github.com/charmbracelet/wish/scp"
22 | "github.com/pkg/sftp"
23 | )
24 |
25 | const (
26 | host = "localhost"
27 | port = "23235"
28 | )
29 |
30 | func main() {
31 | root, _ := filepath.Abs("./examples/scp/testdata")
32 | handler := scp.NewFileSystemHandler(root)
33 | s, err := wish.NewServer(
34 | wish.WithAddress(net.JoinHostPort(host, port)),
35 | wish.WithHostKeyPath(".ssh/id_ed25519"),
36 |
37 | // setup the sftp subsystem
38 | wish.WithSubsystem("sftp", sftpSubsystem(root)),
39 | wish.WithMiddleware(
40 | // setup the scp middleware
41 | scp.Middleware(handler, handler),
42 | ),
43 | )
44 | if err != nil {
45 | log.Error("Could not start server", "error", err)
46 | }
47 |
48 | done := make(chan os.Signal, 1)
49 | signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
50 | log.Info("Starting SSH server", "host", host, "port", port)
51 | go func() {
52 | if err = s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
53 | log.Error("Could not start server", "error", err)
54 | done <- nil
55 | }
56 | }()
57 |
58 | <-done
59 | log.Info("Stopping SSH server")
60 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
61 | defer func() { cancel() }()
62 | if err := s.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
63 | log.Error("Could not stop server", "error", err)
64 | }
65 | }
66 |
67 | func sftpSubsystem(root string) ssh.SubsystemHandler {
68 | return func(s ssh.Session) {
69 | log.Info("sftp", "root", root)
70 | fs := &sftpHandler{root}
71 | srv := sftp.NewRequestServer(s, sftp.Handlers{
72 | FileList: fs,
73 | FileGet: fs,
74 | })
75 | if err := srv.Serve(); err == io.EOF {
76 | if err := srv.Close(); err != nil {
77 | wish.Fatalln(s, "sftp:", err)
78 | }
79 | } else if err != nil {
80 | wish.Fatalln(s, "sftp:", err)
81 | }
82 | }
83 | }
84 |
85 | // Example readonly handler implementation for sftp.
86 | //
87 | // Other example implementations:
88 | // - https://github.com/gravitational/teleport/blob/f57dc2fe2a9900ec198779aae747ac4f833b278d/tool/teleport/common/sftp.go
89 | // - https://github.com/minio/minio/blob/c66c5828eacb4a7fa9a49b4c890c77dd8684b171/cmd/sftp-server.go
90 | type sftpHandler struct {
91 | root string
92 | }
93 |
94 | var (
95 | _ sftp.FileLister = &sftpHandler{}
96 | _ sftp.FileReader = &sftpHandler{}
97 | )
98 |
99 | type listerAt []fs.FileInfo
100 |
101 | func (l listerAt) ListAt(ls []fs.FileInfo, offset int64) (int, error) {
102 | if offset >= int64(len(l)) {
103 | return 0, io.EOF
104 | }
105 | n := copy(ls, l[offset:])
106 | if n < len(ls) {
107 | return n, io.EOF
108 | }
109 | return n, nil
110 | }
111 |
112 | // Fileread implements sftp.FileReader.
113 | func (s *sftpHandler) Fileread(r *sftp.Request) (io.ReaderAt, error) {
114 | var flags int
115 | pflags := r.Pflags()
116 | if pflags.Append {
117 | flags |= os.O_APPEND
118 | }
119 | if pflags.Creat {
120 | flags |= os.O_CREATE
121 | }
122 | if pflags.Excl {
123 | flags |= os.O_EXCL
124 | }
125 | if pflags.Trunc {
126 | flags |= os.O_TRUNC
127 | }
128 |
129 | if pflags.Read && pflags.Write {
130 | flags |= os.O_RDWR
131 | } else if pflags.Read {
132 | flags |= os.O_RDONLY
133 | } else if pflags.Write {
134 | flags |= os.O_WRONLY
135 | }
136 |
137 | f, err := os.OpenFile(filepath.Join(s.root, r.Filepath), flags, 0600)
138 | if err != nil {
139 | return nil, err
140 | }
141 |
142 | return f, nil
143 | }
144 |
145 | // Filelist implements sftp.FileLister.
146 | func (s *sftpHandler) Filelist(r *sftp.Request) (sftp.ListerAt, error) {
147 | switch r.Method {
148 | case "List":
149 | entries, err := os.ReadDir(filepath.Join(s.root, r.Filepath))
150 | if err != nil {
151 | return nil, fmt.Errorf("sftp: %w", err)
152 | }
153 | infos := make([]fs.FileInfo, len(entries))
154 | for i, entry := range entries {
155 | info, err := entry.Info()
156 | if err != nil {
157 | return nil, err
158 | }
159 | infos[i] = info
160 | }
161 | return listerAt(infos), nil
162 | case "Stat":
163 | fi, err := os.Stat(filepath.Join(s.root, r.Filepath))
164 | if err != nil {
165 | return nil, err
166 | }
167 | return listerAt{fi}, nil
168 | default:
169 | return nil, sftp.ErrSSHFxOpUnsupported
170 | }
171 | }
172 |
--------------------------------------------------------------------------------
/examples/scp/testdata/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/charmbracelet/wish/5097be3e4161e3035a74a0d53d89d3adfe662320/examples/scp/testdata/.gitkeep
--------------------------------------------------------------------------------
/examples/simple/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "errors"
5 | "net"
6 |
7 | "github.com/charmbracelet/log"
8 | "github.com/charmbracelet/ssh"
9 | "github.com/charmbracelet/wish"
10 | "github.com/charmbracelet/wish/logging"
11 | )
12 |
// host and port form the listen address of the simple example server.
const host, port = "localhost", "23234"
17 |
18 | func main() {
19 | srv, err := wish.NewServer(
20 | // The address the server will listen to.
21 | wish.WithAddress(net.JoinHostPort(host, port)),
22 |
23 | // The SSH server need its own keys, this will create a keypair in the
24 | // given path if it doesn't exist yet.
25 | // By default, it will create an ED25519 key.
26 | wish.WithHostKeyPath(".ssh/id_ed25519"),
27 |
28 | // Middlewares do something on a ssh.Session, and then call the next
29 | // middleware in the stack.
30 | wish.WithMiddleware(
31 | func(next ssh.Handler) ssh.Handler {
32 | return func(sess ssh.Session) {
33 | wish.Println(sess, "Hello, world!")
34 | next(sess)
35 | }
36 | },
37 |
38 | // The last item in the chain is the first to be called.
39 | logging.Middleware(),
40 | ),
41 | )
42 | if err != nil {
43 | log.Error("Could not start server", "error", err)
44 | }
45 |
46 | log.Info("Starting SSH server", "host", host, "port", port)
47 | if err = srv.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
48 | // We ignore ErrServerClosed because it is expected.
49 | log.Error("Could not start server", "error", err)
50 | }
51 | }
52 |
--------------------------------------------------------------------------------
/git/git.go:
--------------------------------------------------------------------------------
1 | package git
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "os"
7 | "os/exec"
8 | "path/filepath"
9 | "strings"
10 |
11 | "github.com/charmbracelet/log"
12 | "github.com/charmbracelet/ssh"
13 | "github.com/charmbracelet/wish"
14 | "github.com/go-git/go-git/v5"
15 | "github.com/go-git/go-git/v5/plumbing"
16 | )
17 |
// Errors reported to git clients by Middleware.
var (
	// ErrNotAuthed represents unauthorized access.
	ErrNotAuthed = errors.New("you are not authorized to do this")

	// ErrSystemMalfunction represents a general system error returned to clients.
	ErrSystemMalfunction = errors.New("something went wrong")

	// ErrInvalidRepo represents an attempt to access a repo that does not
	// exist or whose name is not acceptable.
	ErrInvalidRepo = errors.New("invalid repo")
)
26 |
// AccessLevel is the level of access allowed to a repo. The constants below
// are ordered from least to most privileged.
type AccessLevel int

// Access levels understood by Middleware.
const (
	NoAccess        AccessLevel = iota // does not allow access to the repo
	ReadOnlyAccess                     // allows read-only access to the repo
	ReadWriteAccess                    // allows read and write access to the repo
	AdminAccess                        // allows read, write, and admin access to the repo
)
43 |
44 | // GitHooks is an interface that allows for custom authorization
45 | // implementations and post push/fetch notifications. Prior to git access,
46 | // AuthRepo will be called with the ssh.Session public key and the repo name.
47 | // Implementers return the appropriate AccessLevel.
48 | //
49 | // Deprecated: use Hooks instead.
50 | type GitHooks = Hooks // nolint: revive
51 |
52 | // Hooks is an interface that allows for custom authorization
53 | // implementations and post push/fetch notifications. Prior to git access,
54 | // AuthRepo will be called with the ssh.Session public key and the repo name.
55 | // Implementers return the appropriate AccessLevel.
56 | type Hooks interface {
57 | AuthRepo(string, ssh.PublicKey) AccessLevel
58 | Push(string, ssh.PublicKey)
59 | Fetch(string, ssh.PublicKey)
60 | }
61 |
62 | // Middleware adds Git server functionality to the ssh.Server. Repos are stored
63 | // in the specified repo directory. The provided Hooks implementation will be
64 | // checked for access on a per repo basis for a ssh.Session public key.
65 | // Hooks.Push and Hooks.Fetch will be called on successful completion of
66 | // their commands.
67 | func Middleware(repoDir string, gh Hooks) wish.Middleware {
68 | return func(sh ssh.Handler) ssh.Handler {
69 | return func(s ssh.Session) {
70 | cmd := s.Command()
71 | if len(cmd) == 2 {
72 | gc := cmd[0]
73 | // repo should be in the form of "repo.git" or "user/repo.git"
74 | repo := strings.TrimSuffix(strings.TrimPrefix(cmd[1], "/"), "/")
75 | repo = filepath.Clean(repo)
76 | if n := strings.Count(repo, "/"); n > 1 {
77 | Fatal(s, ErrInvalidRepo)
78 | return
79 | }
80 | pk := s.PublicKey()
81 | access := gh.AuthRepo(repo, pk)
82 | switch gc {
83 | case "git-receive-pack":
84 | switch access {
85 | case ReadWriteAccess, AdminAccess:
86 | err := gitPack(s, gc, repoDir, repo)
87 | if err != nil {
88 | Fatal(s, ErrSystemMalfunction)
89 | } else {
90 | gh.Push(repo, pk)
91 | }
92 | default:
93 | Fatal(s, ErrNotAuthed)
94 | }
95 | return
96 | case "git-upload-archive", "git-upload-pack":
97 | switch access {
98 | case ReadOnlyAccess, ReadWriteAccess, AdminAccess:
99 | err := gitPack(s, gc, repoDir, repo)
100 | switch err {
101 | case ErrInvalidRepo:
102 | Fatal(s, ErrInvalidRepo)
103 | case nil:
104 | gh.Fetch(repo, pk)
105 | default:
106 | log.Error("unknown git error", "error", err)
107 | Fatal(s, ErrSystemMalfunction)
108 | }
109 | default:
110 | Fatal(s, ErrNotAuthed)
111 | }
112 | return
113 | }
114 | }
115 | sh(s)
116 | }
117 | }
118 | }
119 |
120 | func gitPack(s ssh.Session, gitCmd string, repoDir string, repo string) error {
121 | cmd := strings.TrimPrefix(gitCmd, "git-")
122 | rp := filepath.Join(repoDir, repo)
123 | switch gitCmd {
124 | case "git-upload-archive", "git-upload-pack":
125 | exists, err := fileExists(rp)
126 | if !exists {
127 | return ErrInvalidRepo
128 | }
129 | if err != nil {
130 | return err
131 | }
132 | return runGit(s, "", cmd, rp)
133 | case "git-receive-pack":
134 | err := EnsureRepo(repoDir, repo)
135 | if err != nil {
136 | return err
137 | }
138 | err = runGit(s, "", cmd, rp)
139 | if err != nil {
140 | return err
141 | }
142 | err = ensureDefaultBranch(s, rp)
143 | if err != nil {
144 | return err
145 | }
146 | // Needed for git dumb http server
147 | return runGit(s, rp, "update-server-info")
148 | default:
149 | return fmt.Errorf("unknown git command: %s", gitCmd)
150 | }
151 | }
152 |
// fileExists reports whether path exists. A missing path gives (false, nil).
// Any other Stat failure gives (true, err), so callers must check err.
func fileExists(path string) (bool, error) {
	_, statErr := os.Stat(path)
	switch {
	case statErr == nil:
		return true, nil
	case os.IsNotExist(statErr):
		return false, nil
	default:
		return true, statErr
	}
}
163 |
164 | // Fatal prints to the session's STDOUT as a git response and exit 1.
165 | func Fatal(s ssh.Session, v ...interface{}) {
166 | msg := fmt.Sprint(v...)
167 | // hex length includes 4 byte length prefix and ending newline
168 | pktLine := fmt.Sprintf("%04x%s\n", len(msg)+5, msg)
169 | _, _ = wish.WriteString(s, pktLine)
170 | s.Exit(1) // nolint: errcheck
171 | }
172 |
173 | // EnsureRepo makes sure the given repo exists within the given dir, and that
174 | // it is git repository.
175 | //
176 | // If path does not exist, it'll be created.
177 | // If the path is not a git repo, it will be git init-ed as a bare repository.
178 | func EnsureRepo(dir, repo string) error {
179 | exists, err := fileExists(dir)
180 | if err != nil {
181 | return err
182 | }
183 | if !exists {
184 | err = os.MkdirAll(dir, os.ModeDir|os.FileMode(0o700))
185 | if err != nil {
186 | return err
187 | }
188 | }
189 | rp := filepath.Join(dir, repo)
190 | exists, err = fileExists(rp)
191 | if err != nil {
192 | return err
193 | }
194 | if !exists {
195 | _, err := git.PlainInit(rp, true)
196 | if err != nil {
197 | return err
198 | }
199 | }
200 | return nil
201 | }
202 |
203 | func runGit(s ssh.Session, dir string, args ...string) error {
204 | usi := exec.CommandContext(s.Context(), "git", args...)
205 | usi.Dir = dir
206 | usi.Stdout = s
207 | usi.Stdin = s
208 | if err := usi.Run(); err != nil {
209 | return err
210 | }
211 | return nil
212 | }
213 |
214 | func ensureDefaultBranch(s ssh.Session, repoPath string) error {
215 | r, err := git.PlainOpen(repoPath)
216 | if err != nil {
217 | return err
218 | }
219 | brs, err := r.Branches()
220 | if err != nil {
221 | return err
222 | }
223 | defer brs.Close()
224 | fb, err := brs.Next()
225 | if err != nil {
226 | return err
227 | }
228 | // Rename the default branch to the first branch available
229 | _, err = r.Head()
230 | if err == plumbing.ErrReferenceNotFound {
231 | err = runGit(s, repoPath, "branch", "-M", fb.Name().Short())
232 | if err != nil {
233 | return err
234 | }
235 | }
236 | if err != nil && err != plumbing.ErrReferenceNotFound {
237 | return err
238 | }
239 | return nil
240 | }
241 |
--------------------------------------------------------------------------------
/git/git_test.go:
--------------------------------------------------------------------------------
1 | package git
2 |
3 | import (
4 | "fmt"
5 | "net"
6 | "os/exec"
7 | "path/filepath"
8 | "runtime"
9 | "sync"
10 | "testing"
11 |
12 | "github.com/charmbracelet/keygen"
13 | "github.com/charmbracelet/ssh"
14 | "github.com/charmbracelet/wish"
15 | )
16 |
17 | func TestGitMiddleware(t *testing.T) {
18 | pubkey, pkPath := createKeyPair(t)
19 | hkPath := filepath.Join(t.TempDir(), "id_ed25519")
20 |
21 | l, err := net.Listen("tcp", "127.0.0.1:0")
22 | requireNoError(t, err)
23 | remote := "ssh://" + l.Addr().String()
24 |
25 | repoDir := t.TempDir()
26 | hooks := &testHooks{
27 | pushes: []action{},
28 | fetches: []action{},
29 | access: []accessDetails{
30 | {pubkey, "repo1", AdminAccess},
31 | {pubkey, "repo2", AdminAccess},
32 | {pubkey, "repo3", AdminAccess},
33 | {pubkey, "repo4", AdminAccess},
34 | {pubkey, "repo5", NoAccess},
35 | {pubkey, "repo6", ReadOnlyAccess},
36 | {pubkey, "repo7", AdminAccess},
37 | {pubkey, "abc/repo1", AdminAccess},
38 | {pubkey, "abc/def/repo1", AdminAccess},
39 | },
40 | }
41 | srv, err := wish.NewServer(
42 | wish.WithHostKeyPath(hkPath),
43 | wish.WithMiddleware(Middleware(repoDir, hooks)),
44 | wish.WithPublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
45 | return true
46 | }),
47 | )
48 | requireNoError(t, err)
49 | go func() { srv.Serve(l) }()
50 | t.Cleanup(func() { _ = srv.Close() })
51 |
52 | t.Run("create repo on master", func(t *testing.T) {
53 | cwd := t.TempDir()
54 | requireNoError(t, runGitHelper(t, pkPath, cwd, "init", "-b", "master"))
55 | requireNoError(t, runGitHelper(t, pkPath, cwd, "remote", "add", "origin", remote+"/repo1"))
56 | requireNoError(t, runGitHelper(t, pkPath, cwd, "commit", "--allow-empty", "-m", "initial commit"))
57 | requireNoError(t, runGitHelper(t, pkPath, cwd, "push", "origin", "master"))
58 | requireHasAction(t, hooks.pushes, pubkey, "repo1")
59 | })
60 |
61 | t.Run("create repo on main", func(t *testing.T) {
62 | cwd := t.TempDir()
63 | requireNoError(t, runGitHelper(t, pkPath, cwd, "init", "-b", "main"))
64 | requireNoError(t, runGitHelper(t, pkPath, cwd, "remote", "add", "origin", remote+"/repo2"))
65 | requireNoError(t, runGitHelper(t, pkPath, cwd, "commit", "--allow-empty", "-m", "initial commit"))
66 | requireNoError(t, runGitHelper(t, pkPath, cwd, "push", "origin", "main"))
67 | requireHasAction(t, hooks.pushes, pubkey, "repo2")
68 | })
69 |
70 | t.Run("create repo in subdir", func(t *testing.T) {
71 | if runtime.GOOS == "windows" {
72 | t.Skip("permission issues")
73 | }
74 | cwd := t.TempDir()
75 | requireNoError(t, runGitHelper(t, pkPath, cwd, "init", "-b", "main"))
76 | requireNoError(t, runGitHelper(t, pkPath, cwd, "remote", "add", "origin", remote+"/abc/repo1"))
77 | requireNoError(t, runGitHelper(t, pkPath, cwd, "commit", "--allow-empty", "-m", "initial commit"))
78 | requireNoError(t, runGitHelper(t, pkPath, cwd, "push", "origin", "main"))
79 | requireHasAction(t, hooks.pushes, pubkey, "abc/repo1")
80 | })
81 |
82 | t.Run("create wrong repo", func(t *testing.T) {
83 | cwd := t.TempDir()
84 | requireNoError(t, runGitHelper(t, pkPath, cwd, "init", "-b", "main"))
85 | requireNoError(t, runGitHelper(t, pkPath, cwd, "remote", "add", "origin", remote+"//../../repo1"))
86 | requireNoError(t, runGitHelper(t, pkPath, cwd, "commit", "--allow-empty", "-m", "initial commit"))
87 | requireError(t, runGitHelper(t, pkPath, cwd, "push", "origin", "main"))
88 | })
89 |
90 | t.Run("create wrong repo in subdir", func(t *testing.T) {
91 | cwd := t.TempDir()
92 | requireNoError(t, runGitHelper(t, pkPath, cwd, "init", "-b", "main"))
93 | requireNoError(t, runGitHelper(t, pkPath, cwd, "remote", "add", "origin", remote+"/abc/def/repo1"))
94 | requireNoError(t, runGitHelper(t, pkPath, cwd, "commit", "--allow-empty", "-m", "initial commit"))
95 | requireError(t, runGitHelper(t, pkPath, cwd, "push", "origin", "main"))
96 | })
97 |
98 | t.Run("create and clone repo", func(t *testing.T) {
99 | cwd := t.TempDir()
100 | requireNoError(t, runGitHelper(t, pkPath, cwd, "init", "-b", "main"))
101 | requireNoError(t, runGitHelper(t, pkPath, cwd, "remote", "add", "origin", remote+"/repo3"))
102 | requireNoError(t, runGitHelper(t, pkPath, cwd, "commit", "--allow-empty", "-m", "initial commit"))
103 | requireNoError(t, runGitHelper(t, pkPath, cwd, "push", "origin", "main"))
104 |
105 | cwd = t.TempDir()
106 | requireNoError(t, runGitHelper(t, pkPath, cwd, "clone", remote+"/repo3"))
107 |
108 | requireHasAction(t, hooks.pushes, pubkey, "repo3")
109 | requireHasAction(t, hooks.fetches, pubkey, "repo3")
110 | })
111 |
112 | t.Run("clone repo that doesn't exist", func(t *testing.T) {
113 | cwd := t.TempDir()
114 | requireError(t, runGitHelper(t, pkPath, cwd, "clone", remote+"/repo4"))
115 | })
116 |
117 | t.Run("clone repo with no access", func(t *testing.T) {
118 | cwd := t.TempDir()
119 | requireError(t, runGitHelper(t, pkPath, cwd, "clone", remote+"/repo5"))
120 | })
121 |
122 | t.Run("push repo with with readonly", func(t *testing.T) {
123 | cwd := t.TempDir()
124 | requireNoError(t, runGitHelper(t, pkPath, cwd, "init", "-b", "main"))
125 | requireNoError(t, runGitHelper(t, pkPath, cwd, "remote", "add", "origin", remote+"/repo6"))
126 | requireNoError(t, runGitHelper(t, pkPath, cwd, "commit", "--allow-empty", "-m", "initial commit"))
127 | requireError(t, runGitHelper(t, pkPath, cwd, "push", "origin", "main"))
128 | })
129 |
130 | t.Run("create and clone repo on weird branch", func(t *testing.T) {
131 | cwd := t.TempDir()
132 | requireNoError(t, runGitHelper(t, pkPath, cwd, "init", "-b", "a-weird-branch-name"))
133 | requireNoError(t, runGitHelper(t, pkPath, cwd, "remote", "add", "origin", remote+"/repo7"))
134 | requireNoError(t, runGitHelper(t, pkPath, cwd, "commit", "--allow-empty", "-m", "initial commit"))
135 | requireNoError(t, runGitHelper(t, pkPath, cwd, "push", "origin", "a-weird-branch-name"))
136 |
137 | cwd = t.TempDir()
138 | requireNoError(t, runGitHelper(t, pkPath, cwd, "clone", remote+"/repo7"))
139 |
140 | requireHasAction(t, hooks.pushes, pubkey, "repo7")
141 | requireHasAction(t, hooks.fetches, pubkey, "repo7")
142 | })
143 | }
144 |
// runGitHelper runs git with args in cwd, connecting over SSH with the
// private key at pk.
//
// A fixed identity is configured, signing is disabled, and host-key checking
// and the ssh config file are turned off. The environment is replaced so
// that only GIT_SSH_COMMAND is set. On failure the combined output is logged.
func runGitHelper(t *testing.T, pk, cwd string, args ...string) error {
	t.Helper()

	overrides := [...]string{
		"user.name='wish'",
		"user.email='test@wish'",
		"commit.gpgSign=false",
		"tag.gpgSign=false",
		"log.showSignature=false",
		"ssh.variant=ssh",
	}
	gitArgs := make([]string, 0, 2*len(overrides)+len(args))
	for _, kv := range overrides {
		gitArgs = append(gitArgs, "-c", kv)
	}
	gitArgs = append(gitArgs, args...)

	cmd := exec.Command("git", gitArgs...)
	cmd.Dir = cwd
	sshCommand := fmt.Sprintf(`GIT_SSH_COMMAND=ssh -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no -i "%s" -F /dev/null`, pk)
	cmd.Env = []string{sshCommand}

	out, err := cmd.CombinedOutput()
	if err != nil {
		t.Log("git out:", string(out))
	}
	return err
}
167 |
// requireNoError stops the test immediately when err is non-nil.
func requireNoError(t *testing.T, err error) {
	t.Helper()

	if err == nil {
		return
	}
	t.Fatalf("expected no error, got %q", err.Error())
}
175 |
// requireError stops the test immediately when err is nil.
func requireError(t *testing.T, err error) {
	t.Helper()

	if err != nil {
		return
	}
	t.Fatalf("expected an error, got nil")
}
183 |
184 | func requireHasAction(t *testing.T, actions []action, key ssh.PublicKey, repo string) {
185 | t.Helper()
186 |
187 | for _, action := range actions {
188 | if repo == action.repo && ssh.KeysEqual(key, action.key) {
189 | return
190 | }
191 | }
192 | t.Fatalf("expected action for %q, got none", repo)
193 | }
194 |
195 | func createKeyPair(t *testing.T) (ssh.PublicKey, string) {
196 | t.Helper()
197 |
198 | pk := filepath.Join(t.TempDir(), "id_ed25519")
199 | kp, err := keygen.New(pk, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
200 | requireNoError(t, err)
201 | return kp.PublicKey(), pk
202 | }
203 |
204 | type accessDetails struct {
205 | key ssh.PublicKey
206 | repo string
207 | level AccessLevel
208 | }
209 |
210 | type action struct {
211 | key ssh.PublicKey
212 | repo string
213 | }
214 |
215 | type testHooks struct {
216 | sync.Mutex
217 | pushes []action
218 | fetches []action
219 | access []accessDetails
220 | }
221 |
222 | func (h *testHooks) AuthRepo(repo string, key ssh.PublicKey) AccessLevel {
223 | for _, dets := range h.access {
224 | if dets.repo == repo && ssh.KeysEqual(key, dets.key) {
225 | return dets.level
226 | }
227 | }
228 | return NoAccess
229 | }
230 |
231 | func (h *testHooks) Push(repo string, key ssh.PublicKey) {
232 | h.Lock()
233 | defer h.Unlock()
234 |
235 | h.pushes = append(h.pushes, action{key, repo})
236 | }
237 |
238 | func (h *testHooks) Fetch(repo string, key ssh.PublicKey) {
239 | h.Lock()
240 | defer h.Unlock()
241 |
242 | h.fetches = append(h.fetches, action{key, repo})
243 | }
244 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/charmbracelet/wish
2 |
3 | go 1.23.0
4 |
5 | toolchain go1.24.1
6 |
7 | require (
8 | github.com/charmbracelet/bubbletea v1.3.5
9 | github.com/charmbracelet/keygen v0.5.3
10 | github.com/charmbracelet/lipgloss v1.1.0
11 | github.com/charmbracelet/log v0.4.2
12 | github.com/charmbracelet/ssh v0.0.0-20250128164007-98fd5ae11894
13 | github.com/charmbracelet/x/ansi v0.9.2
14 | github.com/charmbracelet/x/input v0.3.4
15 | github.com/charmbracelet/x/term v0.2.1
16 | github.com/go-git/go-git/v5 v5.16.0
17 | github.com/google/go-cmp v0.7.0
18 | github.com/hashicorp/golang-lru/v2 v2.0.7
19 | github.com/lucasb-eyer/go-colorful v1.2.0
20 | github.com/matryer/is v1.4.1
21 | github.com/muesli/termenv v0.16.0
22 | golang.org/x/crypto v0.38.0
23 | golang.org/x/sync v0.14.0
24 | golang.org/x/time v0.11.0
25 | )
26 |
27 | require (
28 | dario.cat/mergo v1.0.0 // indirect
29 | github.com/Microsoft/go-winio v0.6.2 // indirect
30 | github.com/ProtonMail/go-crypto v1.1.6 // indirect
31 | github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
32 | github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
33 | github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
34 | github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
35 | github.com/charmbracelet/x/conpty v0.1.0 // indirect
36 | github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86 // indirect
37 | github.com/charmbracelet/x/termios v0.1.0 // indirect
38 | github.com/charmbracelet/x/windows v0.2.0 // indirect
39 | github.com/cloudflare/circl v1.6.1 // indirect
40 | github.com/creack/pty v1.1.21 // indirect
41 | github.com/cyphar/filepath-securejoin v0.4.1 // indirect
42 | github.com/emirpasic/gods v1.18.1 // indirect
43 | github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
44 | github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect
45 | github.com/go-git/go-billy/v5 v5.6.2 // indirect
46 | github.com/go-logfmt/logfmt v0.6.0 // indirect
47 | github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
48 | github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect
49 | github.com/kevinburke/ssh_config v1.2.0 // indirect
50 | github.com/mattn/go-isatty v0.0.20 // indirect
51 | github.com/mattn/go-localereader v0.0.1 // indirect
52 | github.com/mattn/go-runewidth v0.0.16 // indirect
53 | github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
54 | github.com/muesli/cancelreader v0.2.2 // indirect
55 | github.com/pjbgf/sha1cd v0.3.2 // indirect
56 | github.com/rivo/uniseg v0.4.7 // indirect
57 | github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect
58 | github.com/skeema/knownhosts v1.3.1 // indirect
59 | github.com/xanzy/ssh-agent v0.3.3 // indirect
60 | github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
61 | golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect
62 | golang.org/x/net v0.39.0 // indirect
63 | golang.org/x/sys v0.33.0 // indirect
64 | golang.org/x/text v0.25.0 // indirect
65 | gopkg.in/warnings.v0 v0.1.2 // indirect
66 | )
67 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk=
2 | dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
3 | github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY=
4 | github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
5 | github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
6 | github.com/ProtonMail/go-crypto v1.1.6 h1:ZcV+Ropw6Qn0AX9brlQLAUXfqLBc7Bl+f/DmNxpLfdw=
7 | github.com/ProtonMail/go-crypto v1.1.6/go.mod h1:rA3QumHc/FZ8pAHreoekgiAbzpNsfQAosU5td4SnOrE=
8 | github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
9 | github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
10 | github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
11 | github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
12 | github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
13 | github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
14 | github.com/charmbracelet/bubbletea v1.3.5 h1:JAMNLTbqMOhSwoELIr0qyP4VidFq72/6E9j7HHmRKQc=
15 | github.com/charmbracelet/bubbletea v1.3.5/go.mod h1:TkCnmH+aBd4LrXhXcqrKiYwRs7qyQx5rBgH5fVY3v54=
16 | github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
17 | github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
18 | github.com/charmbracelet/keygen v0.5.3 h1:2MSDC62OUbDy6VmjIE2jM24LuXUvKywLCmaJDmr/Z/4=
19 | github.com/charmbracelet/keygen v0.5.3/go.mod h1:TcpNoMAO5GSmhx3SgcEMqCrtn8BahKhB8AlwnLjRUpk=
20 | github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
21 | github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
22 | github.com/charmbracelet/log v0.4.2 h1:hYt8Qj6a8yLnvR+h7MwsJv/XvmBJXiueUcI3cIxsyig=
23 | github.com/charmbracelet/log v0.4.2/go.mod h1:qifHGX/tc7eluv2R6pWIpyHDDrrb/AG71Pf2ysQu5nw=
24 | github.com/charmbracelet/ssh v0.0.0-20250128164007-98fd5ae11894 h1:Ffon9TbltLGBsT6XE//YvNuu4OAaThXioqalhH11xEw=
25 | github.com/charmbracelet/ssh v0.0.0-20250128164007-98fd5ae11894/go.mod h1:hg+I6gvlMl16nS9ZzQNgBIrrCasGwEw0QiLsDcP01Ko=
26 | github.com/charmbracelet/x/ansi v0.9.2 h1:92AGsQmNTRMzuzHEYfCdjQeUzTrgE1vfO5/7fEVoXdY=
27 | github.com/charmbracelet/x/ansi v0.9.2/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE=
28 | github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
29 | github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
30 | github.com/charmbracelet/x/conpty v0.1.0 h1:4zc8KaIcbiL4mghEON8D72agYtSeIgq8FSThSPQIb+U=
31 | github.com/charmbracelet/x/conpty v0.1.0/go.mod h1:rMFsDJoDwVmiYM10aD4bH2XiRgwI7NYJtQgl5yskjEQ=
32 | github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86 h1:JSt3B+U9iqk37QUU2Rvb6DSBYRLtWqFqfxf8l5hOZUA=
33 | github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86/go.mod h1:2P0UgXMEa6TsToMSuFqKFQR+fZTO9CNGUNokkPatT/0=
34 | github.com/charmbracelet/x/input v0.3.4 h1:Mujmnv/4DaitU0p+kIsrlfZl/UlmeLKw1wAP3e1fMN0=
35 | github.com/charmbracelet/x/input v0.3.4/go.mod h1:JI8RcvdZWQIhn09VzeK3hdp4lTz7+yhiEdpEQtZN+2c=
36 | github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
37 | github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
38 | github.com/charmbracelet/x/termios v0.1.0 h1:y4rjAHeFksBAfGbkRDmVinMg7x7DELIGAFbdNvxg97k=
39 | github.com/charmbracelet/x/termios v0.1.0/go.mod h1:H/EVv/KRnrYjz+fCYa9bsKdqF3S8ouDK0AZEbG7r+/U=
40 | github.com/charmbracelet/x/windows v0.2.0 h1:ilXA1GJjTNkgOm94CLPeSz7rar54jtFatdmoiONPuEw=
41 | github.com/charmbracelet/x/windows v0.2.0/go.mod h1:ZibNFR49ZFqCXgP76sYanisxRyC+EYrBE7TTknD8s1s=
42 | github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
43 | github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
44 | github.com/creack/pty v1.1.21 h1:1/QdRyBaHHJP61QkWMXlOIBfsgdDeeKfK8SYVUWJKf0=
45 | github.com/creack/pty v1.1.21/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
46 | github.com/cyphar/filepath-securejoin v0.4.1 h1:JyxxyPEaktOD+GAnqIqTf9A8tHyAG22rowi7HkoSU1s=
47 | github.com/cyphar/filepath-securejoin v0.4.1/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGLDGQL7h7bg04C/+u9jI=
48 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
49 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
50 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
51 | github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o=
52 | github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE=
53 | github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
54 | github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
55 | github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
56 | github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
57 | github.com/gliderlabs/ssh v0.3.8 h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c=
58 | github.com/gliderlabs/ssh v0.3.8/go.mod h1:xYoytBv1sV0aL3CavoDuJIQNURXkkfPA/wxQ1pL1fAU=
59 | github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 h1:+zs/tPmkDkHx3U66DAb0lQFJrpS6731Oaa12ikc+DiI=
60 | github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376/go.mod h1:an3vInlBmSxCcxctByoQdvwPiA7DTK7jaaFDBTtu0ic=
61 | github.com/go-git/go-billy/v5 v5.6.2 h1:6Q86EsPXMa7c3YZ3aLAQsMA0VlWmy43r6FHqa/UNbRM=
62 | github.com/go-git/go-billy/v5 v5.6.2/go.mod h1:rcFC2rAsp/erv7CMz9GczHcuD0D32fWzH+MJAU+jaUU=
63 | github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399 h1:eMje31YglSBqCdIqdhKBW8lokaMrL3uTkpGYlE2OOT4=
64 | github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399/go.mod h1:1OCfN199q1Jm3HZlxleg+Dw/mwps2Wbk9frAWm+4FII=
65 | github.com/go-git/go-git/v5 v5.16.0 h1:k3kuOEpkc0DeY7xlL6NaaNg39xdgQbtH5mwCafHO9AQ=
66 | github.com/go-git/go-git/v5 v5.16.0/go.mod h1:4Ge4alE/5gPs30F2H1esi2gPd69R0C39lolkucHBOp8=
67 | github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4=
68 | github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
69 | github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ=
70 | github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUvekVysuuOpQKo3191zZyTpiI6se1N1ULghS0sw=
71 | github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
72 | github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
73 | github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
74 | github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
75 | github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A=
76 | github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo=
77 | github.com/kevinburke/ssh_config v1.2.0 h1:x584FjTGwHzMwvHx18PXxbBVzfnxogHaAReU4gf13a4=
78 | github.com/kevinburke/ssh_config v1.2.0/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM=
79 | github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
80 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
81 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
82 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
83 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
84 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
85 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
86 | github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
87 | github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
88 | github.com/matryer/is v1.4.1 h1:55ehd8zaGABKLXQUe2awZ99BD/PTc2ls+KV/dXphgEQ=
89 | github.com/matryer/is v1.4.1/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU=
90 | github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
91 | github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
92 | github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
93 | github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
94 | github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
95 | github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
96 | github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
97 | github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
98 | github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
99 | github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
100 | github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
101 | github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
102 | github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k=
103 | github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY=
104 | github.com/pjbgf/sha1cd v0.3.2 h1:a9wb0bp1oC2TGwStyn0Umc/IGKQnEgF0vVaZ8QF8eo4=
105 | github.com/pjbgf/sha1cd v0.3.2/go.mod h1:zQWigSxVmsHEZow5qaLtPYxpcKMMQpa09ixqBxuCS6A=
106 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
107 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
108 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
109 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
110 | github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
111 | github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
112 | github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
113 | github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
114 | github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
115 | github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 h1:n661drycOFuPLCN3Uc8sB6B/s6Z4t2xvBgU1htSHuq8=
116 | github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4=
117 | github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
118 | github.com/skeema/knownhosts v1.3.1 h1:X2osQ+RAjK76shCbvhHHHVl3ZlgDm8apHEHFqRjnBY8=
119 | github.com/skeema/knownhosts v1.3.1/go.mod h1:r7KTdC8l4uxWRyK2TpQZ/1o5HaSzh06ePQNxPwTcfiY=
120 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
121 | github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
122 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
123 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
124 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
125 | github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM=
126 | github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI1Bc68Uw=
127 | github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
128 | github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
129 | golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
130 | golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8=
131 | golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw=
132 | golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8=
133 | golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY=
134 | golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
135 | golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
136 | golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
137 | golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ=
138 | golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
139 | golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
140 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
141 | golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
142 | golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
143 | golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
144 | golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
145 | golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
146 | golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
147 | golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
148 | golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
149 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
150 | golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg=
151 | golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ=
152 | golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
153 | golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
154 | golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
155 | golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
156 | golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
157 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
158 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
159 | gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
160 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
161 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
162 | gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME=
163 | gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI=
164 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
165 | gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
166 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
167 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
168 |
--------------------------------------------------------------------------------
/logging/logging.go:
--------------------------------------------------------------------------------
1 | package logging
2 |
3 | import (
4 | "time"
5 |
6 | "github.com/charmbracelet/log"
7 | "github.com/charmbracelet/ssh"
8 | "github.com/charmbracelet/wish"
9 | )
10 |
11 | // Middleware provides basic connection logging.
12 | // Connects are logged with the remote address, invoked command, TERM setting,
13 | // window dimensions, client version, and if the auth was public key based.
14 | // Disconnect will log the remote address and connection duration.
15 | //
16 | // It will use charmbracelet/log.StandardLog() by default.
17 | func Middleware() wish.Middleware {
18 | return MiddlewareWithLogger(log.StandardLog())
19 | }
20 |
21 | // Logger is the interface that wraps the basic Log method.
22 | type Logger interface {
23 | Printf(format string, v ...interface{})
24 | }
25 |
// MiddlewareWithLogger returns a wish.Middleware that reports the lifecycle of
// each session through logger.
//
// Before the wrapped handler runs, one line is printed with the user, remote
// address, whether a public key was presented, the requested command, the TERM
// value, the window width and height, and the client version. After the wrapped
// handler returns, a second line reports the remote address and the elapsed
// session time.
func MiddlewareWithLogger(logger Logger) wish.Middleware {
	return func(next ssh.Handler) ssh.Handler {
		return func(sess ssh.Session) {
			started := time.Now()
			withKey := sess.PublicKey() != nil
			ptyReq, _, _ := sess.Pty()

			connectArgs := []interface{}{
				sess.User(),
				sess.RemoteAddr().String(),
				withKey,
				sess.Command(),
				ptyReq.Term,
				ptyReq.Window.Width,
				ptyReq.Window.Height,
				sess.Context().ClientVersion(),
			}
			logger.Printf("%s connect %s %v %v %s %v %v %v", connectArgs...)

			next(sess)

			logger.Printf("%s disconnect %s\n", sess.RemoteAddr().String(), time.Since(started))
		}
	}
}
56 |
// StructuredMiddleware returns a wish.Middleware that emits key/value connect
// and disconnect records using log.Default() at log.InfoLevel.
// See StructuredMiddlewareWithLogger for the recorded fields.
func StructuredMiddleware() wish.Middleware {
	return StructuredMiddlewareWithLogger(log.Default(), log.InfoLevel)
}

// StructuredMiddlewareWithLogger returns a wish.Middleware that writes
// structured session records to logger at the given level.
//
// The "connect" record has the keys user, remote-addr, public-key (whether a
// key was presented), command, term, width, height and client-version. It is
// written before the wrapped handler runs. The "disconnect" record has user,
// remote-addr and duration, and is written after the handler returns.
func StructuredMiddlewareWithLogger(logger *log.Logger, level log.Level) wish.Middleware {
	return func(next ssh.Handler) ssh.Handler {
		return func(sess ssh.Session) {
			startedAt := time.Now()
			keyAuth := sess.PublicKey() != nil
			ptyReq, _, _ := sess.Pty()

			connectFields := []interface{}{
				"user", sess.User(),
				"remote-addr", sess.RemoteAddr().String(),
				"public-key", keyAuth,
				"command", sess.Command(),
				"term", ptyReq.Term,
				"width", ptyReq.Window.Width,
				"height", ptyReq.Window.Height,
				"client-version", sess.Context().ClientVersion(),
			}
			logger.Log(level, "connect", connectFields...)

			next(sess)

			disconnectFields := []interface{}{
				"user", sess.User(),
				"remote-addr", sess.RemoteAddr().String(),
				"duration", time.Since(startedAt),
			}
			logger.Log(level, "disconnect", disconnectFields...)
		}
	}
}
100 |
--------------------------------------------------------------------------------
/logging/logging_test.go:
--------------------------------------------------------------------------------
1 | package logging_test
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/charmbracelet/ssh"
7 | "github.com/charmbracelet/wish"
8 | "github.com/charmbracelet/wish/logging"
9 | "github.com/charmbracelet/wish/testsession"
10 | gossh "golang.org/x/crypto/ssh"
11 | )
12 |
func TestMiddleware(t *testing.T) {
	t.Run("inactive term", func(t *testing.T) {
		sess := setup(t, logging.Middleware())
		if err := sess.Run(""); err != nil {
			t.Error(err)
		}
	})
}

func TestStructuredMiddleware(t *testing.T) {
	t.Run("inactive term", func(t *testing.T) {
		sess := setup(t, logging.StructuredMiddleware())
		if err := sess.Run(""); err != nil {
			t.Error(err)
		}
	})
}

// setup starts a test server whose handler, wrapped by middleware, writes
// "hello", and returns a client session connected to it.
func setup(tb testing.TB, middleware wish.Middleware) *gossh.Session {
	tb.Helper()
	greet := func(s ssh.Session) {
		_, _ = s.Write([]byte("hello"))
	}
	srv := &ssh.Server{Handler: middleware(greet)}
	return testsession.New(tb, srv, nil)
}
37 |
--------------------------------------------------------------------------------
/options.go:
--------------------------------------------------------------------------------
1 | package wish
2 |
3 | import (
4 | "bufio"
5 | "bytes"
6 | "errors"
7 | "io"
8 | "os"
9 | "strings"
10 | "time"
11 |
12 | "github.com/charmbracelet/keygen"
13 | "github.com/charmbracelet/log"
14 | "github.com/charmbracelet/ssh"
15 | gossh "golang.org/x/crypto/ssh"
16 | )
17 |
18 | // WithAddress returns an ssh.Option that sets the address to listen on.
19 | func WithAddress(addr string) ssh.Option {
20 | return func(s *ssh.Server) error {
21 | s.Addr = addr
22 | return nil
23 | }
24 | }
25 |
26 | // WithVersion returns an ssh.Option that sets the server version.
27 | func WithVersion(version string) ssh.Option {
28 | return func(s *ssh.Server) error {
29 | s.Version = version
30 | return nil
31 | }
32 | }
33 |
34 | // WithBanner return an ssh.Option that sets the server banner.
35 | func WithBanner(banner string) ssh.Option {
36 | return func(s *ssh.Server) error {
37 | s.Banner = banner
38 | return nil
39 | }
40 | }
41 |
42 | // WithBannerHandler return an ssh.Option that sets the server banner handler,
43 | // overriding WithBanner.
44 | func WithBannerHandler(h ssh.BannerHandler) ssh.Option {
45 | return func(s *ssh.Server) error {
46 | s.BannerHandler = h
47 | return nil
48 | }
49 | }
50 |
51 | // WithMiddleware composes the provided Middleware and returns an ssh.Option.
52 | // This is useful if you manually create an ssh.Server and want to set the
53 | // Server.Handler.
54 | //
55 | // Notice that middlewares are composed from first to last, which means the last one is executed first.
56 | func WithMiddleware(mw ...Middleware) ssh.Option {
57 | return func(s *ssh.Server) error {
58 | h := func(ssh.Session) {}
59 | for _, m := range mw {
60 | h = m(h)
61 | }
62 | s.Handler = h
63 | return nil
64 | }
65 | }
66 |
// WithHostKeyPath returns an ssh.Option that loads the server's host key from
// the private key file at path.
//
// If nothing exists at path, a new Ed25519 key is generated with keygen and
// written there. This happens when WithHostKeyPath is called, not when the
// option is applied. If generation fails, the returned option reports that
// error when applied. Any other os.Stat failure (for example a permission
// error) is not handled here; path is passed to ssh.HostKeyFile as-is.
func WithHostKeyPath(path string) ssh.Option {
	if _, err := os.Stat(path); os.IsNotExist(err) {
		if _, genErr := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite()); genErr != nil {
			return func(*ssh.Server) error {
				return genErr
			}
		}
	}
	return ssh.HostKeyFile(path)
}
79 |
// WithHostKeyPEM returns an ssh.Option that sets the host key from PEM-encoded
// bytes.
func WithHostKeyPEM(pem []byte) ssh.Option {
	opt := ssh.HostKeyPEM(pem)
	return opt
}

// WithAuthorizedKeys returns an ssh.Option that installs a public key handler
// accepting only keys listed in the authorized_keys file at path.
//
// Applying the option fails if path cannot be stat'ed. The file is re-read on
// every authentication attempt, so edits take effect without a restart.
func WithAuthorizedKeys(path string) ssh.Option {
	return func(s *ssh.Server) error {
		_, statErr := os.Stat(path)
		if statErr != nil {
			return statErr
		}
		handler := func(_ ssh.Context, offered ssh.PublicKey) bool {
			matches := func(listed ssh.PublicKey) bool {
				return ssh.KeysEqual(offered, listed)
			}
			return isAuthorized(path, matches)
		}
		return WithPublicKeyAuth(handler)(s)
	}
}
98 |
// WithTrustedUserCAKeys returns an ssh.Option that authorizes clients
// presenting a valid user certificate signed by one of the Certificate
// Authority public keys listed in the file at path (authorized_keys format).
// It is analogous to the TrustedUserCAKeys OpenSSH option.
//
// A certificate is accepted only when all of the following hold:
//   - it is a user certificate (host certificates are rejected);
//   - it carries no source-address critical option (see the comment below);
//   - gossh.CertChecker.CheckCert accepts it for the requested user;
//   - its signing key matches one of the CA keys in path.
//
// Plain public keys are always rejected. Applying the option fails if path
// cannot be stat'ed. The CA file is re-read on each attempt that reaches the CA
// comparison.
func WithTrustedUserCAKeys(path string) ssh.Option {
	return func(s *ssh.Server) error {
		if _, err := os.Stat(path); err != nil {
			return err
		}
		return WithPublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
			cert, ok := key.(*gossh.Certificate)
			if !ok {
				// Not a certificate: this handler only authorizes certificates.
				return false
			}

			// A host certificate signed by the same CA must not be usable to
			// log in as a user. CheckCert does not look at CertType (only
			// CertChecker.Authenticate does, and it is not used here).
			if cert.CertType != gossh.UserCert {
				return false
			}

			// CheckCert deliberately skips the source-address critical option
			// because it expects the server to enforce it from the Permissions
			// returned by CertChecker.Authenticate. This handler does not
			// propagate those permissions, so the restriction would be silently
			// ignored. Fail closed instead, as CheckCert already does for every
			// other unsupported critical option.
			if _, restricted := cert.CriticalOptions["source-address"]; restricted {
				return false
			}

			// CheckCert covers the principal, validity window, other critical
			// options and the certificate signature. None of that depends on
			// which CA line matches, so check it once before reading the file.
			checker := &gossh.CertChecker{}
			if err := checker.CheckCert(ctx.User(), cert); err != nil {
				return false
			}

			// Finally, the certificate must be signed by one of the trusted CAs.
			signer := cert.SignatureKey.Marshal()
			return isAuthorized(path, func(ca ssh.PublicKey) bool {
				return bytes.Equal(signer, ca.Marshal())
			})
		})(s)
	}
}
135 |
// isAuthorized reports whether any key in the authorized_keys-format file at
// path satisfies checker.
//
// Blank lines are skipped. So are comment lines, meaning '#' as the first
// non-whitespace character, as sshd treats them. Lines of any length are
// supported. If the file cannot be opened or read, or any other line fails to
// parse, a warning is logged and false is returned. A malformed file therefore
// denies everyone rather than being partially applied.
func isAuthorized(path string, checker func(k ssh.PublicKey) bool) bool {
	f, err := os.Open(path)
	if err != nil {
		log.Warn("failed to parse", "path", path, "error", err)
		return false
	}
	defer f.Close() // nolint: errcheck

	rd := bufio.NewReader(f)
	for {
		// ReadBytes returns the whole line regardless of length. ReadLine
		// would split long lines into fragments, which each failed to parse.
		line, readErr := rd.ReadBytes('\n')
		if readErr != nil && !errors.Is(readErr, io.EOF) {
			log.Warn("failed to parse", "path", path, "error", readErr)
			return false
		}

		// Trim before the '#' check so indented comments are skipped instead
		// of being handed to ParseAuthorizedKey, which rejects them.
		entry := strings.TrimSpace(string(line))
		if entry != "" && !strings.HasPrefix(entry, "#") {
			upk, _, _, _, parseErr := ssh.ParseAuthorizedKey([]byte(entry))
			if parseErr != nil {
				log.Warn("failed to parse", "path", path, "error", parseErr)
				return false
			}
			if checker(upk) {
				return true
			}
		}

		if readErr != nil {
			// io.EOF: the final, possibly unterminated, line has been handled.
			return false
		}
	}
}
171 |
// WithPublicKeyAuth returns an ssh.Option that installs h as the server's
// public key authentication handler.
func WithPublicKeyAuth(h ssh.PublicKeyHandler) ssh.Option {
	opt := ssh.PublicKeyAuth(h)
	return opt
}

// WithPasswordAuth returns an ssh.Option that installs p as the server's
// password authentication handler.
func WithPasswordAuth(p ssh.PasswordHandler) ssh.Option {
	opt := ssh.PasswordAuth(p)
	return opt
}

// WithKeyboardInteractiveAuth returns an ssh.Option that installs h as the
// server's keyboard-interactive authentication handler.
func WithKeyboardInteractiveAuth(h ssh.KeyboardInteractiveHandler) ssh.Option {
	opt := ssh.KeyboardInteractiveAuth(h)
	return opt
}

// WithIdleTimeout returns an ssh.Option that sets Server.IdleTimeout, the
// connection's idle timeout.
func WithIdleTimeout(d time.Duration) ssh.Option {
	return func(srv *ssh.Server) error {
		srv.IdleTimeout = d
		return nil
	}
}

// WithMaxTimeout returns an ssh.Option that sets Server.MaxTimeout, the
// connection's absolute timeout.
func WithMaxTimeout(d time.Duration) ssh.Option {
	return func(srv *ssh.Server) error {
		srv.MaxTimeout = d
		return nil
	}
}

// WithSubsystem returns an ssh.Option that registers h as the handler for the
// subsystem named key. It creates the handler map if needed and replaces any
// handler already registered under key.
func WithSubsystem(key string, h ssh.SubsystemHandler) ssh.Option {
	return func(srv *ssh.Server) error {
		handlers := srv.SubsystemHandlers
		if handlers == nil {
			handlers = make(map[string]ssh.SubsystemHandler)
			srv.SubsystemHandlers = handlers
		}
		handlers[key] = h
		return nil
	}
}
214 |
--------------------------------------------------------------------------------
/options_test.go:
--------------------------------------------------------------------------------
1 | package wish
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | "os"
7 | "strings"
8 | "testing"
9 | "time"
10 |
11 | "github.com/charmbracelet/ssh"
12 | "github.com/charmbracelet/wish/testsession"
13 | gossh "golang.org/x/crypto/ssh"
14 | )
15 |
func TestWithSubsystem(t *testing.T) {
	srv := &ssh.Server{Handler: func(ssh.Session) {}}
	requireNoError(t, WithSubsystem("foo", func(ssh.Session) {})(srv))

	if srv.SubsystemHandlers == nil {
		t.Fatalf("should not have been nil")
	}
	if _, found := srv.SubsystemHandlers["foo"]; !found {
		t.Fatalf("should have set the foo subsystem handler")
	}
}

func TestWithBanner(t *testing.T) {
	const want = "a banner"

	srv := &ssh.Server{Handler: func(ssh.Session) {}}
	requireNoError(t, WithBanner(want)(srv))

	var received string
	cfg := &gossh.ClientConfig{
		BannerCallback: func(msg string) error {
			received = msg
			return nil
		},
	}
	requireNoError(t, testsession.New(t, srv, cfg).Run(""))
	requireEqual(t, want, received)
}

func TestWithBannerHandler(t *testing.T) {
	srv := &ssh.Server{Handler: func(ssh.Session) {}}
	bannerFor := func(ctx ssh.Context) string {
		return fmt.Sprintf("banner for %s", ctx.User())
	}
	requireNoError(t, WithBannerHandler(bannerFor)(srv))

	var received string
	cfg := &gossh.ClientConfig{
		User: "fulano",
		BannerCallback: func(msg string) error {
			received = msg
			return nil
		},
	}
	requireNoError(t, testsession.New(t, srv, cfg).Run(""))
	requireEqual(t, "banner for fulano", received)
}

func TestWithIdleTimeout(t *testing.T) {
	var srv ssh.Server
	requireNoError(t, WithIdleTimeout(time.Second)(&srv))
	requireEqual(t, time.Second, srv.IdleTimeout)
}

func TestWithMaxTimeout(t *testing.T) {
	var srv ssh.Server
	requireNoError(t, WithMaxTimeout(time.Second)(&srv))
	requireEqual(t, time.Second, srv.MaxTimeout)
}
78 |
79 | func TestIsAuthorized(t *testing.T) {
80 | t.Run("valid", func(t *testing.T) {
81 | requireEqual(t, true, isAuthorized("testdata/authorized_keys", func(k ssh.PublicKey) bool { return true }))
82 | })
83 |
84 | t.Run("invalid", func(t *testing.T) {
85 | requireEqual(t, false, isAuthorized("testdata/invalid_authorized_keys", func(k ssh.PublicKey) bool { return true }))
86 | })
87 |
88 | t.Run("file not found", func(t *testing.T) {
89 | requireEqual(t, false, isAuthorized("testdata/nope_authorized_keys", func(k ssh.PublicKey) bool { return true }))
90 | })
91 | }
92 |
93 | func TestWithAuthorizedKeys(t *testing.T) {
94 | t.Run("valid", func(t *testing.T) {
95 | s := ssh.Server{}
96 | requireNoError(t, WithAuthorizedKeys("testdata/authorized_keys")(&s))
97 |
98 | for key, authorize := range map[string]bool{
99 | `ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMJlb/qf2B2kMNdBxfpCQqI2ctPcsOkdZGVh5zTRhKtH k3@test`: true,
100 | `ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIOhsthN+zSFSJF7V2HFSO4+2OJYRghuAA43CIbVyvzF8 k7@test`: false,
101 | } {
102 | parts := strings.Fields(key)
103 | t.Run(parts[len(parts)-1], func(t *testing.T) {
104 | key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(key))
105 | requireNoError(t, err)
106 | requireEqual(t, authorize, s.PublicKeyHandler(nil, key))
107 | })
108 | }
109 | })
110 |
111 | t.Run("invalid", func(t *testing.T) {
112 | s := ssh.Server{}
113 | requireNoError(
114 | t,
115 | WithAuthorizedKeys("testdata/invalid_authorized_keys")(&s),
116 | )
117 | })
118 |
119 | t.Run("file not found", func(t *testing.T) {
120 | s := ssh.Server{}
121 | if err := WithAuthorizedKeys("testdata/nope_authorized_keys")(&s); err == nil {
122 | t.Fatal("expected an error, got nil")
123 | }
124 | })
125 | }
126 |
func TestWithTrustedUserCAKeys(t *testing.T) {
	// setup builds a server trusting testdata/ca.pub and a client config that
	// authenticates as "foo" with the certificate at certPath.
	setup := func(tb testing.TB, certPath string) (*ssh.Server, *gossh.ClientConfig) {
		tb.Helper()
		s := &ssh.Server{
			Handler: func(s ssh.Session) {
				cert, ok := s.PublicKey().(*gossh.Certificate)
				if !ok {
					// cert is nil here. Report it instead of panicking on the
					// field accesses below, so the assertion shows what happened.
					fmt.Fprintf(s, "cert? %v", ok)
					return
				}
				fmt.Fprintf(s, "cert? %v - principals: %v - type: %v", ok, cert.ValidPrincipals, cert.CertType)
			},
		}
		requireNoError(tb, WithTrustedUserCAKeys("testdata/ca.pub")(s))

		signer, err := gossh.ParsePrivateKey(getBytes(tb, "testdata/foo"))
		requireNoError(tb, err)

		cert, _, _, _, err := gossh.ParseAuthorizedKey(getBytes(tb, certPath))
		requireNoError(tb, err)

		certSigner, err := gossh.NewCertSigner(cert.(*gossh.Certificate), signer)
		requireNoError(tb, err)
		return s, &gossh.ClientConfig{
			User: "foo",
			Auth: []gossh.AuthMethod{
				gossh.PublicKeys(certSigner),
			},
		}
	}

	t.Run("invalid ca key", func(t *testing.T) {
		s := &ssh.Server{}
		if err := WithTrustedUserCAKeys("testdata/invalid-path")(s); err == nil {
			t.Fatal("expected an error, got nil")
		}
	})

	t.Run("valid", func(t *testing.T) {
		s, cc := setup(t, "testdata/valid-cert.pub")
		sess := testsession.New(t, s, cc)
		var b bytes.Buffer
		sess.Stdout = &b
		requireNoError(t, sess.Run(""))
		requireEqual(t, "cert? true - principals: [foo] - type: 1", b.String())
	})

	t.Run("valid wrong principal", func(t *testing.T) {
		s, cc := setup(t, "testdata/valid-cert.pub")
		cc.User = "not-foo"
		_, err := testsession.NewClientSession(t, testsession.Listen(t, s), cc)
		requireAuthError(t, err)
	})

	t.Run("expired", func(t *testing.T) {
		s, cc := setup(t, "testdata/expired-cert.pub")
		_, err := testsession.NewClientSession(t, testsession.Listen(t, s), cc)
		requireAuthError(t, err)
	})

	t.Run("signed by another ca", func(t *testing.T) {
		s, cc := setup(t, "testdata/another-ca-cert.pub")
		_, err := testsession.NewClientSession(t, testsession.Listen(t, s), cc)
		requireAuthError(t, err)
	})

	t.Run("not a cert", func(t *testing.T) {
		s := &ssh.Server{
			Handler: func(s ssh.Session) {
				fmt.Fprintln(s, "hello")
			},
		}
		requireNoError(t, WithTrustedUserCAKeys("testdata/ca.pub")(s))

		signer, err := gossh.ParsePrivateKey(getBytes(t, "testdata/foo"))
		requireNoError(t, err)

		_, err = testsession.NewClientSession(t, testsession.Listen(t, s), &gossh.ClientConfig{
			User: "foo",
			Auth: []gossh.AuthMethod{
				gossh.PublicKeys(signer),
			},
		})
		requireAuthError(t, err)
	})
}
209 |
// getBytes returns the contents of the file at path, failing the test if it
// cannot be read.
func getBytes(tb testing.TB, path string) []byte {
	tb.Helper()
	bts, err := os.ReadFile(path)
	requireNoError(tb, err)
	return bts
}

// requireEqual fails the test unless a == b (interface equality). Callers pass
// the expected value first and the actual value second.
func requireEqual(tb testing.TB, a, b interface{}) {
	tb.Helper()
	if a != b {
		tb.Fatalf("expected %v, got %v", a, b)
	}
}

// requireNoError fails the test if err is non-nil.
func requireNoError(tb testing.TB, err error) {
	tb.Helper()
	if err != nil {
		tb.Fatalf("expected no error, got %v", err)
	}
}

// requireAuthError fails the test unless err is the handshake error produced
// when the client runs out of authentication methods.
func requireAuthError(tb testing.TB, err error) {
	// Helper was missing here, so failures were reported at this helper's
	// lines instead of at the caller, unlike the sibling helpers above.
	tb.Helper()
	if err == nil {
		tb.Fatal("required an error, got nil")
	}
	requireEqual(tb, "ssh: handshake failed: ssh: unable to authenticate, attempted methods [none publickey], no supported methods remain", err.Error())
}
237 |
--------------------------------------------------------------------------------
/ratelimiter/ratelimiter.go:
--------------------------------------------------------------------------------
1 | // Package ratelimiter provides basic rate limiting functionality as a middleware.
2 | //
3 | // It limits the number of connections a source can make in a specified amount of time.
4 | package ratelimiter
5 |
6 | import (
7 | "errors"
8 | "net"
9 |
10 | "github.com/charmbracelet/log"
11 | "github.com/charmbracelet/ssh"
12 | "github.com/charmbracelet/wish"
13 | lru "github.com/hashicorp/golang-lru/v2"
14 | "golang.org/x/time/rate"
15 | )
16 |
17 | // ErrRateLimitExceeded happens when the connection was denied due to the rate limit being exceeded.
18 | var ErrRateLimitExceeded = errors.New("rate limit exceeded, please try again later")
19 |
20 | // RateLimiter implementations should check if a given session is allowed to
21 | // proceed or not, returning an error if they aren't.
22 | // Its up to the implementation to handle what identifies an session as well
23 | // as the implementation details of these limits.
24 | type RateLimiter interface {
25 | Allow(s ssh.Session) error
26 | }
27 |
28 | // Middleware provides a new rate limiting Middleware.
29 | func Middleware(limiter RateLimiter) wish.Middleware {
30 | return func(sh ssh.Handler) ssh.Handler {
31 | return func(s ssh.Session) {
32 | if err := limiter.Allow(s); err != nil {
33 | wish.Fatal(s, err)
34 | return
35 | }
36 |
37 | sh(s)
38 | }
39 | }
40 | }
41 |
42 | // NewRateLimiter returns a new RateLimiter that allows events up to rate rate,
43 | // permits bursts of at most burst tokens and keeps a cache of maxEntries
44 | // limiters.
45 | //
46 | // Internally, it creates a LRU Cache of *rate.Limiter, in which the key is
47 | // the remote IP address.
48 | func NewRateLimiter(r rate.Limit, burst int, maxEntries int) RateLimiter {
49 | if maxEntries <= 0 {
50 | maxEntries = 1
51 | }
52 | // only possible error is if maxEntries is <= 0, which is prevented above.
53 | cache, _ := lru.New[string, *rate.Limiter](maxEntries)
54 | return &limiters{
55 | rate: r,
56 | burst: burst,
57 | cache: cache,
58 | }
59 | }
60 |
61 | type limiters struct {
62 | cache *lru.Cache[string, *rate.Limiter]
63 | rate rate.Limit
64 | burst int
65 | }
66 |
// Allow implements RateLimiter.
//
// Sessions over TCP are keyed by remote IP, so different source ports share
// one limiter. Other sessions are keyed by the full remote address string.
// Each key gets its own rate.Limiter (a token bucket) stored in the LRU cache.
// Allow returns ErrRateLimitExceeded when that limiter has no token available.
func (r *limiters) Allow(s ssh.Session) error {
	var key string
	switch addr := s.RemoteAddr().(type) {
	case *net.TCPAddr:
		key = addr.IP.String()
	default:
		key = addr.String()
	}

	limiter, ok := r.cache.Get(key)
	if !ok {
		// Get followed by Add is not atomic. Two concurrent first connections
		// from the same key would each create a limiter, and the later Add
		// would replace the earlier one, letting both through on separate
		// buckets. PeekOrAdd keeps whichever limiter was stored first, so all
		// callers share a single bucket.
		fresh := rate.NewLimiter(r.rate, r.burst)
		if prev, found, _ := r.cache.PeekOrAdd(key, fresh); found {
			limiter = prev
		} else {
			limiter = fresh
		}
	}

	allowed := limiter.Allow()
	log.Debug("rate limiter key", "key", key, "allowed", allowed)
	if !allowed {
		return ErrRateLimitExceeded
	}
	return nil
}
92 |
--------------------------------------------------------------------------------
/ratelimiter/ratelimiter_test.go:
--------------------------------------------------------------------------------
1 | package ratelimiter
2 |
3 | import (
4 | "testing"
5 | "time"
6 |
7 | "github.com/charmbracelet/ssh"
8 | "github.com/charmbracelet/wish/testsession"
9 | "golang.org/x/sync/errgroup"
10 | "golang.org/x/time/rate"
11 | )
12 |
func TestRateLimiterNoLimit(t *testing.T) {
	// A zero limit with a zero burst never grants a token, so even the very
	// first session must be rejected.
	limiter := NewRateLimiter(rate.Limit(0), 0, 5)
	srv := &ssh.Server{
		Handler: Middleware(limiter)(func(s ssh.Session) {
			s.Write([]byte("hello"))
		}),
	}

	if err := testsession.New(t, srv, nil).Run(""); err == nil {
		t.Fatal("expected an error, got nil")
	}
}

func TestRateLimiterZeroedMaxEntried(t *testing.T) {
	// maxEntries of 0 is clamped to 1 rather than breaking the cache.
	limiter := NewRateLimiter(rate.Limit(1), 1, 0)
	srv := &ssh.Server{
		Handler: Middleware(limiter)(func(s ssh.Session) {
			s.Write([]byte("hello"))
		}),
	}

	if err := testsession.New(t, srv, nil).Run(""); err != nil {
		t.Fatalf("expected no error, got %v", err)
	}
}
38 |
39 | func TestRateLimiter(t *testing.T) {
40 | s := &ssh.Server{
41 | Handler: Middleware(NewRateLimiter(rate.Limit(10), 4, 1))(func(s ssh.Session) {
42 | // noop
43 | }),
44 | }
45 |
46 | addr := testsession.Listen(t, s)
47 |
48 | g := errgroup.Group{}
49 | for i := 0; i < 10; i++ {
50 | g.Go(func() error {
51 | sess, err := testsession.NewClientSession(t, addr, nil)
52 | if err != nil {
53 | t.Fatalf("expected no errors, got %v", err)
54 | }
55 | if err := sess.Run(""); err != nil {
56 | return err
57 | }
58 | return nil
59 | })
60 | }
61 |
62 | if err := g.Wait(); err == nil {
63 | t.Fatal("expected error, got nil")
64 | }
65 |
66 | // after some time, it should reset and pass again
67 | time.Sleep(100 * time.Millisecond)
68 | sess, err := testsession.NewClientSession(t, addr, nil)
69 | if err != nil {
70 | t.Fatalf("expected no errors, got %v", err)
71 | }
72 | if err := sess.Run(""); err != nil {
73 | t.Fatalf("expected no errors, got %v", err)
74 | }
75 | }
76 |
--------------------------------------------------------------------------------
/recover/recover.go:
--------------------------------------------------------------------------------
1 | package recover
2 |
3 | import (
4 | "runtime/debug"
5 |
6 | "github.com/charmbracelet/log"
7 | "github.com/charmbracelet/ssh"
8 | "github.com/charmbracelet/wish"
9 | )
10 |
11 | // Middleware is a wish middleware that recovers from panics and log to stderr.
12 | func Middleware(mw ...wish.Middleware) wish.Middleware {
13 | return MiddlewareWithLogger(nil, mw...)
14 | }
15 |
// Logger is the minimal logging interface used to report recovered panics.
// Only a Printf method is required.
type Logger interface {
	Printf(format string, v ...any)
}
20 |
21 | // MiddlewareWithLogger is a wish middleware that recovers from panics and log to
22 | // the provided logger.
23 | func MiddlewareWithLogger(logger Logger, mw ...wish.Middleware) wish.Middleware {
24 | if logger == nil {
25 | logger = log.StandardLog()
26 | }
27 | h := func(ssh.Session) {}
28 | for _, m := range mw {
29 | h = m(h)
30 | }
31 | return func(sh ssh.Handler) ssh.Handler {
32 | return func(s ssh.Session) {
33 | func() {
34 | defer func() {
35 | if r := recover(); r != nil {
36 | logger.Printf(
37 | "panic: %v\n%s",
38 | r,
39 | string(debug.Stack()),
40 | )
41 | }
42 | }()
43 | h(s)
44 | }()
45 | sh(s)
46 | }
47 | }
48 | }
49 |
--------------------------------------------------------------------------------
/recover/recover_test.go:
--------------------------------------------------------------------------------
1 | package recover
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/charmbracelet/ssh"
7 | "github.com/charmbracelet/wish/testsession"
8 | gossh "golang.org/x/crypto/ssh"
9 | )
10 |
11 | func TestMiddleware(t *testing.T) {
12 | t.Run("recover session", func(t *testing.T) {
13 | _, err := setup(t).Output("")
14 | requireNoError(t, err)
15 | })
16 | }
17 |
18 | func setup(tb testing.TB) *gossh.Session {
19 | tb.Helper()
20 | return testsession.New(tb, &ssh.Server{
21 | Handler: Middleware(func(h ssh.Handler) ssh.Handler {
22 | return func(s ssh.Session) {
23 | panic("hello")
24 | }
25 | })(func(s ssh.Session) {}),
26 | }, nil)
27 | }
28 |
// requireNoError stops the test immediately if err is non-nil.
func requireNoError(t *testing.T, err error) {
	t.Helper()
	if err == nil {
		return
	}
	t.Fatalf("expected no error, got %q", err.Error())
}
36 |
--------------------------------------------------------------------------------
/scp/copy_from_client.go:
--------------------------------------------------------------------------------
1 | package scp
2 |
3 | import (
4 | "bufio"
5 | "errors"
6 | "fmt"
7 | "io"
8 | "io/fs"
9 | "path/filepath"
10 | "regexp"
11 | "strconv"
12 | "strings"
13 |
14 | "github.com/charmbracelet/ssh"
15 | )
16 |
var (
	// reTimestamp matches an SCP "T" header: "T<mtime> <usec> <atime> <usec>".
	// The times are unpadded decimal Unix seconds of any length (a 1990 file
	// yields 9 digits, the epoch is "0"), so a fixed 10-digit pattern would
	// reject valid headers. The microsecond fields are accepted but ignored.
	// Capture groups: 1 = mtime, 2 = atime.
	reTimestamp = regexp.MustCompile(`^T(\d+) \d{1,6} (\d+) \d{1,6}$`)
	// reNewFolder matches a "D<mode> 0 <name>" directory header.
	// Capture groups: 1 = octal mode, 2 = name.
	reNewFolder = regexp.MustCompile(`^D(\d{4}) 0 (.*)$`)
	// reNewFile matches a "C<mode> <size> <name>" file header.
	// Capture groups: 1 = octal mode, 2 = size in bytes, 3 = name.
	reNewFile   = regexp.MustCompile(`^C(\d{4}) (\d+) (.*)$`)
)
22 |
// parseError reports an SCP protocol line that could not be understood.
type parseError struct {
	subject string // the offending line, without its trailing newline
}

// Error implements the error interface, quoting the offending line.
func (e parseError) Error() string {
	return "failed to parse: " + strconv.Quote(e.subject)
}
30 |
31 | func copyFromClient(s ssh.Session, info Info, handler CopyFromClientHandler) error {
32 | // accepts the request
33 | _, _ = s.Write(NULL)
34 |
35 | var (
36 | path = info.Path
37 | r = bufio.NewReader(s)
38 | mtime int64
39 | atime int64
40 | )
41 |
42 | for {
43 | line, err := r.ReadString('\n')
44 | if err != nil {
45 | if errors.Is(err, io.EOF) {
46 | break
47 | }
48 | return fmt.Errorf("failed to read line: %w", err)
49 | }
50 | line = strings.TrimSuffix(line, "\n")
51 |
52 | if matches := reTimestamp.FindAllStringSubmatch(line, 2); matches != nil {
53 | mtime, err = strconv.ParseInt(matches[0][1], 10, 64)
54 | if err != nil {
55 | return parseError{line}
56 | }
57 | atime, err = strconv.ParseInt(matches[0][2], 10, 64)
58 | if err != nil {
59 | return parseError{line}
60 | }
61 |
62 | // accepts the header
63 | _, _ = s.Write(NULL)
64 | continue
65 | }
66 |
67 | if matches := reNewFile.FindAllStringSubmatch(line, 3); matches != nil {
68 | if len(matches) != 1 || len(matches[0]) != 4 {
69 | return parseError{line}
70 | }
71 |
72 | mode, err := strconv.ParseUint(matches[0][1], 8, 32)
73 | if err != nil {
74 | return parseError{line}
75 | }
76 |
77 | size, err := strconv.ParseInt(matches[0][2], 10, 64)
78 | if err != nil {
79 | return parseError{line}
80 | }
81 | name := matches[0][3]
82 |
83 | // accepts the header
84 | _, _ = s.Write(NULL)
85 |
86 | written, err := handler.Write(s, &FileEntry{
87 | Name: name,
88 | Filepath: filepath.Join(path, name),
89 | Mode: fs.FileMode(mode), //nolint:gosec
90 | Size: size,
91 | Mtime: mtime,
92 | Atime: atime,
93 | Reader: newLimitReader(r, int(size)),
94 | })
95 | if err != nil {
96 | return fmt.Errorf("failed to write file: %q: %w", name, err)
97 | }
98 | if written != size {
99 | return fmt.Errorf("failed to write the file: %q: written %d out of %d bytes", name, written, size)
100 | }
101 |
102 | // read the trailing nil char
103 | _, _ = r.ReadByte()
104 |
105 | mtime = 0
106 | atime = 0
107 | // says 'hey im done'
108 | _, _ = s.Write(NULL)
109 | continue
110 | }
111 |
112 | if matches := reNewFolder.FindAllStringSubmatch(line, 2); matches != nil {
113 | if len(matches) != 1 || len(matches[0]) != 3 {
114 | return parseError{line}
115 | }
116 |
117 | mode, err := strconv.ParseUint(matches[0][1], 8, 32)
118 | if err != nil {
119 | return parseError{line}
120 | }
121 | name := matches[0][2]
122 |
123 | path = filepath.Join(path, name)
124 | if err := handler.Mkdir(s, &DirEntry{
125 | Name: name,
126 | Filepath: path,
127 | Mode: fs.FileMode(mode), //nolint:gosec
128 | Mtime: mtime,
129 | Atime: atime,
130 | }); err != nil {
131 | return fmt.Errorf("failed to create dir: %q: %w", name, err)
132 | }
133 |
134 | mtime = 0
135 | atime = 0
136 | // says 'hey im done'
137 | _, _ = s.Write(NULL)
138 | continue
139 | }
140 |
141 | if line == "E" {
142 | path = filepath.Dir(path)
143 |
144 | // says 'hey im done'
145 | _, _ = s.Write(NULL)
146 | continue
147 | }
148 |
149 | return fmt.Errorf("unhandled input: %q", line)
150 | }
151 |
152 | _, _ = s.Write(NULL)
153 | return nil
154 | }
155 |
--------------------------------------------------------------------------------
/scp/copy_to_client.go:
--------------------------------------------------------------------------------
1 | package scp
2 |
3 | import (
4 | "fmt"
5 | "io/fs"
6 |
7 | "github.com/charmbracelet/ssh"
8 | )
9 |
10 | func copyToClient(s ssh.Session, info Info, handler CopyToClientHandler) error {
11 | matches, err := handler.Glob(s, info.Path)
12 | if err != nil {
13 | return err
14 | }
15 | if len(matches) == 0 {
16 | return fmt.Errorf("no files matching %q", info.Path)
17 | }
18 |
19 | rootEntry := &RootEntry{}
20 | var closers []func() error
21 | defer func() {
22 | closeAll(closers)
23 | }()
24 |
25 | for _, match := range matches {
26 | if !info.Recursive {
27 | entry, closer, err := handler.NewFileEntry(s, match)
28 | closers = append(closers, closer)
29 | if err != nil {
30 | return err
31 | }
32 | rootEntry.Append(entry)
33 | continue
34 | }
35 |
36 | if err := handler.WalkDir(s, match, func(path string, d fs.DirEntry, err error) error {
37 | if err != nil {
38 | return err
39 | }
40 |
41 | if d.IsDir() {
42 | entry, err := handler.NewDirEntry(s, path)
43 | if err != nil {
44 | return err
45 | }
46 | rootEntry.Append(entry)
47 | } else {
48 | entry, closer, err := handler.NewFileEntry(s, path)
49 | if err != nil {
50 | return err
51 | }
52 | closers = append(closers, closer)
53 | rootEntry.Append(entry)
54 | }
55 |
56 | return nil
57 | }); err != nil {
58 | return err
59 | }
60 | }
61 |
62 | return rootEntry.Write(s)
63 | }
64 |
// closeAll invokes every non-nil closer in order. Their errors are
// deliberately discarded: this runs during cleanup, when a failed Close
// cannot be acted upon.
func closeAll(closers []func() error) {
	for i := range closers {
		if closers[i] == nil {
			continue
		}
		_ = closers[i]()
	}
}
72 |
--------------------------------------------------------------------------------
/scp/filesystem.go:
--------------------------------------------------------------------------------
1 | package scp
2 |
3 | import (
4 | "fmt"
5 | "io"
6 | "io/fs"
7 | "os"
8 | "path/filepath"
9 | "strings"
10 | "time"
11 |
12 | "github.com/charmbracelet/ssh"
13 | )
14 |
15 | // fileSystemHandler is a Handler implementation for a given root path.
16 | type fileSystemHandler struct{ root string }
17 |
18 | var _ Handler = &fileSystemHandler{}
19 |
20 | // NewFileSystemHandler return a Handler based on the given dir.
21 | func NewFileSystemHandler(root string) Handler {
22 | return &fileSystemHandler{
23 | root: filepath.Clean(root),
24 | }
25 | }
26 |
27 | func (h *fileSystemHandler) chtimes(path string, mtime, atime int64) error {
28 | if mtime == 0 || atime == 0 {
29 | return nil
30 | }
31 | if err := os.Chtimes(
32 | h.prefixed(path),
33 | time.Unix(atime, 0),
34 | time.Unix(mtime, 0),
35 | ); err != nil {
36 | return fmt.Errorf("failed to chtimes: %q: %w", path, err)
37 | }
38 | return nil
39 | }
40 |
41 | func (h *fileSystemHandler) prefixed(path string) string {
42 | path = filepath.Clean(path)
43 | if strings.HasPrefix(path, h.root) {
44 | return path
45 | }
46 | return filepath.Join(h.root, path)
47 | }
48 |
49 | func (h *fileSystemHandler) Glob(_ ssh.Session, s string) ([]string, error) {
50 | matches, err := filepath.Glob(h.prefixed(s))
51 | if err != nil {
52 | return []string{}, err
53 | }
54 |
55 | for i, match := range matches {
56 | matches[i], err = filepath.Rel(h.root, match)
57 | if err != nil {
58 | return []string{}, err
59 | }
60 | }
61 | return matches, nil
62 | }
63 |
64 | func (h *fileSystemHandler) WalkDir(_ ssh.Session, path string, fn fs.WalkDirFunc) error {
65 | return filepath.WalkDir(h.prefixed(path), func(path string, d fs.DirEntry, err error) error {
66 | // if h.root is ./foo/bar, we don't want to server `bar` as the root,
67 | // but instead its contents.
68 | if path == h.root {
69 | return err
70 | }
71 | return fn(path, d, err)
72 | })
73 | }
74 |
75 | func (h *fileSystemHandler) NewDirEntry(_ ssh.Session, name string) (*DirEntry, error) {
76 | path := h.prefixed(name)
77 | info, err := os.Stat(path)
78 | if err != nil {
79 | return nil, fmt.Errorf("failed to open dir: %q: %w", path, err)
80 | }
81 | return &DirEntry{
82 | Children: []Entry{},
83 | Name: info.Name(),
84 | Filepath: path,
85 | Mode: info.Mode(),
86 | Mtime: info.ModTime().Unix(),
87 | Atime: info.ModTime().Unix(),
88 | }, nil
89 | }
90 |
91 | func (h *fileSystemHandler) NewFileEntry(_ ssh.Session, name string) (*FileEntry, func() error, error) {
92 | path := h.prefixed(name)
93 | info, err := os.Stat(path)
94 | if err != nil {
95 | return nil, nil, fmt.Errorf("failed to stat %q: %w", path, err)
96 | }
97 | f, err := os.Open(path)
98 | if err != nil {
99 | return nil, nil, fmt.Errorf("failed to open %q: %w", path, err)
100 | }
101 | return &FileEntry{
102 | Name: info.Name(),
103 | Filepath: path,
104 | Mode: info.Mode(),
105 | Size: info.Size(),
106 | Mtime: info.ModTime().Unix(),
107 | Atime: info.ModTime().Unix(),
108 | Reader: f,
109 | }, f.Close, nil
110 | }
111 |
112 | func (h *fileSystemHandler) Mkdir(_ ssh.Session, entry *DirEntry) error {
113 | if err := os.Mkdir(h.prefixed(entry.Filepath), entry.Mode); err != nil {
114 | return fmt.Errorf("failed to create dir: %q: %w", entry.Filepath, err)
115 | }
116 | return h.chtimes(entry.Filepath, entry.Mtime, entry.Atime)
117 | }
118 |
119 | func (h *fileSystemHandler) Write(_ ssh.Session, entry *FileEntry) (int64, error) {
120 | f, err := os.OpenFile(h.prefixed(entry.Filepath), os.O_TRUNC|os.O_RDWR|os.O_CREATE, entry.Mode)
121 | if err != nil {
122 | return 0, fmt.Errorf("failed to open file: %q: %w", entry.Filepath, err)
123 | }
124 | defer f.Close() //nolint:errcheck
125 | written, err := io.Copy(f, entry.Reader)
126 | if err != nil {
127 | return 0, fmt.Errorf("failed to write file: %q: %w", entry.Filepath, err)
128 | }
129 | if err := f.Close(); err != nil {
130 | return 0, fmt.Errorf("failed to close file: %q: %w", entry.Filepath, err)
131 | }
132 | return written, h.chtimes(entry.Filepath, entry.Mtime, entry.Atime)
133 | }
134 |
--------------------------------------------------------------------------------
/scp/filesystem_test.go:
--------------------------------------------------------------------------------
1 | package scp
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | "io/fs"
7 | "os"
8 | "path/filepath"
9 | "testing"
10 | "testing/iotest"
11 | "time"
12 |
13 | "github.com/matryer/is"
14 | )
15 |
16 | func TestFilesystem(t *testing.T) {
17 | mtime := time.Unix(1323853868, 0)
18 | atime := time.Unix(1380425711, 0)
19 |
20 | t.Run("scp -f", func(t *testing.T) {
21 | t.Run("file", func(t *testing.T) {
22 | is := is.New(t)
23 |
24 | dir := t.TempDir()
25 | h := NewFileSystemHandler(dir)
26 | is.NoErr(os.WriteFile(filepath.Join(dir, "a.txt"), []byte("a text file"), 0o644))
27 | chtimesTree(t, dir, atime, mtime)
28 |
29 | session := setup(t, h, nil)
30 | bts, err := session.CombinedOutput("scp -f a.txt")
31 | is.NoErr(err)
32 | requireEqualGolden(t, bts)
33 | })
34 |
35 | t.Run("glob", func(t *testing.T) {
36 | is := is.New(t)
37 |
38 | dir := t.TempDir()
39 | h := NewFileSystemHandler(dir)
40 | is.NoErr(os.WriteFile(filepath.Join(dir, "a.txt"), []byte("a text file"), 0o644))
41 | is.NoErr(os.WriteFile(filepath.Join(dir, "b.txt"), []byte("another text file"), 0o644))
42 | chtimesTree(t, dir, atime, mtime)
43 |
44 | session := setup(t, h, nil)
45 | bts, err := session.CombinedOutput("scp -f *.txt")
46 | is.NoErr(err)
47 | requireEqualGolden(t, bts)
48 | })
49 |
50 | t.Run("invalid file", func(t *testing.T) {
51 | is := is.New(t)
52 |
53 | dir := t.TempDir()
54 | h := NewFileSystemHandler(dir)
55 |
56 | session := setup(t, h, nil)
57 | _, err := session.CombinedOutput("scp -f a.txt")
58 | is.True(err != nil)
59 | })
60 |
61 | t.Run("recursive", func(t *testing.T) {
62 | is := is.New(t)
63 |
64 | dir := t.TempDir()
65 | h := NewFileSystemHandler(dir)
66 |
67 | is.NoErr(os.MkdirAll(filepath.Join(dir, "a/b/c/d/e"), 0o755))
68 | is.NoErr(os.WriteFile(filepath.Join(dir, "a/b/c.txt"), []byte("c text file"), 0o644))
69 | is.NoErr(os.WriteFile(filepath.Join(dir, "a/b/c/d/e/e.txt"), []byte("e text file"), 0o644))
70 | chtimesTree(t, dir, atime, mtime)
71 |
72 | session := setup(t, h, nil)
73 | bts, err := session.CombinedOutput("scp -r -f a")
74 | is.NoErr(err)
75 | requireEqualGolden(t, bts)
76 | })
77 |
78 | t.Run("recursive glob", func(t *testing.T) {
79 | is := is.New(t)
80 |
81 | dir := t.TempDir()
82 | h := NewFileSystemHandler(dir)
83 |
84 | is.NoErr(os.MkdirAll(filepath.Join(dir, "a/b/c/d/e"), 0o755))
85 | is.NoErr(os.WriteFile(filepath.Join(dir, "a/b/c.txt"), []byte("c text file"), 0o644))
86 | is.NoErr(os.WriteFile(filepath.Join(dir, "a/b/c/d/e/e.txt"), []byte("e text file"), 0o644))
87 | chtimesTree(t, dir, atime, mtime)
88 |
89 | session := setup(t, h, nil)
90 | bts, err := session.CombinedOutput("scp -r -f a/*")
91 | is.NoErr(err)
92 | requireEqualGolden(t, bts)
93 | })
94 |
95 | t.Run("recursive invalid file", func(t *testing.T) {
96 | is := is.New(t)
97 |
98 | dir := t.TempDir()
99 | h := NewFileSystemHandler(dir)
100 |
101 | session := setup(t, h, nil)
102 | _, err := session.CombinedOutput("scp -r -f a")
103 | is.True(err != nil)
104 | })
105 |
106 | t.Run("recursive folder", func(t *testing.T) {
107 | is := is.New(t)
108 |
109 | dir := t.TempDir()
110 | h := NewFileSystemHandler(dir)
111 |
112 | is.NoErr(os.MkdirAll(filepath.Join(dir, "a/b/c/d/e"), 0o755))
113 | is.NoErr(os.WriteFile(filepath.Join(dir, "a/b/c.txt"), []byte("c text file"), 0o644))
114 | is.NoErr(os.WriteFile(filepath.Join(dir, "a/b/c/d/e/e.txt"), []byte("e text file"), 0o644))
115 | chtimesTree(t, dir, atime, mtime)
116 |
117 | session := setup(t, h, nil)
118 | bts, err := session.CombinedOutput("scp -r -f /")
119 | is.NoErr(err)
120 | requireEqualGolden(t, bts)
121 | })
122 | })
123 |
124 | t.Run("scp -t", func(t *testing.T) {
125 | t.Run("file", func(t *testing.T) {
126 | is := is.New(t)
127 | dir := t.TempDir()
128 | h := NewFileSystemHandler(dir)
129 | session := setup(t, nil, h)
130 |
131 | var in bytes.Buffer
132 | in.WriteString("T1183832947 0 1183833773 0\n")
133 | in.WriteString("C0644 6 a.txt\n")
134 | in.WriteString("hello\n")
135 | in.Write(NULL)
136 | session.Stdin = &in
137 |
138 | _, err := session.CombinedOutput("scp -t .")
139 | is.NoErr(err)
140 |
141 | bts, err := os.ReadFile(filepath.Join(dir, "a.txt"))
142 | is.NoErr(err)
143 | is.Equal("hello\n", string(bts))
144 | })
145 |
146 | t.Run("recursive", func(t *testing.T) {
147 | is := is.New(t)
148 | dir := t.TempDir()
149 | h := NewFileSystemHandler(dir)
150 |
151 | var in bytes.Buffer
152 | in.WriteString("T1183832947 0 1183833773 0\n")
153 | in.WriteString("D0755 0 folder1\n")
154 | in.WriteString("C0644 6 file1\n")
155 | in.WriteString("hello\n")
156 | in.Write(NULL)
157 | in.WriteString("D0755 0 folder2\n")
158 | in.WriteString("T1183832947 0 1183833773 0\n")
159 | in.WriteString("C0644 6 file2\n")
160 | in.WriteString("hello\n")
161 | in.Write(NULL)
162 | in.WriteString("E\n")
163 | in.WriteString("E\n")
164 |
165 | session := setup(t, nil, h)
166 | session.Stdin = &in
167 | _, err := session.CombinedOutput("scp -r -t .")
168 | is.NoErr(err)
169 |
170 | mtime := int64(1183832947)
171 |
172 | stat, err := os.Stat(filepath.Join(dir, "folder1"))
173 | is.NoErr(err)
174 | is.True(stat.IsDir())
175 | // TODO: check how scp behaves
176 | is.True(stat.ModTime().Unix() != mtime) // should be different because the folder was later modified again
177 |
178 | stat, err = os.Stat(filepath.Join(dir, "folder1/file1"))
179 | is.NoErr(err)
180 | is.True(stat.ModTime().Unix() != mtime)
181 |
182 | stat, err = os.Stat(filepath.Join(dir, "folder1/folder2/file2"))
183 | is.NoErr(err)
184 | is.Equal(stat.ModTime().Unix(), mtime)
185 | })
186 | })
187 |
188 | t.Run("errors", func(t *testing.T) {
189 | t.Run("chtimes", func(t *testing.T) {
190 | h := &fileSystemHandler{t.TempDir()}
191 | is.New(t).True(h.chtimes("nope", 1212212, 323232) != nil) // should err
192 | })
193 |
194 | t.Run("glob", func(t *testing.T) {
195 | t.Run("invalid glob", func(t *testing.T) {
196 | is := is.New(t)
197 | h := &fileSystemHandler{t.TempDir()}
198 | matches, err := h.Glob(nil, "[asda")
199 | is.True(err != nil) // should err
200 | is.Equal([]string{}, matches)
201 | })
202 | })
203 |
204 | t.Run("NewDirEntry", func(t *testing.T) {
205 | t.Run("do not exist", func(t *testing.T) {
206 | is := is.New(t)
207 | h := &fileSystemHandler{t.TempDir()}
208 | _, err := h.NewDirEntry(nil, "foo")
209 | is.True(err != nil) // should err
210 | })
211 | })
212 |
213 | t.Run("NewFileEntry", func(t *testing.T) {
214 | t.Run("do not exist", func(t *testing.T) {
215 | is := is.New(t)
216 | h := &fileSystemHandler{t.TempDir()}
217 | _, _, err := h.NewFileEntry(nil, "foo")
218 | is.True(err != nil) // should err
219 | })
220 | })
221 |
222 | t.Run("Mkdir", func(t *testing.T) {
223 | t.Run("parent do not exist", func(t *testing.T) {
224 | is := is.New(t)
225 | h := &fileSystemHandler{t.TempDir()}
226 | err := h.Mkdir(nil, &DirEntry{
227 | Name: "foo",
228 | Filepath: "foo/bar/baz",
229 | Mode: 0o755,
230 | })
231 | is.True(err != nil) // should err
232 | })
233 | })
234 |
235 | t.Run("Write", func(t *testing.T) {
236 | t.Run("parent do not exist", func(t *testing.T) {
237 | is := is.New(t)
238 | h := &fileSystemHandler{t.TempDir()}
239 | _, err := h.Write(nil, &FileEntry{
240 | Name: "foo.txt",
241 | Filepath: "baz/foo.txt",
242 | Mode: 0o644,
243 | Size: 10,
244 | })
245 | is.True(err != nil) // should err
246 | })
247 |
248 | t.Run("reader fails", func(t *testing.T) {
249 | is := is.New(t)
250 | h := &fileSystemHandler{t.TempDir()}
251 | _, err := h.Write(nil, &FileEntry{
252 | Name: "foo.txt",
253 | Filepath: "foo.txt",
254 | Mode: 0o644,
255 | Size: 10,
256 | Reader: iotest.ErrReader(fmt.Errorf("fake err")),
257 | })
258 | is.True(err != nil) // should err
259 | })
260 | })
261 | })
262 | }
263 |
264 | func chtimesTree(tb testing.TB, dir string, atime, mtime time.Time) {
265 | is.New(tb).NoErr(filepath.WalkDir(dir, func(path string, _ fs.DirEntry, err error) error {
266 | if err != nil {
267 | return err
268 | }
269 | return os.Chtimes(path, atime, mtime)
270 | }))
271 | }
272 |
--------------------------------------------------------------------------------
/scp/fs.go:
--------------------------------------------------------------------------------
1 | package scp
2 |
3 | import (
4 | "fmt"
5 | "io/fs"
6 |
7 | "github.com/charmbracelet/ssh"
8 | )
9 |
10 | type fsHandler struct{ fsys fs.FS }
11 |
12 | var _ CopyToClientHandler = &fsHandler{}
13 |
14 | // NewFSReadHandler returns a read-only CopyToClientHandler that accepts any
15 | // fs.FS as input.
16 | func NewFSReadHandler(fsys fs.FS) CopyToClientHandler {
17 | return &fsHandler{fsys: fsys}
18 | }
19 |
20 | func (h *fsHandler) Glob(_ ssh.Session, s string) ([]string, error) {
21 | return fs.Glob(h.fsys, s)
22 | }
23 |
24 | func (h *fsHandler) WalkDir(_ ssh.Session, path string, fn fs.WalkDirFunc) error {
25 | return fs.WalkDir(h.fsys, path, fn)
26 | }
27 |
28 | func (h *fsHandler) NewDirEntry(_ ssh.Session, path string) (*DirEntry, error) {
29 | path = normalizePath(path)
30 | info, err := fs.Stat(h.fsys, path)
31 | if err != nil {
32 | return nil, fmt.Errorf("failed to open dir: %q: %w", path, err)
33 | }
34 | return &DirEntry{
35 | Children: []Entry{},
36 | Name: info.Name(),
37 | Filepath: path,
38 | Mode: info.Mode(),
39 | Mtime: info.ModTime().Unix(),
40 | Atime: info.ModTime().Unix(),
41 | }, nil
42 | }
43 |
44 | func (h *fsHandler) NewFileEntry(_ ssh.Session, path string) (*FileEntry, func() error, error) {
45 | info, err := fs.Stat(h.fsys, path)
46 | if err != nil {
47 | return nil, nil, fmt.Errorf("failed to stat %q: %w", path, err)
48 | }
49 | f, err := h.fsys.Open(path)
50 | if err != nil {
51 | return nil, nil, fmt.Errorf("failed to open %q: %w", path, err)
52 | }
53 | return &FileEntry{
54 | Name: info.Name(),
55 | Filepath: path,
56 | Mode: info.Mode(),
57 | Size: info.Size(),
58 | Mtime: info.ModTime().Unix(),
59 | Atime: info.ModTime().Unix(),
60 | Reader: f,
61 | }, f.Close, nil
62 | }
63 |
--------------------------------------------------------------------------------
/scp/fs_test.go:
--------------------------------------------------------------------------------
1 | package scp
2 |
3 | import (
4 | "os"
5 | "path/filepath"
6 | "testing"
7 | "time"
8 |
9 | "github.com/matryer/is"
10 | )
11 |
12 | func TestFS(t *testing.T) {
13 | mtime := time.Unix(1323853868, 0)
14 | atime := time.Unix(1380425711, 0)
15 |
16 | t.Run("file", func(t *testing.T) {
17 | is := is.New(t)
18 |
19 | dir := t.TempDir()
20 | h := NewFSReadHandler(os.DirFS(dir))
21 |
22 | is.NoErr(os.WriteFile(filepath.Join(dir, "a.txt"), []byte("a text file"), 0o644))
23 | chtimesTree(t, dir, atime, mtime)
24 |
25 | session := setup(t, h, nil)
26 | bts, err := session.CombinedOutput("scp -f a.txt")
27 | is.NoErr(err)
28 | requireEqualGolden(t, bts)
29 | })
30 |
31 | t.Run("glob", func(t *testing.T) {
32 | is := is.New(t)
33 |
34 | dir := t.TempDir()
35 | h := NewFSReadHandler(os.DirFS(dir))
36 | is.NoErr(os.WriteFile(filepath.Join(dir, "a.txt"), []byte("a text file"), 0o644))
37 | is.NoErr(os.WriteFile(filepath.Join(dir, "b.txt"), []byte("another text file"), 0o644))
38 | chtimesTree(t, dir, atime, mtime)
39 |
40 | session := setup(t, h, nil)
41 | bts, err := session.CombinedOutput("scp -f *.txt")
42 | is.NoErr(err)
43 | requireEqualGolden(t, bts)
44 | })
45 |
46 | t.Run("invalid file", func(t *testing.T) {
47 | is := is.New(t)
48 |
49 | dir := t.TempDir()
50 | h := NewFSReadHandler(os.DirFS(dir))
51 |
52 | session := setup(t, h, nil)
53 | _, err := session.CombinedOutput("scp -f a.txt")
54 | is.True(err != nil)
55 | })
56 |
57 | t.Run("recursive", func(t *testing.T) {
58 | is := is.New(t)
59 |
60 | dir := t.TempDir()
61 | h := NewFSReadHandler(os.DirFS(dir))
62 |
63 | is.NoErr(os.MkdirAll(filepath.Join(dir, "a/b/c/d/e"), 0o755))
64 | is.NoErr(os.WriteFile(filepath.Join(dir, "a/b/c.txt"), []byte("c text file"), 0o644))
65 | is.NoErr(os.WriteFile(filepath.Join(dir, "a/b/c/d/e/e.txt"), []byte("e text file"), 0o644))
66 | chtimesTree(t, dir, atime, mtime)
67 |
68 | session := setup(t, h, nil)
69 | bts, err := session.CombinedOutput("scp -r -f a")
70 | is.NoErr(err)
71 | requireEqualGolden(t, bts)
72 | })
73 |
74 | t.Run("recursive glob", func(t *testing.T) {
75 | is := is.New(t)
76 |
77 | dir := t.TempDir()
78 | h := NewFSReadHandler(os.DirFS(dir))
79 |
80 | is.NoErr(os.MkdirAll(filepath.Join(dir, "a/b/c/d/e"), 0o755))
81 | is.NoErr(os.WriteFile(filepath.Join(dir, "a/b/c.txt"), []byte("c text file"), 0o644))
82 | is.NoErr(os.WriteFile(filepath.Join(dir, "a/b/c/d/e/e.txt"), []byte("e text file"), 0o644))
83 | chtimesTree(t, dir, atime, mtime)
84 |
85 | session := setup(t, h, nil)
86 | bts, err := session.CombinedOutput("scp -r -f a/*")
87 | is.NoErr(err)
88 | requireEqualGolden(t, bts)
89 | })
90 |
91 | t.Run("recursive folder", func(t *testing.T) {
92 | is := is.New(t)
93 |
94 | dir := t.TempDir()
95 | h := NewFileSystemHandler(dir)
96 |
97 | is.NoErr(os.MkdirAll(filepath.Join(dir, "a/b/c/d/e"), 0o755))
98 | is.NoErr(os.WriteFile(filepath.Join(dir, "a/b/c.txt"), []byte("c text file"), 0o644))
99 | is.NoErr(os.WriteFile(filepath.Join(dir, "a/b/c/d/e/e.txt"), []byte("e text file"), 0o644))
100 | chtimesTree(t, dir, atime, mtime)
101 |
102 | session := setup(t, h, nil)
103 | bts, err := session.CombinedOutput("scp -r -f /")
104 | is.NoErr(err)
105 | requireEqualGolden(t, bts)
106 | })
107 |
108 | t.Run("recursive invalid file", func(t *testing.T) {
109 | is := is.New(t)
110 |
111 | dir := t.TempDir()
112 | h := NewFSReadHandler(os.DirFS(dir))
113 |
114 | session := setup(t, h, nil)
115 | _, err := session.CombinedOutput("scp -r -f a")
116 | is.True(err != nil)
117 | })
118 |
119 | t.Run("errors", func(t *testing.T) {
120 | t.Run("glob", func(t *testing.T) {
121 | t.Run("invalid glob", func(t *testing.T) {
122 | is := is.New(t)
123 | h := &fsHandler{os.DirFS(t.TempDir())}
124 | matches, err := h.Glob(nil, "[asda")
125 | is.True(err != nil) // should err
126 | is.Equal(nil, matches)
127 | })
128 | })
129 |
130 | t.Run("NewDirEntry", func(t *testing.T) {
131 | t.Run("do not exist", func(t *testing.T) {
132 | is := is.New(t)
133 | h := &fsHandler{os.DirFS(t.TempDir())}
134 | _, err := h.NewDirEntry(nil, "foo")
135 | is.True(err != nil) // should err
136 | })
137 | })
138 |
139 | t.Run("NewFileEntry", func(t *testing.T) {
140 | t.Run("do not exist", func(t *testing.T) {
141 | is := is.New(t)
142 | h := &fsHandler{os.DirFS(t.TempDir())}
143 | _, _, err := h.NewFileEntry(nil, "foo")
144 | is.True(err != nil) // should err
145 | })
146 | })
147 | })
148 | }
149 |
--------------------------------------------------------------------------------
/scp/limit_reader.go:
--------------------------------------------------------------------------------
1 | package scp
2 |
3 | import (
4 | "io"
5 | "sync"
6 | )
7 |
// newLimitReader wraps r so that at most limit bytes can be read from it.
// Once the limit is exhausted, Read reports io.EOF.
func newLimitReader(r io.Reader, limit int) io.Reader {
	lr := &limitReader{r: r}
	lr.left = limit
	return lr
}

// limitReader behaves like io.LimitReader, with its byte counter guarded by
// a mutex.
type limitReader struct {
	r io.Reader

	lock sync.Mutex
	left int // bytes still allowed to be read
}

// Read reads at most min(len(b), remaining) bytes from the wrapped reader.
func (r *limitReader) Read(b []byte) (int, error) {
	r.lock.Lock()
	defer r.lock.Unlock()

	if r.left < 1 {
		return 0, io.EOF
	}
	buf := b
	if len(buf) > r.left {
		buf = buf[:r.left]
	}
	n, err := r.r.Read(buf)
	r.left -= n
	return n, err
}
36 |
--------------------------------------------------------------------------------
/scp/limit_reader_test.go:
--------------------------------------------------------------------------------
1 | package scp
2 |
3 | import (
4 | "bytes"
5 | "io"
6 | "testing"
7 |
8 | "github.com/matryer/is"
9 | )
10 |
11 | func TestLimitedReader(t *testing.T) {
12 | t.Run("partial", func(t *testing.T) {
13 | is := is.New(t)
14 | var b bytes.Buffer
15 | b.WriteString("writing some bytes")
16 | r := newLimitReader(&b, 7)
17 |
18 | bts, err := io.ReadAll(r)
19 | is.NoErr(err)
20 | is.Equal("writing", string(bts))
21 | })
22 |
23 | t.Run("full", func(t *testing.T) {
24 | is := is.New(t)
25 | var b bytes.Buffer
26 | b.WriteString("some text")
27 | r := newLimitReader(&b, b.Len())
28 |
29 | bts, err := io.ReadAll(r)
30 | is.NoErr(err)
31 | is.Equal("some text", string(bts))
32 | })
33 |
34 | t.Run("pass limit", func(t *testing.T) {
35 | is := is.New(t)
36 | var b bytes.Buffer
37 | b.WriteString("another text")
38 | r := newLimitReader(&b, b.Len()+10)
39 |
40 | bts, err := io.ReadAll(r)
41 | is.NoErr(err)
42 | is.Equal("another text", string(bts))
43 | })
44 | }
45 |
--------------------------------------------------------------------------------
/scp/scp.go:
--------------------------------------------------------------------------------
1 | // Package scp provides a SCP middleware for wish.
2 | package scp
3 |
4 | import (
5 | "fmt"
6 | "io"
7 | "io/fs"
8 | "path/filepath"
9 | "runtime"
10 | "strconv"
11 | "strings"
12 |
13 | "github.com/charmbracelet/ssh"
14 | "github.com/charmbracelet/wish"
15 | )
16 |
17 | // CopyToClientHandler is a handler that can be implemented to handle files
18 | // being copied from the server to the client.
19 | type CopyToClientHandler interface {
20 | // Glob should be implemented if you want to provide server-side globbing
21 | // support.
22 | //
23 | // A minimal implementation to disable it is to return `[]string{s}, nil`.
24 | //
25 | // Note: if your other functions expect a relative path, make sure that
26 | // your Glob implementation returns relative paths as well.
27 | Glob(ssh.Session, string) ([]string, error)
28 |
29 | // WalkDir must be implemented if you want to allow recursive copies.
30 | WalkDir(ssh.Session, string, fs.WalkDirFunc) error
31 |
32 | // NewDirEntry should provide a *DirEntry for the given path.
33 | NewDirEntry(ssh.Session, string) (*DirEntry, error)
34 |
35 | // NewFileEntry should provide a *FileEntry for the given path.
36 | // Users may also provide a closing function.
37 | NewFileEntry(ssh.Session, string) (*FileEntry, func() error, error)
38 | }
39 |
40 | // CopyFromClientHandler is a handler that can be implemented to handle files
41 | // being copied from the client to the server.
42 | type CopyFromClientHandler interface {
43 | // Mkdir should created the given dir.
44 | // Note that this usually shouldn't use os.MkdirAll and the like.
45 | Mkdir(ssh.Session, *DirEntry) error
46 |
47 | // Write should write the given file.
48 | Write(ssh.Session, *FileEntry) (int64, error)
49 | }
50 |
51 | // Handler is a interface that can be implemented to handle both SCP
52 | // directions.
53 | type Handler interface {
54 | CopyFromClientHandler
55 | CopyToClientHandler
56 | }
57 |
58 | // Middleware provides a wish middleware using the given CopyToClientHandler
59 | // and CopyFromClientHandler.
60 | func Middleware(rh CopyToClientHandler, wh CopyFromClientHandler) wish.Middleware {
61 | return func(sh ssh.Handler) ssh.Handler {
62 | return func(s ssh.Session) {
63 | info := GetInfo(s.Command())
64 | if !info.Ok {
65 | sh(s)
66 | return
67 | }
68 |
69 | var err error
70 | switch info.Op {
71 | case OpCopyToClient:
72 | if rh == nil {
73 | err = fmt.Errorf("no handler provided for scp -f")
74 | break
75 | }
76 | err = copyToClient(s, info, rh)
77 | case OpCopyFromClient:
78 | if wh == nil {
79 | err = fmt.Errorf("no handler provided for scp -t")
80 | break
81 | }
82 | err = copyFromClient(s, info, wh)
83 | }
84 | if err != nil {
85 | wish.Fatal(s, err)
86 | return
87 | }
88 | }
89 | }
90 | }
91 |
// NULL is a one-element slice holding the zero byte that FileEntry.Write
// emits after a file's contents.
var NULL = []byte{0}

// Entry is anything that can serialize itself in SCP format and report the
// path it was created from.
type Entry interface {
	// Write serializes the entry in SCP format to w.
	Write(w io.Writer) error

	// path returns the entry's path; it is used to place entries in a tree.
	path() string
}

// AppendableEntry is an Entry-like container that can hold children.
type AppendableEntry interface {
	// Write serializes the entry in SCP format to w.
	Write(w io.Writer) error

	// Append adds entry as a (possibly nested) child.
	Append(entry Entry)
}
112 |
113 | // FileEntry is an Entry that reads from a Reader, defining a file and
114 | // its contents.
115 | type FileEntry struct {
116 | Name string
117 | Filepath string
118 | Mode fs.FileMode
119 | Size int64
120 | Reader io.Reader
121 | Atime int64
122 | Mtime int64
123 | }
124 |
125 | func (e *FileEntry) path() string { return e.Filepath }
126 |
127 | // Write a file to the given writer.
128 | func (e *FileEntry) Write(w io.Writer) error {
129 | if e.Mtime > 0 && e.Atime > 0 {
130 | if _, err := fmt.Fprintf(w, "T%d 0 %d 0\n", e.Mtime, e.Atime); err != nil {
131 | return fmt.Errorf("failed to write file: %q: %w", e.Filepath, err)
132 | }
133 | }
134 | if _, err := fmt.Fprintf(w, "C%s %d %s\n", octalPerms(e.Mode), e.Size, e.Name); err != nil {
135 | return fmt.Errorf("failed to write file: %q: %w", e.Filepath, err)
136 | }
137 |
138 | if _, err := io.Copy(w, e.Reader); err != nil {
139 | return fmt.Errorf("failed to read file: %q: %w", e.Filepath, err)
140 | }
141 |
142 | if _, err := w.Write(NULL); err != nil {
143 | return fmt.Errorf("failed to write file: %q: %w", e.Filepath, err)
144 | }
145 | return nil
146 | }
147 |
148 | // RootEntry is a root entry that can only have children.
149 | type RootEntry []Entry
150 |
151 | // Appennd the given entry to a child directory, or the the itself if
152 | // none matches.
153 | func (e *RootEntry) Append(entry Entry) {
154 | parent := normalizePath(filepath.Dir(entry.path()))
155 |
156 | for _, child := range *e {
157 | switch dir := child.(type) {
158 | case *DirEntry:
159 | if child.path() == parent {
160 | dir.Children = append(dir.Children, entry)
161 | return
162 | }
163 | if strings.HasPrefix(parent, normalizePath(dir.Filepath)) {
164 | dir.Append(entry)
165 | return
166 | }
167 | default:
168 | continue
169 | }
170 | }
171 |
172 | *e = append(*e, entry)
173 | }
174 |
175 | // Write recursively writes all the children to the given writer.
176 | func (e *RootEntry) Write(w io.Writer) error {
177 | for _, child := range *e {
178 | if err := child.Write(w); err != nil {
179 | return err
180 | }
181 | }
182 | return nil
183 | }
184 |
185 | // DirEntry is an Entry with mode, possibly children, and possibly a
186 | // parent.
187 | type DirEntry struct {
188 | Children []Entry
189 | Name string
190 | Filepath string
191 | Mode fs.FileMode
192 | Atime int64
193 | Mtime int64
194 | }
195 |
196 | func (e *DirEntry) path() string { return e.Filepath }
197 |
198 | // Write the current dir entry, all its contents (recursively), and the
199 | // dir closing to the given writer.
200 | func (e *DirEntry) Write(w io.Writer) error {
201 | if e.Mtime > 0 && e.Atime > 0 {
202 | if _, err := fmt.Fprintf(w, "T%d 0 %d 0\n", e.Mtime, e.Atime); err != nil {
203 | return fmt.Errorf("failed to write dir: %q: %w", e.Filepath, err)
204 | }
205 | }
206 | if _, err := fmt.Fprintf(w, "D%s 0 %s\n", octalPerms(e.Mode), e.Name); err != nil {
207 | return fmt.Errorf("failed to write dir: %q: %w", e.Filepath, err)
208 | }
209 |
210 | for _, child := range e.Children {
211 | if err := child.Write(w); err != nil {
212 | return err
213 | }
214 | }
215 |
216 | if _, err := fmt.Fprint(w, "E\n"); err != nil {
217 | return fmt.Errorf("failed to write dir: %q: %w", e.Filepath, err)
218 | }
219 | return nil
220 | }
221 |
222 | // Appends an entry to the folder or their children.
223 | func (e *DirEntry) Append(entry Entry) {
224 | parent := normalizePath(filepath.Dir(entry.path()))
225 |
226 | for _, child := range e.Children {
227 | switch dir := child.(type) {
228 | case *DirEntry:
229 | if child.path() == parent {
230 | dir.Children = append(dir.Children, entry)
231 | return
232 | }
233 | if strings.HasPrefix(parent, normalizePath(dir.path())) {
234 | dir.Append(entry)
235 | return
236 | }
237 | default:
238 | continue
239 | }
240 | }
241 |
242 | e.Children = append(e.Children, entry)
243 | }
244 |
// Op defines which kind of SCP Operation is going on.
type Op byte

const (
	// OpCopyToClient is when a file is being copied from the server to the client.
	OpCopyToClient Op = 'f'

	// OpCopyFromClient is when a file is being copied from the client into the server.
	OpCopyFromClient Op = 't'
)

// Info provides some information about the current SCP Operation.
type Info struct {
	// Ok is true if the current session is a SCP.
	Ok bool

	// Recursive is true if its a recursive SCP.
	Recursive bool

	// Path is the server path of the scp operation.
	Path string

	// Op is the SCP operation kind.
	Op Op
}

// GetInfo returns information about the given command.
//
// The command counts as SCP only when its first element is exactly "scp";
// otherwise the zero Info (Ok == false) is returned. Every argument is then
// scanned. "-r" marks the operation recursive, and "-f"/"-t" select the
// direction and take the following argument as Path. Other arguments are
// ignored, and when several direction flags appear the last one wins.
//
// cmd comes from the remote client, so a trailing "-f" or "-t" without a
// path is tolerated: the direction is recorded and Path is left empty
// instead of indexing past the end of cmd (which used to panic).
func GetInfo(cmd []string) Info {
	info := Info{}
	if len(cmd) == 0 || cmd[0] != "scp" {
		return info
	}

	// argAfter returns the argument following index i, or "" if there is none.
	argAfter := func(i int) string {
		if i+1 < len(cmd) {
			return cmd[i+1]
		}
		return ""
	}

	for i, p := range cmd {
		switch p {
		case "-r":
			info.Recursive = true
		case "-f":
			info.Op = OpCopyToClient
			info.Path = argAfter(i)
		case "-t":
			info.Op = OpCopyFromClient
			info.Path = argAfter(i)
		}
	}

	info.Ok = true
	return info
}
294 |
// octalPerms renders the permission bits of info as exactly four octal
// digits (e.g. "0644"), the form SCP "C" and "D" records use.
//
// OpenSSH's scp sink reads exactly four octal digits followed by a space, so
// narrower permissions such as 0o44 must be sent as "0044", not "044".
// Non-permission bits (type, setuid, ...) are dropped via Perm().
func octalPerms(info fs.FileMode) string {
	digits := strconv.FormatUint(uint64(info.Perm()), 8)
	// Perm() is at most 0o777 (three digits), so at least one "0" is added.
	return strings.Repeat("0", 4-len(digits)) + digits
}
298 |
// normalizePath cleans p and, on Windows only, turns every backslash into a
// forward slash so paths can be compared with "/" separators everywhere.
func normalizePath(p string) string {
	cleaned := filepath.Clean(p)
	if runtime.GOOS != "windows" {
		return cleaned
	}
	return strings.ReplaceAll(cleaned, "\\", "/")
}
306 |
--------------------------------------------------------------------------------
/scp/scp_test.go:
--------------------------------------------------------------------------------
1 | package scp
2 |
3 | import (
4 | "bytes"
5 | "os"
6 | "path/filepath"
7 | "runtime"
8 | "testing"
9 |
10 | "github.com/charmbracelet/ssh"
11 | "github.com/charmbracelet/wish/testsession"
12 | "github.com/google/go-cmp/cmp"
13 | "github.com/matryer/is"
14 | gossh "golang.org/x/crypto/ssh"
15 | )
16 |
17 | func TestGetInfo(t *testing.T) {
18 | t.Run("no exec", func(t *testing.T) {
19 | is := is.New(t)
20 | info := GetInfo([]string{})
21 | is.Equal(info.Ok, false)
22 | })
23 |
24 | t.Run("exec is not scp", func(t *testing.T) {
25 | is := is.New(t)
26 | info := GetInfo([]string{"not-scp"})
27 | is.Equal(info.Ok, false)
28 | })
29 |
30 | t.Run("scp no recursive", func(t *testing.T) {
31 | is := is.New(t)
32 | info := GetInfo([]string{"scp", "-f", "file"})
33 | is.True(info.Ok)
34 | is.Equal(info.Recursive, false)
35 | is.Equal("file", info.Path)
36 | is.Equal(info.Op, OpCopyToClient)
37 | })
38 |
39 | t.Run("scp recursive", func(t *testing.T) {
40 | is := is.New(t)
41 | info := GetInfo([]string{"scp", "-r", "--some-ignored-flag", "-f", "file", "ignored-arg"})
42 | is.True(info.Ok)
43 | is.True(info.Recursive)
44 | is.Equal("file", info.Path)
45 | is.Equal(info.Op, OpCopyToClient)
46 | })
47 |
48 | t.Run("scp op copy from client", func(t *testing.T) {
49 | is := is.New(t)
50 | info := GetInfo([]string{"scp", "-t", "file"})
51 | is.True(info.Ok)
52 | is.Equal(info.Op, OpCopyFromClient)
53 | is.Equal("file", info.Path)
54 | })
55 | }
56 |
57 | func TestNoDirRootEntry(t *testing.T) {
58 | is := is.New(t)
59 | root := RootEntry{}
60 |
61 | var f1 bytes.Buffer
62 | f1.WriteString("hello from file f1\n")
63 |
64 | var f2 bytes.Buffer
65 | f2.WriteString("hello from file f2\nwith multiple lines :)\n")
66 |
67 | dir := &DirEntry{
68 | Children: []Entry{},
69 | Name: "dir1",
70 | Filepath: "dir1",
71 | Mode: 0o755,
72 | }
73 |
74 | dir.Append(&FileEntry{
75 | Name: "f2",
76 | Filepath: "f2",
77 | Mode: 0o600,
78 | Size: int64(f2.Len()),
79 | Reader: &f2,
80 | })
81 |
82 | root.Append(&FileEntry{
83 | Name: "f1",
84 | Filepath: "f1",
85 | Mode: 0o644,
86 | Size: int64(f1.Len()),
87 | Reader: &f1,
88 | })
89 |
90 | root.Append(dir)
91 |
92 | var out bytes.Buffer
93 | is.NoErr(root.Write(&out))
94 |
95 | requireEqualGolden(t, out.Bytes())
96 | }
97 |
98 | func TestInvalidOps(t *testing.T) {
99 | t.Run("not scp", func(t *testing.T) {
100 | _, err := setup(t, nil, nil).CombinedOutput("not-scp ign")
101 | is.New(t).NoErr(err)
102 | })
103 |
104 | t.Run("copy to client", func(t *testing.T) {
105 | _, err := setup(t, nil, nil).CombinedOutput("scp -t .")
106 | is.New(t).True(err != nil)
107 | })
108 |
109 | t.Run("copy from client", func(t *testing.T) {
110 | _, err := setup(t, nil, nil).CombinedOutput("scp -f .")
111 | is.New(t).True(err != nil)
112 | })
113 | }
114 |
115 | func setup(tb testing.TB, rh CopyToClientHandler, wh CopyFromClientHandler) *gossh.Session {
116 | tb.Helper()
117 | return testsession.New(tb, &ssh.Server{
118 | Handler: Middleware(rh, wh)(func(s ssh.Session) {
119 | s.Exit(0)
120 | }),
121 | }, nil)
122 | }
123 |
124 | func requireEqualGolden(tb testing.TB, out []byte) {
125 | tb.Helper()
126 | is := is.New(tb)
127 |
128 | fixOutput := func(bts []byte) []byte {
129 | bts = bytes.ReplaceAll(bts, []byte("\r"), []byte(""))
130 | if runtime.GOOS == "windows" {
131 | // perms always come different on Windows because, well, its Windows.
132 | bts = bytes.ReplaceAll(bts, []byte("0666"), []byte("0644"))
133 | bts = bytes.ReplaceAll(bts, []byte("0777"), []byte("0755"))
134 | }
135 | return bytes.ReplaceAll(bts, NULL, []byte(""))
136 | }
137 |
138 | out = fixOutput(out)
139 | golden := "testdata/" + tb.Name() + ".test"
140 | if os.Getenv("UPDATE") != "" {
141 | is.NoErr(os.MkdirAll(filepath.Dir(golden), 0o755))
142 | is.NoErr(os.WriteFile(golden, out, 0o655))
143 | }
144 |
145 | gbts, err := os.ReadFile(golden)
146 | is.NoErr(err)
147 | gbts = fixOutput(gbts)
148 |
149 | if diff := cmp.Diff(string(gbts), string(out)); diff != "" {
150 | tb.Fatal("files do not match:", diff)
151 | }
152 | }
153 |
--------------------------------------------------------------------------------
/scp/testdata/TestFS/file.test:
--------------------------------------------------------------------------------
1 | T1323853868 0 1323853868 0
2 | C0644 11 a.txt
3 | a text file
--------------------------------------------------------------------------------
/scp/testdata/TestFS/glob.test:
--------------------------------------------------------------------------------
1 | T1323853868 0 1323853868 0
2 | C0644 11 a.txt
3 | a text fileT1323853868 0 1323853868 0
4 | C0644 17 b.txt
5 | another text file
--------------------------------------------------------------------------------
/scp/testdata/TestFS/recursive.test:
--------------------------------------------------------------------------------
1 | T1323853868 0 1323853868 0
2 | D0755 0 a
3 | T1323853868 0 1323853868 0
4 | D0755 0 b
5 | T1323853868 0 1323853868 0
6 | D0755 0 c
7 | T1323853868 0 1323853868 0
8 | D0755 0 d
9 | T1323853868 0 1323853868 0
10 | D0755 0 e
11 | T1323853868 0 1323853868 0
12 | C0644 11 e.txt
13 | e text fileE
14 | E
15 | E
16 | T1323853868 0 1323853868 0
17 | C0644 11 c.txt
18 | c text fileE
19 | E
20 |
--------------------------------------------------------------------------------
/scp/testdata/TestFS/recursive_folder.test:
--------------------------------------------------------------------------------
1 | T1323853868 0 1323853868 0
2 | D0755 0 a
3 | T1323853868 0 1323853868 0
4 | D0755 0 b
5 | T1323853868 0 1323853868 0
6 | D0755 0 c
7 | T1323853868 0 1323853868 0
8 | D0755 0 d
9 | T1323853868 0 1323853868 0
10 | D0755 0 e
11 | T1323853868 0 1323853868 0
12 | C0644 11 e.txt
13 | e text fileE
14 | E
15 | E
16 | T1323853868 0 1323853868 0
17 | C0644 11 c.txt
18 | c text fileE
19 | E
20 |
--------------------------------------------------------------------------------
/scp/testdata/TestFS/recursive_glob.test:
--------------------------------------------------------------------------------
1 | T1323853868 0 1323853868 0
2 | D0755 0 b
3 | T1323853868 0 1323853868 0
4 | D0755 0 c
5 | T1323853868 0 1323853868 0
6 | D0755 0 d
7 | T1323853868 0 1323853868 0
8 | D0755 0 e
9 | T1323853868 0 1323853868 0
10 | C0644 11 e.txt
11 | e text fileE
12 | E
13 | E
14 | T1323853868 0 1323853868 0
15 | C0644 11 c.txt
16 | c text fileE
17 |
--------------------------------------------------------------------------------
/scp/testdata/TestFilesystem/scp_-f/file.test:
--------------------------------------------------------------------------------
1 | T1323853868 0 1323853868 0
2 | C0644 11 a.txt
3 | a text file
--------------------------------------------------------------------------------
/scp/testdata/TestFilesystem/scp_-f/glob.test:
--------------------------------------------------------------------------------
1 | T1323853868 0 1323853868 0
2 | C0644 11 a.txt
3 | a text fileT1323853868 0 1323853868 0
4 | C0644 17 b.txt
5 | another text file
--------------------------------------------------------------------------------
/scp/testdata/TestFilesystem/scp_-f/recursive.test:
--------------------------------------------------------------------------------
1 | T1323853868 0 1323853868 0
2 | D0755 0 a
3 | T1323853868 0 1323853868 0
4 | D0755 0 b
5 | T1323853868 0 1323853868 0
6 | D0755 0 c
7 | T1323853868 0 1323853868 0
8 | D0755 0 d
9 | T1323853868 0 1323853868 0
10 | D0755 0 e
11 | T1323853868 0 1323853868 0
12 | C0644 11 e.txt
13 | e text fileE
14 | E
15 | E
16 | T1323853868 0 1323853868 0
17 | C0644 11 c.txt
18 | c text fileE
19 | E
20 |
--------------------------------------------------------------------------------
/scp/testdata/TestFilesystem/scp_-f/recursive_folder.test:
--------------------------------------------------------------------------------
1 | T1323853868 0 1323853868 0
2 | D0755 0 a
3 | T1323853868 0 1323853868 0
4 | D0755 0 b
5 | T1323853868 0 1323853868 0
6 | D0755 0 c
7 | T1323853868 0 1323853868 0
8 | D0755 0 d
9 | T1323853868 0 1323853868 0
10 | D0755 0 e
11 | T1323853868 0 1323853868 0
12 | C0644 11 e.txt
13 | e text fileE
14 | E
15 | E
16 | T1323853868 0 1323853868 0
17 | C0644 11 c.txt
18 | c text fileE
19 | E
20 |
--------------------------------------------------------------------------------
/scp/testdata/TestFilesystem/scp_-f/recursive_glob.test:
--------------------------------------------------------------------------------
1 | T1323853868 0 1323853868 0
2 | D0755 0 b
3 | T1323853868 0 1323853868 0
4 | D0755 0 c
5 | T1323853868 0 1323853868 0
6 | D0755 0 d
7 | T1323853868 0 1323853868 0
8 | D0755 0 e
9 | T1323853868 0 1323853868 0
10 | C0644 11 e.txt
11 | e text fileE
12 | E
13 | E
14 | T1323853868 0 1323853868 0
15 | C0644 11 c.txt
16 | c text fileE
17 |
--------------------------------------------------------------------------------
/scp/testdata/TestNoDirRootEntry.test:
--------------------------------------------------------------------------------
1 | C0644 19 f1
2 | hello from file f1
3 | D0755 0 dir1
4 | C0600 42 f2
5 | hello from file f2
6 | with multiple lines :)
7 | E
8 |
--------------------------------------------------------------------------------
/testdata/another-ca-cert.pub:
--------------------------------------------------------------------------------
1 | ssh-ed25519-cert-v01@openssh.com AAAAIHNzaC1lZDI1NTE5LWNlcnQtdjAxQG9wZW5zc2guY29tAAAAIJ9gNYnZ0HxMWDQ0Mu6vQiygdMYBKPN8DXJej/8Xi/h/AAAAIGlESze4XlTjx5Emi+mUbCVrkE/JkGUFX8CnPZ6r4toAAAAAAAAAAAAAAAABAAAABGlkLTIAAAAHAAAAA2ZvbwAAAABib9ZIAAAAAGJv1roAAAAAAAAAggAAABVwZXJtaXQtWDExLWZvcndhcmRpbmcAAAAAAAAAF3Blcm1pdC1hZ2VudC1mb3J3YXJkaW5nAAAAAAAAABZwZXJtaXQtcG9ydC1mb3J3YXJkaW5nAAAAAAAAAApwZXJtaXQtcHR5AAAAAAAAAA5wZXJtaXQtdXNlci1yYwAAAAAAAAAAAAAAMwAAAAtzc2gtZWQyNTUxOQAAACBKEE2aUq61ZKQ3sFwf2aax1XW1lSg2VHiLX/WBSIuYyQAAAFMAAAALc3NoLWVkMjU1MTkAAABAuQW2puuCo9hZ4+NlGynMJ0rLcdznY2wIeLdr9ZyAncI0bRm5oU9aaWkUNYFjZXxytqO+Tpr8IIOTDj95VYeQAQ== carlos@darkstar
2 |
--------------------------------------------------------------------------------
/testdata/authorized_keys:
--------------------------------------------------------------------------------
1 | ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIM6LN9MSONoI2Dak7GSAy1vTY92NcioIuZqBnk0xmYR2 k1@test
2 | ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQChxV3pJRnXP7crH+4xxH8skCF/Bs8JX8VTjlS4dpLYzXMUcr0ls0DVwgIkIHvXQtqhR4ymgzciUNTTTYGPLAsda47MVqChO2Kxb+I215ApOmMt11lLX6l1Mp7xO35BYR+jC+s4H8VcespUQbWvASHKGZvhD1cri/FttjdCVs7Gqz7U5Cpo+Ym7UZ6TSiBmEd7zQkg4gR1uR4K8/5oJGpaQDDZr/QZJDGat//qvMAKtPkxomYVzHPnflFdsUIwMJHVver+JqKTMEZm2aDrOji4KpHosvfcbmIlx04N99TdT/0oNIQR1tpUsT/kdc44AqKsKUt7Os6kwYiDrjQlVIjpXPTCrdddgnl+/otH7pFsgVUjgCXz6lvhmV5HYhbJM45UEeDrmfxB0wLhC5J4fmvu4EcJtvO7vgg5PD51NpN4iFdUSj91fjskrLsbI+Do+KjcxhOdQxglZ6JwUV3ljFpPjINyawCveba9DRXO7CL/gLgwEwuIU8HXrgKFSXATSWsk= k4@test
3 | ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCdJkpQAr3zhC+grKMexj8zgJIuAQ/2LR59RvXemEAovd671Et356cmHnCDmUvUlH/70xQdyL3n68tzu2ZEzKheQP5vz05CAFXTi7rlMvhtz632mLMPlU3lGuP+A6rzqNSnTtrIa2Q3Fe2ir6N+ad782J8g6frGJaVfA/G7j/M1JwyDJWzUS3HvDHDO+qFze71h0/o9W1+VoRaSfD67BzPQumkEkt/CilSPU8VKRP3q/FIeIrgTBhNh17SX/qlnyrJipDTF1QtXUOK4H5TsEE0S13z8a4Wo37kRWQPxdjWyfX9tBjsN86n+R7OGSXXdi10n9THrisdgx2GKsk1HjY+u5YlDpDysFLBs6j4nWeTxnrjgx6HUqvMk3mdqrAKHTglt34OUQtB463GMgCW85w+ni8ebPKlt5YQsXalilcoI4K7fakyXe+o9Y0sCwE3SLXEJhtd/Esz1pVzvMBCshpRknBPFh/gs/i1YuL0SJqI2BGBFs0d/ARwqUQSoXXBTJPc= k5@test
4 | ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJ7InQIj/ROngoWWb6kXTcTJd8+u5skDfGm8JJxRugMB k2@test
5 | ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQDgteu96TZLd3iG11D5NqBsQRvhW2I6iD/ycwOiWFjFyv4MAHaDFiIazbeQSbi++1+5vspeNuv9AKJFgG0SpnjMLQM0rJb5DIsuRxGOAS/oh82yNCxYcW2+eXcqUDL4V+fZ6eIqtSIBrPQY89/CbZ4nFtw7+941gmFa2+7Wj9vLk4GTiyu/jQsbGnAZUCMvce1jFZ9XDMYSYzXEtkqhBT6eYDd7xMQejovszJfPqlKDxpMZxpaDsQGf+00IJPZUUxkX62eAmrlX1q4XO+m2zIjGpf/gdNKHEMXQrvBWdvwg0rat2i+PCW4Rbwx7wHBBWPRqEPjcVTfwvOWoZGGU3TSX8M7Gcj+ZvAD/uV7DWcNi61Obtw/6PXYvKFZWcZ1sHxUTI84CUcVLSL55hOtJqCuJXmUdKdBcJLyo3NValIjIZn+ljn6biVAr00nGo06nO4j2eTE2ZLOZFEB6rHuf1iaT18EiJgnJGrB7HY4+KUoUIzmvzQrKxxLbIe957hnx+TE= k6@test
6 |
7 | # a commented line, and the previous line was empty
8 | ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMJlb/qf2B2kMNdBxfpCQqI2ctPcsOkdZGVh5zTRhKtH k3@test
9 |
--------------------------------------------------------------------------------
/testdata/ca:
--------------------------------------------------------------------------------
1 | -----BEGIN OPENSSH PRIVATE KEY-----
2 | b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW
3 | QyNTUxOQAAACBm8cbaX81GkAOSHTriPkFOlFDE8Y9nrOCCYip1xVYVcwAAAJgZFHXsGRR1
4 | 7AAAAAtzc2gtZWQyNTUxOQAAACBm8cbaX81GkAOSHTriPkFOlFDE8Y9nrOCCYip1xVYVcw
5 | AAAEDYnnbXzbn0SnOKCf0ByXm1+FLnqJC+ZErxo2SgaFeVCWbxxtpfzUaQA5IdOuI+QU6U
6 | UMTxj2es4IJiKnXFVhVzAAAAD2Nhcmxvc0BkYXJrc3RhcgECAwQFBg==
7 | -----END OPENSSH PRIVATE KEY-----
8 |
--------------------------------------------------------------------------------
/testdata/ca.pub:
--------------------------------------------------------------------------------
1 | ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIGbxxtpfzUaQA5IdOuI+QU6UUMTxj2es4IJiKnXFVhVz carlos@darkstar
2 |
--------------------------------------------------------------------------------
/testdata/expired-cert.pub:
--------------------------------------------------------------------------------
1 | ssh-ed25519-cert-v01@openssh.com AAAAIHNzaC1lZDI1NTE5LWNlcnQtdjAxQG9wZW5zc2guY29tAAAAIB7uQi6D84xl48p0tQjh7sXk9cpzjEcqsu/ZYjT4k3j7AAAAIGlESze4XlTjx5Emi+mUbCVrkE/JkGUFX8CnPZ6r4toAAAAAAAAAAAAAAAABAAAABGlkLTIAAAAHAAAAA2ZvbwAAAABib8u8AAAAAGJvzAQAAAAAAAAAggAAABVwZXJtaXQtWDExLWZvcndhcmRpbmcAAAAAAAAAF3Blcm1pdC1hZ2VudC1mb3J3YXJkaW5nAAAAAAAAABZwZXJtaXQtcG9ydC1mb3J3YXJkaW5nAAAAAAAAAApwZXJtaXQtcHR5AAAAAAAAAA5wZXJtaXQtdXNlci1yYwAAAAAAAAAAAAAAMwAAAAtzc2gtZWQyNTUxOQAAACBm8cbaX81GkAOSHTriPkFOlFDE8Y9nrOCCYip1xVYVcwAAAFMAAAALc3NoLWVkMjU1MTkAAABA5vfSWp5U1jsyM9u7jO8JmhMQqsxakkUUCwOVYqSVe65h06ANCmrYowmMuYVowvTaXkbdF5RhNjQxqE2fRNa4Bw== carlos@darkstar
2 |
--------------------------------------------------------------------------------
/testdata/foo:
--------------------------------------------------------------------------------
1 | -----BEGIN OPENSSH PRIVATE KEY-----
2 | b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW
3 | QyNTUxOQAAACBpREs3uF5U48eRJovplGwla5BPyZBlBV/Apz2eq+LaAAAAAJj5F0Fq+RdB
4 | agAAAAtzc2gtZWQyNTUxOQAAACBpREs3uF5U48eRJovplGwla5BPyZBlBV/Apz2eq+LaAA
5 | AAAEBYbFEA6Ad/SafFoBHD9RM97SS79vZLrZ0p9DftTk4DbmlESze4XlTjx5Emi+mUbCVr
6 | kE/JkGUFX8CnPZ6r4toAAAAAD2Nhcmxvc0BkYXJrc3RhcgECAwQFBg==
7 | -----END OPENSSH PRIVATE KEY-----
8 |
--------------------------------------------------------------------------------
/testdata/foo.pub:
--------------------------------------------------------------------------------
1 | ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIGlESze4XlTjx5Emi+mUbCVrkE/JkGUFX8CnPZ6r4toA carlos@darkstar
2 |
--------------------------------------------------------------------------------
/testdata/invalid_authorized_keys:
--------------------------------------------------------------------------------
1 | ssh-nope nopenopenope k1@test
2 |
--------------------------------------------------------------------------------
/testdata/valid-cert.pub:
--------------------------------------------------------------------------------
1 | ssh-ed25519-cert-v01@openssh.com AAAAIHNzaC1lZDI1NTE5LWNlcnQtdjAxQG9wZW5zc2guY29tAAAAILHcoWg/uAZrEroYaEwxBkUCNH95Iz+ycr41K6KdnL0ZAAAAIGlESze4XlTjx5Emi+mUbCVrkE/JkGUFX8CnPZ6r4toAAAAAAAAAAAAAAAABAAAABGlkLTEAAAAHAAAAA2ZvbwAAAAAAAAAA//////////8AAAAAAAAAggAAABVwZXJtaXQtWDExLWZvcndhcmRpbmcAAAAAAAAAF3Blcm1pdC1hZ2VudC1mb3J3YXJkaW5nAAAAAAAAABZwZXJtaXQtcG9ydC1mb3J3YXJkaW5nAAAAAAAAAApwZXJtaXQtcHR5AAAAAAAAAA5wZXJtaXQtdXNlci1yYwAAAAAAAAAAAAAAMwAAAAtzc2gtZWQyNTUxOQAAACBm8cbaX81GkAOSHTriPkFOlFDE8Y9nrOCCYip1xVYVcwAAAFMAAAALc3NoLWVkMjU1MTkAAABAUlsnmwWKhk2S0tAT/woHwnT2H0V3sn/PEtruHEzbU+uriVHffG046MHbvudLCZ2H86HoeZsu1N7GJuqWE+P6Bg== carlos@darkstar
2 |
--------------------------------------------------------------------------------
/testsession/testsession.go:
--------------------------------------------------------------------------------
1 | // Package testsession provides utilities to test SSH sessions.
2 | //
3 | // more or less copied from charmbracelet/ssh tests
4 | package testsession
5 |
6 | import (
7 | "net"
8 | "testing"
9 |
10 | "github.com/charmbracelet/ssh"
11 | gossh "golang.org/x/crypto/ssh"
12 | )
13 |
14 | // New starts a local SSH server with the given config and returns a client session.
15 | // It automatically closes everything afterwards.
16 | func New(tb testing.TB, srv *ssh.Server, cfg *gossh.ClientConfig) *gossh.Session {
17 | tb.Helper()
18 | sess, err := NewClientSession(tb, Listen(tb, srv), cfg)
19 | if err != nil {
20 | tb.Fatal(err)
21 | }
22 | return sess
23 | }
24 |
25 | // Listen starts a test server.
26 | func Listen(tb testing.TB, srv *ssh.Server) string {
27 | tb.Helper()
28 | l := newLocalListener(tb)
29 | go func() {
30 | if err := srv.Serve(l); err != nil && err != ssh.ErrServerClosed {
31 | tb.Fatalf("failed to serve: %v", err)
32 | }
33 | }()
34 | tb.Cleanup(func() {
35 | _ = srv.Close()
36 | })
37 | return l.Addr().String()
38 | }
39 |
// newLocalListener listens on an ephemeral loopback port. It tries IPv4
// first and falls back to IPv6. If neither works, the test fails with the
// last error. The listener is closed via tb.Cleanup.
func newLocalListener(tb testing.TB) net.Listener {
	tb.Helper()
	candidates := []struct{ network, address string }{
		{"tcp", "127.0.0.1:0"},
		{"tcp6", "[::1]:0"},
	}

	var (
		l   net.Listener
		err error
	)
	for _, c := range candidates {
		if l, err = net.Listen(c.network, c.address); err == nil {
			break
		}
	}
	if err != nil {
		tb.Fatalf("failed to listen on a port: %v", err)
	}

	tb.Cleanup(func() { _ = l.Close() })
	return l
}
52 |
53 | // NewClientSession creates a new client session to the given address.
54 | func NewClientSession(tb testing.TB, addr string, config *gossh.ClientConfig) (*gossh.Session, error) {
55 | tb.Helper()
56 | if config == nil {
57 | config = &gossh.ClientConfig{
58 | User: "testuser",
59 | Auth: []gossh.AuthMethod{
60 | gossh.Password("testpass"),
61 | },
62 | }
63 | }
64 | if config.HostKeyCallback == nil {
65 | config.HostKeyCallback = gossh.InsecureIgnoreHostKey() // nolint: gosec
66 | }
67 | client, err := gossh.Dial("tcp", addr, config)
68 | if err != nil {
69 | return nil, err
70 | }
71 | session, err := client.NewSession()
72 | if err != nil {
73 | return nil, err
74 | }
75 | tb.Cleanup(func() {
76 | _ = session.Close()
77 | _ = client.Close()
78 | })
79 | return session, nil
80 | }
81 |
--------------------------------------------------------------------------------
/testsession/testsession_test.go:
--------------------------------------------------------------------------------
1 | package testsession
2 |
3 | import (
4 | "fmt"
5 | "testing"
6 |
7 | "github.com/charmbracelet/ssh"
8 | )
9 |
10 | func TestSession(t *testing.T) {
11 | const out = "hello world"
12 | session := New(t, &ssh.Server{
13 | Handler: func(s ssh.Session) {
14 | _, _ = fmt.Fprint(s, out)
15 | },
16 | }, nil)
17 | result, err := session.Output("")
18 | if err != nil {
19 | t.Errorf("expected no error, got %v", err)
20 | }
21 | if string(result) != out {
22 | t.Errorf("expected %q, got %q", out, string(result))
23 | }
24 | }
25 |
--------------------------------------------------------------------------------
/wish.go:
--------------------------------------------------------------------------------
1 | package wish
2 |
3 | import (
4 | "fmt"
5 | "io"
6 |
7 | "github.com/charmbracelet/keygen"
8 | "github.com/charmbracelet/ssh"
9 | )
10 |
// Middleware is a function that takes an ssh.Handler and returns an
// ssh.Handler, typically one that wraps it with extra behavior.
// Implementations should call the provided handler argument so the rest of
// the chain still runs.
type Middleware func(next ssh.Handler) ssh.Handler
14 |
15 | // NewServer is returns a default SSH server with the provided Middleware. A
16 | // new SSH key pair of type ed25519 will be created if one does not exist. By
17 | // default this server will accept all incoming connections, password and
18 | // public key.
19 | func NewServer(ops ...ssh.Option) (*ssh.Server, error) {
20 | s := &ssh.Server{}
21 | for _, op := range ops {
22 | if err := s.SetOption(op); err != nil {
23 | return nil, err
24 | }
25 | }
26 | if len(s.HostSigners) == 0 {
27 | k, err := keygen.New("id_ed25519", keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
28 | if err != nil {
29 | return nil, err
30 | }
31 | err = s.SetOption(WithHostKeyPEM(k.RawPrivateKey()))
32 | if err != nil {
33 | return nil, err
34 | }
35 | }
36 | return s, nil
37 | }
38 |
39 | // Fatal prints to the given session's STDERR and exits 1.
40 | func Fatal(s ssh.Session, v ...interface{}) {
41 | Error(s, v...)
42 | _ = s.Exit(1)
43 | _ = s.Close()
44 | }
45 |
46 | // Fatalf formats according to the given format, prints to the session's STDERR
47 | // followed by an exit 1.
48 | //
49 | // Notice that this might cause formatting issues if you don't add a \r\n in the end of your string.
50 | func Fatalf(s ssh.Session, f string, v ...interface{}) {
51 | Errorf(s, f, v...)
52 | _ = s.Exit(1)
53 | _ = s.Close()
54 | }
55 |
56 | // Fatalln formats according to the default format, prints to the session's
57 | // STDERR, followed by a new line and an exit 1.
58 | func Fatalln(s ssh.Session, v ...interface{}) {
59 | Errorln(s, v...)
60 | Errorf(s, "\r")
61 | _ = s.Exit(1)
62 | _ = s.Close()
63 | }
64 |
65 | // Error prints the given error the the session's STDERR.
66 | func Error(s ssh.Session, v ...interface{}) {
67 | _, _ = fmt.Fprint(s.Stderr(), v...)
68 | }
69 |
70 | // Errorf formats according to the given format and prints to the session's STDERR.
71 | func Errorf(s ssh.Session, f string, v ...interface{}) {
72 | _, _ = fmt.Fprintf(s.Stderr(), f, v...)
73 | }
74 |
75 | // Errorf formats according to the default format and prints to the session's STDERR.
76 | func Errorln(s ssh.Session, v ...interface{}) {
77 | _, _ = fmt.Fprintln(s.Stderr(), v...)
78 | }
79 |
80 | // Print writes to the session's STDOUT followed.
81 | func Print(s ssh.Session, v ...interface{}) {
82 | _, _ = fmt.Fprint(s, v...)
83 | }
84 |
85 | // Printf formats according to the given format and writes to the session's STDOUT.
86 | func Printf(s ssh.Session, f string, v ...interface{}) {
87 | _, _ = fmt.Fprintf(s, f, v...)
88 | }
89 |
90 | // Println formats according to the default format and writes to the session's STDOUT.
91 | func Println(s ssh.Session, v ...interface{}) {
92 | _, _ = fmt.Fprintln(s, v...)
93 | }
94 |
95 | // WriteString writes the given string to the session's STDOUT.
96 | func WriteString(s ssh.Session, v string) (int, error) {
97 | return io.WriteString(s, v)
98 | }
99 |
--------------------------------------------------------------------------------
/wish_test.go:
--------------------------------------------------------------------------------
1 | // go:generate mockgen -package mocks -destination mocks/session.go github.com/charmbracelet/ssh Session
2 | package wish
3 |
4 | import (
5 | "bytes"
6 | "errors"
7 | "path/filepath"
8 | "strings"
9 | "testing"
10 | "time"
11 |
12 | "github.com/charmbracelet/ssh"
13 | "github.com/charmbracelet/wish/testsession"
14 | )
15 |
16 | func TestNewServer(t *testing.T) {
17 | fp := filepath.Join(t.TempDir(), "id_ed25519")
18 | _, err := NewServer(WithHostKeyPath(fp))
19 | if err != nil {
20 | t.Fatal(err)
21 | }
22 | }
23 |
24 | func TestNewServerWithOptions(t *testing.T) {
25 | fp := filepath.Join(t.TempDir(), "id_ed25519")
26 | if _, err := NewServer(
27 | WithHostKeyPath(fp),
28 | WithMaxTimeout(time.Second),
29 | WithBanner("welcome"),
30 | WithAddress(":2222"),
31 | ); err != nil {
32 | t.Fatal(err)
33 | }
34 | }
35 |
36 | func TestError(t *testing.T) {
37 | eerr := errors.New("foo err")
38 | sess := testsession.New(t, &ssh.Server{
39 | Handler: func(s ssh.Session) {
40 | Error(s, eerr)
41 | },
42 | }, nil)
43 | var out bytes.Buffer
44 | sess.Stderr = &out
45 | if err := sess.Run(""); err != nil {
46 | t.Errorf("expected no error, got %s", err)
47 | }
48 | if s := strings.TrimSpace(out.String()); s != eerr.Error() {
49 | t.Errorf("expected %s, got %s", s, eerr)
50 | }
51 | }
52 |
53 | func TestFatal(t *testing.T) {
54 | err := errors.New("foo err")
55 | sess := testsession.New(t, &ssh.Server{
56 | Handler: func(s ssh.Session) {
57 | Fatal(s, err)
58 | },
59 | }, nil)
60 | var out bytes.Buffer
61 | sess.Stderr = &out
62 | if err := sess.Run(""); err == nil {
63 | t.Error("expected an error, got nil")
64 | }
65 | if s := strings.TrimSpace(out.String()); s != err.Error() {
66 | t.Errorf("expected %s, got %s", s, err)
67 | }
68 | }
69 |
--------------------------------------------------------------------------------