├── .dockerignore
├── .github
    └── workflows
    │   └── ci.yml
├── .gitignore
├── .stignore
├── Dockerfile
├── Makefile
├── README.md
├── cmd
    └── main.go
├── go.mod
├── go.sum
├── okteto.yml
└── pkg
    ├── os
        └── os.go
    └── ssh
        ├── key.go
        ├── ssh.go
        └── ssh_test.go

/.dockerignore:
--------------------------------------------------------------------------------
1 | remote
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 | on:
3 |   pull_request:
4 |     branches: main
5 |   push:
6 |     branches:
7 |       - main
8 |   release:
9 |     types:
10 |       - published
11 | 
12 | jobs:
13 |   build:
14 |     runs-on: ubuntu-latest
15 |     steps:
16 |       - name: Check out code
17 |         uses: actions/checkout@v2
18 |       - name: Install Go
19 |         uses: actions/setup-go@v2
20 |         with:
21 |           go-version: 1.23.x # keep in sync with go.mod and the Dockerfile builder image
22 |       - name: Test
23 |         run: make test
24 |       - uses: azure/docker-login@v1
25 |         with:
26 |           username: '${{ secrets.DOCKER_USER }}'
27 |           password: '${{ secrets.DOCKER_PASS }}'
28 | 
29 |       - name: Set up Docker Buildx
30 |         uses: crazy-max/ghaction-docker-buildx@v3
31 |         with:
32 |           qemu-version: latest
33 |           buildx-version: latest
34 | 
35 |       - name: "Build"
36 |         run: |
37 |           docker buildx build \
38 |             --platform linux/amd64,linux/arm64,linux/arm/v7 \
39 |             --output "type=image,push=false" \
40 |             --build-arg COMMIT_SHA=${{ github.sha }} -t remote:${{ github.sha }} .
41 | 
42 |       - id: tag
43 |         name: Get the tag
44 |         run: echo "TAG=${GITHUB_REF#refs/tags/}" >> "$GITHUB_OUTPUT"
45 |         if: ${{ github.event_name == 'release' }}
46 | 
47 |       - name: "Push"
48 |         run: |
49 |           docker buildx build \
50 |             --platform linux/amd64,linux/arm64,linux/arm/v7 \
51 |             --output "type=image,push=true" \
52 |             --build-arg COMMIT_SHA=${{ github.sha }} -t okteto/remote:${{ steps.tag.outputs.TAG }} .
53 |         if: ${{ github.event_name == 'release' }}
54 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | remote
2 | coverage.txt
3 | 
--------------------------------------------------------------------------------
/.stignore:
--------------------------------------------------------------------------------
1 | /.git
2 | /*.exe
3 | /*.exe~
4 | /*.dll
5 | /*.so
6 | /*.dylib
7 | 
8 | # Test binary, built with go test -c
9 | /*.test
10 | 
11 | # Output of the go coverage tool, specifically when used with LiteIDE
12 | /*.out
13 | remote
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM golang:1.23.8-bookworm AS builder
2 | 
3 | WORKDIR /app
4 | 
5 | COPY go.mod .
6 | COPY go.sum .
7 | RUN go mod download
8 | 
9 | COPY Makefile /app
10 | COPY pkg /app/pkg
11 | COPY cmd /app/cmd
12 | ARG COMMIT_SHA
13 | RUN make
14 | 
15 | FROM busybox:1.37.0
16 | 
17 | COPY --from=builder /app/remote /usr/local/bin/remote
18 | RUN chmod +x /usr/local/bin/remote
19 | 
20 | 
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | COMMIT_SHA ?= $(shell git rev-parse --short HEAD)
2 | .DEFAULT_GOAL := build
3 | 
4 | .PHONY: build test
5 | 
6 | # CGO_ENABLED=0 (the variable the go tool reads) builds a binary without cgo, for the busybox runtime image.
7 | build:
8 | 	CGO_ENABLED=0 go build -o remote -ldflags "-X main.CommitString=${COMMIT_SHA}" -tags "osusergo netgo static_build" cmd/main.go
9 | 
10 | test:
11 | 	go test -p 4 -coverprofile=coverage.txt -covermode=atomic ./...
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Remote
2 | 
3 | Minimalistic SSH server compatible with the VS Code Remote-SSH extension.
--------------------------------------------------------------------------------
/cmd/main.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"os"
5 | 	"strconv"
6 | 
7 | 	log "github.com/sirupsen/logrus"
8 | 
9 | 	remoteOS "github.com/okteto/remote/pkg/os"
10 | 	"github.com/okteto/remote/pkg/ssh"
11 | )
12 | 
13 | // CommitString is the commit used to build the server. The Makefile sets it
14 | // at link time with -ldflags "-X main.CommitString=${COMMIT_SHA}".
15 | var CommitString string
16 | 
17 | const (
18 | 	authorizedKeysPath = "/var/okteto/remote/authorized_keys"
19 | 	defaultPort        = 2222
20 | 	maxPort            = 65535 // largest valid TCP port number
21 | )
22 | 
23 | // main resolves the shell used for sessions, reads the listening port from
24 | // OKTETO_REMOTE_PORT (defaultPort when unset), loads the authorized keys and
25 | // serves ssh until ListenAndServe fails. Invalid configuration is fatal.
26 | func main() {
27 | 	log.SetOutput(os.Stdout)
28 | 	shell, err := remoteOS.GetShell()
29 | 	if err != nil {
30 | 		log.Fatal(err)
31 | 	}
32 | 
33 | 	port := defaultPort
34 | 	if p, ok := os.LookupEnv("OKTETO_REMOTE_PORT"); ok {
35 | 		port, err = strconv.Atoi(p)
36 | 		if err != nil || port > maxPort {
37 | 			log.Fatalf("%s is not a valid port number", p)
38 | 		}
39 | 
40 | 		// Anything up to and including 1024 is refused, which also covers 0
41 | 		// and negative values.
42 | 		if port <= 1024 {
43 | 			log.Fatalf("%d is a reserved port", port)
44 | 		}
45 | 	}
46 | 
47 | 	keys, err := ssh.LoadAuthorizedKeys(authorizedKeysPath)
48 | 	if err != nil {
49 | 		log.Fatalf("Failed to load authorized_keys: %s", err)
50 | 	}
51 | 
52 | 	// LoadAuthorizedKeys returns nil when the file does not exist; the server
53 | 	// then installs no public-key handler.
54 | 	if keys == nil {
55 | 		log.Warningf("remote server is running without authentication enabled")
56 | 	}
57 | 
58 | 	srv := ssh.Server{
59 | 		Port:           port,
60 | 		Shell:          shell,
61 | 		AuthorizedKeys: keys,
62 | 	}
63 | 
64 | 	log.Infof("ssh server %s started in 0.0.0.0:%d", CommitString, srv.Port)
65 | 	log.Fatal(srv.ListenAndServe())
66 | }
67 | 
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module
github.com/okteto/remote 2 | 3 | go 1.23.0 4 | 5 | toolchain go1.24.2 6 | 7 | require ( 8 | github.com/creack/pty v1.1.11 9 | github.com/gliderlabs/ssh v0.3.1 10 | github.com/google/uuid v1.1.2 11 | github.com/pkg/sftp v1.12.0 12 | github.com/sirupsen/logrus v1.7.0 13 | golang.org/x/crypto v0.35.0 14 | ) 15 | 16 | require ( 17 | github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect 18 | github.com/kr/fs v0.1.0 // indirect 19 | github.com/pkg/errors v0.9.1 // indirect 20 | golang.org/x/sys v0.30.0 // indirect 21 | ) 22 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= 2 | github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= 3 | github.com/creack/pty v1.1.11 h1:07n33Z8lZxZ2qwegKbObQohDhXDQxiMMz1NOUGYlesw= 4 | github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 5 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 7 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 8 | github.com/gliderlabs/ssh v0.3.1 h1:L6VrMUGZaMlNIMN8Hj+CHh4U9yodJE3FAt/rgvfaKvE= 9 | github.com/gliderlabs/ssh v0.3.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= 10 | github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= 11 | github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 12 | github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= 13 | github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= 14 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 15 | 
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 16 | github.com/pkg/sftp v1.12.0 h1:/f3b24xrDhkhddlaobPe2JgBqfdt+gC/NYl0QY9IOuI= 17 | github.com/pkg/sftp v1.12.0/go.mod h1:fUqqXB5vEgVCZ131L+9say31RAri6aF6KDViawhxKK8= 18 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 19 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 20 | github.com/sirupsen/logrus v1.7.0 h1:ShrD1U9pZB12TX0cVy0DtePoCH97K8EtX+mg7ZARUtM= 21 | github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= 22 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 23 | github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 24 | github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= 25 | github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 26 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 27 | golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= 28 | golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs= 29 | golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ= 30 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 31 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 32 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 33 | golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 34 | golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= 35 | golang.org/x/sys v0.30.0/go.mod 
h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 36 | golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU= 37 | golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s= 38 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 39 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 40 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= 41 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 42 | -------------------------------------------------------------------------------- /okteto.yml: -------------------------------------------------------------------------------- 1 | name: ssh 2 | image: okteto/golang:1 3 | command: 4 | - bash 5 | workdir: /app 6 | forward: 7 | - 2345:2345 8 | - 22000:2222 9 | volumes: 10 | - /go 11 | persistentVolume: 12 | enabled: true 13 | securityContext: 14 | capabilities: 15 | add: ["SYS_PTRACE"] -------------------------------------------------------------------------------- /pkg/os/os.go: -------------------------------------------------------------------------------- 1 | package os 2 | 3 | import ( 4 | "fmt" 5 | "os/exec" 6 | 7 | log "github.com/sirupsen/logrus" 8 | ) 9 | 10 | var ( 11 | // ErrNoShell is used when there is no shell available in the $PATH 12 | ErrNoShell = fmt.Errorf("bash or sh needs to be available in the $PATH of your development container") 13 | ) 14 | 15 | // GetShell returns the available shell 16 | func GetShell() (string, error) { 17 | if p, err := exec.LookPath("bash"); err == nil { 18 | log.Printf("bash exists at %s", p) 19 | return "bash", nil 20 | } 21 | 22 | if p, err := exec.LookPath("sh"); err == nil { 23 | log.Printf("sh exists at %s", p) 24 | return "sh", nil 25 | } 26 | 27 | return "", ErrNoShell 28 | } 29 | 
--------------------------------------------------------------------------------
/pkg/ssh/key.go:
--------------------------------------------------------------------------------
1 | package ssh
2 | 
3 | // hostKeyBytes is the PEM-encoded EC private key the server presents as its
4 | // host key (getServer loads it with ssh.HostKeyPEM).
5 | //
6 | // NOTE(review): the key is embedded in the source, so every deployment shares
7 | // the same host key and anyone with access to this repository can impersonate
8 | // the server to clients. Consider loading a per-instance key instead.
9 | const hostKeyBytes = `-----BEGIN EC PRIVATE KEY-----
10 | MHcCAQEEIKaaS8eETpK6OV6HDXmQ1hwpUNSLtDd2gAwafY+8khpUoAoGCCqGSM49
11 | AwEHoUQDQgAEdqpICiM7YTvLv6sO3VA/MrnmIuCeZ4aPbPh8/os1vx/PfD+DaCht
12 | fnfzZ17fCxLPRmkDqWKEGXZ+Tv5qqnD72g==
13 | -----END EC PRIVATE KEY-----`
14 | 
--------------------------------------------------------------------------------
/pkg/ssh/ssh.go:
--------------------------------------------------------------------------------
1 | package ssh
2 | 
3 | import (
4 | 	"errors"
5 | 	"fmt"
6 | 	"io"
7 | 	"os"
8 | 	"os/exec"
9 | 	"strings"
10 | 	"sync"
11 | 	"syscall"
12 | 	"time"
13 | 
14 | 	"github.com/creack/pty"
15 | 	"github.com/gliderlabs/ssh"
16 | 	"github.com/google/uuid"
17 | 	"github.com/pkg/sftp"
18 | 	log "github.com/sirupsen/logrus"
19 | )
20 | 
21 | var (
22 | 	// ErrEOF is the error when the terminal exits
23 | 	ErrEOF = errors.New("EOF")
24 | )
25 | 
26 | // ptyDrainTimeout bounds how long handlePTY waits, after the command has
27 | // exited, for the remaining terminal output to be copied to the client.
28 | const ptyDrainTimeout = 1 * time.Second
29 | 
30 | // Server holds the ssh server configuration
31 | type Server struct {
32 | 	Port           int
33 | 	Shell          string
34 | 	AuthorizedKeys []ssh.PublicKey
35 | }
36 | 
37 | // getExitStatusFromError converts the error returned by running a command
38 | // into an exit status: 0 for nil, the process exit status for an
39 | // *exec.ExitError (possibly wrapped), and 1 for any other error.
40 | func getExitStatusFromError(err error) int {
41 | 	if err == nil {
42 | 		return 0
43 | 	}
44 | 
45 | 	var exitErr *exec.ExitError
46 | 	if !errors.As(err, &exitErr) {
47 | 		return 1
48 | 	}
49 | 
50 | 	waitStatus, ok := exitErr.Sys().(syscall.WaitStatus)
51 | 	if !ok {
52 | 		if exitErr.Success() {
53 | 			return 0
54 | 		}
55 | 
56 | 		return 1
57 | 	}
58 | 
59 | 	return waitStatus.ExitStatus()
60 | }
61 | 
62 | // setWinsize resizes the terminal f to w columns by h rows. Resizing is best
63 | // effort: errors (for instance once f has been closed) are ignored.
64 | func setWinsize(f *os.File, w, h int) {
65 | 	_ = pty.Setsize(f, &pty.Winsize{Rows: uint16(h), Cols: uint16(w)})
66 | }
67 | 
68 | // handlePTY runs cmd attached to a new pseudo-terminal and relays it to the
69 | // session: window changes resize the terminal, session input is copied to the
70 | // terminal and terminal output is copied back. It returns the error from
71 | // starting or waiting for cmd.
72 | func handlePTY(logger *log.Entry, cmd *exec.Cmd, s ssh.Session, ptyReq ssh.Pty, winCh <-chan ssh.Window) error {
73 | 	if len(ptyReq.Term) > 0 {
74 | 		cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
75 | 	}
76 | 
77 | 	f, err := pty.Start(cmd)
78 | 	if err != nil {
79 | 		logger.WithError(err).Error("failed to start pty session")
80 | 		return err
81 | 	}
82 | 	// Without this every PTY session leaks the pty master's file descriptor.
83 | 	// Closing it also stops the output copier if it is still running.
84 | 	defer f.Close()
85 | 
86 | 	go func() {
87 | 		for win := range winCh {
88 | 			setWinsize(f, win.Width, win.Height)
89 | 		}
90 | 	}()
91 | 
92 | 	go func() {
93 | 		_, _ = io.Copy(f, s) // stdin
94 | 	}()
95 | 
96 | 	waitCh := make(chan struct{})
97 | 	go func() {
98 | 		defer close(waitCh)
99 | 		_, _ = io.Copy(s, f) // stdout
100 | 	}()
101 | 
102 | 	waitErr := cmd.Wait()
103 | 	if waitErr != nil {
104 | 		logger.WithError(waitErr).Errorf("pty command failed while waiting")
105 | 	}
106 | 
107 | 	// Let the output copier flush what the command wrote before exiting, also
108 | 	// when it failed, so the tail of its output still reaches the client.
109 | 	select {
110 | 	case <-waitCh:
111 | 		logger.Info("stdout finished")
112 | 	case <-time.After(ptyDrainTimeout):
113 | 		logger.Info("stdout didn't finish after 1s")
114 | 	}
115 | 
116 | 	return waitErr
117 | }
118 | 
119 | // sendErrAndExit writes err to the session's stderr (without the "exec: "
120 | // prefix of os/exec errors) and ends the session with the exit status derived
121 | // from err.
122 | func sendErrAndExit(logger *log.Entry, s ssh.Session, err error) {
123 | 	msg := strings.TrimPrefix(err.Error(), "exec: ")
124 | 	if _, writeErr := s.Stderr().Write([]byte(msg)); writeErr != nil {
125 | 		logger.WithError(writeErr).Errorf("failed to write error back to session")
126 | 	}
127 | 
128 | 	if exitErr := s.Exit(getExitStatusFromError(err)); exitErr != nil {
129 | 		logger.WithError(exitErr).Errorf("pty session failed to exit")
130 | 	}
131 | }
132 | 
133 | // handleNoTTY runs cmd without a terminal: the session is copied to the
134 | // command's stdin, and the command's stdout and stderr are copied to the
135 | // session's stdout and stderr. It returns the error from starting or waiting
136 | // for cmd.
137 | func handleNoTTY(logger *log.Entry, cmd *exec.Cmd, s ssh.Session) error {
138 | 	stdout, err := cmd.StdoutPipe()
139 | 	if err != nil {
140 | 		logger.WithError(err).Errorf("couldn't get StdoutPipe")
141 | 		return err
142 | 	}
143 | 
144 | 	stderr, err := cmd.StderrPipe()
145 | 	if err != nil {
146 | 		logger.WithError(err).Errorf("couldn't get StderrPipe")
147 | 		return err
148 | 	}
149 | 
150 | 	stdin, err := cmd.StdinPipe()
151 | 	if err != nil {
152 | 		logger.WithError(err).Errorf("couldn't get StdinPipe")
153 | 		return err
154 | 	}
155 | 
156 | 	if err = cmd.Start(); err != nil {
157 | 		logger.WithError(err).Errorf("couldn't start command '%s'", cmd.String())
158 | 		return err
159 | 	}
160 | 
161 | 	go func() {
162 | 		defer stdin.Close()
163 | 		if _, err := io.Copy(stdin, s); err != nil {
164 | 			logger.WithError(err).Errorf("failed to write session to stdin.")
165 | 		}
166 | 	}()
167 | 
168 | 	wg := &sync.WaitGroup{}
169 | 	wg.Add(1)
170 | 	go func() {
171 | 		defer wg.Done()
172 | 		if _, err := io.Copy(s, stdout); err != nil {
173 | 			logger.WithError(err).Errorf("failed to write stdout to session.")
174 | 		}
175 | 	}()
176 | 
177 | 	wg.Add(1)
178 | 	go func() {
179 | 		defer wg.Done()
180 | 		if _, err := io.Copy(s.Stderr(), stderr); err != nil {
181 | 			logger.WithError(err).Errorf("failed to write stderr to session.")
182 | 		}
183 | 	}()
184 | 
185 | 	// All reads from the output pipes must complete before cmd.Wait, which
186 | 	// closes them.
187 | 	wg.Wait()
188 | 
189 | 	if err := cmd.Wait(); err != nil {
190 | 		logger.WithError(err).Errorf("command failed while waiting")
191 | 		return err
192 | 	}
193 | 
194 | 	return nil
195 | }
196 | 
197 | // connectionHandler serves one ssh session: it builds the command, forwards
198 | // the ssh agent when the client requested it, runs the command with or
199 | // without a pty, and reports the command's exit status to the client.
200 | func (srv *Server) connectionHandler(s ssh.Session) {
201 | 	sessionID := uuid.New().String()
202 | 	logger := log.WithFields(log.Fields{"session.id": sessionID})
203 | 	defer func() {
204 | 		s.Close()
205 | 		logger.Info("session closed")
206 | 	}()
207 | 
208 | 	logger.Infof("starting ssh session with command '%+v'", s.RawCommand())
209 | 
210 | 	cmd := srv.buildCmd(s)
211 | 
212 | 	if ssh.AgentRequested(s) {
213 | 		logger.Info("agent requested")
214 | 		l, err := ssh.NewAgentListener()
215 | 		if err != nil {
216 | 			logger.WithError(err).Error("failed to start agent")
217 | 			return
218 | 		}
219 | 
220 | 		defer l.Close()
221 | 		go ssh.ForwardAgentConnections(l, s)
222 | 		cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", "SSH_AUTH_SOCK", l.Addr().String()))
223 | 	}
224 | 
225 | 	var err error
226 | 	if ptyReq, winCh, isPty := s.Pty(); isPty {
227 | 		logger.Println("handling PTY session")
228 | 		err = handlePTY(logger, cmd, s, ptyReq, winCh)
229 | 	} else {
230 | 		logger.Println("handling non PTY session")
231 | 		err = handleNoTTY(logger, cmd, s)
232 | 	}
233 | 
234 | 	if err != nil {
235 | 		sendErrAndExit(logger, s, err)
236 | 		return
237 | 	}
238 | 
239 | 	if err := s.Exit(0); err != nil {
240 | 		logger.WithError(err).Errorf("session failed to exit")
241 | 	}
242 | }
243 | 
244 | // LoadAuthorizedKeys loads the public keys listed in the authorized_keys file
245 | // at path. It returns nil and no error if path doesn't exist, and an error if
246 | // the file exists but holds no key. Blank lines and comment lines are skipped,
247 | // including any that follow the last key.
248 | func LoadAuthorizedKeys(path string) ([]ssh.PublicKey, error) {
249 | 	authorizedKeysBytes, err := os.ReadFile(path)
250 | 	if err != nil {
251 | 		if os.IsNotExist(err) {
252 | 			return nil, nil
253 | 		}
254 | 
255 | 		return nil, err
256 | 	}
257 | 
258 | 	var authorizedKeys []ssh.PublicKey
259 | 	for len(authorizedKeysBytes) > 0 {
260 | 		pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
261 | 		if err != nil {
262 | 			// ParseAuthorizedKey fails when the remaining input holds no further
263 | 			// key, e.g. only trailing blank lines or comments. Once at least one
264 | 			// key has been read that simply ends the list.
265 | 			if len(authorizedKeys) > 0 {
266 | 				break
267 | 			}
268 | 
269 | 			return nil, err
270 | 		}
271 | 
272 | 		authorizedKeys = append(authorizedKeys, pubKey)
273 | 		authorizedKeysBytes = rest
274 | 	}
275 | 
276 | 	if len(authorizedKeys) == 0 {
277 | 		return nil, fmt.Errorf("%s was empty", path)
278 | 	}
279 | 
280 | 	return authorizedKeys, nil
281 | }
282 | 
283 | // authorize is the public-key handler: it accepts key only if it equals one of
284 | // srv.AuthorizedKeys.
285 | func (srv *Server) authorize(ctx ssh.Context, key ssh.PublicKey) bool {
286 | 	for _, k := range srv.AuthorizedKeys {
287 | 		if ssh.KeysEqual(key, k) {
288 | 			return true
289 | 		}
290 | 	}
291 | 
292 | 	log.Println("access denied")
293 | 	return false
294 | }
295 | 
296 | // ListenAndServe starts the SSH server using port
297 | func (srv *Server) ListenAndServe() error {
298 | 	server := srv.getServer()
299 | 	return server.ListenAndServe()
300 | }
301 | 
302 | // getServer builds the underlying ssh.Server: sessions are handled by
303 | // connectionHandler, local and remote port forwarding are allowed for any
304 | // host and port, "sftp" is served as a subsystem, and public-key
305 | // authentication is enforced only when AuthorizedKeys is non-nil.
306 | func (srv *Server) getServer() *ssh.Server {
307 | 	forwardHandler := &ssh.ForwardedTCPHandler{}
308 | 
309 | 	server := &ssh.Server{
310 | 		Addr:    fmt.Sprintf(":%d", srv.Port),
311 | 		Handler: srv.connectionHandler,
312 | 		ChannelHandlers: map[string]ssh.ChannelHandler{
313 | 			"direct-tcpip": ssh.DirectTCPIPHandler,
314 | 			"session":      ssh.DefaultSessionHandler,
315 | 		},
316 | 		LocalPortForwardingCallback: ssh.LocalPortForwardingCallback(func(ctx ssh.Context, dhost string, dport uint32) bool {
317 | 			log.Println("Accepted forward", dhost, dport)
318 | 			return true
319 | 		}),
320 | 		ReversePortForwardingCallback: ssh.ReversePortForwardingCallback(func(ctx ssh.Context, host string, port uint32) bool {
321 | 			log.Println("attempt to bind", host, port, "granted")
322 | 			return true
323 | 		}),
324 | 		RequestHandlers: map[string]ssh.RequestHandler{
325 | 			"tcpip-forward":        forwardHandler.HandleSSHRequest,
326 | 			"cancel-tcpip-forward": forwardHandler.HandleSSHRequest,
327 | 		},
328 | 		SubsystemHandlers: map[string]ssh.SubsystemHandler{
329 | 			"sftp": sftpHandler,
330 | 		},
331 | 	}
332 | 
333 | 	// hostKeyBytes is a constant, so failing to parse it is a programming error
334 | 	// (Test_loadPrivateKey guards it) rather than a runtime condition.
335 | 	if err := server.SetOption(ssh.HostKeyPEM([]byte(hostKeyBytes))); err != nil {
336 | 		panic(fmt.Sprintf("invalid embedded host key: %s", err))
337 | 	}
338 | 
339 | 	if srv.AuthorizedKeys != nil {
340 | 		server.PublicKeyHandler = srv.authorize
341 | 	}
342 | 
343 | 	return server
344 | }
345 | 
346 | // sftpHandler serves the "sftp" subsystem over sess until the client exits or
347 | // the server stops with an error.
348 | func sftpHandler(sess ssh.Session) {
349 | 	serverOptions := []sftp.ServerOption{
350 | 		sftp.WithDebug(io.Discard),
351 | 	}
352 | 	server, err := sftp.NewServer(
353 | 		sess,
354 | 		serverOptions...,
355 | 	)
356 | 	if err != nil {
357 | 		log.Printf("sftp server init error: %s\n", err)
358 | 		return
359 | 	}
360 | 	// Close the server on every exit path, not only after a clean client EOF.
361 | 	defer server.Close()
362 | 
363 | 	if err := server.Serve(); errors.Is(err, io.EOF) {
364 | 		log.Println("sftp client exited session.")
365 | 	} else if err != nil {
366 | 		log.Println("sftp server completed with error:", err)
367 | 	}
368 | }
369 | 
370 | // buildCmd returns the command for session s: srv.Shell without arguments when
371 | // the client sent no command, otherwise srv.Shell -c <raw command>. Its
372 | // environment is the server's own environment followed by the variables sent
373 | // by the client.
374 | func (srv *Server) buildCmd(s ssh.Session) *exec.Cmd {
375 | 	var cmd *exec.Cmd
376 | 	if raw := s.RawCommand(); len(raw) == 0 {
377 | 		cmd = exec.Command(srv.Shell)
378 | 	} else {
379 | 		cmd = exec.Command(srv.Shell, "-c", raw)
380 | 	}
381 | 
382 | 	cmd.Env = append(os.Environ(), s.Environ()...)
383 | 
384 | 	// connectionHandler already logs the raw command at info level, so this
385 | 	// goes to the debug level instead of straight to stdout.
386 | 	log.Debugf("built command: %s", cmd.String())
387 | 	return cmd
388 | }
389 | 
--------------------------------------------------------------------------------
/pkg/ssh/ssh_test.go:
--------------------------------------------------------------------------------
1 | package ssh
2 | 
3 | import (
4 | 	"bytes"
5 | 	"fmt"
6 | 	"net"
7 | 	"os"
8 | 	"strings"
9 | 	"testing"
10 | 
11 | 	"github.com/gliderlabs/ssh"
12 | 	gossh "golang.org/x/crypto/ssh"
13 | )
14 | 
15 | const (
16 | 	goodKey = `ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBAbAgOR8lXercwCWLNSjxHe4YUYUGxXSQU9gTb4MCPTJ5cXXhiFMcz84nTM5X5Dx5GshdAGeoXPl8dO/FgO+iFI= test@example.com`
17 | 	badKey  = `ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBGfBJMSXwNdBC5EM2fPThe5BcSMxbzXbaweK3ynOL2aNxUXk+Xe7BhD4F/L7stMpHkriV8hWKhhsb8a9gPfV5UI= test@example.com`
18 | )
19 | 
20 | // writeTempFile writes content to a new file in a per-test temporary directory
21 | // (removed when the test ends) and returns the file's path.
22 | func writeTempFile(t *testing.T, content string) string {
23 | 	t.Helper()
24 | 	f, err := os.CreateTemp(t.TempDir(), "authorized_keys")
25 | 	if err != nil {
26 | 		t.Fatal(err)
27 | 	}
28 | 	defer f.Close()
29 | 
30 | 	if _, err := f.WriteString(content); err != nil {
31 | 		t.Fatal(err)
32 | 	}
33 | 
34 | 	return f.Name()
35 | }
36 | 
37 | func Test_loadPrivateKey(t *testing.T) {
38 | 	_, err := gossh.ParsePrivateKey([]byte(hostKeyBytes))
39 | 	if err != nil {
40 | 		t.Error(err)
41 | 	}
42 | }
43 | 
44 | func TestLoadAuthorizedKeys(t *testing.T) {
45 | 	// missing file
46 | 	k, err := LoadAuthorizedKeys("missing")
47 | 	if err != nil {
48 | 		t.Error(err)
49 | 	}
50 | 
51 | 	if k != nil {
52 | 		t.Errorf("didn't return nil array")
53 | 	}
54 | 
55 | 	// empty file
56 | 	if _, err := LoadAuthorizedKeys(writeTempFile(t, "")); err == nil {
57 | 		t.Error("empty file didn't fail")
58 | 	}
59 | 
60 | 	parsed, _, _, _, err := gossh.ParseAuthorizedKey([]byte(goodKey))
61 | 	if err != nil {
62 | 		t.Fatalf("failed to parse key: %s", err)
63 | 	}
64 | 
65 | 	k, err = LoadAuthorizedKeys(writeTempFile(t, goodKey))
66 | 	if err != nil {
67 | 		t.Fatal(err)
68 | 	}
69 | 
70 | 	if len(k) != 1 {
71 | 		t.Fatalf("expected 1 key, got %d", len(k))
72 | 	}
73 | 
74 | 	if !ssh.KeysEqual(k[0], parsed) {
75 | 		t.Error("loaded key is not the same")
76 | 	}
77 | 
78 | 	srv := Server{AuthorizedKeys: k}
79 | 	if !srv.authorize(nil, parsed) {
80 | 		t.Error("failed to authorize loaded key")
81 | 	}
82 | 
83 | 	bad, _, _, _, err := gossh.ParseAuthorizedKey([]byte(badKey))
84 | 	if err != nil {
85 | 		t.Fatalf("failed to parse key: %s", err)
86 | 	}
87 | 
88 | 	if srv.authorize(nil, bad) {
89 | 		t.Error("authorized bad key")
90 | 	}
91 | }
92 | 
93 | func TestLoadAuthorizedKeys_multiple(t *testing.T) {
94 | 	path := writeTempFile(t, strings.Repeat(goodKey+"\n", 3))
95 | 
96 | 	parsed, _, _, _, err := gossh.ParseAuthorizedKey([]byte(goodKey))
97 | 	if err != nil {
98 | 		t.Fatalf("failed to parse key: %s", err)
99 | 	}
100 | 
101 | 	k, err := LoadAuthorizedKeys(path)
102 | 	if err != nil {
103 | 		t.Fatal(err)
104 | 	}
105 | 
106 | 	if len(k) != 3 {
107 | 		t.Fatalf("didn't load 3 authorized keys, got %d", len(k))
108 | 	}
109 | 
110 | 	if !ssh.KeysEqual(k[0], parsed) {
111 | 		t.Error("loaded key is not the same")
112 | 	}
113 | 
114 | 	srv := Server{AuthorizedKeys: k}
115 | 	if !srv.authorize(nil, k[1]) {
116 | 		t.Error("failed to authorize loaded key")
117 | 	}
118 | }
119 | 
120 | func TestLoadAuthorizedKeys_blankLinesAndComments(t *testing.T) {
121 | 	content := "# leading comment\n\n" + goodKey + "\n\n# trailing comment\n\n"
122 | 	k, err := LoadAuthorizedKeys(writeTempFile(t, content))
123 | 	if err != nil {
124 | 		t.Fatal(err)
125 | 	}
126 | 
127 | 	if len(k) != 1 {
128 | 		t.Fatalf("expected 1 key, got %d", len(k))
129 | 	}
130 | }
131 | 
132 | func TestLoadAuthorizedKeys_noKeys(t *testing.T) {
133 | 	if _, err := LoadAuthorizedKeys(writeTempFile(t, "# no keys here\n\n")); err == nil {
134 | 		t.Error("file without keys didn't fail")
135 | 	}
136 | }
137 | 
138 | func Test_connectionHandler(t *testing.T) {
139 | 	var tests = []struct {
140 | 		name      string
141 | 		command   string
142 | 		stdout    string
143 | 		stderr    string
144 | 		expectErr bool
145 | 	}{
146 | 		{
147 | 			name:    "basic",
148 | 			command: "echo hi",
149 | 			stdout:  "hi",
150 | 			stderr:  "",
151 | 		},
152 | 		{
153 | 			name:    "with-shell",
154 | 			command: `sh -c "echo hi"`,
155 | 			stdout:  "hi",
156 | 			stderr:  "",
157 | 		},
158 | 		{
159 | 			name:    "several-commands",
160 | 			command: `m=hello; echo $m`,
161 | 			stdout:  "hello",
162 | 			stderr:  "",
163 | 		},
164 | 		{
165 | 			name:    "bad-command",
166 | 			command: "badcommand",
167 | 			stdout:  "",
168 | 			//stderr: `"badcommand": executable file not found in $PATH`,
169 | 			expectErr: true,
170 | 		},
171 | 		{
172 | 			name:    "bad-command-with-shell",
173 | 			command: `sh -c "badcommand"`,
174 | 			stdout:  "",
175 | 			// stderr is not checked because the message differs between OSes
176 | 			//stderr: `sh: badcommand: command not found`
177 | 			expectErr: true,
178 | 		},
179 | 	}
180 | 	for _, tt := range tests {
181 | 		t.Run(tt.name, func(t *testing.T) {
182 | 			s := &Server{}
183 | 			s.Shell = "sh"
184 | 			srv := s.getServer()
185 | 
186 | 			session, _, cleanup := newTestSession(t, srv, nil)
187 | 			defer cleanup()
188 | 
189 | 			var stdout bytes.Buffer
190 | 			var stderr bytes.Buffer
191 | 			session.Stderr = &stderr
192 | 			session.Stdout = &stdout
193 | 
194 | 			err := session.Run(tt.command)
195 | 			if tt.expectErr && err == nil {
196 | 				t.Fatal("expected the command to fail")
197 | 			}
198 | 			if !tt.expectErr && err != nil {
199 | 				t.Fatal(err)
200 | 			}
201 | 
202 | 			out := strings.TrimSuffix(stdout.String(), "\n")
203 | 			if out != tt.stdout {
204 | 				t.Errorf("bad stdout. got:\n%s\nexpected:\n%s", out, tt.stdout)
205 | 			}
206 | 
207 | 			if tt.stderr != "" {
208 | 				got := strings.TrimSuffix(stderr.String(), "\n")
209 | 				if got != tt.stderr {
210 | 					t.Errorf("bad stderr. got:\n'%s'\nexpected\n'%s'", got, tt.stderr)
211 | 				}
212 | 			}
213 | 		})
214 | 	}
215 | }
216 | 
217 | // serveOnce accepts a single connection on l, serves it with srv and then
218 | // closes l.
219 | func serveOnce(srv *ssh.Server, l net.Listener) error {
220 | 	defer l.Close()
221 | 
222 | 	conn, e := l.Accept()
223 | 	if e != nil {
224 | 		return e
225 | 	}
226 | 	srv.ChannelHandlers = map[string]ssh.ChannelHandler{
227 | 		"session":      ssh.DefaultSessionHandler,
228 | 		"direct-tcpip": ssh.DirectTCPIPHandler,
229 | 	}
230 | 	srv.HandleConn(conn)
231 | 	return nil
232 | }
233 | 
234 | // newLocalListener listens on an ephemeral loopback port, falling back to
235 | // IPv6 when IPv4 is unavailable.
236 | func newLocalListener() net.Listener {
237 | 	l, err := net.Listen("tcp", "127.0.0.1:0")
238 | 	if err != nil {
239 | 		if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
240 | 			panic(fmt.Sprintf("failed to listen on a port: %v", err))
241 | 		}
242 | 	}
243 | 	return l
244 | }
245 | 
246 | // newTestSession serves one connection with srv on a local listener and
247 | // returns a client session connected to it.
248 | func newTestSession(t *testing.T, srv *ssh.Server, cfg *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) {
249 | 	l := newLocalListener()
250 | 	go serveOnce(srv, l)
251 | 	return newClientSession(t, l.Addr().String(), cfg)
252 | }
253 | 
254 | // newClientSession dials addr (ignoring the host key unless config sets a
255 | // callback) and opens a session; the returned func closes both.
256 | func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) {
257 | 	t.Helper()
258 | 	if config == nil {
259 | 		config = &gossh.ClientConfig{}
260 | 	}
261 | 
262 | 	if config.HostKeyCallback == nil {
263 | 		config.HostKeyCallback = gossh.InsecureIgnoreHostKey()
264 | 	}
265 | 
266 | 	client, err := gossh.Dial("tcp", addr, config)
267 | 	if err != nil {
268 | 		t.Fatal(err)
269 | 	}
270 | 
271 | 	session, err := client.NewSession()
272 | 	if err != nil {
273 | 		t.Fatal(err)
274 | 	}
275 | 
276 | 	return session, client, func() {
277 | 		session.Close()
278 | 		client.Close()
279 | 	}
280 | }
281 | 
--------------------------------------------------------------------------------