├── .github ├── CODEOWNERS ├── dependabot.yml └── workflows │ ├── build.yml │ ├── dependabot-sync.yml │ ├── examples.yml │ ├── goreleaser.yml │ └── lint.yml ├── .gitignore ├── .golangci.yml ├── .goreleaser.yml ├── LICENSE ├── README.md ├── accesscontrol ├── accesscontrol.go └── accesscontrol_test.go ├── activeterm ├── activeterm.go └── activeterm_test.go ├── bubbletea ├── query.go ├── tea.go ├── tea_other.go └── tea_unix.go ├── cmd.go ├── cmd_test.go ├── cmd_unix.go ├── cmd_windows.go ├── comment ├── comment.go └── comment_test.go ├── elapsed ├── elapsed.go └── elapsed_test.go ├── examples ├── .gitignore ├── README.md ├── banner │ ├── banner.txt │ └── main.go ├── bubbletea-exec │ └── main.go ├── bubbletea │ └── main.go ├── bubbleteaprogram │ └── main.go ├── cobra │ └── main.go ├── exec │ ├── example.sh │ └── main.go ├── forward │ └── main.go ├── git │ └── main.go ├── go.mod ├── go.sum ├── graceful-shutdown │ └── main.go ├── identity │ └── main.go ├── multi-auth │ └── main.go ├── multichat │ └── main.go ├── pty │ └── main.go ├── scp │ ├── main.go │ └── testdata │ │ └── .gitkeep └── simple │ └── main.go ├── git ├── git.go └── git_test.go ├── go.mod ├── go.sum ├── logging ├── logging.go └── logging_test.go ├── options.go ├── options_test.go ├── ratelimiter ├── ratelimiter.go └── ratelimiter_test.go ├── recover ├── recover.go └── recover_test.go ├── scp ├── copy_from_client.go ├── copy_to_client.go ├── filesystem.go ├── filesystem_test.go ├── fs.go ├── fs_test.go ├── limit_reader.go ├── limit_reader_test.go ├── scp.go ├── scp_test.go └── testdata │ ├── TestFS │ ├── file.test │ ├── glob.test │ ├── recursive.test │ ├── recursive_folder.test │ └── recursive_glob.test │ ├── TestFilesystem │ └── scp_-f │ │ ├── file.test │ │ ├── glob.test │ │ ├── recursive.test │ │ ├── recursive_folder.test │ │ └── recursive_glob.test │ └── TestNoDirRootEntry.test ├── testdata ├── another-ca-cert.pub ├── authorized_keys ├── ca ├── ca.pub ├── expired-cert.pub ├── foo ├── foo.pub ├── 
invalid_authorized_keys └── valid-cert.pub ├── testsession ├── testsession.go └── testsession_test.go ├── wish.go └── wish_test.go /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @charmbracelet/everyone 2 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | updates: 4 | - package-ecosystem: "gomod" 5 | directory: "/" 6 | schedule: 7 | interval: "weekly" 8 | day: "monday" 9 | time: "05:00" 10 | timezone: "America/New_York" 11 | labels: 12 | - "dependencies" 13 | commit-message: 14 | prefix: "chore" 15 | include: "scope" 16 | 17 | - package-ecosystem: "github-actions" 18 | directory: "/" 19 | schedule: 20 | interval: "weekly" 21 | day: "monday" 22 | time: "05:00" 23 | timezone: "America/New_York" 24 | labels: 25 | - "dependencies" 26 | commit-message: 27 | prefix: "chore" 28 | include: "scope" 29 | 30 | - package-ecosystem: "docker" 31 | directory: "/" 32 | schedule: 33 | interval: "weekly" 34 | day: "monday" 35 | time: "05:00" 36 | timezone: "America/New_York" 37 | labels: 38 | - "dependencies" 39 | commit-message: 40 | prefix: "chore" 41 | include: "scope" 42 | 43 | - package-ecosystem: "gomod" 44 | directory: "/examples" 45 | schedule: 46 | interval: "weekly" 47 | day: "monday" 48 | time: "05:00" 49 | timezone: "America/New_York" 50 | labels: 51 | - "dependencies" 52 | commit-message: 53 | prefix: "chore" 54 | include: "scope" 55 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: 4 | push: 5 | pull_request: 6 | 7 | jobs: 8 | build: 9 | uses: charmbracelet/meta/.github/workflows/build.yml@main 10 | 11 | codecov: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: 
actions/checkout@v4 15 | - uses: actions/setup-go@v5 16 | with: 17 | go-version: "stable" 18 | cache: true 19 | - run: go test -failfast -race -coverpkg=./... -covermode=atomic -coverprofile=coverage.txt ./... -timeout 5m 20 | - uses: codecov/codecov-action@v5 21 | with: 22 | files: ./coverage.txt 23 | -------------------------------------------------------------------------------- /.github/workflows/dependabot-sync.yml: -------------------------------------------------------------------------------- 1 | name: dependabot-sync 2 | on: 3 | schedule: 4 | - cron: "0 0 * * 0" # every Sunday at midnight 5 | workflow_dispatch: # allows manual triggering 6 | 7 | permissions: 8 | contents: write 9 | pull-requests: write 10 | 11 | jobs: 12 | dependabot-sync: 13 | uses: charmbracelet/meta/.github/workflows/dependabot-sync.yml@main 14 | with: 15 | repo_name: ${{ github.event.repository.name }} 16 | secrets: 17 | gh_token: ${{ secrets.PERSONAL_ACCESS_TOKEN }} 18 | -------------------------------------------------------------------------------- /.github/workflows/examples.yml: -------------------------------------------------------------------------------- 1 | name: examples 2 | 3 | on: 4 | push: 5 | branches: 6 | - 'main' 7 | paths: 8 | - '.github/workflows/examples.yml' 9 | - 'examples/go.mod' 10 | - 'examples/go.sum' 11 | - 'go.mod' 12 | - 'go.sum' 13 | workflow_dispatch: {} 14 | 15 | jobs: 16 | tidy: 17 | permissions: 18 | contents: write 19 | runs-on: ubuntu-latest 20 | steps: 21 | - uses: actions/checkout@v4 22 | - uses: actions/setup-go@v5 23 | with: 24 | go-version: '^1' 25 | cache: true 26 | - shell: bash 27 | run: | 28 | (cd ./examples && go mod tidy) 29 | - uses: stefanzweifel/git-auto-commit-action@v5 30 | with: 31 | commit_message: "chore: go mod tidy examples" 32 | branch: main 33 | commit_user_name: actions-user 34 | commit_user_email: actions@github.com 35 | 36 | 37 | -------------------------------------------------------------------------------- 
/.github/workflows/goreleaser.yml: -------------------------------------------------------------------------------- 1 | name: goreleaser 2 | 3 | on: 4 | push: 5 | tags: 6 | - v*.*.* 7 | 8 | concurrency: 9 | group: goreleaser 10 | cancel-in-progress: true 11 | 12 | jobs: 13 | goreleaser: 14 | uses: charmbracelet/meta/.github/workflows/goreleaser.yml@main 15 | secrets: 16 | docker_username: ${{ secrets.DOCKERHUB_USERNAME }} 17 | docker_token: ${{ secrets.DOCKERHUB_TOKEN }} 18 | gh_pat: ${{ secrets.PERSONAL_ACCESS_TOKEN }} 19 | goreleaser_key: ${{ secrets.GORELEASER_KEY }} 20 | twitter_consumer_key: ${{ secrets.TWITTER_CONSUMER_KEY }} 21 | twitter_consumer_secret: ${{ secrets.TWITTER_CONSUMER_SECRET }} 22 | twitter_access_token: ${{ secrets.TWITTER_ACCESS_TOKEN }} 23 | twitter_access_token_secret: ${{ secrets.TWITTER_ACCESS_TOKEN_SECRET }} 24 | mastodon_client_id: ${{ secrets.MASTODON_CLIENT_ID }} 25 | mastodon_client_secret: ${{ secrets.MASTODON_CLIENT_SECRET }} 26 | mastodon_access_token: ${{ secrets.MASTODON_ACCESS_TOKEN }} 27 | discord_webhook_id: ${{ secrets.DISCORD_WEBHOOK_ID }} 28 | discord_webhook_token: ${{ secrets.DISCORD_WEBHOOK_TOKEN }} 29 | 30 | # yaml-language-server: $schema=https://json.schemastore.org/github-workflow.json 31 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: lint 2 | on: 3 | push: 4 | pull_request: 5 | 6 | jobs: 7 | golangci: 8 | name: lint 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - uses: actions/setup-go@v5 13 | with: 14 | go-version: ^1 15 | cache: true 16 | - uses: golangci/golangci-lint-action@v8 17 | with: 18 | # Optional: golangci-lint command line arguments. 19 | args: --issues-exit-code=0 20 | # Optional: show only new issues if it's a pull request. The default value is `false`. 
21 | only-new-issues: true 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | examples/bubbletea/bubbletea 2 | examples/bubbletea/.ssh 3 | examples/git/git 4 | examples/git/.ssh 5 | examples/git/.repos 6 | .repos 7 | .ssh 8 | coverage.txt 9 | id_ed25519 10 | id_ed25519.pub 11 | 12 | # MacOS specific 13 | .DS_Store 14 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | run: 2 | tests: false 3 | 4 | issues: 5 | include: 6 | - EXC0001 7 | - EXC0005 8 | - EXC0011 9 | - EXC0012 10 | - EXC0013 11 | 12 | max-issues-per-linter: 0 13 | max-same-issues: 0 14 | 15 | linters: 16 | enable: 17 | - bodyclose 18 | - depguard 19 | - dupl 20 | - goconst 21 | - godot 22 | - godox 23 | - gofumpt 24 | - goimports 25 | - goprintffuncname 26 | - gosec 27 | - misspell 28 | - prealloc 29 | - revive 30 | - rowserrcheck 31 | - sqlclosecheck 32 | - unconvert 33 | - unparam 34 | - whitespace 35 | 36 | linters-settings: 37 | depguard: 38 | rules: 39 | main: 40 | deny: 41 | - pkg: "github.com/gliderlabs/ssh" 42 | desc: "use github.com/charmbracelet/ssh instead" 43 | -------------------------------------------------------------------------------- /.goreleaser.yml: -------------------------------------------------------------------------------- 1 | # yaml-language-server: $schema=https://goreleaser.com/static/schema-pro.json 2 | version: 2 3 | includes: 4 | - from_url: 5 | url: charmbracelet/meta/main/goreleaser-lib.yaml 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019-2023 Charmbracelet, Inc 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a 
copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Wish 2 | 3 |

4 | A nice rendering of a star, anthropomorphized somewhat by means of a smile, with the words ‘Charm Wish’ next to it 5 |
6 | Latest Release 7 | GoDoc 8 | Build Status 9 | Codecov branch 10 | Go Report Card 11 |

12 | 13 | 14 | Make SSH apps, just like that! 💫 15 | 16 | SSH is an excellent platform for building remotely accessible applications. It 17 | offers: 18 | * secure communication without the hassle of HTTPS certificates 19 | * user identification with SSH keys 20 | * access from any terminal 21 | 22 | Powerful protocols like Git work over SSH, and you can even render TUIs directly over an SSH connection. 23 | 24 | Wish is an SSH server with sensible defaults and a collection of middlewares that 25 | make building SSH apps really easy. Wish is built on [charmbracelet/ssh](https://github.com/charmbracelet/ssh), a fork of [gliderlabs/ssh][gliderlabs/ssh], 26 | and should be easy to integrate into any existing project. 27 | 28 | ## What are SSH Apps? 29 | 30 | Usually, when we think about SSH, we think about remote shell access into servers, 31 | most commonly through `openssh-server`. 32 | 33 | That's a perfectly valid (and probably the most common) use of SSH, but it can do so much more than that. 34 | Just like HTTP, SMTP, FTP and others, SSH is a protocol! 35 | It is a cryptographic network protocol for operating network services securely over an unsecured network. [^1] 36 | 37 | [^1]: https://en.wikipedia.org/wiki/Secure_Shell 38 | 39 | That means, among other things, that we can write custom SSH servers without touching `openssh-server`, 40 | so we can securely do more things than just providing a shell. 41 | 42 | Wish is a library that helps you write these kinds of apps in Go. 43 | 44 | ## Middleware 45 | 46 | Wish middlewares are analogous to those in several HTTP frameworks. 47 | They are essentially SSH handlers that you can use to do specific tasks, 48 | and then call the next middleware. 49 | 50 | Notice that middlewares are composed from first to last, 51 | which means the last one is executed first. 52 | 53 | ### Bubble Tea 54 | 55 | The [`bubbletea`](bubbletea) middleware makes it easy to serve any 56 | [Bubble Tea][bubbletea] application over SSH.
Each SSH session will get its own 57 | `tea.Program` with the SSH PTY input and output connected. Client window 58 | dimension and resize messages are also natively handled by the `tea.Program`. 59 | 60 | You can see a demo of the Wish middleware in action at: `ssh git.charm.sh` 61 | 62 | ### Git 63 | 64 | The [`git`](git) middleware adds `git` server functionality to any SSH server. 65 | It supports repo creation on initial push and custom public-key-based auth. 66 | 67 | This middleware requires that `git` is installed on the server. 68 | 69 | ### Logging 70 | 71 | The [`logging`](logging) middleware provides basic connection logging. Connections 72 | are logged with the remote address, invoked command, TERM setting, window 73 | dimensions and whether public key authentication was used. Disconnects are logged with the remote 74 | address and connection duration. 75 | 76 | ### Access Control 77 | 78 | Not all applications will support general SSH connections. To restrict access 79 | to supported methods, you can use the [`activeterm`](activeterm) middleware to 80 | only allow connections with active terminals connected and the 81 | [`accesscontrol`](accesscontrol) middleware that lets you specify allowed 82 | commands. 83 | 84 | ## Default Server 85 | 86 | Wish includes the ability to easily create an always authenticating default SSH 87 | server with automatic server key generation. 88 | 89 | ## Examples 90 | 91 | The [examples](examples) folder includes, among others, a standalone [Bubble Tea application](examples/bubbletea) 92 | and a [Git server](examples/git).
93 | 94 | ## Apps Built With Wish 95 | 96 | * [Soft Serve](https://github.com/charmbracelet/soft-serve) 97 | * [Wishlist](https://github.com/charmbracelet/wishlist) 98 | * [SSHWordle](https://github.com/davidcroda/sshwordle) 99 | * [clidle](https://github.com/ajeetdsouza/clidle) 100 | * [ssh-warm-welcome](https://git.coopcloud.tech/decentral1se/ssh-warm-welcome) 101 | 102 | [bubbletea]: https://github.com/charmbracelet/bubbletea 103 | [gliderlabs/ssh]: https://github.com/gliderlabs/ssh 104 | 105 | ## Pro tip 106 | 107 | When building various Wish applications locally, you can add the following to 108 | your `~/.ssh/config` to avoid having to clear out `localhost` entries in your 109 | `~/.ssh/known_hosts` file: 110 | 111 | ``` 112 | Host localhost 113 | UserKnownHostsFile /dev/null 114 | ``` 115 | 116 | ## How does it work? 117 | 118 | Wish uses [charmbracelet/ssh](https://github.com/charmbracelet/ssh), a fork of [gliderlabs/ssh][gliderlabs/ssh], to implement its SSH server, and 119 | OpenSSH is neither used nor needed — you can even uninstall it if you want to. 120 | 121 | Incidentally, there's no risk of accidentally sharing a shell, because Wish has no 122 | default behavior that does that.
123 | 124 | ## Running with systemd 125 | 126 | If you want to run a Wish app with `systemd`, you can create a unit like so: 127 | 128 | `/etc/systemd/system/myapp.service`: 129 | ```service 130 | [Unit] 131 | Description=My App 132 | After=network.target 133 | 134 | [Service] 135 | Type=simple 136 | User=myapp 137 | Group=myapp 138 | WorkingDirectory=/home/myapp/ 139 | ExecStart=/usr/bin/myapp 140 | Restart=on-failure 141 | 142 | [Install] 143 | WantedBy=multi-user.target 144 | ``` 145 | 146 | You can tune the values above, and once you're happy with them, you can run: 147 | 148 | ```bash 149 | # need to run this every time you change the unit file 150 | sudo systemctl daemon-reload 151 | 152 | # start/restart/stop/etc: 153 | sudo systemctl start myapp 154 | ``` 155 | 156 | If you use a new user for each app (which is good), you'll need to create it 157 | first: 158 | 159 | ```bash 160 | sudo useradd --system --user-group --create-home myapp 161 | ``` 162 | 163 | That should do it. 164 | 165 | 166 | 167 | ## Feedback 168 | 169 | We’d love to hear your thoughts on this project. Feel free to drop us a note! 170 | 171 | * [Twitter](https://twitter.com/charmcli) 172 | * [The Fediverse](https://mastodon.social/@charmcli) 173 | * [Discord](https://charm.sh/chat) 174 | 175 | ## License 176 | 177 | [MIT](https://github.com/charmbracelet/wish/raw/main/LICENSE) 178 | 179 | *** 180 | 181 | Part of [Charm](https://charm.sh). 182 | 183 | The Charm logo 184 | 185 | Charm热爱开源 • Charm loves open source 186 | -------------------------------------------------------------------------------- /accesscontrol/accesscontrol.go: -------------------------------------------------------------------------------- 1 | // Package accesscontrol provides a middleware that allows you to restrict the commands the user can execute.
2 | package accesscontrol 3 | 4 | import ( 5 | "fmt" 6 | 7 | "github.com/charmbracelet/ssh" 8 | "github.com/charmbracelet/wish" 9 | ) 10 | 11 | // Middleware exits, with status 1, every session whose command name (its first word) is not in cmds. 12 | // Sessions without a command always pass through; if cmds is empty, every command is denied. 13 | func Middleware(cmds ...string) wish.Middleware { 14 | return func(sh ssh.Handler) ssh.Handler { 15 | return func(s ssh.Session) { 16 | if len(s.Command()) == 0 { 17 | sh(s) 18 | return 19 | } 20 | for _, cmd := range cmds { 21 | if s.Command()[0] == cmd { 22 | sh(s) 23 | return 24 | } 25 | } 26 | _, _ = fmt.Fprintln(s, "Command is not allowed: "+s.Command()[0]) 27 | s.Exit(1) // nolint: errcheck 28 | } 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /accesscontrol/accesscontrol_test.go: -------------------------------------------------------------------------------- 1 | package accesscontrol_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/charmbracelet/ssh" 8 | "github.com/charmbracelet/wish/accesscontrol" 9 | "github.com/charmbracelet/wish/testsession" 10 | gossh "golang.org/x/crypto/ssh" 11 | ) 12 | 13 | const out = "hello world" 14 | 15 | func TestMiddleware(t *testing.T) { 16 | requireDenied := func(tb testing.TB, s, cmd string) { 17 | tb.Helper() 18 | expected := fmt.Sprintf("Command is not allowed: %s\n", cmd) 19 | if s != expected { 20 | tb.Errorf("expected %q, got %q", expected, s) 21 | } 22 | } 23 | 24 | requireOutput := func(tb testing.TB, s string) { 25 | tb.Helper() 26 | if out != s { 27 | tb.Errorf("expected %q, got %q", out, s) 28 | } 29 | } 30 | 31 | t.Run("no allowed cmds no cmd", func(t *testing.T) { 32 | out, err := setup(t).Output("") 33 | if err != nil { 34 | t.Error(err) 35 | } 36 | requireOutput(t, string(out)) 37 | }) 38 | 39 | t.Run("no allowed cmds with cmd", func(t *testing.T) { 40 | out, err := setup(t).Output("echo") 41 | if err == nil { 42 | t.Errorf("should have
errored") 43 | } 44 | requireDenied(t, string(out), "echo") 45 | }) 46 | 47 | t.Run("allowed cmds no cmd", func(t *testing.T) { 48 | out, err := setup(t, "echo").Output("") 49 | if err != nil { 50 | t.Error(err) 51 | } 52 | requireOutput(t, string(out)) 53 | }) 54 | 55 | t.Run("allowed cmds with allowed cmd", func(t *testing.T) { 56 | out, err := setup(t, "echo").Output("echo") 57 | if err != nil { 58 | t.Error(err) 59 | } 60 | requireOutput(t, string(out)) 61 | }) 62 | 63 | t.Run("allowed cmds with disallowed cmd", func(t *testing.T) { 64 | out, err := setup(t, "echo").Output("cat") 65 | if err == nil { 66 | t.Error("should have errored") 67 | } 68 | requireDenied(t, string(out), "cat") 69 | }) 70 | 71 | t.Run("allowed cmds with disallowed cmd followed by allowed cmd", func(t *testing.T) { 72 | out, err := setup(t, "echo").Output("cat echo") 73 | if err == nil { 74 | t.Error("should have errored") 75 | } 76 | requireDenied(t, string(out), "cat") 77 | }) 78 | } 79 | 80 | func setup(tb testing.TB, allowedCmds ...string) *gossh.Session { 81 | tb.Helper() 82 | return testsession.New(tb, &ssh.Server{ 83 | Handler: accesscontrol.Middleware(allowedCmds...)(func(s ssh.Session) { 84 | s.Write([]byte(out)) 85 | }), 86 | }, nil) 87 | } 88 | -------------------------------------------------------------------------------- /activeterm/activeterm.go: -------------------------------------------------------------------------------- 1 | // Package activeterm provides a middleware to block inactive PTYs. 2 | package activeterm 3 | 4 | import ( 5 | "github.com/charmbracelet/ssh" 6 | "github.com/charmbracelet/wish" 7 | ) 8 | 9 | // Middleware rejects sessions without an active PTY by printing an error and exiting them with status 1.
10 | func Middleware() wish.Middleware { 11 | return func(next ssh.Handler) ssh.Handler { 12 | return func(sess ssh.Session) { 13 | _, _, active := sess.Pty() 14 | if active { // only sessions reporting an active PTY reach next 15 | next(sess) 16 | return 17 | } 18 | wish.Println(sess, "Requires an active PTY") 19 | _ = sess.Exit(1) 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /activeterm/activeterm_test.go: -------------------------------------------------------------------------------- 1 | package activeterm_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/charmbracelet/ssh" 7 | "github.com/charmbracelet/wish/activeterm" 8 | "github.com/charmbracelet/wish/testsession" 9 | gossh "golang.org/x/crypto/ssh" 10 | ) 11 | 12 | func TestMiddleware(t *testing.T) { 13 | t.Run("inactive term", func(t *testing.T) { 14 | out, err := setup(t).Output("") // no PTY is requested, so the middleware should reject the session 15 | if err == nil { 16 | t.Errorf("tests should be an inactive pty") 17 | } 18 | if string(out) != "Requires an active PTY\n" { 19 | t.Errorf("invalid output: %q", string(out)) 20 | } 21 | }) 22 | } 23 | 24 | func setup(tb testing.TB) *gossh.Session { 25 | tb.Helper() 26 | return testsession.New(tb, &ssh.Server{ 27 | Handler: activeterm.Middleware()(func(s ssh.Session) { 28 | s.Write([]byte("hello")) // only reached if the middleware lets the session through 29 | }), 30 | }, nil) 31 | } 32 | -------------------------------------------------------------------------------- /bubbletea/query.go: -------------------------------------------------------------------------------- 1 | package bubbletea 2 | 3 | import ( 4 | "image/color" 5 | "io" 6 | "time" 7 | 8 | "github.com/charmbracelet/x/ansi" 9 | "github.com/charmbracelet/x/input" 10 | ) 11 | 12 | // queryBackgroundColor queries the terminal for the background color. 13 | // If the terminal does not support querying the background color, nil is 14 | // returned. 15 | // 16 | // Note: you will need to set the input to raw mode before calling this 17 | // function.
18 | // 19 | // state, _ := term.MakeRaw(in.Fd()) 20 | // defer term.Restore(in.Fd(), state) 21 | // 22 | // copied from x/term@v0.1.3. 23 | func queryBackgroundColor(in io.Reader, out io.Writer) (c color.Color, err error) { 24 | // The DA1 request is appended so its reply marks the end of the response; c stays nil if no background color arrives before it. 25 | err = queryTerminal(in, out, defaultQueryTimeout, 26 | func(events []input.Event) bool { 27 | for _, e := range events { 28 | switch e := e.(type) { 29 | case input.BackgroundColorEvent: 30 | c = e.Color 31 | continue // we need to consume the next DA1 event 32 | case input.PrimaryDeviceAttributesEvent: 33 | return false 34 | } 35 | } 36 | return true 37 | }, ansi.RequestBackgroundColor+ansi.RequestPrimaryDeviceAttributes) 38 | return 39 | } 40 | 41 | const defaultQueryTimeout = time.Second * 2 42 | 43 | // QueryTerminalFilter is a function that filters input events using a type 44 | // switch. If false is returned, queryTerminal stops reading 45 | // input. 46 | type QueryTerminalFilter func(events []input.Event) bool 47 | 48 | // queryTerminal writes query to out, then passes each batch of events read 49 | // from in to filter until filter returns false or a read or write fails. 50 | // Most of the time, you will need to set stdin to raw mode before calling this 51 | // function. 52 | // Note: This function will block until the terminal responds or the timeout 53 | // is reached. 54 | // copied from x/term@v0.1.3.
55 | func queryTerminal( 56 | in io.Reader, 57 | out io.Writer, 58 | timeout time.Duration, 59 | filter QueryTerminalFilter, 60 | query string, 61 | ) error { 62 | rd, err := input.NewReader(in, "", 0) 63 | if err != nil { 64 | return err 65 | } 66 | 67 | defer rd.Close() // nolint: errcheck 68 | 69 | done := make(chan struct{}, 1) // closed on return so the timeout goroutine exits 70 | defer close(done) 71 | go func() { 72 | select { 73 | case <-done: 74 | case <-time.After(timeout): 75 | rd.Cancel() // give up on a silent terminal; presumably makes the pending ReadEvents fail 76 | } 77 | }() 78 | 79 | if _, err := io.WriteString(out, query); err != nil { 80 | return err 81 | } 82 | 83 | for { 84 | events, err := rd.ReadEvents() 85 | if err != nil { 86 | return err 87 | } 88 | 89 | if !filter(events) { // filter decides when the expected reply is complete 90 | break 91 | } 92 | } 93 | 94 | return nil 95 | } 96 | -------------------------------------------------------------------------------- /bubbletea/tea.go: -------------------------------------------------------------------------------- 1 | // Package bubbletea provides middleware for serving bubbletea apps over SSH. 2 | package bubbletea 3 | 4 | import ( 5 | "context" 6 | "fmt" 7 | "strings" 8 | 9 | tea "github.com/charmbracelet/bubbletea" 10 | "github.com/charmbracelet/lipgloss" 11 | "github.com/charmbracelet/log" 12 | "github.com/charmbracelet/ssh" 13 | "github.com/charmbracelet/wish" 14 | "github.com/muesli/termenv" 15 | ) 16 | 17 | // BubbleTeaHandler is the function Bubble Tea apps implement to hook into the 18 | // SSH Middleware. This will create a new tea.Program for every connection and 19 | // start it with the tea.ProgramOptions returned. 20 | // 21 | // Deprecated: use Handler instead. 22 | type BubbleTeaHandler = Handler // nolint: revive 23 | 24 | // Handler is the function Bubble Tea apps implement to hook into the 25 | // SSH Middleware. This will create a new tea.Program for every connection and 26 | // start it with the tea.ProgramOptions returned.
27 | type Handler func(sess ssh.Session) (tea.Model, []tea.ProgramOption) 28 | 29 | // ProgramHandler is the function Bubble Tea apps implement to hook into the SSH 30 | // Middleware. This should return a new tea.Program. This handler is different 31 | // from the default handler in that it returns a tea.Program instead of 32 | // (tea.Model, tea.ProgramOptions). 33 | // 34 | // Make sure to set the tea.WithInput and tea.WithOutput to the ssh.Session 35 | // otherwise the program will not function properly. Returning nil passes the session straight to the next handler. 36 | type ProgramHandler func(sess ssh.Session) *tea.Program 37 | 38 | // Middleware takes a Handler and hooks the input and output for the 39 | // ssh.Session into the tea.Program. 40 | // 41 | // It also captures window resize events and sends them to the tea.Program 42 | // as tea.WindowSizeMsgs. The minimum color profile is termenv.Ascii, so no profile is forced. 43 | func Middleware(handler Handler) wish.Middleware { 44 | return MiddlewareWithProgramHandler(newDefaultProgramHandler(handler), termenv.Ascii) 45 | } 46 | 47 | // MiddlewareWithColorProfile allows you to specify the minimum number of colors 48 | // this program needs to work properly. 49 | // 50 | // If the client's color profile has fewer colors than profile, profile will be forced. 51 | // Use with caution. 52 | func MiddlewareWithColorProfile(handler Handler, profile termenv.Profile) wish.Middleware { 53 | return MiddlewareWithProgramHandler(newDefaultProgramHandler(handler), profile) 54 | } 55 | 56 | // MiddlewareWithProgramHandler allows you to specify the ProgramHandler to be 57 | // able to access the underlying tea.Program, and the minimum supported color 58 | // profile. 59 | // 60 | // This is useful for creating custom middlewares that need access to 61 | // tea.Program for instance to use p.Send() to send messages to tea.Program. 62 | // 63 | // Make sure to set the tea.WithInput and tea.WithOutput to the ssh.Session 64 | // otherwise the program will not function properly. The recommended way 65 | // of doing so is by using MakeOptions.
66 | // 67 | // If the client's color profile has fewer colors than profile, profile will be forced. 68 | // Use with caution. 69 | func MiddlewareWithProgramHandler(handler ProgramHandler, profile termenv.Profile) wish.Middleware { 70 | return func(next ssh.Handler) ssh.Handler { 71 | return func(sess ssh.Session) { 72 | sess.Context().SetValue(minColorProfileKey, profile) 73 | program := handler(sess) 74 | if program == nil { 75 | next(sess) 76 | return 77 | } 78 | _, windowChanges, ok := sess.Pty() 79 | if !ok { 80 | wish.Fatalln(sess, "no active terminal, skipping") 81 | return 82 | } 83 | ctx, cancel := context.WithCancel(sess.Context()) 84 | go func() { 85 | for { 86 | select { 87 | case <-ctx.Done(): 88 | program.Quit() 89 | return 90 | case w, ok := <-windowChanges: 91 | if !ok { 92 | windowChanges = nil // closed: a nil channel never fires, so only ctx.Done can end the loop 93 | continue 94 | } 95 | program.Send(tea.WindowSizeMsg{Width: w.Width, Height: w.Height}) 96 | } 97 | } 98 | }() 99 | if _, err := program.Run(); err != nil { 100 | log.Error("app exit with error", "error", err) 101 | } 102 | // program.Kill() force-kills the program if it's still running, 103 | // and restores the terminal to its original state in case of a 104 | // TUI crash. 105 | program.Kill() 106 | cancel() 107 | next(sess) 108 | } 109 | } 110 | } 111 | 112 | type contextKey struct{} // unexported key type, so it cannot collide with context keys from other packages 113 | var minColorProfileKey contextKey 114 | 115 | var profileNames = [4]string{"TrueColor", "ANSI256", "ANSI", "Ascii"} 116 | 117 | // MakeRenderer returns a lipgloss renderer for the current session. 118 | // This function handles PTYs as well, and should be used to style your application. 119 | func MakeRenderer(sess ssh.Session) *lipgloss.Renderer { 120 | cp, ok := sess.Context().Value(minColorProfileKey).(termenv.Profile) 121 | if !ok { 122 | cp = termenv.Ascii 123 | } 124 | 125 | r := newRenderer(sess) 126 | 127 | // We only force the color profile if the requested session is a PTY.
128 | _, _, ok = sess.Pty() 129 | if !ok { 130 | return r 131 | } 132 | 133 | if r.ColorProfile() > cp { 134 | _, _ = fmt.Fprintf(sess.Stderr(), "Warning: Client's terminal is %q, forcing %q\r\n", 135 | profileNames[r.ColorProfile()], profileNames[cp]) 136 | r.SetColorProfile(cp) 137 | } 138 | return r 139 | } 140 | 141 | // MakeOptions returns the tea.WithInput and tea.WithOutput program options 142 | // taking into account possible Emulated or Allocated PTYs. 143 | func MakeOptions(sess ssh.Session) []tea.ProgramOption { 144 | return makeOpts(sess) 145 | } 146 | 147 | type sshEnviron []string // "KEY=value" entries; implements termenv.Environ 148 | 149 | var _ termenv.Environ = sshEnviron(nil) 150 | 151 | // Environ implements termenv.Environ. 152 | func (e sshEnviron) Environ() []string { 153 | return e 154 | } 155 | 156 | // Getenv implements termenv.Environ. 157 | func (e sshEnviron) Getenv(k string) string { 158 | for _, v := range e { 159 | if strings.HasPrefix(v, k+"=") { 160 | return v[len(k)+1:] 161 | } 162 | } 163 | return "" 164 | } 165 | 166 | // newDefaultProgramHandler adapts handler to a ProgramHandler; a nil model yields a nil program. 167 | func newDefaultProgramHandler(handler Handler) ProgramHandler { 168 | return func(s ssh.Session) *tea.Program { 169 | m, opts := handler(s) 170 | if m == nil { 171 | return nil 172 | } 173 | return tea.NewProgram(m, append(opts[:len(opts):len(opts)], makeOpts(s)...)...) // full slice expression: append must not write into the handler's backing array
174 | } 175 | } 176 | -------------------------------------------------------------------------------- /bubbletea/tea_other.go: -------------------------------------------------------------------------------- 1 | //go:build !linux && !darwin && !freebsd && !dragonfly && !netbsd && !openbsd && !solaris 2 | // +build !linux,!darwin,!freebsd,!dragonfly,!netbsd,!openbsd,!solaris 3 | 4 | package bubbletea 5 | 6 | import ( 7 | tea "github.com/charmbracelet/bubbletea" 8 | "github.com/charmbracelet/lipgloss" 9 | "github.com/charmbracelet/ssh" 10 | "github.com/muesli/termenv" 11 | ) 12 | 13 | // makeOpts wires the session itself as the program's input and output. 14 | func makeOpts(s ssh.Session) []tea.ProgramOption { 15 | return []tea.ProgramOption{ 16 | tea.WithInput(s), 17 | tea.WithOutput(s), 18 | } 19 | } 20 | 21 | // newRenderer returns a renderer that writes to the session, using the session's 22 | // environment plus the PTY's TERM value. 23 | func newRenderer(s ssh.Session) *lipgloss.Renderer { 24 | pty, _, _ := s.Pty() 25 | env := sshEnviron(append(s.Environ(), "TERM="+pty.Term)) 26 | return lipgloss.NewRenderer(s, termenv.WithEnvironment(env), termenv.WithUnsafe(), termenv.WithColorCache(true)) 27 | } 28 | -------------------------------------------------------------------------------- /bubbletea/tea_unix.go: -------------------------------------------------------------------------------- 1 | //go:build darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris 2 | // +build darwin dragonfly freebsd linux netbsd openbsd solaris 3 | 4 | package bubbletea 5 | 6 | import ( 7 | "image/color" 8 | "time" 9 | 10 | tea "github.com/charmbracelet/bubbletea" 11 | "github.com/charmbracelet/lipgloss" 12 | "github.com/charmbracelet/ssh" 13 | "github.com/charmbracelet/x/ansi" 14 | "github.com/charmbracelet/x/input" 15 | "github.com/charmbracelet/x/term" 16 | "github.com/lucasb-eyer/go-colorful" 17 | "github.com/muesli/termenv" 18 | ) 19 | 20 | func makeOpts(s ssh.Session) []tea.ProgramOption { 21 | pty, _, ok := s.Pty() 22 | if !ok || s.EmulatedPty() || pty.Slave == nil { // no PTY, an emulated one, or no slave device: use the session streams directly 23 | return []tea.ProgramOption{ 24 | tea.WithInput(s), 25 | tea.WithOutput(s), 26 | } 27 | } 28 | 29 | return
[]tea.ProgramOption{ 30 | tea.WithInput(pty.Slave), 31 | tea.WithOutput(pty.Slave), 32 | } 33 | } 34 | 35 | func newRenderer(s ssh.Session) *lipgloss.Renderer { 36 | pty, _, ok := s.Pty() 37 | if !ok || pty.Term == "" || pty.Term == "dumb" { 38 | return lipgloss.NewRenderer(s, termenv.WithProfile(termenv.Ascii)) 39 | } 40 | env := sshEnviron(append(s.Environ(), "TERM="+pty.Term)) 41 | var r *lipgloss.Renderer 42 | var bg color.Color 43 | if pty.Slave != nil { // allocated PTY: render to, and query the background through, the slave device 44 | r = lipgloss.NewRenderer( 45 | pty.Slave, 46 | termenv.WithEnvironment(env), 47 | termenv.WithColorCache(true), 48 | ) 49 | state, err := term.MakeRaw(pty.Slave.Fd()) 50 | if err == nil { 51 | bg, _ = queryBackgroundColor(pty.Slave, pty.Slave) 52 | _ = term.Restore(pty.Slave.Fd(), state) 53 | } 54 | } else { 55 | r = lipgloss.NewRenderer( 56 | s, 57 | termenv.WithEnvironment(env), 58 | termenv.WithUnsafe(), 59 | termenv.WithColorCache(true), 60 | ) 61 | bg = querySessionBackgroundColor(s) 62 | } 63 | if bg != nil { 64 | c, ok := colorful.MakeColor(bg) 65 | if ok { 66 | _, _, l := c.Hsl() 67 | r.SetHasDarkBackground(l < 0.5) // lightness below half counts as a dark background 68 | } 69 | } 70 | return r 71 | } 72 | 73 | // copied from x/term@v0.1.3.
74 | func querySessionBackgroundColor(s ssh.Session) (bg color.Color) { 75 | _ = queryTerminal(s, s, time.Second, func(events []input.Event) bool { 76 | for _, e := range events { 77 | switch e := e.(type) { 78 | case input.BackgroundColorEvent: 79 | bg = e.Color 80 | continue // we need to consume the next DA1 event 81 | case input.PrimaryDeviceAttributesEvent: 82 | return false 83 | } 84 | } 85 | return true 86 | }, ansi.RequestBackgroundColor+ansi.RequestPrimaryDeviceAttributes) 87 | return 88 | } 89 | -------------------------------------------------------------------------------- /cmd.go: -------------------------------------------------------------------------------- 1 | package wish 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "os/exec" 7 | 8 | tea "github.com/charmbracelet/bubbletea" 9 | "github.com/charmbracelet/ssh" 10 | ) 11 | 12 | // CommandContext is like Command but includes a context. 13 | // 14 | // Like Command, if the current session does not have a PTY, stdin, stdout, 15 | // and stderr are set to the session itself when the command is run. 16 | func CommandContext(ctx context.Context, s ssh.Session, name string, args ...string) *Cmd { 17 | cmd := exec.CommandContext(ctx, name, args...) 18 | return &Cmd{s, cmd} 19 | } 20 | 21 | // Command returns a Cmd whose Run sets stdin, stdout, and stderr to the current session's PTY. 22 | // 23 | // If the current session does not have a PTY, it sets them to the session 24 | // itself. 25 | // 26 | // This will use the session's context as the context for exec.CommandContext. 27 | func Command(s ssh.Session, name string, args ...string) *Cmd { 28 | return CommandContext(s.Context(), s, name, args...) 29 | } 30 | 31 | // Cmd wraps an *exec.Cmd together with the ssh.Session whose I/O (or PTY) it runs on. 32 | type Cmd struct { 33 | sess ssh.Session 34 | cmd *exec.Cmd 35 | } 36 | 37 | // SetEnv sets the underlying exec.Cmd env. 38 | func (c *Cmd) SetEnv(env []string) { 39 | c.cmd.Env = env 40 | } 41 | 42 | // Environ returns the underlying exec.Cmd environment.
43 | func (c *Cmd) Environ() []string { 44 | return c.cmd.Environ() 45 | } 46 | 47 | // SetDir sets the underlying exec.Cmd dir. 48 | func (c *Cmd) SetDir(dir string) { 49 | c.cmd.Dir = dir 50 | } 51 | 52 | // Run runs the program and waits for it to finish. Without a PTY, the session is used for stdin/stdout and its Stderr() for stderr. 53 | func (c *Cmd) Run() error { 54 | ppty, winCh, ok := c.sess.Pty() 55 | if !ok { 56 | c.cmd.Stdin, c.cmd.Stdout, c.cmd.Stderr = c.sess, c.sess, c.sess.Stderr() 57 | return c.cmd.Run() 58 | } 59 | return c.doRun(ppty, winCh) 60 | } 61 | 62 | var _ tea.ExecCommand = &Cmd{} 63 | 64 | // SetStderr is a no-op that satisfies tea.ExecCommand; Run always wires I/O to the session (or its PTY). 65 | func (*Cmd) SetStderr(io.Writer) {} 66 | 67 | // SetStdin is a no-op that satisfies tea.ExecCommand; Run always wires I/O to the session (or its PTY). 68 | func (*Cmd) SetStdin(io.Reader) {} 69 | 70 | // SetStdout is a no-op that satisfies tea.ExecCommand; Run always wires I/O to the session (or its PTY). 71 | func (*Cmd) SetStdout(io.Writer) {} 72 | -------------------------------------------------------------------------------- /cmd_test.go: -------------------------------------------------------------------------------- 1 | package wish 2 | 3 | import ( 4 | "bytes" 5 | "runtime" 6 | "strings" 7 | "testing" 8 | "time" 9 | 10 | "github.com/charmbracelet/ssh" 11 | "github.com/charmbracelet/wish/testsession" 12 | ) 13 | 14 | func TestCommandNoPty(t *testing.T) { 15 | tmp := t.TempDir() 16 | sess := testsession.New(t, &ssh.Server{ 17 | Handler: func(s ssh.Session) { 18 | runEcho(s, "hello") 19 | runEnv(s, []string{"HELLO=world"}) 20 | runPwd(s, tmp) 21 | }, 22 | }, nil) 23 | var stdout bytes.Buffer 24 | var stderr bytes.Buffer 25 | sess.Stdout = &stdout 26 | sess.Stderr = &stderr 27 | if err := sess.Run(""); err != nil { 28 | t.Errorf("expected no error, got %v: %s", err, stderr.String()) 29 | } 30 | out := stdout.String() 31 | expectContains(t, out, "hello") 32 | expectContains(t, out, "HELLO=world") 33 | expectContains(t, out, tmp) 34 | } 35 | 36 | func TestCommandPty(t *testing.T) { 37 | tmp := t.TempDir() 38 | srv := &ssh.Server{ 39 | Handler: func(s ssh.Session) { 40 | runEcho(s, "hello") 41 | runEnv(s,
[]string{"HELLO=world"}) 42 | runPwd(s, tmp) 43 | // On macOS GitHub Actions runners part of the output is sometimes cut off; 44 | // the short sleep below works around it. 45 | time.Sleep(100 * time.Millisecond) 46 | }, 47 | } 48 | if err := ssh.AllocatePty()(srv); err != nil { 49 | t.Fatalf("expected no error, got %v", err) 50 | } 51 | 52 | sess := testsession.New(t, srv, nil) 53 | if err := sess.RequestPty("xterm", 500, 200, nil); err != nil { 54 | t.Fatalf("expected no error, got %v", err) 55 | } 56 | 57 | var stdout bytes.Buffer 58 | var stderr bytes.Buffer 59 | sess.Stdout = &stdout 60 | sess.Stderr = &stderr 61 | if err := sess.Run(""); err != nil { 62 | t.Errorf("expected no error, got %v: %s", err, stderr.String()) 63 | } 64 | out := stdout.String() 65 | expectContains(t, out, "hello") 66 | expectContains(t, out, "HELLO=world") 67 | expectContains(t, out, tmp) 68 | } 69 | 70 | func TestCommandPtyError(t *testing.T) { 71 | if runtime.GOOS == "windows" { 72 | t.Skip() 73 | } 74 | srv := &ssh.Server{ 75 | Handler: func(s ssh.Session) { 76 | if err := Command(s, "nopenopenope").Run(); err != nil { 77 | Fatal(s, err) 78 | } 79 | }, 80 | } 81 | if err := ssh.AllocatePty()(srv); err != nil { 82 | t.Fatalf("expected no error, got %v", err) 83 | } 84 | 85 | sess := testsession.New(t, srv, nil) 86 | if err := sess.RequestPty("xterm", 500, 200, nil); err != nil { 87 | t.Fatalf("expected no error, got %v", err) 88 | } 89 | 90 | var stderr bytes.Buffer 91 | sess.Stderr = &stderr 92 | if err := sess.Run(""); err == nil { 93 | t.Errorf("expected an error, got nil") 94 | } 95 | expect := `exec: "nopenopenope"` 96 | if s := stderr.String(); !strings.Contains(s, expect) { 97 | t.Errorf("expected output to contain %q, got %q", expect, s) 98 | } 99 | } 100 | 101 | func runEcho(s ssh.Session, str string) { 102 | cmd := Command(s, "echo", str) 103 | if runtime.GOOS == "windows" { 104 | cmd = Command(s, "cmd", "/C", "echo", str) 105 | } 106 | // These are no-ops on *Cmd, so they must not affect the command's output.
107 | cmd.SetStderr(nil) 108 | cmd.SetStdin(nil) 109 | cmd.SetStdout(nil) 110 | if err := cmd.Run(); err != nil { 111 | Fatal(s, err) 112 | } 113 | } 114 | 115 | func runEnv(s ssh.Session, env []string) { 116 | cmd := Command(s, "env") 117 | if runtime.GOOS == "windows" { 118 | cmd = Command(s, "cmd", "/C", "set") 119 | } 120 | cmd.SetEnv(env) 121 | if err := cmd.Run(); err != nil { 122 | Fatal(s, err) 123 | } 124 | if len(cmd.Environ()) == 0 { 125 | Fatal(s, "cmd.Environ() should not be empty") 126 | } 127 | } 128 | 129 | func runPwd(s ssh.Session, dir string) { 130 | cmd := Command(s, "pwd") 131 | if runtime.GOOS == "windows" { 132 | cmd = Command(s, "cmd", "/C", "cd") 133 | } 134 | cmd.SetDir(dir) 135 | if err := cmd.Run(); err != nil { 136 | Fatal(s, err) 137 | } 138 | } 139 | 140 | func expectContains(tb testing.TB, s, substr string) { 141 | if !strings.Contains(s, substr) { 142 | tb.Errorf("expected output %q to contain %q", s, substr) 143 | } 144 | } 145 | -------------------------------------------------------------------------------- /cmd_unix.go: -------------------------------------------------------------------------------- 1 | //go:build darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris 2 | // +build darwin dragonfly freebsd linux netbsd openbsd solaris 3 | 4 | package wish 5 | 6 | import "github.com/charmbracelet/ssh" 7 | 8 | func (c *Cmd) doRun(ppty ssh.Pty, _ <-chan ssh.Window) error { 9 | if err := ppty.Start(c.cmd); err != nil { 10 | return err 11 | } 12 | return c.cmd.Wait() 13 | } 14 | -------------------------------------------------------------------------------- /cmd_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | // +build windows 3 | 4 | package wish 5 | 6 | import ( 7 | "fmt" 8 | "time" 9 | 10 | "github.com/charmbracelet/ssh" 11 | ) 12 | 13 | func (c *Cmd) doRun(ppty ssh.Pty, _ <-chan ssh.Window) error { 14 | if err := ppty.Start(c.cmd); err !=
nil { 15 | return err 16 | } 17 | 18 | start := time.Now() 19 | for c.cmd.ProcessState == nil { 20 | if time.Since(start) > time.Second*10 { // NOTE(review): ProcessState is presumably only set once the process exits, so this may also fire for a command that is merely long-running (e.g. an editor) — confirm intended 21 | return fmt.Errorf("could not start process") 22 | } 23 | time.Sleep(100 * time.Millisecond) 24 | } 25 | if !c.cmd.ProcessState.Success() { 26 | return fmt.Errorf("process failed: exit %d", c.cmd.ProcessState.ExitCode()) 27 | } 28 | return nil 29 | } 30 | -------------------------------------------------------------------------------- /comment/comment.go: -------------------------------------------------------------------------------- 1 | package comment 2 | 3 | import ( 4 | "github.com/charmbracelet/ssh" 5 | "github.com/charmbracelet/wish" 6 | ) 7 | 8 | // Middleware runs the wrapped handler, then prints comment (followed by a newline) to the session. 9 | func Middleware(comment string) wish.Middleware { 10 | return func(sh ssh.Handler) ssh.Handler { 11 | return func(s ssh.Session) { 12 | sh(s) 13 | wish.Println(s, comment) 14 | } 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /comment/comment_test.go: -------------------------------------------------------------------------------- 1 | package comment 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/charmbracelet/ssh" 7 | "github.com/charmbracelet/wish/testsession" 8 | gossh "golang.org/x/crypto/ssh" 9 | ) 10 | 11 | func TestMiddleware(t *testing.T) { 12 | t.Run("recover session", func(t *testing.T) { 13 | b, err := setup(t).Output("") 14 | requireNoError(t, err) 15 | if string(b) != "test\n" { 16 | t.Errorf("expected comment to be 'test', got %q", string(b)) 17 | } 18 | }) 19 | } 20 | 21 | func setup(tb testing.TB) *gossh.Session { 22 | tb.Helper() 23 | return testsession.New(tb, &ssh.Server{ 24 | Handler: Middleware("test")(func(s ssh.Session) {}), 25 | }, nil) 26 | } 27 | 28 | func requireNoError(t *testing.T, err error) { 29 | t.Helper() 30 | 31 | if err != nil { 32 | t.Fatalf("expected no error, got %q", err.Error()) 33 | } 34 | } 35 |
-------------------------------------------------------------------------------- /elapsed/elapsed.go: -------------------------------------------------------------------------------- 1 | package elapsed 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/charmbracelet/ssh" 7 | "github.com/charmbracelet/wish" 8 | ) 9 | 10 | // MiddlewareWithFormat returns a middleware that, after the wrapped handler 11 | // returns, prints the elapsed session time to the session using format. 12 | // 13 | // In order to provide an accurate elapsed time for the entire session, 14 | // this must be called as the last middleware in the chain. 15 | func MiddlewareWithFormat(format string) wish.Middleware { 16 | return func(sh ssh.Handler) ssh.Handler { 17 | return func(s ssh.Session) { 18 | now := time.Now() 19 | sh(s) 20 | wish.Printf(s, format, time.Since(now)) 21 | } 22 | } 23 | } 24 | 25 | // Middleware returns a middleware that prints the elapsed time of the session to the session. 26 | // 27 | // In order to provide an accurate elapsed time for the entire session, 28 | // this must be called as the last middleware in the chain.
29 | func Middleware() wish.Middleware { 30 | return MiddlewareWithFormat("elapsed time: %v\n") 31 | } 32 | -------------------------------------------------------------------------------- /elapsed/elapsed_test.go: -------------------------------------------------------------------------------- 1 | package elapsed 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/charmbracelet/ssh" 8 | "github.com/charmbracelet/wish/testsession" 9 | gossh "golang.org/x/crypto/ssh" 10 | ) 11 | 12 | var waitDuration = time.Second 13 | 14 | func TestMiddleware(t *testing.T) { 15 | t.Run("recover session", func(t *testing.T) { 16 | b, err := setup(t).Output("") 17 | requireNoError(t, err) 18 | dur, err := time.ParseDuration(string(b)) 19 | requireNoError(t, err) 20 | if dur < waitDuration { 21 | t.Errorf("expected elapsed time to be at least 1s, got %v", dur) 22 | } 23 | }) 24 | } 25 | 26 | func setup(tb testing.TB) *gossh.Session { 27 | tb.Helper() 28 | return testsession.New(tb, &ssh.Server{ 29 | Handler: MiddlewareWithFormat("%v")(func(s ssh.Session) { 30 | time.Sleep(waitDuration) 31 | }), 32 | }, nil) 33 | } 34 | 35 | func requireNoError(t *testing.T, err error) { 36 | t.Helper() 37 | 38 | if err != nil { 39 | t.Fatalf("expected no error, got %q", err.Error()) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /examples/.gitignore: -------------------------------------------------------------------------------- 1 | id_ed25519* 2 | file.txt 3 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Wish Examples 2 | 3 | We recommend you follow the examples in the following order: 4 | 5 | ## Basics 6 | 7 | 1. [Simple](./simple) 8 | 1. [Graceful Shutdown](./graceful-shutdown) 9 | 1. [Server banner and middleware](./banner) 10 | 1. [Identifying Users](./identity) 11 | 1.
[Multiple authentication types](./multi-auth) 12 | 13 | ## Making SSH apps 14 | 15 | 1. [Using spf13/cobra](./cobra) 16 | 1. [Serving Bubble Tea apps](./bubbletea) 17 | 1. [Serving Bubble Tea programs](./bubbleteaprogram) 18 | 1. [Reverse Port Forwarding](./forward) 19 | 1. [Multichat](./multichat) 20 | 21 | ## SCP, SFTP, and Git 22 | 23 | 1. [Serving a Git repository](./git) 24 | 1. [SCP and SFTP](./scp) 25 | 26 | ## Pseudo Terminals 27 | 28 | 1. [Allocate a PTY](./pty) 29 | 1. [Running Bubble Tea, and executing another program on an allocated PTY](./bubbletea-exec) 30 | -------------------------------------------------------------------------------- /examples/banner/banner.txt: -------------------------------------------------------------------------------- 1 | 2 | Hello %s, welcome to a SSH server powered by 3 | __ ___ _ _ _ _ 4 | \ \ / (_)___| |__ | | | | 5 | \ \ /\ / /| / __| '_ \| | | | 6 | \ V V / | \__ \ | | |_|_|_| 7 | \_/\_/ |_|___/_| |_(_|_|_) 8 | 9 | 10 | PS: The password is "asd123". 11 | PPS: Visit https://charm.sh to learn more! 12 | 13 | --- 14 | 15 | -------------------------------------------------------------------------------- /examples/banner/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "net" 8 | "os" 9 | "os/signal" 10 | "syscall" 11 | "time" 12 | 13 | _ "embed" 14 | 15 | "github.com/charmbracelet/log" 16 | "github.com/charmbracelet/ssh" 17 | "github.com/charmbracelet/wish" 18 | "github.com/charmbracelet/wish/elapsed" 19 | "github.com/charmbracelet/wish/logging" 20 | ) 21 | 22 | const ( 23 | host = "localhost" 24 | port = "23234" 25 | ) 26 | 27 | //go:embed banner.txt 28 | var banner string 29 | 30 | func main() { 31 | s, err := wish.NewServer( 32 | wish.WithAddress(net.JoinHostPort(host, port)), 33 | wish.WithHostKeyPath(".ssh/id_ed25519"), 34 | // A banner is always shown, even before authentication.
35 | wish.WithBannerHandler(func(ctx ssh.Context) string { 36 | return fmt.Sprintf(banner, ctx.User()) 37 | }), 38 | wish.WithPasswordAuth(func(ctx ssh.Context, password string) bool { 39 | return password == "asd123" 40 | }), 41 | wish.WithMiddleware( 42 | func(next ssh.Handler) ssh.Handler { 43 | return func(sess ssh.Session) { 44 | wish.Println(sess, fmt.Sprintf("Hello, %s!", sess.User())) 45 | next(sess) 46 | } 47 | }, 48 | logging.Middleware(), 49 | // This middleware prints the session duration before disconnecting. 50 | elapsed.Middleware(), 51 | ), 52 | ) 53 | if err != nil { 54 | log.Error("Could not start server", "error", err) // NOTE(review): execution continues and uses s below even though NewServer failed; consider exiting here 55 | } 56 | 57 | done := make(chan os.Signal, 1) 58 | signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) 59 | log.Info("Starting SSH server", "host", host, "port", port) 60 | go func() { 61 | if err = s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) { 62 | log.Error("Could not start server", "error", err) 63 | done <- nil 64 | } 65 | }() 66 | 67 | <-done 68 | log.Info("Stopping SSH server") 69 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 70 | defer func() { cancel() }() 71 | if err := s.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) { 72 | log.Error("Could not stop server", "error", err) 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /examples/bubbletea-exec/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net" 7 | "os" 8 | "os/signal" 9 | "runtime" 10 | "syscall" 11 | "time" 12 | 13 | tea "github.com/charmbracelet/bubbletea" 14 | "github.com/charmbracelet/lipgloss" 15 | "github.com/charmbracelet/log" 16 | "github.com/charmbracelet/ssh" 17 | "github.com/charmbracelet/wish" 18 | "github.com/charmbracelet/wish/activeterm" 19 | "github.com/charmbracelet/wish/bubbletea" 20 |
"github.com/charmbracelet/wish/logging" 21 | "github.com/charmbracelet/x/editor" 22 | ) 23 | 24 | const ( 25 | host = "localhost" 26 | port = "23234" 27 | ) 28 | 29 | func main() { 30 | s, err := wish.NewServer( 31 | wish.WithAddress(net.JoinHostPort(host, port)), 32 | 33 | // Allocate a pty. 34 | // On Windows this creates a pseudoconsole instead; compatibility is limited 35 | // there, so see the open issues for more details. 36 | ssh.AllocatePty(), 37 | wish.WithMiddleware( 38 | // run our Bubble Tea handler 39 | bubbletea.Middleware(teaHandler), 40 | 41 | // ensure the user has requested a tty 42 | activeterm.Middleware(), 43 | logging.Middleware(), 44 | ), 45 | ) 46 | if err != nil { 47 | log.Error("Could not start server", "error", err) // NOTE(review): execution continues and uses s below even though NewServer failed; consider exiting here 48 | } 49 | 50 | done := make(chan os.Signal, 1) 51 | signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) 52 | log.Info("Starting SSH server", "host", host, "port", port) 53 | go func() { 54 | if err = s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) { 55 | log.Error("Could not start server", "error", err) 56 | done <- nil 57 | } 58 | }() 59 | 60 | <-done 61 | log.Info("Stopping SSH server") 62 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 63 | defer func() { cancel() }() 64 | if err := s.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) { 65 | log.Error("Could not stop server", "error", err) 66 | } 67 | } 68 | 69 | func teaHandler(s ssh.Session) (tea.Model, []tea.ProgramOption) { 70 | // Create a lipgloss.Renderer for the session 71 | renderer := bubbletea.MakeRenderer(s) 72 | // Set up the model with the current session and styles. 73 | // We'll use the session to call wish.Command, whose *wish.Cmd implements 74 | // tea.ExecCommand and can therefore be run with tea.Exec.
75 | m := model{ 76 | sess: s, 77 | style: renderer.NewStyle().Foreground(lipgloss.Color("8")), 78 | errStyle: renderer.NewStyle().Foreground(lipgloss.Color("3")), 79 | } 80 | return m, []tea.ProgramOption{tea.WithAltScreen()} 81 | } 82 | 83 | type model struct { 84 | err error 85 | sess ssh.Session 86 | style lipgloss.Style 87 | errStyle lipgloss.Style 88 | } 89 | 90 | func (m model) Init() tea.Cmd { 91 | return nil 92 | } 93 | 94 | type cmdFinishedMsg struct{ err error } 95 | 96 | func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { 97 | switch msg := msg.(type) { 98 | case tea.KeyMsg: 99 | switch msg.String() { 100 | case "e": 101 | // Open file.txt in the default editor. 102 | edit, err := editor.Cmd("wish", "file.txt") 103 | if err != nil { 104 | m.err = err 105 | return m, nil 106 | } 107 | // Creates a wish.Cmd from the exec.Cmd. NOTE(review): exec.Cmd.Args[0] is the command name itself, so passing all of edit.Args likely repeats it as the first argument; edit.Args[1:] may be intended — confirm. 108 | wishCmd := wish.Command(m.sess, edit.Path, edit.Args...) 109 | // Runs the cmd through Bubble Tea. 110 | // Bubble Tea should handle the IO to the program, and get it back 111 | // once the program quits. 112 | cmd := tea.Exec(wishCmd, func(err error) tea.Msg { 113 | if err != nil { 114 | log.Error("editor finished", "error", err) 115 | } 116 | return cmdFinishedMsg{err: err} 117 | }) 118 | return m, cmd 119 | case "s": 120 | // We can also execute a shell and give it over to the user. 121 | // Note that this session won't have control, so it can't run tasks 122 | // in background, suspend, etc.
123 | c := wish.Command(m.sess, "bash", "-im") 124 | if runtime.GOOS == "windows" { 125 | c = wish.Command(m.sess, "powershell") 126 | } 127 | cmd := tea.Exec(c, func(err error) tea.Msg { 128 | if err != nil { 129 | log.Error("shell finished", "error", err) 130 | } 131 | return cmdFinishedMsg{err: err} 132 | }) 133 | return m, cmd 134 | case "q", "ctrl+c": 135 | return m, tea.Quit 136 | } 137 | case cmdFinishedMsg: 138 | m.err = msg.err 139 | return m, nil 140 | } 141 | 142 | return m, nil 143 | } 144 | 145 | func (m model) View() string { 146 | if m.err != nil { 147 | return m.errStyle.Render(m.err.Error() + "\n") 148 | } 149 | return m.style.Render("Press 'e' to edit, 's' to hop into a shell, or 'q' to quit...\n") 150 | } 151 | -------------------------------------------------------------------------------- /examples/bubbletea/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // An example Bubble Tea server. This will put an ssh session into alt screen 4 | // and continually print up to date terminal information. 5 | 6 | import ( 7 | "context" 8 | "errors" 9 | "fmt" 10 | "net" 11 | "os" 12 | "os/signal" 13 | "syscall" 14 | "time" 15 | 16 | tea "github.com/charmbracelet/bubbletea" 17 | "github.com/charmbracelet/lipgloss" 18 | "github.com/charmbracelet/log" 19 | "github.com/charmbracelet/ssh" 20 | "github.com/charmbracelet/wish" 21 | "github.com/charmbracelet/wish/activeterm" 22 | "github.com/charmbracelet/wish/bubbletea" 23 | "github.com/charmbracelet/wish/logging" 24 | ) 25 | 26 | const ( 27 | host = "localhost" 28 | port = "23234" 29 | ) 30 | 31 | func main() { 32 | s, err := wish.NewServer( 33 | wish.WithAddress(net.JoinHostPort(host, port)), 34 | wish.WithHostKeyPath(".ssh/id_ed25519"), 35 | wish.WithMiddleware( 36 | bubbletea.Middleware(teaHandler), 37 | activeterm.Middleware(), // Bubble Tea apps usually require a PTY.
38 | logging.Middleware(), 39 | ), 40 | ) 41 | if err != nil { 42 | log.Error("Could not start server", "error", err) // NOTE(review): execution continues and uses s below even though NewServer failed; consider exiting here 43 | } 44 | 45 | done := make(chan os.Signal, 1) 46 | signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) 47 | log.Info("Starting SSH server", "host", host, "port", port) 48 | go func() { 49 | if err = s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) { 50 | log.Error("Could not start server", "error", err) 51 | done <- nil 52 | } 53 | }() 54 | 55 | <-done 56 | log.Info("Stopping SSH server") 57 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 58 | defer func() { cancel() }() 59 | if err := s.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) { 60 | log.Error("Could not stop server", "error", err) 61 | } 62 | } 63 | 64 | // You can wire any Bubble Tea model up to the middleware with a function that 65 | // handles the incoming ssh.Session. Here we just grab the terminal info and 66 | // pass it to the new model. You can also return tea.ProgramOptions (such as 67 | // tea.WithAltScreen) on a session by session basis. 68 | func teaHandler(s ssh.Session) (tea.Model, []tea.ProgramOption) { 69 | // This should never fail, as we are using the activeterm middleware. 70 | pty, _, _ := s.Pty() 71 | 72 | // When running a Bubble Tea app over SSH, you shouldn't use the default 73 | // lipgloss.NewStyle function. 74 | // That function will use the color profile from the os.Stdin, which is the 75 | // server, not the client. 76 | // We provide a MakeRenderer function in the bubbletea middleware package, 77 | // so you can easily get the correct renderer for the current session, and 78 | // use it to create the styles. 79 | // The recommended way to use these styles is to then pass them down to 80 | // your Bubble Tea model.
81 | renderer := bubbletea.MakeRenderer(s) 82 | txtStyle := renderer.NewStyle().Foreground(lipgloss.Color("10")) 83 | quitStyle := renderer.NewStyle().Foreground(lipgloss.Color("8")) 84 | 85 | bg := "light" 86 | if renderer.HasDarkBackground() { 87 | bg = "dark" 88 | } 89 | 90 | m := model{ 91 | term: pty.Term, 92 | profile: renderer.ColorProfile().Name(), 93 | width: pty.Window.Width, 94 | height: pty.Window.Height, 95 | bg: bg, 96 | txtStyle: txtStyle, 97 | quitStyle: quitStyle, 98 | } 99 | return m, []tea.ProgramOption{tea.WithAltScreen()} 100 | } 101 | 102 | // model is a minimal tea.Model that displays the SSH client's terminal information. 103 | type model struct { 104 | term string 105 | profile string 106 | width int 107 | height int 108 | bg string 109 | txtStyle lipgloss.Style 110 | quitStyle lipgloss.Style 111 | } 112 | 113 | func (m model) Init() tea.Cmd { 114 | return nil 115 | } 116 | 117 | func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { 118 | switch msg := msg.(type) { 119 | case tea.WindowSizeMsg: 120 | m.height = msg.Height 121 | m.width = msg.Width 122 | case tea.KeyMsg: 123 | switch msg.String() { 124 | case "q", "ctrl+c": 125 | return m, tea.Quit 126 | } 127 | } 128 | return m, nil 129 | } 130 | 131 | func (m model) View() string { 132 | s := fmt.Sprintf("Your term is %s\nYour window size is %dx%d\nBackground: %s\nColor Profile: %s", m.term, m.width, m.height, m.bg, m.profile) 133 | return m.txtStyle.Render(s) + "\n\n" + m.quitStyle.Render("Press 'q' to quit\n") 134 | } 135 | -------------------------------------------------------------------------------- /examples/bubbleteaprogram/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // An example Bubble Tea server. This will put an ssh session into alt screen 4 | // and continually print up to date terminal information.
5 | 6 | import ( 7 | "context" 8 | "errors" 9 | "fmt" 10 | "net" 11 | "os" 12 | "os/signal" 13 | "syscall" 14 | "time" 15 | 16 | tea "github.com/charmbracelet/bubbletea" 17 | "github.com/charmbracelet/log" 18 | "github.com/charmbracelet/ssh" 19 | "github.com/charmbracelet/wish" 20 | "github.com/charmbracelet/wish/bubbletea" 21 | "github.com/charmbracelet/wish/logging" 22 | "github.com/muesli/termenv" 23 | ) 24 | 25 | const ( 26 | host = "localhost" 27 | port = "23234" 28 | ) 29 | 30 | func main() { 31 | s, err := wish.NewServer( 32 | wish.WithAddress(net.JoinHostPort(host, port)), 33 | wish.WithHostKeyPath(".ssh/id_ed25519"), 34 | wish.WithMiddleware( 35 | myCustomBubbleteaMiddleware(), 36 | logging.Middleware(), 37 | ), 38 | ) 39 | if err != nil { 40 | log.Error("Could not start server", "error", err) 41 | } 42 | 43 | done := make(chan os.Signal, 1) 44 | signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) 45 | log.Info("Starting SSH server", "host", host, "port", port) 46 | go func() { 47 | if err = s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) { 48 | log.Error("Could not start server", "error", err) 49 | done <- nil 50 | } 51 | }() 52 | 53 | <-done 54 | log.Info("Stopping SSH server") 55 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 56 | defer func() { cancel() }() 57 | if err := s.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) { 58 | log.Error("Could not stop server", "error", err) 59 | } 60 | } 61 | 62 | // You can write your own custom bubbletea middleware that wraps tea.Program. 63 | // Make sure you set the program input and output to ssh.Session. 64 | func myCustomBubbleteaMiddleware() wish.Middleware { 65 | newProg := func(m tea.Model, opts ...tea.ProgramOption) *tea.Program { 66 | p := tea.NewProgram(m, opts...) 
67 | go func() { 68 | for { 69 | <-time.After(1 * time.Second) 70 | p.Send(timeMsg(time.Now())) 71 | } 72 | }() 73 | return p 74 | } 75 | teaHandler := func(s ssh.Session) *tea.Program { 76 | pty, _, active := s.Pty() 77 | if !active { 78 | wish.Fatalln(s, "no active terminal, skipping") 79 | return nil 80 | } 81 | m := model{ 82 | term: pty.Term, 83 | width: pty.Window.Width, 84 | height: pty.Window.Height, 85 | time: time.Now(), 86 | } 87 | return newProg(m, append(bubbletea.MakeOptions(s), tea.WithAltScreen())...) 88 | } 89 | return bubbletea.MiddlewareWithProgramHandler(teaHandler, termenv.ANSI256) 90 | } 91 | 92 | // Just a generic tea.Model to demo terminal information of ssh. 93 | type model struct { 94 | term string 95 | width int 96 | height int 97 | time time.Time 98 | } 99 | 100 | type timeMsg time.Time 101 | 102 | func (m model) Init() tea.Cmd { 103 | return nil 104 | } 105 | 106 | func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { 107 | switch msg := msg.(type) { 108 | case timeMsg: 109 | m.time = time.Time(msg) 110 | case tea.WindowSizeMsg: 111 | m.height = msg.Height 112 | m.width = msg.Width 113 | case tea.KeyMsg: 114 | switch msg.String() { 115 | case "q", "ctrl+c": 116 | return m, tea.Quit 117 | } 118 | } 119 | return m, nil 120 | } 121 | 122 | func (m model) View() string { 123 | s := "Your term is %s\n" 124 | s += "Your window size is x: %d y: %d\n" 125 | s += "Time: " + m.time.Format(time.RFC1123) + "\n\n" 126 | s += "Press 'q' to quit\n" 127 | return fmt.Sprintf(s, m.term, m.width, m.height) 128 | } 129 | -------------------------------------------------------------------------------- /examples/cobra/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net" 7 | "os" 8 | "os/signal" 9 | "syscall" 10 | "time" 11 | 12 | "github.com/charmbracelet/log" 13 | "github.com/charmbracelet/ssh" 14 | "github.com/charmbracelet/wish" 15 | 
"github.com/charmbracelet/wish/logging" 16 | "github.com/spf13/cobra" 17 | ) 18 | 19 | const ( 20 | host = "localhost" 21 | port = "23235" 22 | ) 23 | 24 | func cmd() *cobra.Command { 25 | var reverse bool 26 | cmd := &cobra.Command{ 27 | Use: "echo [string]", 28 | Args: cobra.ExactArgs(1), 29 | RunE: func(cmd *cobra.Command, args []string) error { 30 | s := args[0] 31 | if reverse { 32 | ss := make([]byte, 0, len(s)) 33 | for i := len(s) - 1; i >= 0; i-- { 34 | ss = append(ss, s[i]) 35 | } 36 | s = string(ss) 37 | } 38 | cmd.Println(s) 39 | return nil 40 | }, 41 | } 42 | 43 | cmd.PersistentFlags().BoolVarP(&reverse, "reverse", "r", false, "Reverse string on echo") 44 | return cmd 45 | } 46 | 47 | func main() { 48 | s, err := wish.NewServer( 49 | wish.WithAddress(net.JoinHostPort(host, port)), 50 | wish.WithHostKeyPath(".ssh/id_ed25519"), 51 | wish.WithMiddleware( 52 | func(next ssh.Handler) ssh.Handler { 53 | return func(sess ssh.Session) { 54 | // Here we wire our command's args and IO to the user 55 | // session's 56 | rootCmd := cmd() 57 | rootCmd.SetArgs(sess.Command()) 58 | rootCmd.SetIn(sess) 59 | rootCmd.SetOut(sess) 60 | rootCmd.SetErr(sess.Stderr()) 61 | rootCmd.CompletionOptions.DisableDefaultCmd = true 62 | if err := rootCmd.Execute(); err != nil { 63 | _ = sess.Exit(1) 64 | return 65 | } 66 | 67 | next(sess) 68 | } 69 | }, 70 | logging.Middleware(), 71 | ), 72 | ) 73 | if err != nil { 74 | log.Error("Could not start server", "error", err) 75 | } 76 | 77 | done := make(chan os.Signal, 1) 78 | signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) 79 | log.Info("Starting SSH server", "host", host, "port", port) 80 | go func() { 81 | if err = s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) { 82 | log.Error("Could not start server", "error", err) 83 | done <- nil 84 | } 85 | }() 86 | 87 | <-done 88 | log.Info("Stopping SSH server") 89 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 90 | defer 
func() { cancel() }() 91 | if err := s.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) { 92 | log.Error("Could not stop server", "error", err) 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /examples/exec/example.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gum choose a b c d 4 | -------------------------------------------------------------------------------- /examples/exec/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net" 7 | "os" 8 | "os/signal" 9 | "syscall" 10 | "time" 11 | 12 | "github.com/charmbracelet/log" 13 | "github.com/charmbracelet/ssh" 14 | "github.com/charmbracelet/wish" 15 | "github.com/charmbracelet/wish/activeterm" 16 | "github.com/charmbracelet/wish/logging" 17 | ) 18 | 19 | const ( 20 | host = "localhost" 21 | port = "23234" 22 | ) 23 | 24 | func main() { 25 | s, err := wish.NewServer( 26 | wish.WithAddress(net.JoinHostPort(host, port)), 27 | 28 | // Allocate a pty. 29 | // This creates a pseudoconsole on windows, compatibility is limited in 30 | // that case, see the open issues for more details. 
31 | ssh.AllocatePty(), 32 | wish.WithMiddleware( 33 | func(next ssh.Handler) ssh.Handler { 34 | return func(s ssh.Session) { 35 | cmd := wish.Command(s, "bash", "example.sh") 36 | if err := cmd.Run(); err != nil { 37 | wish.Fatalln(s, err) 38 | } 39 | next(s) 40 | } 41 | }, 42 | // ensure the user has requested a tty 43 | activeterm.Middleware(), 44 | logging.Middleware(), 45 | ), 46 | ) 47 | if err != nil { 48 | log.Error("Could not start server", "error", err) 49 | } 50 | 51 | done := make(chan os.Signal, 1) 52 | signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) 53 | log.Info("Starting SSH server", "host", host, "port", port) 54 | go func() { 55 | if err = s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) { 56 | log.Error("Could not start server", "error", err) 57 | done <- nil 58 | } 59 | }() 60 | 61 | <-done 62 | log.Info("Stopping SSH server") 63 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 64 | defer func() { cancel() }() 65 | if err := s.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) { 66 | log.Error("Could not stop server", "error", err) 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /examples/forward/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net" 7 | "os" 8 | "os/signal" 9 | "syscall" 10 | "time" 11 | 12 | "github.com/charmbracelet/log" 13 | "github.com/charmbracelet/ssh" 14 | "github.com/charmbracelet/wish" 15 | "github.com/charmbracelet/wish/logging" 16 | ) 17 | 18 | const ( 19 | host = "localhost" 20 | port = "23234" 21 | ) 22 | 23 | // example usage: ssh -N -R 23236:localhost:23235 -p 23234 localhost 24 | 25 | func main() { 26 | // Create a new SSH ForwardedTCPHandler. 
27 | forwardHandler := &ssh.ForwardedTCPHandler{} 28 | s, err := wish.NewServer( 29 | wish.WithAddress(net.JoinHostPort(host, port)), 30 | wish.WithHostKeyPath(".ssh/id_ed25519"), 31 | func(s *ssh.Server) error { 32 | // Set the Reverse TCP Handler up: 33 | s.ReversePortForwardingCallback = func(_ ssh.Context, bindHost string, bindPort uint32) bool { 34 | log.Info("reverse port forwarding allowed", "host", bindHost, "port", bindPort) 35 | return true 36 | } 37 | s.RequestHandlers = map[string]ssh.RequestHandler{ 38 | "tcpip-forward": forwardHandler.HandleSSHRequest, 39 | "cancel-tcpip-forward": forwardHandler.HandleSSHRequest, 40 | } 41 | return nil 42 | }, 43 | wish.WithMiddleware( 44 | func(h ssh.Handler) ssh.Handler { 45 | return func(s ssh.Session) { 46 | wish.Println(s, "Remote port forwarding available!") 47 | wish.Println(s, "Try it with:") 48 | wish.Println(s, " ssh -N -R 23236:localhost:23235 -p 23234 localhost") 49 | h(s) 50 | } 51 | }, 52 | logging.Middleware(), 53 | ), 54 | ) 55 | if err != nil { 56 | log.Error("Could not start server", "error", err) 57 | } 58 | 59 | done := make(chan os.Signal, 1) 60 | signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) 61 | log.Info("Starting SSH server", "host", host, "port", port) 62 | go func() { 63 | if err = s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) { 64 | log.Error("Could not start server", "error", err) 65 | done <- nil 66 | } 67 | }() 68 | 69 | <-done 70 | log.Info("Stopping SSH server") 71 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 72 | defer func() { cancel() }() 73 | if err := s.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) { 74 | log.Error("Could not stop server", "error", err) 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /examples/git/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // An 
example git server. This will list all available repos if you ssh 4 | // directly to the server. To test `ssh -p 23233 localhost` once it's running. 5 | 6 | import ( 7 | "context" 8 | "errors" 9 | "fmt" 10 | "io/fs" 11 | "net" 12 | "os" 13 | "os/signal" 14 | "syscall" 15 | "time" 16 | 17 | "github.com/charmbracelet/log" 18 | "github.com/charmbracelet/ssh" 19 | "github.com/charmbracelet/wish" 20 | "github.com/charmbracelet/wish/git" 21 | "github.com/charmbracelet/wish/logging" 22 | ) 23 | 24 | const ( 25 | port = "23233" 26 | host = "localhost" 27 | repoDir = ".repos" 28 | ) 29 | 30 | type app struct { 31 | access git.AccessLevel 32 | } 33 | 34 | func (a app) AuthRepo(string, ssh.PublicKey) git.AccessLevel { 35 | return a.access 36 | } 37 | 38 | func (a app) Push(repo string, _ ssh.PublicKey) { 39 | log.Info("push", "repo", repo) 40 | } 41 | 42 | func (a app) Fetch(repo string, _ ssh.PublicKey) { 43 | log.Info("fetch", "repo", repo) 44 | } 45 | 46 | func main() { 47 | // A simple GitHooks implementation to allow global read write access. 48 | a := app{git.ReadWriteAccess} 49 | 50 | s, err := wish.NewServer( 51 | wish.WithAddress(net.JoinHostPort(host, port)), 52 | wish.WithHostKeyPath(".ssh/id_ed25519"), 53 | // Accept any public key. 54 | ssh.PublicKeyAuth(func(ssh.Context, ssh.PublicKey) bool { return true }), 55 | // Do not accept password auth. 56 | ssh.PasswordAuth(func(ssh.Context, string) bool { return false }), 57 | wish.WithMiddleware( 58 | // Setup the git middleware. 59 | git.Middleware(repoDir, a), 60 | // Adds a middleware to list all available repositories to the user. 
61 | gitListMiddleware, 62 | logging.Middleware(), 63 | ), 64 | ) 65 | if err != nil { 66 | log.Error("Could not start server", "error", err) 67 | } 68 | 69 | done := make(chan os.Signal, 1) 70 | signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) 71 | log.Info("Starting SSH server", "host", host, "port", port) 72 | go func() { 73 | if err = s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) { 74 | log.Error("Could not start server", "error", err) 75 | done <- nil 76 | } 77 | }() 78 | 79 | <-done 80 | log.Info("Stopping SSH server") 81 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 82 | defer func() { cancel() }() 83 | if err := s.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) { 84 | log.Error("Could not stop server", "error", err) 85 | } 86 | } 87 | 88 | // Normally we would use a Bubble Tea program for the TUI but for simplicity, 89 | // we'll just write a list of the pushed repos to the terminal and exit the ssh 90 | // session. 91 | func gitListMiddleware(next ssh.Handler) ssh.Handler { 92 | return func(sess ssh.Session) { 93 | // Git will have a command included so only run this if there are no 94 | // commands passed to ssh. 95 | if len(sess.Command()) != 0 { 96 | next(sess) 97 | return 98 | } 99 | 100 | dest, err := os.ReadDir(repoDir) 101 | if err != nil && err != fs.ErrNotExist { 102 | log.Error("Invalid repository", "error", err) 103 | } 104 | if len(dest) > 0 { 105 | fmt.Fprintf(sess, "\n### Repo Menu ###\n\n") 106 | } 107 | for _, dir := range dest { 108 | wish.Println(sess, fmt.Sprintf("• %s - ", dir.Name())) 109 | wish.Println(sess, fmt.Sprintf("git clone ssh://%s/%s", net.JoinHostPort(host, port), dir.Name())) 110 | } 111 | wish.Printf(sess, "\n\n### Add some repos! 
###\n\n") 112 | wish.Printf(sess, "> cd some_repo\n") 113 | wish.Printf(sess, "> git remote add wish_test ssh://%s/some_repo\n", net.JoinHostPort(host, port)) 114 | wish.Printf(sess, "> git push wish_test\n\n\n") 115 | next(sess) 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /examples/go.mod: -------------------------------------------------------------------------------- 1 | module examples 2 | 3 | go 1.23.0 4 | 5 | toolchain go1.24.1 6 | 7 | require ( 8 | github.com/charmbracelet/bubbles v0.21.0 9 | github.com/charmbracelet/bubbletea v1.3.5 10 | github.com/charmbracelet/lipgloss v1.1.0 11 | github.com/charmbracelet/log v0.4.2 12 | github.com/charmbracelet/ssh v0.0.0-20250128164007-98fd5ae11894 13 | github.com/charmbracelet/wish v0.5.0 14 | github.com/charmbracelet/x/editor v0.1.0 15 | github.com/muesli/termenv v0.16.0 16 | github.com/pkg/sftp v1.13.9 17 | github.com/spf13/cobra v1.9.1 18 | golang.org/x/crypto v0.38.0 19 | ) 20 | 21 | require ( 22 | dario.cat/mergo v1.0.0 // indirect 23 | github.com/Microsoft/go-winio v0.6.2 // indirect 24 | github.com/ProtonMail/go-crypto v1.1.6 // indirect 25 | github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect 26 | github.com/atotto/clipboard v0.1.4 // indirect 27 | github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect 28 | github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect 29 | github.com/charmbracelet/keygen v0.5.3 // indirect 30 | github.com/charmbracelet/x/ansi v0.9.2 // indirect 31 | github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect 32 | github.com/charmbracelet/x/conpty v0.1.0 // indirect 33 | github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86 // indirect 34 | github.com/charmbracelet/x/input v0.3.4 // indirect 35 | github.com/charmbracelet/x/term v0.2.1 // indirect 36 | github.com/charmbracelet/x/termios v0.1.0 // indirect 37 | 
github.com/charmbracelet/x/windows v0.2.0 // indirect 38 | github.com/cloudflare/circl v1.6.1 // indirect 39 | github.com/creack/pty v1.1.21 // indirect 40 | github.com/cyphar/filepath-securejoin v0.4.1 // indirect 41 | github.com/emirpasic/gods v1.18.1 // indirect 42 | github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect 43 | github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect 44 | github.com/go-git/go-billy/v5 v5.6.2 // indirect 45 | github.com/go-git/go-git/v5 v5.16.0 // indirect 46 | github.com/go-logfmt/logfmt v0.6.0 // indirect 47 | github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect 48 | github.com/inconshreveable/mousetrap v1.1.0 // indirect 49 | github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect 50 | github.com/kevinburke/ssh_config v1.2.0 // indirect 51 | github.com/kr/fs v0.1.0 // indirect 52 | github.com/lucasb-eyer/go-colorful v1.2.0 // indirect 53 | github.com/mattn/go-isatty v0.0.20 // indirect 54 | github.com/mattn/go-localereader v0.0.1 // indirect 55 | github.com/mattn/go-runewidth v0.0.16 // indirect 56 | github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect 57 | github.com/muesli/cancelreader v0.2.2 // indirect 58 | github.com/pjbgf/sha1cd v0.3.2 // indirect 59 | github.com/rivo/uniseg v0.4.7 // indirect 60 | github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect 61 | github.com/skeema/knownhosts v1.3.1 // indirect 62 | github.com/spf13/pflag v1.0.6 // indirect 63 | github.com/xanzy/ssh-agent v0.3.3 // indirect 64 | github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect 65 | golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect 66 | golang.org/x/net v0.39.0 // indirect 67 | golang.org/x/sync v0.14.0 // indirect 68 | golang.org/x/sys v0.33.0 // indirect 69 | golang.org/x/text v0.25.0 // indirect 70 | gopkg.in/warnings.v0 v0.1.2 // indirect 71 | ) 72 | 73 | replace github.com/charmbracelet/wish 
=> ../ 74 | -------------------------------------------------------------------------------- /examples/graceful-shutdown/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net" 7 | "os" 8 | "os/signal" 9 | "syscall" 10 | "time" 11 | 12 | "github.com/charmbracelet/log" 13 | "github.com/charmbracelet/ssh" 14 | "github.com/charmbracelet/wish" 15 | "github.com/charmbracelet/wish/logging" 16 | ) 17 | 18 | const ( 19 | host = "localhost" 20 | port = "23234" 21 | ) 22 | 23 | func main() { 24 | srv, err := wish.NewServer( 25 | wish.WithAddress(net.JoinHostPort(host, port)), 26 | wish.WithHostKeyPath(".ssh/id_ed25519"), 27 | wish.WithMiddleware( 28 | func(next ssh.Handler) ssh.Handler { 29 | return func(sess ssh.Session) { 30 | wish.Println(sess, "Hello, world!") 31 | next(sess) 32 | } 33 | }, 34 | logging.Middleware(), 35 | ), 36 | ) 37 | if err != nil { 38 | log.Error("Could not start server", "error", err) 39 | } 40 | 41 | // Before starting our server, we create a channel and listen for some 42 | // common interrupt signals. 43 | done := make(chan os.Signal, 1) 44 | signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) 45 | 46 | // We then start the server in a goroutine, as we'll listen for the done 47 | // signal later. 48 | go func() { 49 | log.Info("Starting SSH server", "host", host, "port", port) 50 | if err = srv.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) { 51 | // We ignore ErrServerClosed because it is expected. 52 | log.Error("Could not start server", "error", err) 53 | done <- nil 54 | } 55 | }() 56 | 57 | // Here we wait for the done signal: this can be either an interrupt, or 58 | // the server shutting down for any other reason. 59 | <-done 60 | 61 | // When it arrives, we create a context with a timeout. 
62 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 63 | defer func() { cancel() }() 64 | 65 | // When we start the shutdown, the server will no longer accept new 66 | // connections, but will wait as much as the given context allows for the 67 | // active connections to finish. 68 | // After the timeout, it shuts down anyway. 69 | log.Info("Stopping SSH server") 70 | if err := srv.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) { 71 | log.Error("Could not stop server", "error", err) 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /examples/identity/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "net" 8 | "os" 9 | "os/signal" 10 | "syscall" 11 | "time" 12 | 13 | "github.com/charmbracelet/log" 14 | "github.com/charmbracelet/ssh" 15 | "github.com/charmbracelet/wish" 16 | "github.com/charmbracelet/wish/logging" 17 | ) 18 | 19 | const ( 20 | host = "localhost" 21 | port = "23234" 22 | ) 23 | 24 | var users = map[string]string{ 25 | "Carlos": "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILxWe2rXKoiO6W14LYPVfJKzRfJ1f3Jhzxrgjc/D4tU7", 26 | // You can add add your name and public key here :) 27 | } 28 | 29 | func main() { 30 | s, err := wish.NewServer( 31 | wish.WithAddress(net.JoinHostPort(host, port)), 32 | wish.WithHostKeyPath(".ssh/id_ed25519"), 33 | // This will allow anyone to log in, as long as they have given an 34 | // ed25519 public key. 
35 | // You can test this by doing something like: 36 | // ssh -i ~/.ssh/id_ed25519 -p 23234 localhost 37 | // ssh -i ~/.ssh/id_rsa -p 23234 localhost 38 | // ssh -o PreferredAuthentications=password -p 23234 localhost 39 | wish.WithPublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool { 40 | return key.Type() == "ssh-ed25519" 41 | }), 42 | wish.WithMiddleware( 43 | func(next ssh.Handler) ssh.Handler { 44 | return func(sess ssh.Session) { 45 | // if the current session's user public key is one of the 46 | // known users, we greet them and return. 47 | for name, pubkey := range users { 48 | parsed, _, _, _, _ := ssh.ParseAuthorizedKey( 49 | []byte(pubkey), 50 | ) 51 | if ssh.KeysEqual(sess.PublicKey(), parsed) { 52 | wish.Println(sess, fmt.Sprintf("Hey %s!", name)) 53 | next(sess) 54 | return 55 | } 56 | } 57 | wish.Println(sess, "Hey, I don't know who you are!") 58 | next(sess) 59 | } 60 | }, 61 | logging.Middleware(), 62 | ), 63 | ) 64 | if err != nil { 65 | log.Error("Could not start server", "error", err) 66 | } 67 | 68 | done := make(chan os.Signal, 1) 69 | signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) 70 | log.Info("Starting SSH server", "host", host, "port", port) 71 | go func() { 72 | if err = s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) { 73 | log.Error("Could not start server", "error", err) 74 | done <- nil 75 | } 76 | }() 77 | 78 | <-done 79 | log.Info("Stopping SSH server") 80 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 81 | defer func() { cancel() }() 82 | if err := s.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) { 83 | log.Error("Could not stop server", "error", err) 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /examples/multi-auth/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net" 7 | 
"os" 8 | "os/signal" 9 | "syscall" 10 | "time" 11 | 12 | "github.com/charmbracelet/log" 13 | "github.com/charmbracelet/ssh" 14 | "github.com/charmbracelet/wish" 15 | "github.com/charmbracelet/wish/logging" 16 | gossh "golang.org/x/crypto/ssh" 17 | ) 18 | 19 | const ( 20 | host = "localhost" 21 | port = "23234" 22 | validPassword = "asd123" 23 | ) 24 | 25 | var users = map[string]string{ 26 | "Carlos": "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILxWe2rXKoiO6W14LYPVfJKzRfJ1f3Jhzxrgjc/D4tU7", 27 | // You can add add your name and public key here :) 28 | } 29 | 30 | func main() { 31 | s, err := wish.NewServer( 32 | wish.WithAddress(net.JoinHostPort(host, port)), 33 | wish.WithHostKeyPath(".ssh/id_ed25519"), 34 | 35 | // In this example, we'll have multiple possible authentication methods. 36 | // The order of preference is defined by the user (via 37 | // PreferredAuthentications), and if all of them fails, they aren't 38 | // allowed in. 39 | // 40 | // You can SSH into the server like so: 41 | // ssh -o PreferredAuthentications=none -p 23234 localhost 42 | // ssh -o PreferredAuthentications=password -p 23234 localhost 43 | // ssh -o PreferredAuthentications=publickey -p 23234 localhost 44 | // ssh -o PreferredAuthentications=keyboard-interactive -p 23234 localhost 45 | 46 | // First, public-key authentication: 47 | wish.WithPublicKeyAuth(func(_ ssh.Context, key ssh.PublicKey) bool { 48 | log.Info("publickey") 49 | for _, pubkey := range users { 50 | parsed, _, _, _, _ := ssh.ParseAuthorizedKey( 51 | []byte(pubkey), 52 | ) 53 | if ssh.KeysEqual(key, parsed) { 54 | return true 55 | } 56 | } 57 | return false 58 | }), 59 | 60 | // Then, password. 
61 | wish.WithPasswordAuth(func(_ ssh.Context, password string) bool { 62 | log.Info("password") 63 | return password == validPassword 64 | }), 65 | 66 | // Finally, keyboard-interactive, which you can use to ask the user to 67 | // answer a challenge: 68 | wish.WithKeyboardInteractiveAuth(func(_ ssh.Context, challenger gossh.KeyboardInteractiveChallenge) bool { 69 | log.Info("keyboard-interactive") 70 | answers, err := challenger( 71 | "", "", 72 | []string{ 73 | "♦ How much is 2+3: ", 74 | "♦ Which editor is best, vim or emacs? ", 75 | }, 76 | []bool{true, true}, 77 | ) 78 | if err != nil { 79 | return false 80 | } 81 | // here we check for the correct answers: 82 | return len(answers) == 2 && answers[0] == "5" && answers[1] == "vim" 83 | }), 84 | 85 | wish.WithMiddleware( 86 | logging.Middleware(), 87 | func(next ssh.Handler) ssh.Handler { 88 | return func(sess ssh.Session) { 89 | wish.Println(sess, "Authorized!") 90 | wish.Println(sess, sess.PublicKey()) 91 | } 92 | }, 93 | ), 94 | ) 95 | if err != nil { 96 | log.Error("Could not start server", "error", err) 97 | } 98 | 99 | done := make(chan os.Signal, 1) 100 | signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) 101 | log.Info("Starting SSH server", "host", host, "port", port) 102 | go func() { 103 | if err = s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) { 104 | log.Error("Could not start server", "error", err) 105 | done <- nil 106 | } 107 | }() 108 | 109 | <-done 110 | log.Info("Stopping SSH server") 111 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 112 | defer func() { cancel() }() 113 | if err := s.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) { 114 | log.Error("Could not stop server", "error", err) 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /examples/multichat/main.go: -------------------------------------------------------------------------------- 1 | 
package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net" 7 | "os" 8 | "os/signal" 9 | "strings" 10 | "syscall" 11 | "time" 12 | 13 | "github.com/charmbracelet/bubbles/textarea" 14 | "github.com/charmbracelet/bubbles/viewport" 15 | tea "github.com/charmbracelet/bubbletea" 16 | "github.com/charmbracelet/lipgloss" 17 | "github.com/charmbracelet/log" 18 | "github.com/charmbracelet/ssh" 19 | "github.com/charmbracelet/wish" 20 | "github.com/charmbracelet/wish/activeterm" 21 | "github.com/charmbracelet/wish/bubbletea" 22 | "github.com/charmbracelet/wish/logging" 23 | "github.com/muesli/termenv" 24 | ) 25 | 26 | const ( 27 | host = "localhost" 28 | port = "23234" 29 | ) 30 | 31 | // app contains a wish server and the list of running programs. 32 | type app struct { 33 | *ssh.Server 34 | progs []*tea.Program 35 | } 36 | 37 | // send dispatches a message to all running programs. 38 | func (a *app) send(msg tea.Msg) { 39 | for _, p := range a.progs { 40 | go p.Send(msg) 41 | } 42 | } 43 | 44 | func newApp() *app { 45 | a := new(app) 46 | s, err := wish.NewServer( 47 | wish.WithAddress(net.JoinHostPort(host, port)), 48 | wish.WithHostKeyPath(".ssh/id_ed25519"), 49 | wish.WithMiddleware( 50 | bubbletea.MiddlewareWithProgramHandler(a.ProgramHandler, termenv.ANSI256), 51 | activeterm.Middleware(), 52 | logging.Middleware(), 53 | ), 54 | ) 55 | if err != nil { 56 | log.Error("Could not start server", "error", err) 57 | } 58 | 59 | a.Server = s 60 | return a 61 | } 62 | 63 | func (a *app) Start() { 64 | var err error 65 | done := make(chan os.Signal, 1) 66 | signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) 67 | log.Info("Starting SSH server", "host", host, "port", port) 68 | go func() { 69 | if err = a.ListenAndServe(); err != nil { 70 | log.Error("Could not start server", "error", err) 71 | done <- nil 72 | } 73 | }() 74 | 75 | <-done 76 | log.Info("Stopping SSH server") 77 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 78 | 
defer func() { cancel() }() 79 | if err := a.Shutdown(ctx); err != nil { 80 | log.Error("Could not stop server", "error", err) 81 | } 82 | } 83 | 84 | func (a *app) ProgramHandler(s ssh.Session) *tea.Program { 85 | model := initialModel() 86 | model.app = a 87 | model.id = s.User() 88 | 89 | p := tea.NewProgram(model, bubbletea.MakeOptions(s)...) 90 | a.progs = append(a.progs, p) 91 | 92 | return p 93 | } 94 | 95 | func main() { 96 | app := newApp() 97 | app.Start() 98 | } 99 | 100 | type ( 101 | errMsg error 102 | chatMsg struct { 103 | id string 104 | text string 105 | } 106 | ) 107 | 108 | type model struct { 109 | *app 110 | viewport viewport.Model 111 | messages []string 112 | id string 113 | textarea textarea.Model 114 | senderStyle lipgloss.Style 115 | err error 116 | } 117 | 118 | func initialModel() model { 119 | ta := textarea.New() 120 | ta.Placeholder = "Send a message..." 121 | ta.Focus() 122 | 123 | ta.Prompt = "┃ " 124 | ta.CharLimit = 280 125 | 126 | ta.SetWidth(30) 127 | ta.SetHeight(3) 128 | 129 | // Remove cursor line styling 130 | ta.FocusedStyle.CursorLine = lipgloss.NewStyle() 131 | 132 | ta.ShowLineNumbers = false 133 | 134 | vp := viewport.New(30, 5) 135 | vp.SetContent(`Welcome to the chat room! 
136 | Type a message and press Enter to send.`) 137 | 138 | ta.KeyMap.InsertNewline.SetEnabled(false) 139 | 140 | return model{ 141 | textarea: ta, 142 | messages: []string{}, 143 | viewport: vp, 144 | senderStyle: lipgloss.NewStyle().Foreground(lipgloss.Color("5")), 145 | err: nil, 146 | } 147 | } 148 | 149 | func (m model) Init() tea.Cmd { 150 | return textarea.Blink 151 | } 152 | 153 | func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { 154 | var ( 155 | tiCmd tea.Cmd 156 | vpCmd tea.Cmd 157 | ) 158 | 159 | m.textarea, tiCmd = m.textarea.Update(msg) 160 | m.viewport, vpCmd = m.viewport.Update(msg) 161 | 162 | switch msg := msg.(type) { 163 | case tea.KeyMsg: 164 | switch msg.Type { 165 | case tea.KeyCtrlC, tea.KeyEsc: 166 | return m, tea.Quit 167 | case tea.KeyEnter: 168 | m.app.send(chatMsg{ 169 | id: m.id, 170 | text: m.textarea.Value(), 171 | }) 172 | m.textarea.Reset() 173 | } 174 | 175 | case chatMsg: 176 | m.messages = append(m.messages, m.senderStyle.Render(msg.id)+": "+msg.text) 177 | m.viewport.SetContent(strings.Join(m.messages, "\n")) 178 | m.viewport.GotoBottom() 179 | 180 | // We handle errors just like any other message 181 | case errMsg: 182 | m.err = msg 183 | return m, nil 184 | } 185 | 186 | return m, tea.Batch(tiCmd, vpCmd) 187 | } 188 | 189 | func (m model) View() string { 190 | return fmt.Sprintf( 191 | "%s\n\n%s", 192 | m.viewport.View(), 193 | m.textarea.View(), 194 | ) + "\n\n" 195 | } 196 | -------------------------------------------------------------------------------- /examples/pty/main.go: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | 3 | package main 4 | 5 | import ( 6 | "context" 7 | "errors" 8 | "net" 9 | "os" 10 | "os/signal" 11 | "syscall" 12 | "time" 13 | 14 | "github.com/charmbracelet/log" 15 | "github.com/charmbracelet/ssh" 16 | "github.com/charmbracelet/wish" 17 | "github.com/charmbracelet/wish/activeterm" 18 | "github.com/charmbracelet/wish/bubbletea" 19 | 
"github.com/charmbracelet/wish/logging" 20 | ) 21 | 22 | const ( 23 | host = "localhost" 24 | port = "23234" 25 | ) 26 | 27 | func main() { 28 | srv, err := wish.NewServer( 29 | wish.WithAddress(net.JoinHostPort(host, port)), 30 | wish.WithHostKeyPath(".ssh/id_ed25519"), 31 | 32 | // Wish can allocate a PTY per user session. 33 | ssh.AllocatePty(), 34 | 35 | wish.WithMiddleware( 36 | func(next ssh.Handler) ssh.Handler { 37 | return func(sess ssh.Session) { 38 | pty, _, _ := sess.Pty() 39 | renderer := bubbletea.MakeRenderer(sess) 40 | 41 | bg := "light" 42 | if renderer.HasDarkBackground() { 43 | bg = "dark" 44 | } 45 | 46 | wish.Printf(sess, "Hello, world!\r\n") 47 | wish.Printf(sess, "Term: %s\r\n", pty.Term) 48 | wish.Printf(sess, "PTY: %s\r\n", pty.Slave.Name()) 49 | wish.Printf(sess, "FD: %d\r\n", pty.Slave.Fd()) 50 | wish.Printf(sess, "Background: %v\r\n", bg) 51 | next(sess) 52 | } 53 | }, 54 | 55 | activeterm.Middleware(), 56 | logging.Middleware(), 57 | ), 58 | ) 59 | if err != nil { 60 | log.Error("Could not start server", "error", err) 61 | } 62 | 63 | done := make(chan os.Signal, 1) 64 | signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) 65 | 66 | go func() { 67 | log.Info("Starting SSH server", "host", host, "port", port) 68 | if err = srv.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) { 69 | log.Error("Could not start server", "error", err) 70 | done <- nil 71 | } 72 | }() 73 | 74 | <-done 75 | 76 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 77 | defer func() { cancel() }() 78 | log.Info("Stopping SSH server") 79 | if err := srv.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) { 80 | log.Error("Could not stop server", "error", err) 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /examples/scp/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // An 
example SCP server. This will serve files from and to ./examples/scp/testdata. 4 | 5 | import ( 6 | "context" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "io/fs" 11 | "net" 12 | "os" 13 | "os/signal" 14 | "path/filepath" 15 | "syscall" 16 | "time" 17 | 18 | "github.com/charmbracelet/log" 19 | "github.com/charmbracelet/ssh" 20 | "github.com/charmbracelet/wish" 21 | "github.com/charmbracelet/wish/scp" 22 | "github.com/pkg/sftp" 23 | ) 24 | 25 | const ( 26 | host = "localhost" 27 | port = "23235" 28 | ) 29 | 30 | func main() { 31 | root, _ := filepath.Abs("./examples/scp/testdata") 32 | handler := scp.NewFileSystemHandler(root) 33 | s, err := wish.NewServer( 34 | wish.WithAddress(net.JoinHostPort(host, port)), 35 | wish.WithHostKeyPath(".ssh/id_ed25519"), 36 | 37 | // setup the sftp subsystem 38 | wish.WithSubsystem("sftp", sftpSubsystem(root)), 39 | wish.WithMiddleware( 40 | // setup the scp middleware 41 | scp.Middleware(handler, handler), 42 | ), 43 | ) 44 | if err != nil { 45 | log.Error("Could not start server", "error", err) 46 | } 47 | 48 | done := make(chan os.Signal, 1) 49 | signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) 50 | log.Info("Starting SSH server", "host", host, "port", port) 51 | go func() { 52 | if err = s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) { 53 | log.Error("Could not start server", "error", err) 54 | done <- nil 55 | } 56 | }() 57 | 58 | <-done 59 | log.Info("Stopping SSH server") 60 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 61 | defer func() { cancel() }() 62 | if err := s.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) { 63 | log.Error("Could not stop server", "error", err) 64 | } 65 | } 66 | 67 | func sftpSubsystem(root string) ssh.SubsystemHandler { 68 | return func(s ssh.Session) { 69 | log.Info("sftp", "root", root) 70 | fs := &sftpHandler{root} 71 | srv := sftp.NewRequestServer(s, sftp.Handlers{ 72 | FileList: fs, 73 | FileGet: fs, 74 | }) 75 | 
if err := srv.Serve(); err == io.EOF { 76 | if err := srv.Close(); err != nil { 77 | wish.Fatalln(s, "sftp:", err) 78 | } 79 | } else if err != nil { 80 | wish.Fatalln(s, "sftp:", err) 81 | } 82 | } 83 | } 84 | 85 | // Example readonly handler implementation for sftp. 86 | // 87 | // Other example implementations: 88 | // - https://github.com/gravitational/teleport/blob/f57dc2fe2a9900ec198779aae747ac4f833b278d/tool/teleport/common/sftp.go 89 | // - https://github.com/minio/minio/blob/c66c5828eacb4a7fa9a49b4c890c77dd8684b171/cmd/sftp-server.go 90 | type sftpHandler struct { 91 | root string 92 | } 93 | 94 | var ( 95 | _ sftp.FileLister = &sftpHandler{} 96 | _ sftp.FileReader = &sftpHandler{} 97 | ) 98 | 99 | type listerAt []fs.FileInfo 100 | 101 | func (l listerAt) ListAt(ls []fs.FileInfo, offset int64) (int, error) { 102 | if offset >= int64(len(l)) { 103 | return 0, io.EOF 104 | } 105 | n := copy(ls, l[offset:]) 106 | if n < len(ls) { 107 | return n, io.EOF 108 | } 109 | return n, nil 110 | } 111 | 112 | // Fileread implements sftp.FileReader. 113 | func (s *sftpHandler) Fileread(r *sftp.Request) (io.ReaderAt, error) { 114 | var flags int 115 | pflags := r.Pflags() 116 | if pflags.Append { 117 | flags |= os.O_APPEND 118 | } 119 | if pflags.Creat { 120 | flags |= os.O_CREATE 121 | } 122 | if pflags.Excl { 123 | flags |= os.O_EXCL 124 | } 125 | if pflags.Trunc { 126 | flags |= os.O_TRUNC 127 | } 128 | 129 | if pflags.Read && pflags.Write { 130 | flags |= os.O_RDWR 131 | } else if pflags.Read { 132 | flags |= os.O_RDONLY 133 | } else if pflags.Write { 134 | flags |= os.O_WRONLY 135 | } 136 | 137 | f, err := os.OpenFile(filepath.Join(s.root, r.Filepath), flags, 0600) 138 | if err != nil { 139 | return nil, err 140 | } 141 | 142 | return f, nil 143 | } 144 | 145 | // Filelist implements sftp.FileLister. 
146 | func (s *sftpHandler) Filelist(r *sftp.Request) (sftp.ListerAt, error) { 147 | switch r.Method { 148 | case "List": 149 | entries, err := os.ReadDir(filepath.Join(s.root, r.Filepath)) 150 | if err != nil { 151 | return nil, fmt.Errorf("sftp: %w", err) 152 | } 153 | infos := make([]fs.FileInfo, len(entries)) 154 | for i, entry := range entries { 155 | info, err := entry.Info() 156 | if err != nil { 157 | return nil, err 158 | } 159 | infos[i] = info 160 | } 161 | return listerAt(infos), nil 162 | case "Stat": 163 | fi, err := os.Stat(filepath.Join(s.root, r.Filepath)) 164 | if err != nil { 165 | return nil, err 166 | } 167 | return listerAt{fi}, nil 168 | default: 169 | return nil, sftp.ErrSSHFxOpUnsupported 170 | } 171 | } 172 | -------------------------------------------------------------------------------- /examples/scp/testdata/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/charmbracelet/wish/5097be3e4161e3035a74a0d53d89d3adfe662320/examples/scp/testdata/.gitkeep -------------------------------------------------------------------------------- /examples/simple/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "errors" 5 | "net" 6 | 7 | "github.com/charmbracelet/log" 8 | "github.com/charmbracelet/ssh" 9 | "github.com/charmbracelet/wish" 10 | "github.com/charmbracelet/wish/logging" 11 | ) 12 | 13 | const ( 14 | host = "localhost" 15 | port = "23234" 16 | ) 17 | 18 | func main() { 19 | srv, err := wish.NewServer( 20 | // The address the server will listen to. 21 | wish.WithAddress(net.JoinHostPort(host, port)), 22 | 23 | // The SSH server need its own keys, this will create a keypair in the 24 | // given path if it doesn't exist yet. 25 | // By default, it will create an ED25519 key. 
26 | wish.WithHostKeyPath(".ssh/id_ed25519"), 27 | 28 | // Middlewares do something on a ssh.Session, and then call the next 29 | // middleware in the stack. 30 | wish.WithMiddleware( 31 | func(next ssh.Handler) ssh.Handler { 32 | return func(sess ssh.Session) { 33 | wish.Println(sess, "Hello, world!") 34 | next(sess) 35 | } 36 | }, 37 | 38 | // The last item in the chain is the first to be called. 39 | logging.Middleware(), 40 | ), 41 | ) 42 | if err != nil { 43 | log.Error("Could not start server", "error", err) 44 | } 45 | 46 | log.Info("Starting SSH server", "host", host, "port", port) 47 | if err = srv.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) { 48 | // We ignore ErrServerClosed because it is expected. 49 | log.Error("Could not start server", "error", err) 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /git/git.go: -------------------------------------------------------------------------------- 1 | package git 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "os" 7 | "os/exec" 8 | "path/filepath" 9 | "strings" 10 | 11 | "github.com/charmbracelet/log" 12 | "github.com/charmbracelet/ssh" 13 | "github.com/charmbracelet/wish" 14 | "github.com/go-git/go-git/v5" 15 | "github.com/go-git/go-git/v5/plumbing" 16 | ) 17 | 18 | // ErrNotAuthed represents unauthorized access. 19 | var ErrNotAuthed = errors.New("you are not authorized to do this") 20 | 21 | // ErrSystemMalfunction represents a general system error returned to clients. 22 | var ErrSystemMalfunction = errors.New("something went wrong") 23 | 24 | // ErrInvalidRepo represents an attempt to access a non-existent repo. 25 | var ErrInvalidRepo = errors.New("invalid repo") 26 | 27 | // AccessLevel is the level of access allowed to a repo. 28 | type AccessLevel int 29 | 30 | const ( 31 | // NoAccess does not allow access to the repo. 32 | NoAccess AccessLevel = iota 33 | 34 | // ReadOnlyAccess allows read-only access to the repo. 
35 | ReadOnlyAccess 36 | 37 | // ReadWriteAccess allows read and write access to the repo. 38 | ReadWriteAccess 39 | 40 | // AdminAccess allows read, write, and admin access to the repo. 41 | AdminAccess 42 | ) 43 | 44 | // GitHooks is an interface that allows for custom authorization 45 | // implementations and post push/fetch notifications. Prior to git access, 46 | // AuthRepo will be called with the ssh.Session public key and the repo name. 47 | // Implementers return the appropriate AccessLevel. 48 | // 49 | // Deprecated: use Hooks instead. 50 | type GitHooks = Hooks // nolint: revive 51 | 52 | // Hooks is an interface that allows for custom authorization 53 | // implementations and post push/fetch notifications. Prior to git access, 54 | // AuthRepo will be called with the ssh.Session public key and the repo name. 55 | // Implementers return the appropriate AccessLevel. 56 | type Hooks interface { 57 | AuthRepo(string, ssh.PublicKey) AccessLevel 58 | Push(string, ssh.PublicKey) 59 | Fetch(string, ssh.PublicKey) 60 | } 61 | 62 | // Middleware adds Git server functionality to the ssh.Server. Repos are stored 63 | // in the specified repo directory. The provided Hooks implementation will be 64 | // checked for access on a per repo basis for a ssh.Session public key. 65 | // Hooks.Push and Hooks.Fetch will be called on successful completion of 66 | // their commands. 
67 | func Middleware(repoDir string, gh Hooks) wish.Middleware { 68 | return func(sh ssh.Handler) ssh.Handler { 69 | return func(s ssh.Session) { 70 | cmd := s.Command() 71 | if len(cmd) == 2 { 72 | gc := cmd[0] 73 | // repo should be in the form of "repo.git" or "user/repo.git" 74 | repo := strings.TrimSuffix(strings.TrimPrefix(cmd[1], "/"), "/") 75 | repo = filepath.Clean(repo) 76 | if n := strings.Count(repo, "/"); n > 1 { 77 | Fatal(s, ErrInvalidRepo) 78 | return 79 | } 80 | pk := s.PublicKey() 81 | access := gh.AuthRepo(repo, pk) 82 | switch gc { 83 | case "git-receive-pack": 84 | switch access { 85 | case ReadWriteAccess, AdminAccess: 86 | err := gitPack(s, gc, repoDir, repo) 87 | if err != nil { 88 | Fatal(s, ErrSystemMalfunction) 89 | } else { 90 | gh.Push(repo, pk) 91 | } 92 | default: 93 | Fatal(s, ErrNotAuthed) 94 | } 95 | return 96 | case "git-upload-archive", "git-upload-pack": 97 | switch access { 98 | case ReadOnlyAccess, ReadWriteAccess, AdminAccess: 99 | err := gitPack(s, gc, repoDir, repo) 100 | switch err { 101 | case ErrInvalidRepo: 102 | Fatal(s, ErrInvalidRepo) 103 | case nil: 104 | gh.Fetch(repo, pk) 105 | default: 106 | log.Error("unknown git error", "error", err) 107 | Fatal(s, ErrSystemMalfunction) 108 | } 109 | default: 110 | Fatal(s, ErrNotAuthed) 111 | } 112 | return 113 | } 114 | } 115 | sh(s) 116 | } 117 | } 118 | } 119 | 120 | func gitPack(s ssh.Session, gitCmd string, repoDir string, repo string) error { 121 | cmd := strings.TrimPrefix(gitCmd, "git-") 122 | rp := filepath.Join(repoDir, repo) 123 | switch gitCmd { 124 | case "git-upload-archive", "git-upload-pack": 125 | exists, err := fileExists(rp) 126 | if !exists { 127 | return ErrInvalidRepo 128 | } 129 | if err != nil { 130 | return err 131 | } 132 | return runGit(s, "", cmd, rp) 133 | case "git-receive-pack": 134 | err := EnsureRepo(repoDir, repo) 135 | if err != nil { 136 | return err 137 | } 138 | err = runGit(s, "", cmd, rp) 139 | if err != nil { 140 | return err 141 | } 
142 | err = ensureDefaultBranch(s, rp) 143 | if err != nil { 144 | return err 145 | } 146 | // Needed for git dumb http server 147 | return runGit(s, rp, "update-server-info") 148 | default: 149 | return fmt.Errorf("unknown git command: %s", gitCmd) 150 | } 151 | } 152 | 153 | func fileExists(path string) (bool, error) { 154 | _, err := os.Stat(path) 155 | if err == nil { 156 | return true, nil 157 | } 158 | if os.IsNotExist(err) { 159 | return false, nil 160 | } 161 | return true, err 162 | } 163 | 164 | // Fatal prints to the session's STDOUT as a git response and exit 1. 165 | func Fatal(s ssh.Session, v ...interface{}) { 166 | msg := fmt.Sprint(v...) 167 | // hex length includes 4 byte length prefix and ending newline 168 | pktLine := fmt.Sprintf("%04x%s\n", len(msg)+5, msg) 169 | _, _ = wish.WriteString(s, pktLine) 170 | s.Exit(1) // nolint: errcheck 171 | } 172 | 173 | // EnsureRepo makes sure the given repo exists within the given dir, and that 174 | // it is git repository. 175 | // 176 | // If path does not exist, it'll be created. 177 | // If the path is not a git repo, it will be git init-ed as a bare repository. 178 | func EnsureRepo(dir, repo string) error { 179 | exists, err := fileExists(dir) 180 | if err != nil { 181 | return err 182 | } 183 | if !exists { 184 | err = os.MkdirAll(dir, os.ModeDir|os.FileMode(0o700)) 185 | if err != nil { 186 | return err 187 | } 188 | } 189 | rp := filepath.Join(dir, repo) 190 | exists, err = fileExists(rp) 191 | if err != nil { 192 | return err 193 | } 194 | if !exists { 195 | _, err := git.PlainInit(rp, true) 196 | if err != nil { 197 | return err 198 | } 199 | } 200 | return nil 201 | } 202 | 203 | func runGit(s ssh.Session, dir string, args ...string) error { 204 | usi := exec.CommandContext(s.Context(), "git", args...) 
205 | usi.Dir = dir 206 | usi.Stdout = s 207 | usi.Stdin = s 208 | if err := usi.Run(); err != nil { 209 | return err 210 | } 211 | return nil 212 | } 213 | 214 | func ensureDefaultBranch(s ssh.Session, repoPath string) error { 215 | r, err := git.PlainOpen(repoPath) 216 | if err != nil { 217 | return err 218 | } 219 | brs, err := r.Branches() 220 | if err != nil { 221 | return err 222 | } 223 | defer brs.Close() 224 | fb, err := brs.Next() 225 | if err != nil { 226 | return err 227 | } 228 | // Rename the default branch to the first branch available 229 | _, err = r.Head() 230 | if err == plumbing.ErrReferenceNotFound { 231 | err = runGit(s, repoPath, "branch", "-M", fb.Name().Short()) 232 | if err != nil { 233 | return err 234 | } 235 | } 236 | if err != nil && err != plumbing.ErrReferenceNotFound { 237 | return err 238 | } 239 | return nil 240 | } 241 | -------------------------------------------------------------------------------- /git/git_test.go: -------------------------------------------------------------------------------- 1 | package git 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "os/exec" 7 | "path/filepath" 8 | "runtime" 9 | "sync" 10 | "testing" 11 | 12 | "github.com/charmbracelet/keygen" 13 | "github.com/charmbracelet/ssh" 14 | "github.com/charmbracelet/wish" 15 | ) 16 | 17 | func TestGitMiddleware(t *testing.T) { 18 | pubkey, pkPath := createKeyPair(t) 19 | hkPath := filepath.Join(t.TempDir(), "id_ed25519") 20 | 21 | l, err := net.Listen("tcp", "127.0.0.1:0") 22 | requireNoError(t, err) 23 | remote := "ssh://" + l.Addr().String() 24 | 25 | repoDir := t.TempDir() 26 | hooks := &testHooks{ 27 | pushes: []action{}, 28 | fetches: []action{}, 29 | access: []accessDetails{ 30 | {pubkey, "repo1", AdminAccess}, 31 | {pubkey, "repo2", AdminAccess}, 32 | {pubkey, "repo3", AdminAccess}, 33 | {pubkey, "repo4", AdminAccess}, 34 | {pubkey, "repo5", NoAccess}, 35 | {pubkey, "repo6", ReadOnlyAccess}, 36 | {pubkey, "repo7", AdminAccess}, 37 | {pubkey, "abc/repo1", 
AdminAccess}, 38 | {pubkey, "abc/def/repo1", AdminAccess}, 39 | }, 40 | } 41 | srv, err := wish.NewServer( 42 | wish.WithHostKeyPath(hkPath), 43 | wish.WithMiddleware(Middleware(repoDir, hooks)), 44 | wish.WithPublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool { 45 | return true 46 | }), 47 | ) 48 | requireNoError(t, err) 49 | go func() { srv.Serve(l) }() 50 | t.Cleanup(func() { _ = srv.Close() }) 51 | 52 | t.Run("create repo on master", func(t *testing.T) { 53 | cwd := t.TempDir() 54 | requireNoError(t, runGitHelper(t, pkPath, cwd, "init", "-b", "master")) 55 | requireNoError(t, runGitHelper(t, pkPath, cwd, "remote", "add", "origin", remote+"/repo1")) 56 | requireNoError(t, runGitHelper(t, pkPath, cwd, "commit", "--allow-empty", "-m", "initial commit")) 57 | requireNoError(t, runGitHelper(t, pkPath, cwd, "push", "origin", "master")) 58 | requireHasAction(t, hooks.pushes, pubkey, "repo1") 59 | }) 60 | 61 | t.Run("create repo on main", func(t *testing.T) { 62 | cwd := t.TempDir() 63 | requireNoError(t, runGitHelper(t, pkPath, cwd, "init", "-b", "main")) 64 | requireNoError(t, runGitHelper(t, pkPath, cwd, "remote", "add", "origin", remote+"/repo2")) 65 | requireNoError(t, runGitHelper(t, pkPath, cwd, "commit", "--allow-empty", "-m", "initial commit")) 66 | requireNoError(t, runGitHelper(t, pkPath, cwd, "push", "origin", "main")) 67 | requireHasAction(t, hooks.pushes, pubkey, "repo2") 68 | }) 69 | 70 | t.Run("create repo in subdir", func(t *testing.T) { 71 | if runtime.GOOS == "windows" { 72 | t.Skip("permission issues") 73 | } 74 | cwd := t.TempDir() 75 | requireNoError(t, runGitHelper(t, pkPath, cwd, "init", "-b", "main")) 76 | requireNoError(t, runGitHelper(t, pkPath, cwd, "remote", "add", "origin", remote+"/abc/repo1")) 77 | requireNoError(t, runGitHelper(t, pkPath, cwd, "commit", "--allow-empty", "-m", "initial commit")) 78 | requireNoError(t, runGitHelper(t, pkPath, cwd, "push", "origin", "main")) 79 | requireHasAction(t, hooks.pushes, pubkey, 
"abc/repo1") 80 | }) 81 | 82 | t.Run("create wrong repo", func(t *testing.T) { 83 | cwd := t.TempDir() 84 | requireNoError(t, runGitHelper(t, pkPath, cwd, "init", "-b", "main")) 85 | requireNoError(t, runGitHelper(t, pkPath, cwd, "remote", "add", "origin", remote+"//../../repo1")) 86 | requireNoError(t, runGitHelper(t, pkPath, cwd, "commit", "--allow-empty", "-m", "initial commit")) 87 | requireError(t, runGitHelper(t, pkPath, cwd, "push", "origin", "main")) 88 | }) 89 | 90 | t.Run("create wrong repo in subdir", func(t *testing.T) { 91 | cwd := t.TempDir() 92 | requireNoError(t, runGitHelper(t, pkPath, cwd, "init", "-b", "main")) 93 | requireNoError(t, runGitHelper(t, pkPath, cwd, "remote", "add", "origin", remote+"/abc/def/repo1")) 94 | requireNoError(t, runGitHelper(t, pkPath, cwd, "commit", "--allow-empty", "-m", "initial commit")) 95 | requireError(t, runGitHelper(t, pkPath, cwd, "push", "origin", "main")) 96 | }) 97 | 98 | t.Run("create and clone repo", func(t *testing.T) { 99 | cwd := t.TempDir() 100 | requireNoError(t, runGitHelper(t, pkPath, cwd, "init", "-b", "main")) 101 | requireNoError(t, runGitHelper(t, pkPath, cwd, "remote", "add", "origin", remote+"/repo3")) 102 | requireNoError(t, runGitHelper(t, pkPath, cwd, "commit", "--allow-empty", "-m", "initial commit")) 103 | requireNoError(t, runGitHelper(t, pkPath, cwd, "push", "origin", "main")) 104 | 105 | cwd = t.TempDir() 106 | requireNoError(t, runGitHelper(t, pkPath, cwd, "clone", remote+"/repo3")) 107 | 108 | requireHasAction(t, hooks.pushes, pubkey, "repo3") 109 | requireHasAction(t, hooks.fetches, pubkey, "repo3") 110 | }) 111 | 112 | t.Run("clone repo that doesn't exist", func(t *testing.T) { 113 | cwd := t.TempDir() 114 | requireError(t, runGitHelper(t, pkPath, cwd, "clone", remote+"/repo4")) 115 | }) 116 | 117 | t.Run("clone repo with no access", func(t *testing.T) { 118 | cwd := t.TempDir() 119 | requireError(t, runGitHelper(t, pkPath, cwd, "clone", remote+"/repo5")) 120 | }) 121 | 122 | 
t.Run("push repo with with readonly", func(t *testing.T) { 123 | cwd := t.TempDir() 124 | requireNoError(t, runGitHelper(t, pkPath, cwd, "init", "-b", "main")) 125 | requireNoError(t, runGitHelper(t, pkPath, cwd, "remote", "add", "origin", remote+"/repo6")) 126 | requireNoError(t, runGitHelper(t, pkPath, cwd, "commit", "--allow-empty", "-m", "initial commit")) 127 | requireError(t, runGitHelper(t, pkPath, cwd, "push", "origin", "main")) 128 | }) 129 | 130 | t.Run("create and clone repo on weird branch", func(t *testing.T) { 131 | cwd := t.TempDir() 132 | requireNoError(t, runGitHelper(t, pkPath, cwd, "init", "-b", "a-weird-branch-name")) 133 | requireNoError(t, runGitHelper(t, pkPath, cwd, "remote", "add", "origin", remote+"/repo7")) 134 | requireNoError(t, runGitHelper(t, pkPath, cwd, "commit", "--allow-empty", "-m", "initial commit")) 135 | requireNoError(t, runGitHelper(t, pkPath, cwd, "push", "origin", "a-weird-branch-name")) 136 | 137 | cwd = t.TempDir() 138 | requireNoError(t, runGitHelper(t, pkPath, cwd, "clone", remote+"/repo7")) 139 | 140 | requireHasAction(t, hooks.pushes, pubkey, "repo7") 141 | requireHasAction(t, hooks.fetches, pubkey, "repo7") 142 | }) 143 | } 144 | 145 | func runGitHelper(t *testing.T, pk, cwd string, args ...string) error { 146 | t.Helper() 147 | 148 | allArgs := []string{ 149 | "-c", "user.name='wish'", 150 | "-c", "user.email='test@wish'", 151 | "-c", "commit.gpgSign=false", 152 | "-c", "tag.gpgSign=false", 153 | "-c", "log.showSignature=false", 154 | "-c", "ssh.variant=ssh", 155 | } 156 | allArgs = append(allArgs, args...) 157 | 158 | cmd := exec.Command("git", allArgs...) 
159 | cmd.Dir = cwd 160 | cmd.Env = []string{fmt.Sprintf(`GIT_SSH_COMMAND=ssh -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no -i "%s" -F /dev/null`, pk)} 161 | out, err := cmd.CombinedOutput() 162 | if err != nil { 163 | t.Log("git out:", string(out)) 164 | } 165 | return err 166 | } 167 | 168 | func requireNoError(t *testing.T, err error) { 169 | t.Helper() 170 | 171 | if err != nil { 172 | t.Fatalf("expected no error, got %q", err.Error()) 173 | } 174 | } 175 | 176 | func requireError(t *testing.T, err error) { 177 | t.Helper() 178 | 179 | if err == nil { 180 | t.Fatalf("expected an error, got nil") 181 | } 182 | } 183 | 184 | func requireHasAction(t *testing.T, actions []action, key ssh.PublicKey, repo string) { 185 | t.Helper() 186 | 187 | for _, action := range actions { 188 | if repo == action.repo && ssh.KeysEqual(key, action.key) { 189 | return 190 | } 191 | } 192 | t.Fatalf("expected action for %q, got none", repo) 193 | } 194 | 195 | func createKeyPair(t *testing.T) (ssh.PublicKey, string) { 196 | t.Helper() 197 | 198 | pk := filepath.Join(t.TempDir(), "id_ed25519") 199 | kp, err := keygen.New(pk, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite()) 200 | requireNoError(t, err) 201 | return kp.PublicKey(), pk 202 | } 203 | 204 | type accessDetails struct { 205 | key ssh.PublicKey 206 | repo string 207 | level AccessLevel 208 | } 209 | 210 | type action struct { 211 | key ssh.PublicKey 212 | repo string 213 | } 214 | 215 | type testHooks struct { 216 | sync.Mutex 217 | pushes []action 218 | fetches []action 219 | access []accessDetails 220 | } 221 | 222 | func (h *testHooks) AuthRepo(repo string, key ssh.PublicKey) AccessLevel { 223 | for _, dets := range h.access { 224 | if dets.repo == repo && ssh.KeysEqual(key, dets.key) { 225 | return dets.level 226 | } 227 | } 228 | return NoAccess 229 | } 230 | 231 | func (h *testHooks) Push(repo string, key ssh.PublicKey) { 232 | h.Lock() 233 | defer h.Unlock() 234 | 235 | h.pushes = append(h.pushes, 
action{key, repo}) 236 | } 237 | 238 | func (h *testHooks) Fetch(repo string, key ssh.PublicKey) { 239 | h.Lock() 240 | defer h.Unlock() 241 | 242 | h.fetches = append(h.fetches, action{key, repo}) 243 | } 244 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/charmbracelet/wish 2 | 3 | go 1.23.0 4 | 5 | toolchain go1.24.1 6 | 7 | require ( 8 | github.com/charmbracelet/bubbletea v1.3.5 9 | github.com/charmbracelet/keygen v0.5.3 10 | github.com/charmbracelet/lipgloss v1.1.0 11 | github.com/charmbracelet/log v0.4.2 12 | github.com/charmbracelet/ssh v0.0.0-20250128164007-98fd5ae11894 13 | github.com/charmbracelet/x/ansi v0.9.2 14 | github.com/charmbracelet/x/input v0.3.4 15 | github.com/charmbracelet/x/term v0.2.1 16 | github.com/go-git/go-git/v5 v5.16.0 17 | github.com/google/go-cmp v0.7.0 18 | github.com/hashicorp/golang-lru/v2 v2.0.7 19 | github.com/lucasb-eyer/go-colorful v1.2.0 20 | github.com/matryer/is v1.4.1 21 | github.com/muesli/termenv v0.16.0 22 | golang.org/x/crypto v0.38.0 23 | golang.org/x/sync v0.14.0 24 | golang.org/x/time v0.11.0 25 | ) 26 | 27 | require ( 28 | dario.cat/mergo v1.0.0 // indirect 29 | github.com/Microsoft/go-winio v0.6.2 // indirect 30 | github.com/ProtonMail/go-crypto v1.1.6 // indirect 31 | github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect 32 | github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect 33 | github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect 34 | github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect 35 | github.com/charmbracelet/x/conpty v0.1.0 // indirect 36 | github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86 // indirect 37 | github.com/charmbracelet/x/termios v0.1.0 // indirect 38 | github.com/charmbracelet/x/windows v0.2.0 // indirect 39 | github.com/cloudflare/circl 
v1.6.1 // indirect 40 | github.com/creack/pty v1.1.21 // indirect 41 | github.com/cyphar/filepath-securejoin v0.4.1 // indirect 42 | github.com/emirpasic/gods v1.18.1 // indirect 43 | github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect 44 | github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect 45 | github.com/go-git/go-billy/v5 v5.6.2 // indirect 46 | github.com/go-logfmt/logfmt v0.6.0 // indirect 47 | github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect 48 | github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect 49 | github.com/kevinburke/ssh_config v1.2.0 // indirect 50 | github.com/mattn/go-isatty v0.0.20 // indirect 51 | github.com/mattn/go-localereader v0.0.1 // indirect 52 | github.com/mattn/go-runewidth v0.0.16 // indirect 53 | github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect 54 | github.com/muesli/cancelreader v0.2.2 // indirect 55 | github.com/pjbgf/sha1cd v0.3.2 // indirect 56 | github.com/rivo/uniseg v0.4.7 // indirect 57 | github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect 58 | github.com/skeema/knownhosts v1.3.1 // indirect 59 | github.com/xanzy/ssh-agent v0.3.3 // indirect 60 | github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect 61 | golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect 62 | golang.org/x/net v0.39.0 // indirect 63 | golang.org/x/sys v0.33.0 // indirect 64 | golang.org/x/text v0.25.0 // indirect 65 | gopkg.in/warnings.v0 v0.1.2 // indirect 66 | ) 67 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= 2 | dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= 3 | github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY= 4 | 
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= 5 | github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= 6 | github.com/ProtonMail/go-crypto v1.1.6 h1:ZcV+Ropw6Qn0AX9brlQLAUXfqLBc7Bl+f/DmNxpLfdw= 7 | github.com/ProtonMail/go-crypto v1.1.6/go.mod h1:rA3QumHc/FZ8pAHreoekgiAbzpNsfQAosU5td4SnOrE= 8 | github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= 9 | github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= 10 | github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= 11 | github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= 12 | github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= 13 | github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= 14 | github.com/charmbracelet/bubbletea v1.3.5 h1:JAMNLTbqMOhSwoELIr0qyP4VidFq72/6E9j7HHmRKQc= 15 | github.com/charmbracelet/bubbletea v1.3.5/go.mod h1:TkCnmH+aBd4LrXhXcqrKiYwRs7qyQx5rBgH5fVY3v54= 16 | github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs= 17 | github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk= 18 | github.com/charmbracelet/keygen v0.5.3 h1:2MSDC62OUbDy6VmjIE2jM24LuXUvKywLCmaJDmr/Z/4= 19 | github.com/charmbracelet/keygen v0.5.3/go.mod h1:TcpNoMAO5GSmhx3SgcEMqCrtn8BahKhB8AlwnLjRUpk= 20 | github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY= 21 | github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30= 22 | github.com/charmbracelet/log v0.4.2 h1:hYt8Qj6a8yLnvR+h7MwsJv/XvmBJXiueUcI3cIxsyig= 23 | 
github.com/charmbracelet/log v0.4.2/go.mod h1:qifHGX/tc7eluv2R6pWIpyHDDrrb/AG71Pf2ysQu5nw= 24 | github.com/charmbracelet/ssh v0.0.0-20250128164007-98fd5ae11894 h1:Ffon9TbltLGBsT6XE//YvNuu4OAaThXioqalhH11xEw= 25 | github.com/charmbracelet/ssh v0.0.0-20250128164007-98fd5ae11894/go.mod h1:hg+I6gvlMl16nS9ZzQNgBIrrCasGwEw0QiLsDcP01Ko= 26 | github.com/charmbracelet/x/ansi v0.9.2 h1:92AGsQmNTRMzuzHEYfCdjQeUzTrgE1vfO5/7fEVoXdY= 27 | github.com/charmbracelet/x/ansi v0.9.2/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE= 28 | github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8= 29 | github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs= 30 | github.com/charmbracelet/x/conpty v0.1.0 h1:4zc8KaIcbiL4mghEON8D72agYtSeIgq8FSThSPQIb+U= 31 | github.com/charmbracelet/x/conpty v0.1.0/go.mod h1:rMFsDJoDwVmiYM10aD4bH2XiRgwI7NYJtQgl5yskjEQ= 32 | github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86 h1:JSt3B+U9iqk37QUU2Rvb6DSBYRLtWqFqfxf8l5hOZUA= 33 | github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86/go.mod h1:2P0UgXMEa6TsToMSuFqKFQR+fZTO9CNGUNokkPatT/0= 34 | github.com/charmbracelet/x/input v0.3.4 h1:Mujmnv/4DaitU0p+kIsrlfZl/UlmeLKw1wAP3e1fMN0= 35 | github.com/charmbracelet/x/input v0.3.4/go.mod h1:JI8RcvdZWQIhn09VzeK3hdp4lTz7+yhiEdpEQtZN+2c= 36 | github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ= 37 | github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg= 38 | github.com/charmbracelet/x/termios v0.1.0 h1:y4rjAHeFksBAfGbkRDmVinMg7x7DELIGAFbdNvxg97k= 39 | github.com/charmbracelet/x/termios v0.1.0/go.mod h1:H/EVv/KRnrYjz+fCYa9bsKdqF3S8ouDK0AZEbG7r+/U= 40 | github.com/charmbracelet/x/windows v0.2.0 h1:ilXA1GJjTNkgOm94CLPeSz7rar54jtFatdmoiONPuEw= 41 | github.com/charmbracelet/x/windows v0.2.0/go.mod 
h1:ZibNFR49ZFqCXgP76sYanisxRyC+EYrBE7TTknD8s1s= 42 | github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0= 43 | github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs= 44 | github.com/creack/pty v1.1.21 h1:1/QdRyBaHHJP61QkWMXlOIBfsgdDeeKfK8SYVUWJKf0= 45 | github.com/creack/pty v1.1.21/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= 46 | github.com/cyphar/filepath-securejoin v0.4.1 h1:JyxxyPEaktOD+GAnqIqTf9A8tHyAG22rowi7HkoSU1s= 47 | github.com/cyphar/filepath-securejoin v0.4.1/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGLDGQL7h7bg04C/+u9jI= 48 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 49 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 50 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 51 | github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o= 52 | github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE= 53 | github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= 54 | github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= 55 | github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= 56 | github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= 57 | github.com/gliderlabs/ssh v0.3.8 h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c= 58 | github.com/gliderlabs/ssh v0.3.8/go.mod h1:xYoytBv1sV0aL3CavoDuJIQNURXkkfPA/wxQ1pL1fAU= 59 | github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 h1:+zs/tPmkDkHx3U66DAb0lQFJrpS6731Oaa12ikc+DiI= 60 | github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376/go.mod h1:an3vInlBmSxCcxctByoQdvwPiA7DTK7jaaFDBTtu0ic= 61 | github.com/go-git/go-billy/v5 v5.6.2 h1:6Q86EsPXMa7c3YZ3aLAQsMA0VlWmy43r6FHqa/UNbRM= 62 | 
github.com/go-git/go-billy/v5 v5.6.2/go.mod h1:rcFC2rAsp/erv7CMz9GczHcuD0D32fWzH+MJAU+jaUU= 63 | github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399 h1:eMje31YglSBqCdIqdhKBW8lokaMrL3uTkpGYlE2OOT4= 64 | github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399/go.mod h1:1OCfN199q1Jm3HZlxleg+Dw/mwps2Wbk9frAWm+4FII= 65 | github.com/go-git/go-git/v5 v5.16.0 h1:k3kuOEpkc0DeY7xlL6NaaNg39xdgQbtH5mwCafHO9AQ= 66 | github.com/go-git/go-git/v5 v5.16.0/go.mod h1:4Ge4alE/5gPs30F2H1esi2gPd69R0C39lolkucHBOp8= 67 | github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4= 68 | github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= 69 | github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ= 70 | github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUvekVysuuOpQKo3191zZyTpiI6se1N1ULghS0sw= 71 | github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= 72 | github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= 73 | github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= 74 | github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= 75 | github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A= 76 | github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo= 77 | github.com/kevinburke/ssh_config v1.2.0 h1:x584FjTGwHzMwvHx18PXxbBVzfnxogHaAReU4gf13a4= 78 | github.com/kevinburke/ssh_config v1.2.0/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM= 79 | github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= 80 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 81 | github.com/kr/pretty v0.3.1/go.mod 
h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 82 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 83 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 84 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 85 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 86 | github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= 87 | github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= 88 | github.com/matryer/is v1.4.1 h1:55ehd8zaGABKLXQUe2awZ99BD/PTc2ls+KV/dXphgEQ= 89 | github.com/matryer/is v1.4.1/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU= 90 | github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= 91 | github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= 92 | github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= 93 | github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= 94 | github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= 95 | github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= 96 | github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= 97 | github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= 98 | github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= 99 | github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= 100 | github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= 101 | github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= 102 | github.com/onsi/gomega v1.34.1 
h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k= 103 | github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY= 104 | github.com/pjbgf/sha1cd v0.3.2 h1:a9wb0bp1oC2TGwStyn0Umc/IGKQnEgF0vVaZ8QF8eo4= 105 | github.com/pjbgf/sha1cd v0.3.2/go.mod h1:zQWigSxVmsHEZow5qaLtPYxpcKMMQpa09ixqBxuCS6A= 106 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 107 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 108 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 109 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 110 | github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= 111 | github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= 112 | github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= 113 | github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= 114 | github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= 115 | github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 h1:n661drycOFuPLCN3Uc8sB6B/s6Z4t2xvBgU1htSHuq8= 116 | github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4= 117 | github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= 118 | github.com/skeema/knownhosts v1.3.1 h1:X2osQ+RAjK76shCbvhHHHVl3ZlgDm8apHEHFqRjnBY8= 119 | github.com/skeema/knownhosts v1.3.1/go.mod h1:r7KTdC8l4uxWRyK2TpQZ/1o5HaSzh06ePQNxPwTcfiY= 120 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 121 | github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 122 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= 123 | github.com/stretchr/testify v1.10.0 
h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 124 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 125 | github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM= 126 | github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI1Bc68Uw= 127 | github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= 128 | github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= 129 | golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= 130 | golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= 131 | golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= 132 | golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8= 133 | golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= 134 | golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= 135 | golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= 136 | golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= 137 | golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= 138 | golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= 139 | golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 140 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 141 | golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 142 | golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 143 | golang.org/x/sys 
v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 144 | golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 145 | golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 146 | golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 147 | golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= 148 | golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= 149 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 150 | golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg= 151 | golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ= 152 | golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 153 | golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= 154 | golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= 155 | golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0= 156 | golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= 157 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 158 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 159 | gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 160 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 161 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 162 | gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME= 163 | gopkg.in/warnings.v0 v0.1.2/go.mod 
h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= 164 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 165 | gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= 166 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 167 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 168 | -------------------------------------------------------------------------------- /logging/logging.go: -------------------------------------------------------------------------------- 1 | package logging 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/charmbracelet/log" 7 | "github.com/charmbracelet/ssh" 8 | "github.com/charmbracelet/wish" 9 | ) 10 | 11 | // Middleware provides basic connection logging. 12 | // Connects are logged with the remote address, invoked command, TERM setting, 13 | // window dimensions, client version, and if the auth was public key based. 14 | // Disconnect will log the remote address and connection duration. 15 | // 16 | // It will use charmbracelet/log.StandardLog() by default. 17 | func Middleware() wish.Middleware { 18 | return MiddlewareWithLogger(log.StandardLog()) 19 | } 20 | 21 | // Logger is the interface that wraps the basic Log method. 22 | type Logger interface { 23 | Printf(format string, v ...interface{}) 24 | } 25 | 26 | // MiddlewareWithLogger provides basic connection logging. 27 | // Connects are logged with the remote address, invoked command, TERM setting, 28 | // window dimensions, client version, and if the auth was public key based. 29 | // Disconnect will log the remote address and connection duration. 
30 | func MiddlewareWithLogger(logger Logger) wish.Middleware { 31 | return func(next ssh.Handler) ssh.Handler { 32 | return func(sess ssh.Session) { 33 | ct := time.Now() // session start, used for the disconnect duration 34 | hpk := sess.PublicKey() != nil // true when the client authenticated with a public key 35 | pty, _, _ := sess.Pty() // ok flag ignored; presumably a zero Pty (empty Term, 0x0 window) is logged when none was requested, confirm 36 | logger.Printf( 37 | "%s connect %s %v %v %s %v %v %v", 38 | sess.User(), 39 | sess.RemoteAddr().String(), 40 | hpk, 41 | sess.Command(), 42 | pty.Term, 43 | pty.Window.Width, 44 | pty.Window.Height, 45 | sess.Context().ClientVersion(), 46 | ) 47 | next(sess) // the disconnect line below is logged once the wrapped handler returns 48 | logger.Printf( 49 | "%s disconnect %s\n", 50 | sess.RemoteAddr().String(), 51 | time.Since(ct), 52 | ) 53 | } 54 | } 55 | } 56 | 57 | // StructuredMiddleware provides basic connection logging in a structured form. 58 | // Connects are logged with the user, remote address, invoked command, TERM setting, 59 | // window dimensions, client version, and if the auth was public key based. 60 | // Disconnect will log the user, remote address and connection duration. 61 | // 62 | // It will use the charmbracelet/log.Default() and Info level by default. 63 | func StructuredMiddleware() wish.Middleware { 64 | return StructuredMiddlewareWithLogger(log.Default(), log.InfoLevel) 65 | } 66 | 67 | // StructuredMiddlewareWithLogger provides basic connection logging in a structured form, emitting entries on logger at the given level. 68 | // Connects are logged with the user, remote address, invoked command, TERM setting, 69 | // window dimensions, client version, and if the auth was public key based. 70 | // Disconnect will log the user, remote address and connection duration.
71 | func StructuredMiddlewareWithLogger(logger *log.Logger, level log.Level) wish.Middleware { 72 | return func(next ssh.Handler) ssh.Handler { 73 | return func(sess ssh.Session) { 74 | ct := time.Now() // session start, used for the disconnect duration 75 | hpk := sess.PublicKey() != nil // true when the client authenticated with a public key 76 | pty, _, _ := sess.Pty() // ok flag ignored; presumably zero-valued term/size fields are logged when no PTY was requested, confirm 77 | logger.Log( 78 | level, 79 | "connect", 80 | "user", sess.User(), 81 | "remote-addr", sess.RemoteAddr().String(), 82 | "public-key", hpk, 83 | "command", sess.Command(), 84 | "term", pty.Term, 85 | "width", pty.Window.Width, 86 | "height", pty.Window.Height, 87 | "client-version", sess.Context().ClientVersion(), 88 | ) 89 | next(sess) // the disconnect entry below is logged once the wrapped handler returns 90 | logger.Log( 91 | level, 92 | "disconnect", 93 | "user", sess.User(), 94 | "remote-addr", sess.RemoteAddr().String(), 95 | "duration", time.Since(ct), 96 | ) 97 | } 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /logging/logging_test.go: -------------------------------------------------------------------------------- 1 | package logging_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/charmbracelet/ssh" 7 | "github.com/charmbracelet/wish" 8 | "github.com/charmbracelet/wish/logging" 9 | "github.com/charmbracelet/wish/testsession" 10 | gossh "golang.org/x/crypto/ssh" 11 | ) 12 | 13 | func TestMiddleware(t *testing.T) { 14 | t.Run("inactive term", func(t *testing.T) { 15 | if err := setup(t, logging.Middleware()).Run(""); err != nil { 16 | t.Error(err) 17 | } 18 | }) 19 | } 20 | 21 | func TestStructuredMiddleware(t *testing.T) { 22 | t.Run("inactive term", func(t *testing.T) { 23 | if err := setup(t, logging.StructuredMiddleware()).Run(""); err != nil { 24 | t.Error(err) 25 | } 26 | }) 27 | } 28 | 29 | func setup(tb testing.TB, middleware wish.Middleware) *gossh.Session { 30 | tb.Helper() 31 | return testsession.New(tb, &ssh.Server{ // test server whose handler is wrapped by the middleware under test 32 | Handler: middleware(func(s ssh.Session) { 33 | s.Write([]byte("hello")) 34 | }), 35 | }, nil) 36 | } 37 |
-------------------------------------------------------------------------------- /options.go: -------------------------------------------------------------------------------- 1 | package wish 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "errors" 7 | "io" 8 | "os" 9 | "strings" 10 | "time" 11 | 12 | "github.com/charmbracelet/keygen" 13 | "github.com/charmbracelet/log" 14 | "github.com/charmbracelet/ssh" 15 | gossh "golang.org/x/crypto/ssh" 16 | ) 17 | 18 | // WithAddress returns an ssh.Option that sets the address to listen on. 19 | func WithAddress(addr string) ssh.Option { 20 | return func(s *ssh.Server) error { 21 | s.Addr = addr 22 | return nil 23 | } 24 | } 25 | 26 | // WithVersion returns an ssh.Option that sets the server version. 27 | func WithVersion(version string) ssh.Option { 28 | return func(s *ssh.Server) error { 29 | s.Version = version 30 | return nil 31 | } 32 | } 33 | 34 | // WithBanner returns an ssh.Option that sets the server banner. 35 | func WithBanner(banner string) ssh.Option { 36 | return func(s *ssh.Server) error { 37 | s.Banner = banner 38 | return nil 39 | } 40 | } 41 | 42 | // WithBannerHandler returns an ssh.Option that sets the server banner handler, 43 | // overriding WithBanner. 44 | func WithBannerHandler(h ssh.BannerHandler) ssh.Option { 45 | return func(s *ssh.Server) error { 46 | s.BannerHandler = h 47 | return nil 48 | } 49 | } 50 | 51 | // WithMiddleware composes the provided Middleware and returns an ssh.Option. 52 | // This is useful if you manually create an ssh.Server and want to set the 53 | // Server.Handler. 54 | // 55 | // Notice that middlewares are composed from first to last, which means the last one is executed first.
56 | func WithMiddleware(mw ...Middleware) ssh.Option { 57 | return func(s *ssh.Server) error { 58 | h := func(ssh.Session) {} // innermost no-op handler; each middleware in turn wraps the result 59 | for _, m := range mw { 60 | h = m(h) 61 | } 62 | s.Handler = h 63 | return nil 64 | } 65 | } 66 | 67 | // WithHostKeyPath returns an ssh.Option that sets the path to the private key. If no file exists at path, an Ed25519 key is generated there when WithHostKeyPath is called (not when the option is applied); a generation error is returned when the option is applied. 68 | func WithHostKeyPath(path string) ssh.Option { 69 | if _, err := os.Stat(path); os.IsNotExist(err) { 70 | _, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite()) 71 | if err != nil { 72 | return func(*ssh.Server) error { 73 | return err 74 | } 75 | } 76 | } 77 | return ssh.HostKeyFile(path) 78 | } 79 | 80 | // WithHostKeyPEM returns an ssh.Option that sets the host key from a PEM block. 81 | func WithHostKeyPEM(pem []byte) ssh.Option { 82 | return ssh.HostKeyPEM(pem) 83 | } 84 | 85 | // WithAuthorizedKeys allows the use of an SSH authorized_keys file to allowlist users. 86 | func WithAuthorizedKeys(path string) ssh.Option { 87 | return func(s *ssh.Server) error { 88 | if _, err := os.Stat(path); err != nil { 89 | return err 90 | } 91 | return WithPublicKeyAuth(func(_ ssh.Context, key ssh.PublicKey) bool { 92 | return isAuthorized(path, func(k ssh.PublicKey) bool { // the file is re-read on every auth attempt 93 | return ssh.KeysEqual(key, k) 94 | }) 95 | })(s) 96 | } 97 | } 98 | 99 | // WithTrustedUserCAKeys authorizes certificates that are signed by one of the 100 | // Certificate Authority public keys listed in the file at path, and are valid. 101 | // Analogous to the TrustedUserCAKeys OpenSSH option. 102 | func WithTrustedUserCAKeys(path string) ssh.Option { 103 | return func(s *ssh.Server) error { 104 | if _, err := os.Stat(path); err != nil { 105 | return err 106 | } 107 | return WithPublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool { 108 | cert, ok := key.(*gossh.Certificate) 109 | if !ok { 110 | // not a certificate...
111 | return false 112 | } 113 | 114 | return isAuthorized(path, func(k ssh.PublicKey) bool { 115 | checker := &gossh.CertChecker{ 116 | IsUserAuthority: func(auth gossh.PublicKey) bool { 117 | // it's a cert signed by one of the CAs 118 | return bytes.Equal(auth.Marshal(), k.Marshal()) 119 | }, 120 | } 121 | 122 | if !checker.IsUserAuthority(cert.SignatureKey) { 123 | return false 124 | } 125 | 126 | if err := checker.CheckCert(ctx.User(), cert); err != nil { // principal, validity window and signature checks 127 | return false 128 | } 129 | 130 | return true 131 | }) 132 | })(s) 133 | } 134 | } 135 | // isAuthorized reports whether checker accepts any key in the authorized_keys-format file at path. Blank lines and lines starting with '#' are skipped; an open/read error, or an unparsable line reached before a match, makes it return false. 136 | func isAuthorized(path string, checker func(k ssh.PublicKey) bool) bool { 137 | f, err := os.Open(path) 138 | if err != nil { 139 | log.Warn("failed to parse", "path", path, "error", err) 140 | return false 141 | } 142 | defer f.Close() // nolint: errcheck 143 | 144 | rd := bufio.NewReader(f) 145 | for { 146 | line, _, err := rd.ReadLine() // NOTE(review): isPrefix is ignored, so a line longer than the reader's buffer is split into fragments that fail to parse 147 | if err != nil { 148 | if errors.Is(err, io.EOF) { 149 | break 150 | } 151 | log.Warn("failed to parse", "path", path, "error", err) 152 | return false 153 | } 154 | if strings.TrimSpace(string(line)) == "" { 155 | continue 156 | } 157 | if bytes.HasPrefix(line, []byte{'#'}) { 158 | continue 159 | } 160 | upk, _, _, _, err := ssh.ParseAuthorizedKey(line) 161 | if err != nil { 162 | log.Warn("failed to parse", "path", path, "error", err) 163 | return false 164 | } 165 | if checker(upk) { 166 | return true 167 | } 168 | } 169 | return false 170 | } 171 | 172 | // WithPublicKeyAuth returns an ssh.Option that sets the public key auth handler. 173 | func WithPublicKeyAuth(h ssh.PublicKeyHandler) ssh.Option { 174 | return ssh.PublicKeyAuth(h) 175 | } 176 | 177 | // WithPasswordAuth returns an ssh.Option that sets the password auth handler. 178 | func WithPasswordAuth(p ssh.PasswordHandler) ssh.Option { 179 | return ssh.PasswordAuth(p) 180 | } 181 | 182 | // WithKeyboardInteractiveAuth returns an ssh.Option that sets the keyboard interactive auth handler.
183 | func WithKeyboardInteractiveAuth(h ssh.KeyboardInteractiveHandler) ssh.Option { 184 | return ssh.KeyboardInteractiveAuth(h) 185 | } 186 | 187 | // WithIdleTimeout returns an ssh.Option that sets the connection's idle timeout. 188 | func WithIdleTimeout(d time.Duration) ssh.Option { 189 | return func(s *ssh.Server) error { 190 | s.IdleTimeout = d 191 | return nil 192 | } 193 | } 194 | 195 | // WithMaxTimeout returns an ssh.Option that sets the connection's absolute timeout. 196 | func WithMaxTimeout(d time.Duration) ssh.Option { 197 | return func(s *ssh.Server) error { 198 | s.MaxTimeout = d 199 | return nil 200 | } 201 | } 202 | 203 | // WithSubsystem returns an ssh.Option that sets the subsystem 204 | // handler for the given subsystem name (key), creating the handler map if it is nil. 205 | func WithSubsystem(key string, h ssh.SubsystemHandler) ssh.Option { 206 | return func(s *ssh.Server) error { 207 | if s.SubsystemHandlers == nil { 208 | s.SubsystemHandlers = map[string]ssh.SubsystemHandler{} 209 | } 210 | s.SubsystemHandlers[key] = h 211 | return nil 212 | } 213 | } 214 | -------------------------------------------------------------------------------- /options_test.go: -------------------------------------------------------------------------------- 1 | package wish 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "os" 7 | "strings" 8 | "testing" 9 | "time" 10 | 11 | "github.com/charmbracelet/ssh" 12 | "github.com/charmbracelet/wish/testsession" 13 | gossh "golang.org/x/crypto/ssh" 14 | ) 15 | 16 | func TestWithSubsystem(t *testing.T) { 17 | srv := &ssh.Server{ 18 | Handler: func(s ssh.Session) {}, 19 | } 20 | requireNoError(t, WithSubsystem("foo", func(s ssh.Session) {})(srv)) 21 | if srv.SubsystemHandlers == nil { 22 | t.Fatalf("should not have been nil") 23 | } 24 | if _, ok := srv.SubsystemHandlers["foo"]; !ok { 25 | t.Fatalf("should have set the foo subsystem handler") 26 | } 27 | } 28 | 29 | func TestWithBanner(t *testing.T) { 30 | const banner = "a banner" 31 | var got string 32 | 33 | srv := &ssh.Server{
34 | Handler: func(s ssh.Session) {}, 35 | } 36 | requireNoError(t, WithBanner(banner)(srv)) 37 | 38 | requireNoError(t, testsession.New(t, srv, &gossh.ClientConfig{ 39 | BannerCallback: func(message string) error { 40 | got = message 41 | return nil 42 | }, 43 | }).Run("")) 44 | requireEqual(t, banner, got) 45 | } 46 | 47 | func TestWithBannerHandler(t *testing.T) { 48 | var got string 49 | 50 | srv := &ssh.Server{ 51 | Handler: func(s ssh.Session) {}, 52 | } 53 | requireNoError(t, WithBannerHandler(func(ctx ssh.Context) string { 54 | return fmt.Sprintf("banner for %s", ctx.User()) 55 | })(srv)) 56 | 57 | requireNoError(t, testsession.New(t, srv, &gossh.ClientConfig{ 58 | User: "fulano", 59 | BannerCallback: func(message string) error { 60 | got = message 61 | return nil 62 | }, 63 | }).Run("")) 64 | requireEqual(t, "banner for fulano", got) 65 | } 66 | 67 | func TestWithIdleTimeout(t *testing.T) { 68 | s := ssh.Server{} 69 | requireNoError(t, WithIdleTimeout(time.Second)(&s)) 70 | requireEqual(t, time.Second, s.IdleTimeout) 71 | } 72 | 73 | func TestWithMaxTimeout(t *testing.T) { 74 | s := ssh.Server{} 75 | requireNoError(t, WithMaxTimeout(time.Second)(&s)) 76 | requireEqual(t, time.Second, s.MaxTimeout) 77 | } 78 | 79 | func TestIsAuthorized(t *testing.T) { 80 | t.Run("valid", func(t *testing.T) { 81 | requireEqual(t, true, isAuthorized("testdata/authorized_keys", func(k ssh.PublicKey) bool { return true })) 82 | }) 83 | 84 | t.Run("invalid", func(t *testing.T) { 85 | requireEqual(t, false, isAuthorized("testdata/invalid_authorized_keys", func(k ssh.PublicKey) bool { return true })) 86 | }) 87 | 88 | t.Run("file not found", func(t *testing.T) { 89 | requireEqual(t, false, isAuthorized("testdata/nope_authorized_keys", func(k ssh.PublicKey) bool { return true })) 90 | }) 91 | } 92 | 93 | func TestWithAuthorizedKeys(t *testing.T) { 94 | t.Run("valid", func(t *testing.T) { 95 | s := ssh.Server{} 96 | requireNoError(t,
WithAuthorizedKeys("testdata/authorized_keys")(&s)) 97 | 98 | for key, authorize := range map[string]bool{ 99 | `ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMJlb/qf2B2kMNdBxfpCQqI2ctPcsOkdZGVh5zTRhKtH k3@test`: true, 100 | `ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIOhsthN+zSFSJF7V2HFSO4+2OJYRghuAA43CIbVyvzF8 k7@test`: false, 101 | } { 102 | parts := strings.Fields(key) 103 | t.Run(parts[len(parts)-1], func(t *testing.T) { 104 | key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(key)) 105 | requireNoError(t, err) 106 | requireEqual(t, authorize, s.PublicKeyHandler(nil, key)) 107 | }) 108 | } 109 | }) 110 | 111 | t.Run("invalid", func(t *testing.T) { 112 | s := ssh.Server{} 113 | requireNoError( 114 | t, 115 | WithAuthorizedKeys("testdata/invalid_authorized_keys")(&s), 116 | ) 117 | }) 118 | 119 | t.Run("file not found", func(t *testing.T) { 120 | s := ssh.Server{} 121 | if err := WithAuthorizedKeys("testdata/nope_authorized_keys")(&s); err == nil { 122 | t.Fatal("expected an error, got nil") 123 | } 124 | }) 125 | } 126 | 127 | func TestWithTrustedUserCAKeys(t *testing.T) { 128 | setup := func(tb testing.TB, certPath string) (*ssh.Server, *gossh.ClientConfig) { 129 | tb.Helper() 130 | s := &ssh.Server{ 131 | Handler: func(s ssh.Session) { 132 | cert, ok := s.PublicKey().(*gossh.Certificate) // NOTE(review): cert is nil when !ok, so the Fprintf below would panic for a non-certificate key 133 | fmt.Fprintf(s, "cert?
%v - principals: %v - type: %v", ok, cert.ValidPrincipals, cert.CertType) 134 | }, 135 | } 136 | requireNoError(tb, WithTrustedUserCAKeys("testdata/ca.pub")(s)) 137 | 138 | signer, err := gossh.ParsePrivateKey(getBytes(tb, "testdata/foo")) 139 | requireNoError(tb, err) 140 | 141 | cert, _, _, _, err := gossh.ParseAuthorizedKey(getBytes(tb, certPath)) 142 | requireNoError(tb, err) 143 | 144 | certSigner, err := gossh.NewCertSigner(cert.(*gossh.Certificate), signer) 145 | requireNoError(tb, err) 146 | return s, &gossh.ClientConfig{ 147 | User: "foo", 148 | Auth: []gossh.AuthMethod{ 149 | gossh.PublicKeys(certSigner), 150 | }, 151 | } 152 | } 153 | 154 | t.Run("invalid ca key", func(t *testing.T) { 155 | s := &ssh.Server{} 156 | if err := WithTrustedUserCAKeys("testdata/invalid-path")(s); err == nil { 157 | t.Fatal("expected an error, got nil") 158 | } 159 | }) 160 | 161 | t.Run("valid", func(t *testing.T) { 162 | s, cc := setup(t, "testdata/valid-cert.pub") 163 | sess := testsession.New(t, s, cc) 164 | var b bytes.Buffer 165 | sess.Stdout = &b 166 | requireNoError(t, sess.Run("")) 167 | requireEqual(t, "cert?
true - principals: [foo] - type: 1", b.String()) 168 | }) 169 | 170 | t.Run("valid wrong principal", func(t *testing.T) { 171 | s, cc := setup(t, "testdata/valid-cert.pub") 172 | cc.User = "not-foo" 173 | _, err := testsession.NewClientSession(t, testsession.Listen(t, s), cc) 174 | requireAuthError(t, err) 175 | }) 176 | 177 | t.Run("expired", func(t *testing.T) { 178 | s, cc := setup(t, "testdata/expired-cert.pub") 179 | _, err := testsession.NewClientSession(t, testsession.Listen(t, s), cc) 180 | requireAuthError(t, err) 181 | }) 182 | 183 | t.Run("signed by another ca", func(t *testing.T) { 184 | s, cc := setup(t, "testdata/another-ca-cert.pub") 185 | _, err := testsession.NewClientSession(t, testsession.Listen(t, s), cc) 186 | requireAuthError(t, err) 187 | }) 188 | 189 | t.Run("not a cert", func(t *testing.T) { 190 | s := &ssh.Server{ 191 | Handler: func(s ssh.Session) { 192 | fmt.Fprintln(s, "hello") 193 | }, 194 | } 195 | requireNoError(t, WithTrustedUserCAKeys("testdata/ca.pub")(s)) 196 | 197 | signer, err := gossh.ParsePrivateKey(getBytes(t, "testdata/foo")) 198 | requireNoError(t, err) 199 | 200 | _, err = testsession.NewClientSession(t, testsession.Listen(t, s), &gossh.ClientConfig{ 201 | User: "foo", 202 | Auth: []gossh.AuthMethod{ 203 | gossh.PublicKeys(signer), 204 | }, 205 | }) 206 | requireAuthError(t, err) 207 | }) 208 | } 209 | 210 | func getBytes(tb testing.TB, path string) []byte { 211 | tb.Helper() 212 | bts, err := os.ReadFile(path) 213 | requireNoError(tb, err) 214 | return bts 215 | } 216 | 217 | func requireEqual(tb testing.TB, a, b interface{}) { 218 | tb.Helper() 219 | if a != b { 220 | tb.Fatalf("expected %v, got %v", a, b) 221 | } 222 | } 223 | 224 | func requireNoError(tb testing.TB, err error) { 225 | tb.Helper() 226 | if err != nil { 227 | tb.Fatalf("expected no error, got %v", err) 228 | } 229 | } 230 | 231 | func requireAuthError(tb testing.TB, err error) { // NOTE(review): missing tb.Helper(), so failures are reported here instead of at the caller 232 | if err == nil { 233 | tb.Fatal("required an error, got nil") 234 | }
235 | requireEqual(tb, "ssh: handshake failed: ssh: unable to authenticate, attempted methods [none publickey], no supported methods remain", err.Error()) 236 | } 237 | -------------------------------------------------------------------------------- /ratelimiter/ratelimiter.go: -------------------------------------------------------------------------------- 1 | // Package ratelimiter provides basic rate limiting functionality as a middleware. 2 | // 3 | // It limits the number of connections a source can make in a specified amount of time. 4 | package ratelimiter 5 | 6 | import ( 7 | "errors" 8 | "net" 9 | 10 | "github.com/charmbracelet/log" 11 | "github.com/charmbracelet/ssh" 12 | "github.com/charmbracelet/wish" 13 | lru "github.com/hashicorp/golang-lru/v2" 14 | "golang.org/x/time/rate" 15 | ) 16 | 17 | // ErrRateLimitExceeded is returned when the connection was denied due to the rate limit being exceeded. 18 | var ErrRateLimitExceeded = errors.New("rate limit exceeded, please try again later") 19 | 20 | // RateLimiter implementations should check if a given session is allowed to 21 | // proceed or not, returning an error if they aren't. 22 | // It's up to the implementation to handle what identifies a session as well 23 | // as the implementation details of these limits. 24 | type RateLimiter interface { 25 | Allow(s ssh.Session) error 26 | } 27 | 28 | // Middleware returns a middleware that ends the session with wish.Fatal when limiter.Allow returns an error, and otherwise calls the next handler. 29 | func Middleware(limiter RateLimiter) wish.Middleware { 30 | return func(sh ssh.Handler) ssh.Handler { 31 | return func(s ssh.Session) { 32 | if err := limiter.Allow(s); err != nil { 33 | wish.Fatal(s, err) 34 | return 35 | } 36 | 37 | sh(s) 38 | } 39 | } 40 | } 41 | 42 | // NewRateLimiter returns a new RateLimiter that allows events up to rate r, 43 | // permits bursts of at most burst tokens and keeps a cache of maxEntries 44 | // limiters.
45 | // 46 | // Internally, it creates an LRU Cache of *rate.Limiter, in which the key is 47 | // the remote IP address (or the full address string for non-TCP connections). 48 | func NewRateLimiter(r rate.Limit, burst int, maxEntries int) RateLimiter { 49 | if maxEntries <= 0 { 50 | maxEntries = 1 51 | } 52 | // only possible error is if maxEntries is <= 0, which is prevented above. 53 | cache, _ := lru.New[string, *rate.Limiter](maxEntries) 54 | return &limiters{ 55 | rate: r, 56 | burst: burst, 57 | cache: cache, 58 | } 59 | } 60 | 61 | type limiters struct { 62 | cache *lru.Cache[string, *rate.Limiter] 63 | rate rate.Limit 64 | burst int 65 | } 66 | // Allow implements RateLimiter: it takes one token from the limiter for the session's remote host, creating that limiter on first use, and returns ErrRateLimitExceeded when none is available. 67 | func (r *limiters) Allow(s ssh.Session) error { 68 | var key string 69 | switch addr := s.RemoteAddr().(type) { 70 | case *net.TCPAddr: 71 | key = addr.IP.String() 72 | default: 73 | key = addr.String() 74 | } 75 | 76 | limiter, ok := r.cache.Get(key) 77 | if !ok { 78 | // Get followed by Add is racy: PeekOrAdd reuses a limiter that a concurrent session for the same key added meanwhile instead of overwriting it with a fresh full-burst one. 79 | fresh := rate.NewLimiter(r.rate, r.burst) 80 | if limiter, ok, _ = r.cache.PeekOrAdd(key, fresh); !ok { 81 | limiter = fresh 82 | } 83 | } 84 | allowed := limiter.Allow() 85 | 86 | log.Debug("rate limiter key", "key", key, "allowed", allowed) 87 | if allowed { 88 | return nil 89 | } 90 | return ErrRateLimitExceeded 91 | } 92 | -------------------------------------------------------------------------------- /ratelimiter/ratelimiter_test.go: -------------------------------------------------------------------------------- 1 | package ratelimiter 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/charmbracelet/ssh" 8 | "github.com/charmbracelet/wish/testsession" 9 | "golang.org/x/sync/errgroup" 10 | "golang.org/x/time/rate" 11 | ) 12 | 13 | func TestRateLimiterNoLimit(t *testing.T) { 14 | s := &ssh.Server{ // rate 0 with burst 0 admits no session, so Run must fail 15 | Handler: Middleware(NewRateLimiter(rate.Limit(0), 0, 5))(func(s ssh.Session) { 16 | s.Write([]byte("hello")) 17 | }), 18 | } 19 | 20 | sess := testsession.New(t, s, nil) 21 | if err := sess.Run(""); err == nil { 22 | t.Fatal("expected an error, got nil") 23 | } 24 | } 25 | 26
| func TestRateLimiterZeroedMaxEntried(t *testing.T) { 27 | s := &ssh.Server{ 28 | Handler: Middleware(NewRateLimiter(rate.Limit(1), 1, 0))(func(s ssh.Session) { 29 | s.Write([]byte("hello")) 30 | }), 31 | } 32 | 33 | sess := testsession.New(t, s, nil) 34 | if err := sess.Run(""); err != nil { 35 | t.Fatalf("expected no error, got %v", err) 36 | } 37 | } 38 | 39 | func TestRateLimiter(t *testing.T) { 40 | s := &ssh.Server{ 41 | Handler: Middleware(NewRateLimiter(rate.Limit(10), 4, 1))(func(s ssh.Session) { 42 | // noop 43 | }), 44 | } 45 | 46 | addr := testsession.Listen(t, s) 47 | 48 | g := errgroup.Group{} 49 | for i := 0; i < 10; i++ { 50 | g.Go(func() error { 51 | sess, err := testsession.NewClientSession(t, addr, nil) 52 | if err != nil { 53 | // t.Fatalf must only be called from the test goroutine; record the failure and bail out instead. 54 | t.Errorf("expected no errors, got %v", err) 55 | return err 56 | } 57 | // A session rejected by the rate limiter makes Run fail; g.Wait reports the first such error. 58 | return sess.Run("") 59 | }) 60 | } 61 | 62 | if err := g.Wait(); err == nil { 63 | t.Fatal("expected error, got nil") 64 | } 65 | 66 | // at 10 events/s, 100ms refills one token, so a single new session is allowed again 67 | time.Sleep(100 * time.Millisecond) 68 | sess, err := testsession.NewClientSession(t, addr, nil) 69 | if err != nil { 70 | t.Fatalf("expected no errors, got %v", err) 71 | } 72 | if err := sess.Run(""); err != nil { 73 | t.Fatalf("expected no errors, got %v", err) 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /recover/recover.go: -------------------------------------------------------------------------------- 1 | package recover 2 | 3 | import ( 4 | "runtime/debug" 5 | 6 | "github.com/charmbracelet/log" 7 | "github.com/charmbracelet/ssh" 8 | "github.com/charmbracelet/wish" 9 | ) 10 | 11 | // Middleware is a wish middleware that runs mw, recovers from any panic they raise, and logs it to stderr. 12 | func Middleware(mw ...wish.Middleware) wish.Middleware { 13 | return MiddlewareWithLogger(nil, mw...) 14 | } 15 | 16 | // Logger is the interface that wraps the basic Printf method.
17 | type Logger interface { 18 | Printf(format string, v ...interface{}) 19 | } 20 | 21 | // MiddlewareWithLogger is a wish middleware that recovers from panics and logs them to 22 | // the provided logger (log.StandardLog() when logger is nil). The mw chain runs under recovery; the next handler always runs afterwards, even if mw panicked. 23 | func MiddlewareWithLogger(logger Logger, mw ...wish.Middleware) wish.Middleware { 24 | if logger == nil { 25 | logger = log.StandardLog() 26 | } 27 | h := func(ssh.Session) {} // innermost no-op handler that the mw chain wraps 28 | for _, m := range mw { 29 | h = m(h) 30 | } 31 | return func(sh ssh.Handler) ssh.Handler { 32 | return func(s ssh.Session) { 33 | func() { // scopes the deferred recover to the mw chain only, so sh(s) below still runs 34 | defer func() { 35 | if r := recover(); r != nil { 36 | logger.Printf( 37 | "panic: %v\n%s", 38 | r, 39 | string(debug.Stack()), 40 | ) 41 | } 42 | }() 43 | h(s) 44 | }() 45 | sh(s) 46 | } 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /recover/recover_test.go: -------------------------------------------------------------------------------- 1 | package recover 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/charmbracelet/ssh" 7 | "github.com/charmbracelet/wish/testsession" 8 | gossh "golang.org/x/crypto/ssh" 9 | ) 10 | 11 | func TestMiddleware(t *testing.T) { 12 | t.Run("recover session", func(t *testing.T) { 13 | _, err := setup(t).Output("") 14 | requireNoError(t, err) 15 | }) 16 | } 17 | 18 | func setup(tb testing.TB) *gossh.Session { 19 | tb.Helper() 20 | return testsession.New(tb, &ssh.Server{ 21 | Handler: Middleware(func(h ssh.Handler) ssh.Handler { 22 | return func(s ssh.Session) { 23 | panic("hello") // must be recovered by Middleware so the session still exits cleanly 24 | } 25 | })(func(s ssh.Session) {}), 26 | }, nil) 27 | } 28 | 29 | func requireNoError(t *testing.T, err error) { 30 | t.Helper() 31 | 32 | if err != nil { 33 | t.Fatalf("expected no error, got %q", err.Error()) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /scp/copy_from_client.go: -------------------------------------------------------------------------------- 1 | package scp 2 | 3 | import ( 4 | "bufio" 5 | "errors" 6 | "fmt"
7 | "io" 8 | "io/fs" 9 | "path/filepath" 10 | "regexp" 11 | "strconv" 12 | "strings" 13 | 14 | "github.com/charmbracelet/ssh" 15 | ) 16 | 17 | var ( 18 | reTimestamp = regexp.MustCompile(`^T(\d{10}) 0 (\d{10}) 0$`) 19 | reNewFolder = regexp.MustCompile(`^D(\d{4}) 0 (.*)$`) 20 | reNewFile = regexp.MustCompile(`^C(\d{4}) (\d+) (.*)$`) 21 | ) 22 | 23 | type parseError struct { 24 | subject string 25 | } 26 | 27 | func (e parseError) Error() string { 28 | return fmt.Sprintf("failed to parse: %q", e.subject) 29 | } 30 | 31 | func copyFromClient(s ssh.Session, info Info, handler CopyFromClientHandler) error { 32 | // accepts the request 33 | _, _ = s.Write(NULL) 34 | 35 | var ( 36 | path = info.Path 37 | r = bufio.NewReader(s) 38 | mtime int64 39 | atime int64 40 | ) 41 | 42 | for { 43 | line, err := r.ReadString('\n') 44 | if err != nil { 45 | if errors.Is(err, io.EOF) { 46 | break 47 | } 48 | return fmt.Errorf("failed to read line: %w", err) 49 | } 50 | line = strings.TrimSuffix(line, "\n") 51 | 52 | if matches := reTimestamp.FindAllStringSubmatch(line, 2); matches != nil { 53 | mtime, err = strconv.ParseInt(matches[0][1], 10, 64) 54 | if err != nil { 55 | return parseError{line} 56 | } 57 | atime, err = strconv.ParseInt(matches[0][2], 10, 64) 58 | if err != nil { 59 | return parseError{line} 60 | } 61 | 62 | // accepts the header 63 | _, _ = s.Write(NULL) 64 | continue 65 | } 66 | 67 | if matches := reNewFile.FindAllStringSubmatch(line, 3); matches != nil { 68 | if len(matches) != 1 || len(matches[0]) != 4 { 69 | return parseError{line} 70 | } 71 | 72 | mode, err := strconv.ParseUint(matches[0][1], 8, 32) 73 | if err != nil { 74 | return parseError{line} 75 | } 76 | 77 | size, err := strconv.ParseInt(matches[0][2], 10, 64) 78 | if err != nil { 79 | return parseError{line} 80 | } 81 | name := matches[0][3] 82 | 83 | // accepts the header 84 | _, _ = s.Write(NULL) 85 | 86 | written, err := handler.Write(s, &FileEntry{ 87 | Name: name, 88 | Filepath: filepath.Join(path, 
name), 89 | Mode: fs.FileMode(mode), //nolint:gosec 90 | Size: size, 91 | Mtime: mtime, 92 | Atime: atime, 93 | Reader: newLimitReader(r, int(size)), 94 | }) 95 | if err != nil { 96 | return fmt.Errorf("failed to write file: %q: %w", name, err) 97 | } 98 | if written != size { 99 | return fmt.Errorf("failed to write the file: %q: written %d out of %d bytes", name, written, size) 100 | } 101 | 102 | // read the trailing nil char 103 | _, _ = r.ReadByte() 104 | 105 | mtime = 0 106 | atime = 0 107 | // says 'hey im done' 108 | _, _ = s.Write(NULL) 109 | continue 110 | } 111 | 112 | if matches := reNewFolder.FindAllStringSubmatch(line, 2); matches != nil { 113 | if len(matches) != 1 || len(matches[0]) != 3 { 114 | return parseError{line} 115 | } 116 | 117 | mode, err := strconv.ParseUint(matches[0][1], 8, 32) 118 | if err != nil { 119 | return parseError{line} 120 | } 121 | name := matches[0][2] 122 | 123 | path = filepath.Join(path, name) 124 | if err := handler.Mkdir(s, &DirEntry{ 125 | Name: name, 126 | Filepath: path, 127 | Mode: fs.FileMode(mode), //nolint:gosec 128 | Mtime: mtime, 129 | Atime: atime, 130 | }); err != nil { 131 | return fmt.Errorf("failed to create dir: %q: %w", name, err) 132 | } 133 | 134 | mtime = 0 135 | atime = 0 136 | // says 'hey im done' 137 | _, _ = s.Write(NULL) 138 | continue 139 | } 140 | 141 | if line == "E" { 142 | path = filepath.Dir(path) 143 | 144 | // says 'hey im done' 145 | _, _ = s.Write(NULL) 146 | continue 147 | } 148 | 149 | return fmt.Errorf("unhandled input: %q", line) 150 | } 151 | 152 | _, _ = s.Write(NULL) 153 | return nil 154 | } 155 | -------------------------------------------------------------------------------- /scp/copy_to_client.go: -------------------------------------------------------------------------------- 1 | package scp 2 | 3 | import ( 4 | "fmt" 5 | "io/fs" 6 | 7 | "github.com/charmbracelet/ssh" 8 | ) 9 | 10 | func copyToClient(s ssh.Session, info Info, handler CopyToClientHandler) error { 11 | matches, 
err := handler.Glob(s, info.Path) 12 | if err != nil { 13 | return err 14 | } 15 | if len(matches) == 0 { 16 | return fmt.Errorf("no files matching %q", info.Path) 17 | } 18 | 19 | rootEntry := &RootEntry{} 20 | var closers []func() error 21 | defer func() { 22 | closeAll(closers) 23 | }() 24 | 25 | for _, match := range matches { 26 | if !info.Recursive { 27 | entry, closer, err := handler.NewFileEntry(s, match) 28 | closers = append(closers, closer) 29 | if err != nil { 30 | return err 31 | } 32 | rootEntry.Append(entry) 33 | continue 34 | } 35 | 36 | if err := handler.WalkDir(s, match, func(path string, d fs.DirEntry, err error) error { 37 | if err != nil { 38 | return err 39 | } 40 | 41 | if d.IsDir() { 42 | entry, err := handler.NewDirEntry(s, path) 43 | if err != nil { 44 | return err 45 | } 46 | rootEntry.Append(entry) 47 | } else { 48 | entry, closer, err := handler.NewFileEntry(s, path) 49 | if err != nil { 50 | return err 51 | } 52 | closers = append(closers, closer) 53 | rootEntry.Append(entry) 54 | } 55 | 56 | return nil 57 | }); err != nil { 58 | return err 59 | } 60 | } 61 | 62 | return rootEntry.Write(s) 63 | } 64 | 65 | func closeAll(closers []func() error) { 66 | for _, closer := range closers { 67 | if closer != nil { 68 | _ = closer() 69 | } 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /scp/filesystem.go: -------------------------------------------------------------------------------- 1 | package scp 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "io/fs" 7 | "os" 8 | "path/filepath" 9 | "strings" 10 | "time" 11 | 12 | "github.com/charmbracelet/ssh" 13 | ) 14 | 15 | // fileSystemHandler is a Handler implementation for a given root path. 16 | type fileSystemHandler struct{ root string } 17 | 18 | var _ Handler = &fileSystemHandler{} 19 | 20 | // NewFileSystemHandler return a Handler based on the given dir. 
21 | func NewFileSystemHandler(root string) Handler { 22 | return &fileSystemHandler{ 23 | root: filepath.Clean(root), 24 | } 25 | } 26 | 27 | func (h *fileSystemHandler) chtimes(path string, mtime, atime int64) error { 28 | if mtime == 0 || atime == 0 { 29 | return nil 30 | } 31 | if err := os.Chtimes( 32 | h.prefixed(path), 33 | time.Unix(atime, 0), 34 | time.Unix(mtime, 0), 35 | ); err != nil { 36 | return fmt.Errorf("failed to chtimes: %q: %w", path, err) 37 | } 38 | return nil 39 | } 40 | 41 | func (h *fileSystemHandler) prefixed(path string) string { 42 | path = filepath.Clean(path) 43 | if strings.HasPrefix(path, h.root) { 44 | return path 45 | } 46 | return filepath.Join(h.root, path) 47 | } 48 | 49 | func (h *fileSystemHandler) Glob(_ ssh.Session, s string) ([]string, error) { 50 | matches, err := filepath.Glob(h.prefixed(s)) 51 | if err != nil { 52 | return []string{}, err 53 | } 54 | 55 | for i, match := range matches { 56 | matches[i], err = filepath.Rel(h.root, match) 57 | if err != nil { 58 | return []string{}, err 59 | } 60 | } 61 | return matches, nil 62 | } 63 | 64 | func (h *fileSystemHandler) WalkDir(_ ssh.Session, path string, fn fs.WalkDirFunc) error { 65 | return filepath.WalkDir(h.prefixed(path), func(path string, d fs.DirEntry, err error) error { 66 | // if h.root is ./foo/bar, we don't want to serve `bar` as the root, 67 | // but only its contents.
68 | if path == h.root { 69 | return err 70 | } 71 | return fn(path, d, err) 72 | }) 73 | } 74 | 75 | func (h *fileSystemHandler) NewDirEntry(_ ssh.Session, name string) (*DirEntry, error) { 76 | path := h.prefixed(name) 77 | info, err := os.Stat(path) 78 | if err != nil { 79 | return nil, fmt.Errorf("failed to open dir: %q: %w", path, err) 80 | } 81 | return &DirEntry{ 82 | Children: []Entry{}, 83 | Name: info.Name(), 84 | Filepath: path, 85 | Mode: info.Mode(), 86 | Mtime: info.ModTime().Unix(), 87 | Atime: info.ModTime().Unix(), 88 | }, nil 89 | } 90 | 91 | func (h *fileSystemHandler) NewFileEntry(_ ssh.Session, name string) (*FileEntry, func() error, error) { 92 | path := h.prefixed(name) 93 | info, err := os.Stat(path) 94 | if err != nil { 95 | return nil, nil, fmt.Errorf("failed to stat %q: %w", path, err) 96 | } 97 | f, err := os.Open(path) 98 | if err != nil { 99 | return nil, nil, fmt.Errorf("failed to open %q: %w", path, err) 100 | } 101 | return &FileEntry{ 102 | Name: info.Name(), 103 | Filepath: path, 104 | Mode: info.Mode(), 105 | Size: info.Size(), 106 | Mtime: info.ModTime().Unix(), 107 | Atime: info.ModTime().Unix(), 108 | Reader: f, 109 | }, f.Close, nil 110 | } 111 | 112 | func (h *fileSystemHandler) Mkdir(_ ssh.Session, entry *DirEntry) error { 113 | if err := os.Mkdir(h.prefixed(entry.Filepath), entry.Mode); err != nil { 114 | return fmt.Errorf("failed to create dir: %q: %w", entry.Filepath, err) 115 | } 116 | return h.chtimes(entry.Filepath, entry.Mtime, entry.Atime) 117 | } 118 | 119 | func (h *fileSystemHandler) Write(_ ssh.Session, entry *FileEntry) (int64, error) { 120 | f, err := os.OpenFile(h.prefixed(entry.Filepath), os.O_TRUNC|os.O_RDWR|os.O_CREATE, entry.Mode) 121 | if err != nil { 122 | return 0, fmt.Errorf("failed to open file: %q: %w", entry.Filepath, err) 123 | } 124 | defer f.Close() //nolint:errcheck 125 | written, err := io.Copy(f, entry.Reader) 126 | if err != nil { 127 | return 0, fmt.Errorf("failed to write file: %q: %w", 
entry.Filepath, err) 128 | } 129 | if err := f.Close(); err != nil { 130 | return 0, fmt.Errorf("failed to close file: %q: %w", entry.Filepath, err) 131 | } 132 | return written, h.chtimes(entry.Filepath, entry.Mtime, entry.Atime) 133 | } 134 | -------------------------------------------------------------------------------- /scp/filesystem_test.go: -------------------------------------------------------------------------------- 1 | package scp 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io/fs" 7 | "os" 8 | "path/filepath" 9 | "testing" 10 | "testing/iotest" 11 | "time" 12 | 13 | "github.com/matryer/is" 14 | ) 15 | 16 | func TestFilesystem(t *testing.T) { 17 | mtime := time.Unix(1323853868, 0) 18 | atime := time.Unix(1380425711, 0) 19 | 20 | t.Run("scp -f", func(t *testing.T) { 21 | t.Run("file", func(t *testing.T) { 22 | is := is.New(t) 23 | 24 | dir := t.TempDir() 25 | h := NewFileSystemHandler(dir) 26 | is.NoErr(os.WriteFile(filepath.Join(dir, "a.txt"), []byte("a text file"), 0o644)) 27 | chtimesTree(t, dir, atime, mtime) 28 | 29 | session := setup(t, h, nil) 30 | bts, err := session.CombinedOutput("scp -f a.txt") 31 | is.NoErr(err) 32 | requireEqualGolden(t, bts) 33 | }) 34 | 35 | t.Run("glob", func(t *testing.T) { 36 | is := is.New(t) 37 | 38 | dir := t.TempDir() 39 | h := NewFileSystemHandler(dir) 40 | is.NoErr(os.WriteFile(filepath.Join(dir, "a.txt"), []byte("a text file"), 0o644)) 41 | is.NoErr(os.WriteFile(filepath.Join(dir, "b.txt"), []byte("another text file"), 0o644)) 42 | chtimesTree(t, dir, atime, mtime) 43 | 44 | session := setup(t, h, nil) 45 | bts, err := session.CombinedOutput("scp -f *.txt") 46 | is.NoErr(err) 47 | requireEqualGolden(t, bts) 48 | }) 49 | 50 | t.Run("invalid file", func(t *testing.T) { 51 | is := is.New(t) 52 | 53 | dir := t.TempDir() 54 | h := NewFileSystemHandler(dir) 55 | 56 | session := setup(t, h, nil) 57 | _, err := session.CombinedOutput("scp -f a.txt") 58 | is.True(err != nil) 59 | }) 60 | 61 | t.Run("recursive", func(t 
*testing.T) { 62 | is := is.New(t) 63 | 64 | dir := t.TempDir() 65 | h := NewFileSystemHandler(dir) 66 | 67 | is.NoErr(os.MkdirAll(filepath.Join(dir, "a/b/c/d/e"), 0o755)) 68 | is.NoErr(os.WriteFile(filepath.Join(dir, "a/b/c.txt"), []byte("c text file"), 0o644)) 69 | is.NoErr(os.WriteFile(filepath.Join(dir, "a/b/c/d/e/e.txt"), []byte("e text file"), 0o644)) 70 | chtimesTree(t, dir, atime, mtime) 71 | 72 | session := setup(t, h, nil) 73 | bts, err := session.CombinedOutput("scp -r -f a") 74 | is.NoErr(err) 75 | requireEqualGolden(t, bts) 76 | }) 77 | 78 | t.Run("recursive glob", func(t *testing.T) { 79 | is := is.New(t) 80 | 81 | dir := t.TempDir() 82 | h := NewFileSystemHandler(dir) 83 | 84 | is.NoErr(os.MkdirAll(filepath.Join(dir, "a/b/c/d/e"), 0o755)) 85 | is.NoErr(os.WriteFile(filepath.Join(dir, "a/b/c.txt"), []byte("c text file"), 0o644)) 86 | is.NoErr(os.WriteFile(filepath.Join(dir, "a/b/c/d/e/e.txt"), []byte("e text file"), 0o644)) 87 | chtimesTree(t, dir, atime, mtime) 88 | 89 | session := setup(t, h, nil) 90 | bts, err := session.CombinedOutput("scp -r -f a/*") 91 | is.NoErr(err) 92 | requireEqualGolden(t, bts) 93 | }) 94 | 95 | t.Run("recursive invalid file", func(t *testing.T) { 96 | is := is.New(t) 97 | 98 | dir := t.TempDir() 99 | h := NewFileSystemHandler(dir) 100 | 101 | session := setup(t, h, nil) 102 | _, err := session.CombinedOutput("scp -r -f a") 103 | is.True(err != nil) 104 | }) 105 | 106 | t.Run("recursive folder", func(t *testing.T) { 107 | is := is.New(t) 108 | 109 | dir := t.TempDir() 110 | h := NewFileSystemHandler(dir) 111 | 112 | is.NoErr(os.MkdirAll(filepath.Join(dir, "a/b/c/d/e"), 0o755)) 113 | is.NoErr(os.WriteFile(filepath.Join(dir, "a/b/c.txt"), []byte("c text file"), 0o644)) 114 | is.NoErr(os.WriteFile(filepath.Join(dir, "a/b/c/d/e/e.txt"), []byte("e text file"), 0o644)) 115 | chtimesTree(t, dir, atime, mtime) 116 | 117 | session := setup(t, h, nil) 118 | bts, err := session.CombinedOutput("scp -r -f /") 119 | is.NoErr(err) 120 | 
requireEqualGolden(t, bts) 121 | }) 122 | }) 123 | 124 | t.Run("scp -t", func(t *testing.T) { 125 | t.Run("file", func(t *testing.T) { 126 | is := is.New(t) 127 | dir := t.TempDir() 128 | h := NewFileSystemHandler(dir) 129 | session := setup(t, nil, h) 130 | 131 | var in bytes.Buffer 132 | in.WriteString("T1183832947 0 1183833773 0\n") 133 | in.WriteString("C0644 6 a.txt\n") 134 | in.WriteString("hello\n") 135 | in.Write(NULL) 136 | session.Stdin = &in 137 | 138 | _, err := session.CombinedOutput("scp -t .") 139 | is.NoErr(err) 140 | 141 | bts, err := os.ReadFile(filepath.Join(dir, "a.txt")) 142 | is.NoErr(err) 143 | is.Equal("hello\n", string(bts)) 144 | }) 145 | 146 | t.Run("recursive", func(t *testing.T) { 147 | is := is.New(t) 148 | dir := t.TempDir() 149 | h := NewFileSystemHandler(dir) 150 | 151 | var in bytes.Buffer 152 | in.WriteString("T1183832947 0 1183833773 0\n") 153 | in.WriteString("D0755 0 folder1\n") 154 | in.WriteString("C0644 6 file1\n") 155 | in.WriteString("hello\n") 156 | in.Write(NULL) 157 | in.WriteString("D0755 0 folder2\n") 158 | in.WriteString("T1183832947 0 1183833773 0\n") 159 | in.WriteString("C0644 6 file2\n") 160 | in.WriteString("hello\n") 161 | in.Write(NULL) 162 | in.WriteString("E\n") 163 | in.WriteString("E\n") 164 | 165 | session := setup(t, nil, h) 166 | session.Stdin = &in 167 | _, err := session.CombinedOutput("scp -r -t .") 168 | is.NoErr(err) 169 | 170 | mtime := int64(1183832947) 171 | 172 | stat, err := os.Stat(filepath.Join(dir, "folder1")) 173 | is.NoErr(err) 174 | is.True(stat.IsDir()) 175 | // TODO: check how scp behaves 176 | is.True(stat.ModTime().Unix() != mtime) // should be different because the folder was later modified again 177 | 178 | stat, err = os.Stat(filepath.Join(dir, "folder1/file1")) 179 | is.NoErr(err) 180 | is.True(stat.ModTime().Unix() != mtime) 181 | 182 | stat, err = os.Stat(filepath.Join(dir, "folder1/folder2/file2")) 183 | is.NoErr(err) 184 | is.Equal(stat.ModTime().Unix(), mtime) 185 | }) 186 | 
}) 187 | 188 | t.Run("errors", func(t *testing.T) { 189 | t.Run("chtimes", func(t *testing.T) { 190 | h := &fileSystemHandler{t.TempDir()} 191 | is.New(t).True(h.chtimes("nope", 1212212, 323232) != nil) // should err 192 | }) 193 | 194 | t.Run("glob", func(t *testing.T) { 195 | t.Run("invalid glob", func(t *testing.T) { 196 | is := is.New(t) 197 | h := &fileSystemHandler{t.TempDir()} 198 | matches, err := h.Glob(nil, "[asda") 199 | is.True(err != nil) // should err 200 | is.Equal([]string{}, matches) 201 | }) 202 | }) 203 | 204 | t.Run("NewDirEntry", func(t *testing.T) { 205 | t.Run("do not exist", func(t *testing.T) { 206 | is := is.New(t) 207 | h := &fileSystemHandler{t.TempDir()} 208 | _, err := h.NewDirEntry(nil, "foo") 209 | is.True(err != nil) // should err 210 | }) 211 | }) 212 | 213 | t.Run("NewFileEntry", func(t *testing.T) { 214 | t.Run("do not exist", func(t *testing.T) { 215 | is := is.New(t) 216 | h := &fileSystemHandler{t.TempDir()} 217 | _, _, err := h.NewFileEntry(nil, "foo") 218 | is.True(err != nil) // should err 219 | }) 220 | }) 221 | 222 | t.Run("Mkdir", func(t *testing.T) { 223 | t.Run("parent do not exist", func(t *testing.T) { 224 | is := is.New(t) 225 | h := &fileSystemHandler{t.TempDir()} 226 | err := h.Mkdir(nil, &DirEntry{ 227 | Name: "foo", 228 | Filepath: "foo/bar/baz", 229 | Mode: 0o755, 230 | }) 231 | is.True(err != nil) // should err 232 | }) 233 | }) 234 | 235 | t.Run("Write", func(t *testing.T) { 236 | t.Run("parent do not exist", func(t *testing.T) { 237 | is := is.New(t) 238 | h := &fileSystemHandler{t.TempDir()} 239 | _, err := h.Write(nil, &FileEntry{ 240 | Name: "foo.txt", 241 | Filepath: "baz/foo.txt", 242 | Mode: 0o644, 243 | Size: 10, 244 | }) 245 | is.True(err != nil) // should err 246 | }) 247 | 248 | t.Run("reader fails", func(t *testing.T) { 249 | is := is.New(t) 250 | h := &fileSystemHandler{t.TempDir()} 251 | _, err := h.Write(nil, &FileEntry{ 252 | Name: "foo.txt", 253 | Filepath: "foo.txt", 254 | Mode: 0o644, 255 | 
Size: 10, 256 | Reader: iotest.ErrReader(fmt.Errorf("fake err")), 257 | }) 258 | is.True(err != nil) // should err 259 | }) 260 | }) 261 | }) 262 | } 263 | 264 | func chtimesTree(tb testing.TB, dir string, atime, mtime time.Time) { 265 | is.New(tb).NoErr(filepath.WalkDir(dir, func(path string, _ fs.DirEntry, err error) error { 266 | if err != nil { 267 | return err 268 | } 269 | return os.Chtimes(path, atime, mtime) 270 | })) 271 | } 272 | -------------------------------------------------------------------------------- /scp/fs.go: -------------------------------------------------------------------------------- 1 | package scp 2 | 3 | import ( 4 | "fmt" 5 | "io/fs" 6 | 7 | "github.com/charmbracelet/ssh" 8 | ) 9 | 10 | type fsHandler struct{ fsys fs.FS } 11 | 12 | var _ CopyToClientHandler = &fsHandler{} 13 | 14 | // NewFSReadHandler returns a read-only CopyToClientHandler that accepts any 15 | // fs.FS as input. 16 | func NewFSReadHandler(fsys fs.FS) CopyToClientHandler { 17 | return &fsHandler{fsys: fsys} 18 | } 19 | 20 | func (h *fsHandler) Glob(_ ssh.Session, s string) ([]string, error) { 21 | return fs.Glob(h.fsys, s) 22 | } 23 | 24 | func (h *fsHandler) WalkDir(_ ssh.Session, path string, fn fs.WalkDirFunc) error { 25 | return fs.WalkDir(h.fsys, path, fn) 26 | } 27 | 28 | func (h *fsHandler) NewDirEntry(_ ssh.Session, path string) (*DirEntry, error) { 29 | path = normalizePath(path) 30 | info, err := fs.Stat(h.fsys, path) 31 | if err != nil { 32 | return nil, fmt.Errorf("failed to open dir: %q: %w", path, err) 33 | } 34 | return &DirEntry{ 35 | Children: []Entry{}, 36 | Name: info.Name(), 37 | Filepath: path, 38 | Mode: info.Mode(), 39 | Mtime: info.ModTime().Unix(), 40 | Atime: info.ModTime().Unix(), 41 | }, nil 42 | } 43 | 44 | func (h *fsHandler) NewFileEntry(_ ssh.Session, path string) (*FileEntry, func() error, error) { 45 | info, err := fs.Stat(h.fsys, path) 46 | if err != nil { 47 | return nil, nil, fmt.Errorf("failed to stat %q: %w", path, err) 48 | } 
49 | f, err := h.fsys.Open(path) 50 | if err != nil { 51 | return nil, nil, fmt.Errorf("failed to open %q: %w", path, err) 52 | } 53 | return &FileEntry{ 54 | Name: info.Name(), 55 | Filepath: path, 56 | Mode: info.Mode(), 57 | Size: info.Size(), 58 | Mtime: info.ModTime().Unix(), 59 | Atime: info.ModTime().Unix(), 60 | Reader: f, 61 | }, f.Close, nil 62 | } 63 | -------------------------------------------------------------------------------- /scp/fs_test.go: -------------------------------------------------------------------------------- 1 | package scp 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | "time" 8 | 9 | "github.com/matryer/is" 10 | ) 11 | 12 | func TestFS(t *testing.T) { 13 | mtime := time.Unix(1323853868, 0) 14 | atime := time.Unix(1380425711, 0) 15 | 16 | t.Run("file", func(t *testing.T) { 17 | is := is.New(t) 18 | 19 | dir := t.TempDir() 20 | h := NewFSReadHandler(os.DirFS(dir)) 21 | 22 | is.NoErr(os.WriteFile(filepath.Join(dir, "a.txt"), []byte("a text file"), 0o644)) 23 | chtimesTree(t, dir, atime, mtime) 24 | 25 | session := setup(t, h, nil) 26 | bts, err := session.CombinedOutput("scp -f a.txt") 27 | is.NoErr(err) 28 | requireEqualGolden(t, bts) 29 | }) 30 | 31 | t.Run("glob", func(t *testing.T) { 32 | is := is.New(t) 33 | 34 | dir := t.TempDir() 35 | h := NewFSReadHandler(os.DirFS(dir)) 36 | is.NoErr(os.WriteFile(filepath.Join(dir, "a.txt"), []byte("a text file"), 0o644)) 37 | is.NoErr(os.WriteFile(filepath.Join(dir, "b.txt"), []byte("another text file"), 0o644)) 38 | chtimesTree(t, dir, atime, mtime) 39 | 40 | session := setup(t, h, nil) 41 | bts, err := session.CombinedOutput("scp -f *.txt") 42 | is.NoErr(err) 43 | requireEqualGolden(t, bts) 44 | }) 45 | 46 | t.Run("invalid file", func(t *testing.T) { 47 | is := is.New(t) 48 | 49 | dir := t.TempDir() 50 | h := NewFSReadHandler(os.DirFS(dir)) 51 | 52 | session := setup(t, h, nil) 53 | _, err := session.CombinedOutput("scp -f a.txt") 54 | is.True(err != nil) 55 | }) 56 | 57 | 
t.Run("recursive", func(t *testing.T) { 58 | is := is.New(t) 59 | 60 | dir := t.TempDir() 61 | h := NewFSReadHandler(os.DirFS(dir)) 62 | 63 | is.NoErr(os.MkdirAll(filepath.Join(dir, "a/b/c/d/e"), 0o755)) 64 | is.NoErr(os.WriteFile(filepath.Join(dir, "a/b/c.txt"), []byte("c text file"), 0o644)) 65 | is.NoErr(os.WriteFile(filepath.Join(dir, "a/b/c/d/e/e.txt"), []byte("e text file"), 0o644)) 66 | chtimesTree(t, dir, atime, mtime) 67 | 68 | session := setup(t, h, nil) 69 | bts, err := session.CombinedOutput("scp -r -f a") 70 | is.NoErr(err) 71 | requireEqualGolden(t, bts) 72 | }) 73 | 74 | t.Run("recursive glob", func(t *testing.T) { 75 | is := is.New(t) 76 | 77 | dir := t.TempDir() 78 | h := NewFSReadHandler(os.DirFS(dir)) 79 | 80 | is.NoErr(os.MkdirAll(filepath.Join(dir, "a/b/c/d/e"), 0o755)) 81 | is.NoErr(os.WriteFile(filepath.Join(dir, "a/b/c.txt"), []byte("c text file"), 0o644)) 82 | is.NoErr(os.WriteFile(filepath.Join(dir, "a/b/c/d/e/e.txt"), []byte("e text file"), 0o644)) 83 | chtimesTree(t, dir, atime, mtime) 84 | 85 | session := setup(t, h, nil) 86 | bts, err := session.CombinedOutput("scp -r -f a/*") 87 | is.NoErr(err) 88 | requireEqualGolden(t, bts) 89 | }) 90 | 91 | t.Run("recursive folder", func(t *testing.T) { 92 | is := is.New(t) 93 | 94 | dir := t.TempDir() 95 | h := NewFileSystemHandler(dir) 96 | 97 | is.NoErr(os.MkdirAll(filepath.Join(dir, "a/b/c/d/e"), 0o755)) 98 | is.NoErr(os.WriteFile(filepath.Join(dir, "a/b/c.txt"), []byte("c text file"), 0o644)) 99 | is.NoErr(os.WriteFile(filepath.Join(dir, "a/b/c/d/e/e.txt"), []byte("e text file"), 0o644)) 100 | chtimesTree(t, dir, atime, mtime) 101 | 102 | session := setup(t, h, nil) 103 | bts, err := session.CombinedOutput("scp -r -f /") 104 | is.NoErr(err) 105 | requireEqualGolden(t, bts) 106 | }) 107 | 108 | t.Run("recursive invalid file", func(t *testing.T) { 109 | is := is.New(t) 110 | 111 | dir := t.TempDir() 112 | h := NewFSReadHandler(os.DirFS(dir)) 113 | 114 | session := setup(t, h, nil) 115 | _, err 
:= session.CombinedOutput("scp -r -f a") 116 | is.True(err != nil) 117 | }) 118 | 119 | t.Run("errors", func(t *testing.T) { 120 | t.Run("glob", func(t *testing.T) { 121 | t.Run("invalid glob", func(t *testing.T) { 122 | is := is.New(t) 123 | h := &fsHandler{os.DirFS(t.TempDir())} 124 | matches, err := h.Glob(nil, "[asda") 125 | is.True(err != nil) // should err 126 | is.Equal(nil, matches) 127 | }) 128 | }) 129 | 130 | t.Run("NewDirEntry", func(t *testing.T) { 131 | t.Run("do not exist", func(t *testing.T) { 132 | is := is.New(t) 133 | h := &fsHandler{os.DirFS(t.TempDir())} 134 | _, err := h.NewDirEntry(nil, "foo") 135 | is.True(err != nil) // should err 136 | }) 137 | }) 138 | 139 | t.Run("NewFileEntry", func(t *testing.T) { 140 | t.Run("do not exist", func(t *testing.T) { 141 | is := is.New(t) 142 | h := &fsHandler{os.DirFS(t.TempDir())} 143 | _, _, err := h.NewFileEntry(nil, "foo") 144 | is.True(err != nil) // should err 145 | }) 146 | }) 147 | }) 148 | } 149 | -------------------------------------------------------------------------------- /scp/limit_reader.go: -------------------------------------------------------------------------------- 1 | package scp 2 | 3 | import ( 4 | "io" 5 | "sync" 6 | ) 7 | 8 | func newLimitReader(r io.Reader, limit int) io.Reader { 9 | return &limitReader{ 10 | r: r, 11 | left: limit, 12 | } 13 | } 14 | 15 | type limitReader struct { 16 | r io.Reader 17 | 18 | lock sync.Mutex 19 | left int 20 | } 21 | 22 | func (r *limitReader) Read(b []byte) (int, error) { 23 | r.lock.Lock() 24 | defer r.lock.Unlock() 25 | 26 | if r.left <= 0 { 27 | return 0, io.EOF 28 | } 29 | if len(b) > r.left { 30 | b = b[0:r.left] 31 | } 32 | n, err := r.r.Read(b) 33 | r.left -= n 34 | return n, err 35 | } 36 | -------------------------------------------------------------------------------- /scp/limit_reader_test.go: -------------------------------------------------------------------------------- 1 | package scp 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | 
"testing" 7 | 8 | "github.com/matryer/is" 9 | ) 10 | 11 | func TestLimitedReader(t *testing.T) { 12 | t.Run("partial", func(t *testing.T) { 13 | is := is.New(t) 14 | var b bytes.Buffer 15 | b.WriteString("writing some bytes") 16 | r := newLimitReader(&b, 7) 17 | 18 | bts, err := io.ReadAll(r) 19 | is.NoErr(err) 20 | is.Equal("writing", string(bts)) 21 | }) 22 | 23 | t.Run("full", func(t *testing.T) { 24 | is := is.New(t) 25 | var b bytes.Buffer 26 | b.WriteString("some text") 27 | r := newLimitReader(&b, b.Len()) 28 | 29 | bts, err := io.ReadAll(r) 30 | is.NoErr(err) 31 | is.Equal("some text", string(bts)) 32 | }) 33 | 34 | t.Run("pass limit", func(t *testing.T) { 35 | is := is.New(t) 36 | var b bytes.Buffer 37 | b.WriteString("another text") 38 | r := newLimitReader(&b, b.Len()+10) 39 | 40 | bts, err := io.ReadAll(r) 41 | is.NoErr(err) 42 | is.Equal("another text", string(bts)) 43 | }) 44 | } 45 | -------------------------------------------------------------------------------- /scp/scp.go: -------------------------------------------------------------------------------- 1 | // Package scp provides a SCP middleware for wish. 2 | package scp 3 | 4 | import ( 5 | "fmt" 6 | "io" 7 | "io/fs" 8 | "path/filepath" 9 | "runtime" 10 | "strconv" 11 | "strings" 12 | 13 | "github.com/charmbracelet/ssh" 14 | "github.com/charmbracelet/wish" 15 | ) 16 | 17 | // CopyToClientHandler is a handler that can be implemented to handle files 18 | // being copied from the server to the client. 19 | type CopyToClientHandler interface { 20 | // Glob should be implemented if you want to provide server-side globbing 21 | // support. 22 | // 23 | // A minimal implementation to disable it is to return `[]string{s}, nil`. 24 | // 25 | // Note: if your other functions expect a relative path, make sure that 26 | // your Glob implementation returns relative paths as well. 27 | Glob(ssh.Session, string) ([]string, error) 28 | 29 | // WalkDir must be implemented if you want to allow recursive copies. 
30 | WalkDir(ssh.Session, string, fs.WalkDirFunc) error 31 | 32 | // NewDirEntry should provide a *DirEntry for the given path. 33 | NewDirEntry(ssh.Session, string) (*DirEntry, error) 34 | 35 | // NewFileEntry should provide a *FileEntry for the given path. 36 | // Users may also provide a closing function. 37 | NewFileEntry(ssh.Session, string) (*FileEntry, func() error, error) 38 | } 39 | 40 | // CopyFromClientHandler is a handler that can be implemented to handle files 41 | // being copied from the client to the server. 42 | type CopyFromClientHandler interface { 43 | // Mkdir should create the given dir. 44 | // Note that this usually shouldn't use os.MkdirAll and the like. 45 | Mkdir(ssh.Session, *DirEntry) error 46 | 47 | // Write should write the given file. 48 | Write(ssh.Session, *FileEntry) (int64, error) 49 | } 50 | 51 | // Handler is an interface that can be implemented to handle both SCP 52 | // directions. 53 | type Handler interface { 54 | CopyFromClientHandler 55 | CopyToClientHandler 56 | } 57 | 58 | // Middleware provides a wish middleware using the given CopyToClientHandler 59 | // and CopyFromClientHandler. 60 | func Middleware(rh CopyToClientHandler, wh CopyFromClientHandler) wish.Middleware { 61 | return func(sh ssh.Handler) ssh.Handler { 62 | return func(s ssh.Session) { 63 | info := GetInfo(s.Command()) 64 | if !info.Ok { 65 | sh(s) 66 | return 67 | } 68 | 69 | var err error 70 | switch info.Op { 71 | case OpCopyToClient: 72 | if rh == nil { 73 | err = fmt.Errorf("no handler provided for scp -f") 74 | break 75 | } 76 | err = copyToClient(s, info, rh) 77 | case OpCopyFromClient: 78 | if wh == nil { 79 | err = fmt.Errorf("no handler provided for scp -t") 80 | break 81 | } 82 | err = copyFromClient(s, info, wh) 83 | } 84 | if err != nil { 85 | wish.Fatal(s, err) 86 | return 87 | } 88 | } 89 | } 90 | } 91 | 92 | // NULL is a byte slice containing a single NUL (0x00) byte.
93 | var NULL = []byte{'\x00'} 94 | 95 | // Entry defines something that knows how to write itself and its path. 96 | type Entry interface { 97 | // Write the current entry in SCP format. 98 | Write(io.Writer) error 99 | 100 | path() string 101 | } 102 | 103 | // AppendableEntry defines a special kind of Entry, which can contain 104 | // children. 105 | type AppendableEntry interface { 106 | // Write the current entry in SCP format. 107 | Write(io.Writer) error 108 | 109 | // Append another entry to the current entry. 110 | Append(entry Entry) 111 | } 112 | 113 | // FileEntry is an Entry that reads from a Reader, defining a file and 114 | // its contents. 115 | type FileEntry struct { 116 | Name string 117 | Filepath string 118 | Mode fs.FileMode 119 | Size int64 120 | Reader io.Reader 121 | Atime int64 122 | Mtime int64 123 | } 124 | 125 | func (e *FileEntry) path() string { return e.Filepath } 126 | 127 | // Write a file to the given writer. 128 | func (e *FileEntry) Write(w io.Writer) error { 129 | if e.Mtime > 0 && e.Atime > 0 { 130 | if _, err := fmt.Fprintf(w, "T%d 0 %d 0\n", e.Mtime, e.Atime); err != nil { 131 | return fmt.Errorf("failed to write file: %q: %w", e.Filepath, err) 132 | } 133 | } 134 | if _, err := fmt.Fprintf(w, "C%s %d %s\n", octalPerms(e.Mode), e.Size, e.Name); err != nil { 135 | return fmt.Errorf("failed to write file: %q: %w", e.Filepath, err) 136 | } 137 | 138 | if _, err := io.Copy(w, e.Reader); err != nil { 139 | return fmt.Errorf("failed to read file: %q: %w", e.Filepath, err) 140 | } 141 | 142 | if _, err := w.Write(NULL); err != nil { 143 | return fmt.Errorf("failed to write file: %q: %w", e.Filepath, err) 144 | } 145 | return nil 146 | } 147 | 148 | // RootEntry is a root entry that can only have children. 149 | type RootEntry []Entry 150 | 151 | // Append adds the given entry to the matching child directory, or to the 152 | // root itself if none matches.
153 | func (e *RootEntry) Append(entry Entry) { 154 | parent := normalizePath(filepath.Dir(entry.path())) 155 | 156 | for _, child := range *e { 157 | switch dir := child.(type) { 158 | case *DirEntry: 159 | if child.path() == parent { 160 | dir.Children = append(dir.Children, entry) 161 | return 162 | } 163 | if strings.HasPrefix(parent, normalizePath(dir.Filepath)) { 164 | dir.Append(entry) 165 | return 166 | } 167 | default: 168 | continue 169 | } 170 | } 171 | 172 | *e = append(*e, entry) 173 | } 174 | 175 | // Write recursively writes all the children to the given writer. 176 | func (e *RootEntry) Write(w io.Writer) error { 177 | for _, child := range *e { 178 | if err := child.Write(w); err != nil { 179 | return err 180 | } 181 | } 182 | return nil 183 | } 184 | 185 | // DirEntry is an Entry describing a directory: its name, path, mode, 186 | // timestamps, and any children. 187 | type DirEntry struct { 188 | Children []Entry 189 | Name string 190 | Filepath string 191 | Mode fs.FileMode 192 | Atime int64 193 | Mtime int64 194 | } 195 | 196 | func (e *DirEntry) path() string { return e.Filepath } 197 | 198 | // Write writes the dir header, all its contents (recursively), and the 199 | // closing "E" record to the given writer.
200 | func (e *DirEntry) Write(w io.Writer) error { 201 | if e.Mtime > 0 && e.Atime > 0 { 202 | if _, err := fmt.Fprintf(w, "T%d 0 %d 0\n", e.Mtime, e.Atime); err != nil { 203 | return fmt.Errorf("failed to write dir: %q: %w", e.Filepath, err) 204 | } 205 | } 206 | if _, err := fmt.Fprintf(w, "D%s 0 %s\n", octalPerms(e.Mode), e.Name); err != nil { 207 | return fmt.Errorf("failed to write dir: %q: %w", e.Filepath, err) 208 | } 209 | 210 | for _, child := range e.Children { 211 | if err := child.Write(w); err != nil { 212 | return err 213 | } 214 | } 215 | 216 | if _, err := fmt.Fprint(w, "E\n"); err != nil { 217 | return fmt.Errorf("failed to write dir: %q: %w", e.Filepath, err) 218 | } 219 | return nil 220 | } 221 | 222 | // Appends an entry to the folder or their children. 223 | func (e *DirEntry) Append(entry Entry) { 224 | parent := normalizePath(filepath.Dir(entry.path())) 225 | 226 | for _, child := range e.Children { 227 | switch dir := child.(type) { 228 | case *DirEntry: 229 | if child.path() == parent { 230 | dir.Children = append(dir.Children, entry) 231 | return 232 | } 233 | if strings.HasPrefix(parent, normalizePath(dir.path())) { 234 | dir.Append(entry) 235 | return 236 | } 237 | default: 238 | continue 239 | } 240 | } 241 | 242 | e.Children = append(e.Children, entry) 243 | } 244 | 245 | // Op defines which kind of SCP Operation is going on. 246 | type Op byte 247 | 248 | const ( 249 | // OpCopyToClient is when a file is being copied from the server to the client. 250 | OpCopyToClient Op = 'f' 251 | 252 | // OpCopyFromClient is when a file is being copied from the client into the server. 253 | OpCopyFromClient Op = 't' 254 | ) 255 | 256 | // Info provides some information about the current SCP Operation. 257 | type Info struct { 258 | // Ok is true if the current session is a SCP. 259 | Ok bool 260 | 261 | // Recursive is true if its a recursive SCP. 262 | Recursive bool 263 | 264 | // Path is the server path of the scp operation. 
265 | Path string 266 | 267 | // Op is the SCP operation kind. 268 | Op Op 269 | } 270 | 271 | // GetInfo return information about the given command. 272 | func GetInfo(cmd []string) Info { 273 | info := Info{} 274 | if len(cmd) == 0 || cmd[0] != "scp" { 275 | return info 276 | } 277 | 278 | for i, p := range cmd { 279 | switch p { 280 | case "-r": 281 | info.Recursive = true 282 | case "-f": 283 | info.Op = OpCopyToClient 284 | info.Path = cmd[i+1] 285 | case "-t": 286 | info.Op = OpCopyFromClient 287 | info.Path = cmd[i+1] 288 | } 289 | } 290 | 291 | info.Ok = true 292 | return info 293 | } 294 | 295 | func octalPerms(info fs.FileMode) string { 296 | return "0" + strconv.FormatUint(uint64(info.Perm()), 8) 297 | } 298 | 299 | func normalizePath(p string) string { 300 | p = filepath.Clean(p) 301 | if runtime.GOOS == "windows" { 302 | return strings.ReplaceAll(p, "\\", "/") 303 | } 304 | return p 305 | } 306 | -------------------------------------------------------------------------------- /scp/scp_test.go: -------------------------------------------------------------------------------- 1 | package scp 2 | 3 | import ( 4 | "bytes" 5 | "os" 6 | "path/filepath" 7 | "runtime" 8 | "testing" 9 | 10 | "github.com/charmbracelet/ssh" 11 | "github.com/charmbracelet/wish/testsession" 12 | "github.com/google/go-cmp/cmp" 13 | "github.com/matryer/is" 14 | gossh "golang.org/x/crypto/ssh" 15 | ) 16 | 17 | func TestGetInfo(t *testing.T) { 18 | t.Run("no exec", func(t *testing.T) { 19 | is := is.New(t) 20 | info := GetInfo([]string{}) 21 | is.Equal(info.Ok, false) 22 | }) 23 | 24 | t.Run("exec is not scp", func(t *testing.T) { 25 | is := is.New(t) 26 | info := GetInfo([]string{"not-scp"}) 27 | is.Equal(info.Ok, false) 28 | }) 29 | 30 | t.Run("scp no recursive", func(t *testing.T) { 31 | is := is.New(t) 32 | info := GetInfo([]string{"scp", "-f", "file"}) 33 | is.True(info.Ok) 34 | is.Equal(info.Recursive, false) 35 | is.Equal("file", info.Path) 36 | is.Equal(info.Op, OpCopyToClient) 
37 | }) 38 | 39 | t.Run("scp recursive", func(t *testing.T) { 40 | is := is.New(t) 41 | info := GetInfo([]string{"scp", "-r", "--some-ignored-flag", "-f", "file", "ignored-arg"}) 42 | is.True(info.Ok) 43 | is.True(info.Recursive) 44 | is.Equal("file", info.Path) 45 | is.Equal(info.Op, OpCopyToClient) 46 | }) 47 | 48 | t.Run("scp op copy from client", func(t *testing.T) { 49 | is := is.New(t) 50 | info := GetInfo([]string{"scp", "-t", "file"}) 51 | is.True(info.Ok) 52 | is.Equal(info.Op, OpCopyFromClient) 53 | is.Equal("file", info.Path) 54 | }) 55 | } 56 | 57 | func TestNoDirRootEntry(t *testing.T) { 58 | is := is.New(t) 59 | root := RootEntry{} 60 | 61 | var f1 bytes.Buffer 62 | f1.WriteString("hello from file f1\n") 63 | 64 | var f2 bytes.Buffer 65 | f2.WriteString("hello from file f2\nwith multiple lines :)\n") 66 | 67 | dir := &DirEntry{ 68 | Children: []Entry{}, 69 | Name: "dir1", 70 | Filepath: "dir1", 71 | Mode: 0o755, 72 | } 73 | 74 | dir.Append(&FileEntry{ 75 | Name: "f2", 76 | Filepath: "f2", 77 | Mode: 0o600, 78 | Size: int64(f2.Len()), 79 | Reader: &f2, 80 | }) 81 | 82 | root.Append(&FileEntry{ 83 | Name: "f1", 84 | Filepath: "f1", 85 | Mode: 0o644, 86 | Size: int64(f1.Len()), 87 | Reader: &f1, 88 | }) 89 | 90 | root.Append(dir) 91 | 92 | var out bytes.Buffer 93 | is.NoErr(root.Write(&out)) 94 | 95 | requireEqualGolden(t, out.Bytes()) 96 | } 97 | 98 | func TestInvalidOps(t *testing.T) { 99 | t.Run("not scp", func(t *testing.T) { 100 | _, err := setup(t, nil, nil).CombinedOutput("not-scp ign") 101 | is.New(t).NoErr(err) 102 | }) 103 | 104 | t.Run("copy to client", func(t *testing.T) { 105 | _, err := setup(t, nil, nil).CombinedOutput("scp -t .") 106 | is.New(t).True(err != nil) 107 | }) 108 | 109 | t.Run("copy from client", func(t *testing.T) { 110 | _, err := setup(t, nil, nil).CombinedOutput("scp -f .") 111 | is.New(t).True(err != nil) 112 | }) 113 | } 114 | 115 | func setup(tb testing.TB, rh CopyToClientHandler, wh CopyFromClientHandler) 
*gossh.Session { 116 | tb.Helper() 117 | return testsession.New(tb, &ssh.Server{ 118 | Handler: Middleware(rh, wh)(func(s ssh.Session) { 119 | s.Exit(0) 120 | }), 121 | }, nil) 122 | } 123 | 124 | func requireEqualGolden(tb testing.TB, out []byte) { 125 | tb.Helper() 126 | is := is.New(tb) 127 | 128 | fixOutput := func(bts []byte) []byte { 129 | bts = bytes.ReplaceAll(bts, []byte("\r"), []byte("")) 130 | if runtime.GOOS == "windows" { 131 | // perms always come different on Windows because, well, its Windows. 132 | bts = bytes.ReplaceAll(bts, []byte("0666"), []byte("0644")) 133 | bts = bytes.ReplaceAll(bts, []byte("0777"), []byte("0755")) 134 | } 135 | return bytes.ReplaceAll(bts, NULL, []byte("")) 136 | } 137 | 138 | out = fixOutput(out) 139 | golden := "testdata/" + tb.Name() + ".test" 140 | if os.Getenv("UPDATE") != "" { 141 | is.NoErr(os.MkdirAll(filepath.Dir(golden), 0o755)) 142 | is.NoErr(os.WriteFile(golden, out, 0o655)) 143 | } 144 | 145 | gbts, err := os.ReadFile(golden) 146 | is.NoErr(err) 147 | gbts = fixOutput(gbts) 148 | 149 | if diff := cmp.Diff(string(gbts), string(out)); diff != "" { 150 | tb.Fatal("files do not match:", diff) 151 | } 152 | } 153 | -------------------------------------------------------------------------------- /scp/testdata/TestFS/file.test: -------------------------------------------------------------------------------- 1 | T1323853868 0 1323853868 0 2 | C0644 11 a.txt 3 | a text file -------------------------------------------------------------------------------- /scp/testdata/TestFS/glob.test: -------------------------------------------------------------------------------- 1 | T1323853868 0 1323853868 0 2 | C0644 11 a.txt 3 | a text fileT1323853868 0 1323853868 0 4 | C0644 17 b.txt 5 | another text file -------------------------------------------------------------------------------- /scp/testdata/TestFS/recursive.test: -------------------------------------------------------------------------------- 1 | T1323853868 0 1323853868 0 
2 | D0755 0 a 3 | T1323853868 0 1323853868 0 4 | D0755 0 b 5 | T1323853868 0 1323853868 0 6 | D0755 0 c 7 | T1323853868 0 1323853868 0 8 | D0755 0 d 9 | T1323853868 0 1323853868 0 10 | D0755 0 e 11 | T1323853868 0 1323853868 0 12 | C0644 11 e.txt 13 | e text fileE 14 | E 15 | E 16 | T1323853868 0 1323853868 0 17 | C0644 11 c.txt 18 | c text fileE 19 | E 20 | -------------------------------------------------------------------------------- /scp/testdata/TestFS/recursive_folder.test: -------------------------------------------------------------------------------- 1 | T1323853868 0 1323853868 0 2 | D0755 0 a 3 | T1323853868 0 1323853868 0 4 | D0755 0 b 5 | T1323853868 0 1323853868 0 6 | D0755 0 c 7 | T1323853868 0 1323853868 0 8 | D0755 0 d 9 | T1323853868 0 1323853868 0 10 | D0755 0 e 11 | T1323853868 0 1323853868 0 12 | C0644 11 e.txt 13 | e text fileE 14 | E 15 | E 16 | T1323853868 0 1323853868 0 17 | C0644 11 c.txt 18 | c text fileE 19 | E 20 | -------------------------------------------------------------------------------- /scp/testdata/TestFS/recursive_glob.test: -------------------------------------------------------------------------------- 1 | T1323853868 0 1323853868 0 2 | D0755 0 b 3 | T1323853868 0 1323853868 0 4 | D0755 0 c 5 | T1323853868 0 1323853868 0 6 | D0755 0 d 7 | T1323853868 0 1323853868 0 8 | D0755 0 e 9 | T1323853868 0 1323853868 0 10 | C0644 11 e.txt 11 | e text fileE 12 | E 13 | E 14 | T1323853868 0 1323853868 0 15 | C0644 11 c.txt 16 | c text fileE 17 | -------------------------------------------------------------------------------- /scp/testdata/TestFilesystem/scp_-f/file.test: -------------------------------------------------------------------------------- 1 | T1323853868 0 1323853868 0 2 | C0644 11 a.txt 3 | a text file -------------------------------------------------------------------------------- /scp/testdata/TestFilesystem/scp_-f/glob.test: -------------------------------------------------------------------------------- 1 | 
T1323853868 0 1323853868 0 2 | C0644 11 a.txt 3 | a text fileT1323853868 0 1323853868 0 4 | C0644 17 b.txt 5 | another text file -------------------------------------------------------------------------------- /scp/testdata/TestFilesystem/scp_-f/recursive.test: -------------------------------------------------------------------------------- 1 | T1323853868 0 1323853868 0 2 | D0755 0 a 3 | T1323853868 0 1323853868 0 4 | D0755 0 b 5 | T1323853868 0 1323853868 0 6 | D0755 0 c 7 | T1323853868 0 1323853868 0 8 | D0755 0 d 9 | T1323853868 0 1323853868 0 10 | D0755 0 e 11 | T1323853868 0 1323853868 0 12 | C0644 11 e.txt 13 | e text fileE 14 | E 15 | E 16 | T1323853868 0 1323853868 0 17 | C0644 11 c.txt 18 | c text fileE 19 | E 20 | -------------------------------------------------------------------------------- /scp/testdata/TestFilesystem/scp_-f/recursive_folder.test: -------------------------------------------------------------------------------- 1 | T1323853868 0 1323853868 0 2 | D0755 0 a 3 | T1323853868 0 1323853868 0 4 | D0755 0 b 5 | T1323853868 0 1323853868 0 6 | D0755 0 c 7 | T1323853868 0 1323853868 0 8 | D0755 0 d 9 | T1323853868 0 1323853868 0 10 | D0755 0 e 11 | T1323853868 0 1323853868 0 12 | C0644 11 e.txt 13 | e text fileE 14 | E 15 | E 16 | T1323853868 0 1323853868 0 17 | C0644 11 c.txt 18 | c text fileE 19 | E 20 | -------------------------------------------------------------------------------- /scp/testdata/TestFilesystem/scp_-f/recursive_glob.test: -------------------------------------------------------------------------------- 1 | T1323853868 0 1323853868 0 2 | D0755 0 b 3 | T1323853868 0 1323853868 0 4 | D0755 0 c 5 | T1323853868 0 1323853868 0 6 | D0755 0 d 7 | T1323853868 0 1323853868 0 8 | D0755 0 e 9 | T1323853868 0 1323853868 0 10 | C0644 11 e.txt 11 | e text fileE 12 | E 13 | E 14 | T1323853868 0 1323853868 0 15 | C0644 11 c.txt 16 | c text fileE 17 | -------------------------------------------------------------------------------- 
/scp/testdata/TestNoDirRootEntry.test: -------------------------------------------------------------------------------- 1 | C0644 19 f1 2 | hello from file f1 3 | D0755 0 dir1 4 | C0600 42 f2 5 | hello from file f2 6 | with multiple lines :) 7 | E 8 | -------------------------------------------------------------------------------- /testdata/another-ca-cert.pub: -------------------------------------------------------------------------------- 1 | ssh-ed25519-cert-v01@openssh.com AAAAIHNzaC1lZDI1NTE5LWNlcnQtdjAxQG9wZW5zc2guY29tAAAAIJ9gNYnZ0HxMWDQ0Mu6vQiygdMYBKPN8DXJej/8Xi/h/AAAAIGlESze4XlTjx5Emi+mUbCVrkE/JkGUFX8CnPZ6r4toAAAAAAAAAAAAAAAABAAAABGlkLTIAAAAHAAAAA2ZvbwAAAABib9ZIAAAAAGJv1roAAAAAAAAAggAAABVwZXJtaXQtWDExLWZvcndhcmRpbmcAAAAAAAAAF3Blcm1pdC1hZ2VudC1mb3J3YXJkaW5nAAAAAAAAABZwZXJtaXQtcG9ydC1mb3J3YXJkaW5nAAAAAAAAAApwZXJtaXQtcHR5AAAAAAAAAA5wZXJtaXQtdXNlci1yYwAAAAAAAAAAAAAAMwAAAAtzc2gtZWQyNTUxOQAAACBKEE2aUq61ZKQ3sFwf2aax1XW1lSg2VHiLX/WBSIuYyQAAAFMAAAALc3NoLWVkMjU1MTkAAABAuQW2puuCo9hZ4+NlGynMJ0rLcdznY2wIeLdr9ZyAncI0bRm5oU9aaWkUNYFjZXxytqO+Tpr8IIOTDj95VYeQAQ== carlos@darkstar 2 | -------------------------------------------------------------------------------- /testdata/authorized_keys: -------------------------------------------------------------------------------- 1 | ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIM6LN9MSONoI2Dak7GSAy1vTY92NcioIuZqBnk0xmYR2 k1@test 2 | ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQChxV3pJRnXP7crH+4xxH8skCF/Bs8JX8VTjlS4dpLYzXMUcr0ls0DVwgIkIHvXQtqhR4ymgzciUNTTTYGPLAsda47MVqChO2Kxb+I215ApOmMt11lLX6l1Mp7xO35BYR+jC+s4H8VcespUQbWvASHKGZvhD1cri/FttjdCVs7Gqz7U5Cpo+Ym7UZ6TSiBmEd7zQkg4gR1uR4K8/5oJGpaQDDZr/QZJDGat//qvMAKtPkxomYVzHPnflFdsUIwMJHVver+JqKTMEZm2aDrOji4KpHosvfcbmIlx04N99TdT/0oNIQR1tpUsT/kdc44AqKsKUt7Os6kwYiDrjQlVIjpXPTCrdddgnl+/otH7pFsgVUjgCXz6lvhmV5HYhbJM45UEeDrmfxB0wLhC5J4fmvu4EcJtvO7vgg5PD51NpN4iFdUSj91fjskrLsbI+Do+KjcxhOdQxglZ6JwUV3ljFpPjINyawCveba9DRXO7CL/gLgwEwuIU8HXrgKFSXATSWsk= k4@test 3 | ssh-rsa 
AAAAB3NzaC1yc2EAAAADAQABAAABgQCdJkpQAr3zhC+grKMexj8zgJIuAQ/2LR59RvXemEAovd671Et356cmHnCDmUvUlH/70xQdyL3n68tzu2ZEzKheQP5vz05CAFXTi7rlMvhtz632mLMPlU3lGuP+A6rzqNSnTtrIa2Q3Fe2ir6N+ad782J8g6frGJaVfA/G7j/M1JwyDJWzUS3HvDHDO+qFze71h0/o9W1+VoRaSfD67BzPQumkEkt/CilSPU8VKRP3q/FIeIrgTBhNh17SX/qlnyrJipDTF1QtXUOK4H5TsEE0S13z8a4Wo37kRWQPxdjWyfX9tBjsN86n+R7OGSXXdi10n9THrisdgx2GKsk1HjY+u5YlDpDysFLBs6j4nWeTxnrjgx6HUqvMk3mdqrAKHTglt34OUQtB463GMgCW85w+ni8ebPKlt5YQsXalilcoI4K7fakyXe+o9Y0sCwE3SLXEJhtd/Esz1pVzvMBCshpRknBPFh/gs/i1YuL0SJqI2BGBFs0d/ARwqUQSoXXBTJPc= k5@test 4 | ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJ7InQIj/ROngoWWb6kXTcTJd8+u5skDfGm8JJxRugMB k2@test 5 | ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQDgteu96TZLd3iG11D5NqBsQRvhW2I6iD/ycwOiWFjFyv4MAHaDFiIazbeQSbi++1+5vspeNuv9AKJFgG0SpnjMLQM0rJb5DIsuRxGOAS/oh82yNCxYcW2+eXcqUDL4V+fZ6eIqtSIBrPQY89/CbZ4nFtw7+941gmFa2+7Wj9vLk4GTiyu/jQsbGnAZUCMvce1jFZ9XDMYSYzXEtkqhBT6eYDd7xMQejovszJfPqlKDxpMZxpaDsQGf+00IJPZUUxkX62eAmrlX1q4XO+m2zIjGpf/gdNKHEMXQrvBWdvwg0rat2i+PCW4Rbwx7wHBBWPRqEPjcVTfwvOWoZGGU3TSX8M7Gcj+ZvAD/uV7DWcNi61Obtw/6PXYvKFZWcZ1sHxUTI84CUcVLSL55hOtJqCuJXmUdKdBcJLyo3NValIjIZn+ljn6biVAr00nGo06nO4j2eTE2ZLOZFEB6rHuf1iaT18EiJgnJGrB7HY4+KUoUIzmvzQrKxxLbIe957hnx+TE= k6@test 6 | 7 | # a commented line, and the previous line was empty 8 | ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMJlb/qf2B2kMNdBxfpCQqI2ctPcsOkdZGVh5zTRhKtH k3@test 9 | -------------------------------------------------------------------------------- /testdata/ca: -------------------------------------------------------------------------------- 1 | -----BEGIN OPENSSH PRIVATE KEY----- 2 | b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW 3 | QyNTUxOQAAACBm8cbaX81GkAOSHTriPkFOlFDE8Y9nrOCCYip1xVYVcwAAAJgZFHXsGRR1 4 | 7AAAAAtzc2gtZWQyNTUxOQAAACBm8cbaX81GkAOSHTriPkFOlFDE8Y9nrOCCYip1xVYVcw 5 | AAAEDYnnbXzbn0SnOKCf0ByXm1+FLnqJC+ZErxo2SgaFeVCWbxxtpfzUaQA5IdOuI+QU6U 6 | UMTxj2es4IJiKnXFVhVzAAAAD2Nhcmxvc0BkYXJrc3RhcgECAwQFBg== 7 | -----END OPENSSH PRIVATE KEY----- 8 | 
-------------------------------------------------------------------------------- /testdata/ca.pub: -------------------------------------------------------------------------------- 1 | ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIGbxxtpfzUaQA5IdOuI+QU6UUMTxj2es4IJiKnXFVhVz carlos@darkstar 2 | -------------------------------------------------------------------------------- /testdata/expired-cert.pub: -------------------------------------------------------------------------------- 1 | ssh-ed25519-cert-v01@openssh.com AAAAIHNzaC1lZDI1NTE5LWNlcnQtdjAxQG9wZW5zc2guY29tAAAAIB7uQi6D84xl48p0tQjh7sXk9cpzjEcqsu/ZYjT4k3j7AAAAIGlESze4XlTjx5Emi+mUbCVrkE/JkGUFX8CnPZ6r4toAAAAAAAAAAAAAAAABAAAABGlkLTIAAAAHAAAAA2ZvbwAAAABib8u8AAAAAGJvzAQAAAAAAAAAggAAABVwZXJtaXQtWDExLWZvcndhcmRpbmcAAAAAAAAAF3Blcm1pdC1hZ2VudC1mb3J3YXJkaW5nAAAAAAAAABZwZXJtaXQtcG9ydC1mb3J3YXJkaW5nAAAAAAAAAApwZXJtaXQtcHR5AAAAAAAAAA5wZXJtaXQtdXNlci1yYwAAAAAAAAAAAAAAMwAAAAtzc2gtZWQyNTUxOQAAACBm8cbaX81GkAOSHTriPkFOlFDE8Y9nrOCCYip1xVYVcwAAAFMAAAALc3NoLWVkMjU1MTkAAABA5vfSWp5U1jsyM9u7jO8JmhMQqsxakkUUCwOVYqSVe65h06ANCmrYowmMuYVowvTaXkbdF5RhNjQxqE2fRNa4Bw== carlos@darkstar 2 | -------------------------------------------------------------------------------- /testdata/foo: -------------------------------------------------------------------------------- 1 | -----BEGIN OPENSSH PRIVATE KEY----- 2 | b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW 3 | QyNTUxOQAAACBpREs3uF5U48eRJovplGwla5BPyZBlBV/Apz2eq+LaAAAAAJj5F0Fq+RdB 4 | agAAAAtzc2gtZWQyNTUxOQAAACBpREs3uF5U48eRJovplGwla5BPyZBlBV/Apz2eq+LaAA 5 | AAAEBYbFEA6Ad/SafFoBHD9RM97SS79vZLrZ0p9DftTk4DbmlESze4XlTjx5Emi+mUbCVr 6 | kE/JkGUFX8CnPZ6r4toAAAAAD2Nhcmxvc0BkYXJrc3RhcgECAwQFBg== 7 | -----END OPENSSH PRIVATE KEY----- 8 | -------------------------------------------------------------------------------- /testdata/foo.pub: -------------------------------------------------------------------------------- 1 | ssh-ed25519 
AAAAC3NzaC1lZDI1NTE5AAAAIGlESze4XlTjx5Emi+mUbCVrkE/JkGUFX8CnPZ6r4toA carlos@darkstar 2 | -------------------------------------------------------------------------------- /testdata/invalid_authorized_keys: -------------------------------------------------------------------------------- 1 | ssh-nope nopenopenope k1@test 2 | -------------------------------------------------------------------------------- /testdata/valid-cert.pub: -------------------------------------------------------------------------------- 1 | ssh-ed25519-cert-v01@openssh.com AAAAIHNzaC1lZDI1NTE5LWNlcnQtdjAxQG9wZW5zc2guY29tAAAAILHcoWg/uAZrEroYaEwxBkUCNH95Iz+ycr41K6KdnL0ZAAAAIGlESze4XlTjx5Emi+mUbCVrkE/JkGUFX8CnPZ6r4toAAAAAAAAAAAAAAAABAAAABGlkLTEAAAAHAAAAA2ZvbwAAAAAAAAAA//////////8AAAAAAAAAggAAABVwZXJtaXQtWDExLWZvcndhcmRpbmcAAAAAAAAAF3Blcm1pdC1hZ2VudC1mb3J3YXJkaW5nAAAAAAAAABZwZXJtaXQtcG9ydC1mb3J3YXJkaW5nAAAAAAAAAApwZXJtaXQtcHR5AAAAAAAAAA5wZXJtaXQtdXNlci1yYwAAAAAAAAAAAAAAMwAAAAtzc2gtZWQyNTUxOQAAACBm8cbaX81GkAOSHTriPkFOlFDE8Y9nrOCCYip1xVYVcwAAAFMAAAALc3NoLWVkMjU1MTkAAABAUlsnmwWKhk2S0tAT/woHwnT2H0V3sn/PEtruHEzbU+uriVHffG046MHbvudLCZ2H86HoeZsu1N7GJuqWE+P6Bg== carlos@darkstar 2 | -------------------------------------------------------------------------------- /testsession/testsession.go: -------------------------------------------------------------------------------- 1 | // Package testsession provides utilities to test SSH sessions. 2 | // 3 | // more or less copied from charmbracelet/ssh tests 4 | package testsession 5 | 6 | import ( 7 | "net" 8 | "testing" 9 | 10 | "github.com/charmbracelet/ssh" 11 | gossh "golang.org/x/crypto/ssh" 12 | ) 13 | 14 | // New starts a local SSH server with the given config and returns a client session. 15 | // It automatically closes everything afterwards. 
16 | func New(tb testing.TB, srv *ssh.Server, cfg *gossh.ClientConfig) *gossh.Session { 17 | tb.Helper() 18 | sess, err := NewClientSession(tb, Listen(tb, srv), cfg) 19 | if err != nil { 20 | tb.Fatal(err) 21 | } 22 | return sess 23 | } 24 | 25 | // Listen starts a test server. 26 | func Listen(tb testing.TB, srv *ssh.Server) string { 27 | tb.Helper() 28 | l := newLocalListener(tb) 29 | go func() { 30 | if err := srv.Serve(l); err != nil && err != ssh.ErrServerClosed { 31 | tb.Fatalf("failed to serve: %v", err) 32 | } 33 | }() 34 | tb.Cleanup(func() { 35 | _ = srv.Close() 36 | }) 37 | return l.Addr().String() 38 | } 39 | 40 | func newLocalListener(tb testing.TB) net.Listener { 41 | tb.Helper() 42 | l, err := net.Listen("tcp", "127.0.0.1:0") 43 | if err != nil { 44 | if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { 45 | tb.Fatalf("failed to listen on a port: %v", err) 46 | } 47 | } 48 | 49 | tb.Cleanup(func() { _ = l.Close() }) 50 | return l 51 | } 52 | 53 | // NewClientSession creates a new client session to the given address. 
54 | func NewClientSession(tb testing.TB, addr string, config *gossh.ClientConfig) (*gossh.Session, error) { 55 | tb.Helper() 56 | if config == nil { 57 | config = &gossh.ClientConfig{ 58 | User: "testuser", 59 | Auth: []gossh.AuthMethod{ 60 | gossh.Password("testpass"), 61 | }, 62 | } 63 | } 64 | if config.HostKeyCallback == nil { 65 | config.HostKeyCallback = gossh.InsecureIgnoreHostKey() // nolint: gosec 66 | } 67 | client, err := gossh.Dial("tcp", addr, config) 68 | if err != nil { 69 | return nil, err 70 | } 71 | session, err := client.NewSession() 72 | if err != nil { 73 | return nil, err 74 | } 75 | tb.Cleanup(func() { 76 | _ = session.Close() 77 | _ = client.Close() 78 | }) 79 | return session, nil 80 | } 81 | -------------------------------------------------------------------------------- /testsession/testsession_test.go: -------------------------------------------------------------------------------- 1 | package testsession 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/charmbracelet/ssh" 8 | ) 9 | 10 | func TestSession(t *testing.T) { 11 | const out = "hello world" 12 | session := New(t, &ssh.Server{ 13 | Handler: func(s ssh.Session) { 14 | _, _ = fmt.Fprint(s, out) 15 | }, 16 | }, nil) 17 | result, err := session.Output("") 18 | if err != nil { 19 | t.Errorf("expected no error, got %v", err) 20 | } 21 | if string(result) != out { 22 | t.Errorf("expected %q, got %q", out, string(result)) 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /wish.go: -------------------------------------------------------------------------------- 1 | package wish 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | 7 | "github.com/charmbracelet/keygen" 8 | "github.com/charmbracelet/ssh" 9 | ) 10 | 11 | // Middleware is a function that takes an ssh.Handler and returns an 12 | // ssh.Handler. Implementations should call the provided handler argument. 
13 | type Middleware func(next ssh.Handler) ssh.Handler 14 | 15 | // NewServer is returns a default SSH server with the provided Middleware. A 16 | // new SSH key pair of type ed25519 will be created if one does not exist. By 17 | // default this server will accept all incoming connections, password and 18 | // public key. 19 | func NewServer(ops ...ssh.Option) (*ssh.Server, error) { 20 | s := &ssh.Server{} 21 | for _, op := range ops { 22 | if err := s.SetOption(op); err != nil { 23 | return nil, err 24 | } 25 | } 26 | if len(s.HostSigners) == 0 { 27 | k, err := keygen.New("id_ed25519", keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite()) 28 | if err != nil { 29 | return nil, err 30 | } 31 | err = s.SetOption(WithHostKeyPEM(k.RawPrivateKey())) 32 | if err != nil { 33 | return nil, err 34 | } 35 | } 36 | return s, nil 37 | } 38 | 39 | // Fatal prints to the given session's STDERR and exits 1. 40 | func Fatal(s ssh.Session, v ...interface{}) { 41 | Error(s, v...) 42 | _ = s.Exit(1) 43 | _ = s.Close() 44 | } 45 | 46 | // Fatalf formats according to the given format, prints to the session's STDERR 47 | // followed by an exit 1. 48 | // 49 | // Notice that this might cause formatting issues if you don't add a \r\n in the end of your string. 50 | func Fatalf(s ssh.Session, f string, v ...interface{}) { 51 | Errorf(s, f, v...) 52 | _ = s.Exit(1) 53 | _ = s.Close() 54 | } 55 | 56 | // Fatalln formats according to the default format, prints to the session's 57 | // STDERR, followed by a new line and an exit 1. 58 | func Fatalln(s ssh.Session, v ...interface{}) { 59 | Errorln(s, v...) 60 | Errorf(s, "\r") 61 | _ = s.Exit(1) 62 | _ = s.Close() 63 | } 64 | 65 | // Error prints the given error the the session's STDERR. 66 | func Error(s ssh.Session, v ...interface{}) { 67 | _, _ = fmt.Fprint(s.Stderr(), v...) 68 | } 69 | 70 | // Errorf formats according to the given format and prints to the session's STDERR. 
71 | func Errorf(s ssh.Session, f string, v ...interface{}) { 72 | _, _ = fmt.Fprintf(s.Stderr(), f, v...) 73 | } 74 | 75 | // Errorf formats according to the default format and prints to the session's STDERR. 76 | func Errorln(s ssh.Session, v ...interface{}) { 77 | _, _ = fmt.Fprintln(s.Stderr(), v...) 78 | } 79 | 80 | // Print writes to the session's STDOUT followed. 81 | func Print(s ssh.Session, v ...interface{}) { 82 | _, _ = fmt.Fprint(s, v...) 83 | } 84 | 85 | // Printf formats according to the given format and writes to the session's STDOUT. 86 | func Printf(s ssh.Session, f string, v ...interface{}) { 87 | _, _ = fmt.Fprintf(s, f, v...) 88 | } 89 | 90 | // Println formats according to the default format and writes to the session's STDOUT. 91 | func Println(s ssh.Session, v ...interface{}) { 92 | _, _ = fmt.Fprintln(s, v...) 93 | } 94 | 95 | // WriteString writes the given string to the session's STDOUT. 96 | func WriteString(s ssh.Session, v string) (int, error) { 97 | return io.WriteString(s, v) 98 | } 99 | -------------------------------------------------------------------------------- /wish_test.go: -------------------------------------------------------------------------------- 1 | // go:generate mockgen -package mocks -destination mocks/session.go github.com/charmbracelet/ssh Session 2 | package wish 3 | 4 | import ( 5 | "bytes" 6 | "errors" 7 | "path/filepath" 8 | "strings" 9 | "testing" 10 | "time" 11 | 12 | "github.com/charmbracelet/ssh" 13 | "github.com/charmbracelet/wish/testsession" 14 | ) 15 | 16 | func TestNewServer(t *testing.T) { 17 | fp := filepath.Join(t.TempDir(), "id_ed25519") 18 | _, err := NewServer(WithHostKeyPath(fp)) 19 | if err != nil { 20 | t.Fatal(err) 21 | } 22 | } 23 | 24 | func TestNewServerWithOptions(t *testing.T) { 25 | fp := filepath.Join(t.TempDir(), "id_ed25519") 26 | if _, err := NewServer( 27 | WithHostKeyPath(fp), 28 | WithMaxTimeout(time.Second), 29 | WithBanner("welcome"), 30 | WithAddress(":2222"), 31 | ); err != 
nil { 32 | t.Fatal(err) 33 | } 34 | } 35 | 36 | func TestError(t *testing.T) { 37 | eerr := errors.New("foo err") 38 | sess := testsession.New(t, &ssh.Server{ 39 | Handler: func(s ssh.Session) { 40 | Error(s, eerr) 41 | }, 42 | }, nil) 43 | var out bytes.Buffer 44 | sess.Stderr = &out 45 | if err := sess.Run(""); err != nil { 46 | t.Errorf("expected no error, got %s", err) 47 | } 48 | if s := strings.TrimSpace(out.String()); s != eerr.Error() { 49 | t.Errorf("expected %s, got %s", s, eerr) 50 | } 51 | } 52 | 53 | func TestFatal(t *testing.T) { 54 | err := errors.New("foo err") 55 | sess := testsession.New(t, &ssh.Server{ 56 | Handler: func(s ssh.Session) { 57 | Fatal(s, err) 58 | }, 59 | }, nil) 60 | var out bytes.Buffer 61 | sess.Stderr = &out 62 | if err := sess.Run(""); err == nil { 63 | t.Error("expected an error, got nil") 64 | } 65 | if s := strings.TrimSpace(out.String()); s != err.Error() { 66 | t.Errorf("expected %s, got %s", s, err) 67 | } 68 | } 69 | --------------------------------------------------------------------------------