├── .gitignore ├── .golangci.yml ├── LICENSE ├── README.md ├── cmd.go ├── communicator.go ├── communicator_acc_test.go ├── config.go ├── dial.go ├── doc.go ├── go.mod ├── go.sum ├── httpshell ├── conn.go ├── dialer.go ├── doc.go ├── listener.go └── listener_test.go └── keepalive.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.dll 4 | *.so 5 | *.dylib 6 | 7 | # Development binary, built with makefile 8 | *.dev 9 | 10 | # Test binary, built with `go test -c` 11 | *.test 12 | 13 | # Output of the go coverage tool, specifically when used with LiteIDE 14 | *.out 15 | 16 | # IDE 17 | .idea/* 18 | .vscode/* 19 | .history 20 | 21 | # Go mod 22 | vendor 23 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | run: 2 | deadline: 5m 3 | tests: false 4 | 5 | linters-settings: 6 | errcheck: 7 | check-blank: true 8 | lll: 9 | line-length: 180 10 | 11 | linters: 12 | enable-all: true 13 | disable: 14 | - gas 15 | - gochecknoglobals 16 | - gochecknoinits 17 | - interfacer 18 | - maligned 19 | - prealloc 20 | 21 | issues: 22 | exclude-use-default: false 23 | exclude: 24 | - composite literal uses unkeyed fields 25 | - exported function `New.+` should have comment or be unexported 26 | - exported method `Listener\..+` should have comment or be unexported 27 | - Error return value of `.+\.Close` is not checked 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 
9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. 
For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Scylla SSH tools 2 | 3 | Package sshtools provides a wrapper around an SSH client with the following features: 4 | 5 | * Context-aware (dial and execution), 6 | * Keepalive enabled, 7 | * Can copy files using SCP.
8 | 9 | ## License 10 | 11 | Copyright (C) 2017 ScyllaDB 12 | 13 | This project is distributed under the Apache 2.0 license. See the [LICENSE](https://github.com/scylladb/gocqlx/blob/master/LICENSE) file for details. 14 | It contains software from: 15 | 16 | * [Hashicorp Terraform](https://github.com/hashicorp/terraform), licensed under the MPL-2 license 17 | -------------------------------------------------------------------------------- /cmd.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 ScyllaDB 2 | 3 | package sshtools 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "io" 9 | ) 10 | 11 | // Cmd represents a remote command being prepared or run. 12 | type Cmd struct { 13 | // Command is the command to run remotely. This is executed as if 14 | // it were a shell command, so you are expected to do any shell escaping 15 | // necessary. 16 | Command string 17 | 18 | // Stdin specifies the process's standard input. If Stdin is nil, 19 | // the process reads from an empty bytes.Buffer. 20 | Stdin io.Reader 21 | 22 | // Stdout and Stderr represent the process's standard output and error. 23 | // 24 | // If either is nil, it will be set to ioutil.Discard. 25 | Stdout io.Writer 26 | Stderr io.Writer 27 | 28 | // Internal fields 29 | ctx context.Context 30 | closer io.Closer 31 | 32 | exitStatus int 33 | err error 34 | exitCh chan struct{} // protects exitStatus and err 35 | } 36 | 37 | // Init must be called by the Communicator before executing the command. 38 | func (c *Cmd) init(ctx context.Context, closer io.Closer) { 39 | c.ctx = ctx 40 | c.closer = closer 41 | c.exitCh = make(chan struct{}) 42 | } 43 | 44 | // setExitStatus stores the exit status of the remote command as well as any 45 | // communicator related error. SetExitStatus then unblocks any pending calls 46 | // to Wait. 47 | // This should only be called by communicators executing the remote.Cmd. 
48 | func (c *Cmd) setExitStatus(status int, err error) { 49 | c.exitStatus = status 50 | c.err = err 51 | 52 | close(c.exitCh) 53 | } 54 | 55 | // Wait waits for the remote command completion or cancellation. 56 | // Wait may return an error from the communicator, or an ExitError if the 57 | // process exits with a non-zero exit status. 58 | func (c *Cmd) Wait() error { 59 | select { 60 | case <-c.ctx.Done(): 61 | c.closer.Close() 62 | return c.ctx.Err() 63 | case <-c.exitCh: 64 | // continue 65 | } 66 | 67 | if c.err != nil || c.exitStatus != 0 { 68 | return &ExitError{ 69 | Command: c.Command, 70 | ExitStatus: c.exitStatus, 71 | Err: c.err, 72 | } 73 | } 74 | 75 | return nil 76 | } 77 | 78 | // ExitError is returned by Wait to indicate an error while executing the remote 79 | // command, or a non-zero exit status. 80 | type ExitError struct { 81 | Command string 82 | ExitStatus int 83 | Err error 84 | } 85 | 86 | func (e *ExitError) Error() string { 87 | if e.Err != nil { 88 | return fmt.Sprintf("error executing %q: %v", e.Command, e.Err) 89 | } 90 | return fmt.Sprintf("%q exit status: %d", e.Command, e.ExitStatus) 91 | } 92 | -------------------------------------------------------------------------------- /communicator.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 ScyllaDB 2 | 3 | package sshtools 4 | 5 | import ( 6 | "bufio" 7 | "context" 8 | "fmt" 9 | "io" 10 | "io/ioutil" 11 | "net" 12 | "os" 13 | "path/filepath" 14 | "strings" 15 | 16 | "github.com/pkg/errors" 17 | "golang.org/x/crypto/ssh" 18 | ) 19 | 20 | // Logger is the minimal interface Communicator needs for logging. Note that 21 | // log.Logger from the standard library implements this interface, and it is 22 | // easy to implement by custom loggers, if they don't do so already anyway. 
23 | type Logger interface { 24 | Println(v ...interface{}) 25 | } 26 | 27 | // Communicator allows for executing commands on a remote host over SSH, it is 28 | // not thread safe. New communicator is not connected by default, however, 29 | // calling Start or Upload on not connected communicator would try to establish 30 | // SSH connection before executing. 31 | type Communicator struct { 32 | host string 33 | config Config 34 | dial DialContextFunc 35 | logger Logger 36 | 37 | // OnDial is a listener that may be set to track openning SSH connection to 38 | // the remote host. It is called for both successful and failed trials. 39 | OnDial func(host string, err error) 40 | // OnConnClose is a listener that may be set to track closing of SSH 41 | // connection. 42 | OnConnClose func(host string) 43 | 44 | client *ssh.Client 45 | keepaliveDone chan struct{} 46 | } 47 | 48 | func NewCommunicator(host string, config Config, dial DialContextFunc, logger Logger) *Communicator { 49 | return &Communicator{ 50 | host: host, 51 | config: config, 52 | dial: dial, 53 | logger: logger, 54 | } 55 | } 56 | 57 | // Connect must be called to connect the communicator to remote host. It can 58 | // be called multiple times, in that case the current SSH connection is closed 59 | // and a new connection is established. 
60 | func (c *Communicator) Connect(ctx context.Context) (err error) { 61 | c.logger.Println("Connecting to remote host", "host", c.host) 62 | 63 | defer func() { 64 | if c.OnDial != nil { 65 | c.OnDial(c.host, err) 66 | } 67 | }() 68 | 69 | c.reset() 70 | 71 | client, err := c.dial(ctx, "tcp", net.JoinHostPort(c.host, fmt.Sprint(c.config.Port)), &c.config.ClientConfig) 72 | if err != nil { 73 | return errors.Wrap(err, "ssh: dial failed") 74 | } 75 | c.client = client 76 | 77 | c.logger.Println("Connected!", "host", c.host) 78 | 79 | if c.config.KeepaliveEnabled() { 80 | c.logger.Println("Starting ssh KeepAlives", "host", c.host) 81 | c.keepaliveDone = make(chan struct{}) 82 | go StartKeepalive(client, c.config.ServerAliveInterval, c.config.ServerAliveCountMax, c.keepaliveDone) 83 | } 84 | 85 | return nil 86 | } 87 | 88 | // Disconnect closes the current SSH connection. 89 | func (c *Communicator) Disconnect() { 90 | c.reset() 91 | } 92 | 93 | func (c *Communicator) reset() { 94 | if c.keepaliveDone != nil { 95 | close(c.keepaliveDone) 96 | } 97 | c.keepaliveDone = nil 98 | 99 | if c.client != nil { 100 | c.client.Close() 101 | if c.OnConnClose != nil { 102 | c.OnConnClose(c.host) 103 | } 104 | } 105 | c.client = nil 106 | } 107 | 108 | // Start starts the specified command but does not wait for it to complete. 109 | // Each command is executed in a new SSH session. If context is canceled 110 | // the session is immediately closed and error is returned. 111 | // 112 | // The cmd Wait method will return the exit code and release associated 113 | // resources once the command exits. 
114 | func (c *Communicator) Start(ctx context.Context, cmd *Cmd) error { 115 | session, err := c.newSession(ctx) 116 | if err != nil { 117 | return err 118 | } 119 | 120 | // Setup command 121 | cmd.init(ctx, session) 122 | 123 | // Setup session 124 | session.Stdin = cmd.Stdin 125 | session.Stdout = cmd.Stdout 126 | session.Stderr = cmd.Stderr 127 | 128 | if c.config.Pty { 129 | // request a PTY 130 | termModes := ssh.TerminalModes{ 131 | ssh.ECHO: 0, // do not echo 132 | ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud 133 | ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud 134 | } 135 | 136 | if err := session.RequestPty("xterm", 80, 40, termModes); err != nil { 137 | return err 138 | } 139 | } 140 | 141 | c.logger.Println("Starting remote command", 142 | "host", c.host, 143 | "cmd", cmd.Command, 144 | ) 145 | err = session.Start(strings.TrimSpace(cmd.Command) + "\n") 146 | if err != nil { 147 | return err 148 | } 149 | 150 | // Start a goroutine to wait for the session to end 151 | go func() { 152 | defer session.Close() 153 | 154 | err := session.Wait() 155 | exitStatus := 0 156 | if err != nil { 157 | if exitErr, ok := err.(*ssh.ExitError); ok { 158 | exitStatus = exitErr.ExitStatus() 159 | } else { 160 | exitStatus = -1 161 | } 162 | } 163 | 164 | if err != nil { 165 | c.logger.Println("Remote command exited with error", 166 | "host", c.host, 167 | "cmd", cmd.Command, 168 | "status", exitStatus, 169 | "error", err, 170 | ) 171 | } else { 172 | c.logger.Println("Remote command exited", 173 | "host", c.host, 174 | "cmd", cmd.Command, 175 | "status", exitStatus, 176 | ) 177 | } 178 | 179 | cmd.setExitStatus(exitStatus, err) 180 | }() 181 | 182 | return nil 183 | } 184 | 185 | func (c *Communicator) newSession(ctx context.Context) (session *ssh.Session, err error) { 186 | c.logger.Println("Opening new ssh session", "host", c.host) 187 | if c.client == nil { 188 | err = errors.New("ssh client is not connected") 189 | } else { 190 | session, err = 
c.client.NewSession() 191 | } 192 | 193 | if err != nil { 194 | c.logger.Println("ssh session open error", "host", c.host, "error", err) 195 | if err := c.Connect(ctx); err != nil { 196 | return nil, err 197 | } 198 | 199 | return c.client.NewSession() 200 | } 201 | 202 | return session, nil 203 | } 204 | 205 | // Upload creates a file with a given path and permissions and content on 206 | // a remote host. If context is canceled the upload is interrupted, file is not 207 | // saved and error is returned. 208 | func (c *Communicator) Upload(ctx context.Context, path string, perm os.FileMode, src io.Reader) error { 209 | // The target directory and file for talking the SCP protocol 210 | targetDir := filepath.Dir(path) 211 | targetFile := filepath.Base(path) 212 | 213 | // On windows, filepath.Dir uses backslash separators (ie. "\tmp"). 214 | // This does not work when the target host is unix. Switch to forward slash 215 | // which works for unix and windows 216 | targetDir = filepath.ToSlash(targetDir) 217 | 218 | // Skip copying if we can get the file size directly from common io.Readers 219 | size := int64(0) 220 | 221 | switch s := src.(type) { 222 | case interface { 223 | Stat() (os.FileInfo, error) 224 | }: 225 | fi, err := s.Stat() 226 | if err == nil { 227 | size = fi.Size() 228 | } 229 | case interface { 230 | Len() int 231 | }: 232 | size = int64(s.Len()) 233 | } 234 | 235 | c.logger.Println("Uploading file", 236 | "host", c.host, 237 | "path", path, 238 | "perm", perm.Perm(), 239 | ) 240 | 241 | scpFunc := func(w io.Writer, stdoutR *bufio.Reader) error { 242 | return scpUploadFile(w, src, stdoutR, targetFile, perm, size) 243 | } 244 | err := c.scpSession(ctx, "scp -vt "+targetDir, scpFunc) 245 | 246 | if err != nil { 247 | c.logger.Println("Uploading file ended with error", 248 | "host", c.host, 249 | "path", path, 250 | "perm", perm.Perm(), 251 | "error", err, 252 | ) 253 | } else { 254 | c.logger.Println("Uploading file ended", 255 | "host", c.host, 256 
| "path", path, 257 | "perm", perm.Perm(), 258 | ) 259 | } 260 | 261 | return err 262 | } 263 | 264 | func (c *Communicator) scpSession(ctx context.Context, scpCommand string, f func(io.Writer, *bufio.Reader) error) error { 265 | session, err := c.newSession(ctx) 266 | if err != nil { 267 | return err 268 | } 269 | defer session.Close() 270 | 271 | // Get a pipe to stdin so that we can send data down 272 | stdinW, err := session.StdinPipe() 273 | if err != nil { 274 | return err 275 | } 276 | 277 | // We only want to close once, so we nil w after we close it, 278 | // and only close in the defer if it hasn't been closed already. 279 | defer func() { 280 | if stdinW != nil { 281 | stdinW.Close() 282 | } 283 | }() 284 | 285 | // Get a pipe to stdout so that we can get responses back 286 | stdoutPipe, err := session.StdoutPipe() 287 | if err != nil { 288 | return err 289 | } 290 | stdoutR := bufio.NewReader(stdoutPipe) 291 | 292 | // Start the sink mode on the other side 293 | if err := session.Start(scpCommand); err != nil { 294 | return err 295 | } 296 | 297 | // Call our callback that executes in the context of SCP. We ignore 298 | // EOF errors if they occur because it usually means that SCP prematurely 299 | // ended on the other side. 300 | if err := f(stdinW, stdoutR); err != nil && err != io.EOF { 301 | return err 302 | } 303 | 304 | // Close the stdin, which sends an EOF, and then set w to nil so that 305 | // our defer func doesn't close it again since that is unsafe with 306 | // the Go SSH package. 307 | stdinW.Close() 308 | stdinW = nil 309 | 310 | // Wait for the SCP connection to close, meaning it has consumed all 311 | // our data and has completed. Or has errored. 
312 | exitCh := make(chan struct{}) 313 | go func() { 314 | // Ignore result if context was cancelled 315 | if ctx.Err() != nil { 316 | return 317 | } 318 | err = session.Wait() 319 | close(exitCh) 320 | }() 321 | 322 | select { 323 | case <-ctx.Done(): 324 | err = ctx.Err() 325 | case <-exitCh: 326 | // continue 327 | } 328 | 329 | if err != nil { 330 | if exitErr, ok := err.(*ssh.ExitError); ok { 331 | // Otherwise, we have an ExitErorr, meaning we can just read 332 | // the exit status 333 | c.logger.Println("scp error", "host", c.host, "error", exitErr) 334 | 335 | // If we exited with status 127, it means SCP isn't available. 336 | // Return a more descriptive error for that. 337 | if exitErr.ExitStatus() == 127 { 338 | return errors.New("SCP failed to start, this usually means that SCP is not properly installed on the remote system") 339 | } 340 | } 341 | 342 | return err 343 | } 344 | 345 | return nil 346 | } 347 | 348 | // checkSCPStatus checks that a prior command sent to SCP completed 349 | // successfully. If it did not complete successfully, an error will 350 | // be returned. 351 | func checkSCPStatus(r *bufio.Reader) error { 352 | code, err := r.ReadByte() 353 | if err != nil { 354 | return err 355 | } 356 | 357 | if code != 0 { 358 | // Treat any non-zero (really 1 and 2) as fatal errors 359 | message, _, err := r.ReadLine() 360 | if err != nil { 361 | return errors.Wrapf(err, "error reading error message") 362 | } 363 | 364 | return errors.New(string(message)) 365 | } 366 | 367 | return nil 368 | } 369 | 370 | func scpUploadFile(dst io.Writer, src io.Reader, stdout *bufio.Reader, file string, perm os.FileMode, size int64) error { 371 | if size == 0 { 372 | // Create a temporary file where we can copy the contents of the src 373 | // so that we can determine the length, since SCP is length-prefixed. 
374 | tf, err := ioutil.TempFile("", "scylla-manager-upload") 375 | if err != nil { 376 | return errors.Wrapf(err, "error creating temporary file for upload") 377 | } 378 | defer os.Remove(tf.Name()) // nolint: errcheck 379 | defer tf.Close() 380 | 381 | if _, err := io.Copy(tf, src); err != nil { 382 | return err 383 | } 384 | 385 | // Sync the file so that the contents are definitely on disk, then 386 | // read the length of it. 387 | if err := tf.Sync(); err != nil { 388 | return errors.Wrapf(err, "error creating temporary file for upload") 389 | } 390 | 391 | // Seek the file to the beginning so we can re-read all of it 392 | if _, err := tf.Seek(0, 0); err != nil { 393 | return errors.Wrapf(err, "error creating temporary file for upload") 394 | } 395 | 396 | fi, err := tf.Stat() 397 | if err != nil { 398 | return errors.Wrapf(err, "error creating temporary file for upload") 399 | } 400 | 401 | src = tf 402 | size = fi.Size() 403 | } 404 | 405 | // Start the protocol 406 | mode := fmt.Sprintf("C%04o", uint32(perm.Perm())) 407 | 408 | fmt.Fprintln(dst, mode, size, file) 409 | if err := checkSCPStatus(stdout); err != nil { 410 | return err 411 | } 412 | 413 | if _, err := io.Copy(dst, src); err != nil { 414 | return err 415 | } 416 | 417 | fmt.Fprint(dst, "\x00") 418 | if err := checkSCPStatus(stdout); err != nil { 419 | return err 420 | } 421 | 422 | return nil 423 | } 424 | -------------------------------------------------------------------------------- /communicator_acc_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 ScyllaDB 2 | 3 | // +build acc 4 | 5 | package sshtools 6 | 7 | import ( 8 | "bytes" 9 | "context" 10 | "fmt" 11 | "io" 12 | "log" 13 | "math/rand" 14 | "net" 15 | "os" 16 | "testing" 17 | "time" 18 | 19 | "github.com/pkg/errors" 20 | ) 21 | 22 | func newTestCommunicator(host, user, pass string) *Communicator { 23 | var ( 24 | config = DefaultConfig().WithPasswordAuth(user, pass) 
25 | dial = ContextDialer(&net.Dialer{}) 26 | logger = log.New(os.Stdout, "", log.LstdFlags) 27 | ) 28 | return NewCommunicator(host, config, dial, logger) 29 | } 30 | 31 | func newTestCommunicatorFromEnv(t *testing.T) *Communicator { 32 | t.Helper() 33 | 34 | host := os.Getenv("TEST_SSH_HOST") 35 | user := os.Getenv("TEST_SSH_USER") 36 | pass := os.Getenv("TEST_SSH_PASS") 37 | 38 | if host == "" || user == "" || pass == "" { 39 | t.Skip("Missing environment variables TEST_SSH_HOST, TEST_SSH_USER, TEST_SSH_PASS") 40 | } 41 | 42 | return newTestCommunicator(host, user, pass) 43 | } 44 | 45 | func TestAccStart(t *testing.T) { 46 | c := newTestCommunicatorFromEnv(t) 47 | 48 | if err := c.Connect(context.Background()); err != nil { 49 | t.Fatal(err) 50 | } 51 | 52 | t.Run("command success", func(t *testing.T) { 53 | var cmd Cmd 54 | stdout := new(bytes.Buffer) 55 | cmd.Command = "echo foo" 56 | cmd.Stdout = stdout 57 | 58 | ctx := context.Background() 59 | if err := c.Start(ctx, &cmd); err != nil { 60 | t.Fatalf("error executing remote command: %s", err) 61 | } 62 | 63 | if err := cmd.Wait(); err != nil { 64 | t.Fatal("command failed", err) 65 | } 66 | 67 | if stdout.String() != "foo\n" { 68 | t.Fatal("expected", "foo", "got", stdout.String()) 69 | } 70 | }) 71 | 72 | t.Run("command failure", func(t *testing.T) { 73 | var cmd Cmd 74 | cmd.Command = "false" 75 | 76 | ctx := context.Background() 77 | if err := c.Start(ctx, &cmd); err != nil { 78 | t.Fatalf("error executing remote command: %s", err) 79 | } 80 | 81 | err := cmd.Wait() 82 | if err == nil { 83 | t.Fatal("expected communicator error") 84 | } 85 | _, ok := err.(*ExitError) 86 | if !ok { 87 | t.Fatal("expected exit error") 88 | } 89 | }) 90 | 91 | t.Run("context canceled", func(t *testing.T) { 92 | var cmd Cmd 93 | cmd.Command = "sleep 5" 94 | 95 | ctx, cancel := context.WithCancel(context.Background()) 96 | if err := c.Start(ctx, &cmd); err != nil { 97 | t.Fatalf("error executing remote command: %s", err) 98 | 
} 99 | 100 | // Cancel the context, to cause the command to fail 101 | go func() { 102 | time.Sleep(100 * time.Millisecond) 103 | cancel() 104 | }() 105 | 106 | if err := cmd.Wait(); errors.Cause(err) != context.Canceled { 107 | t.Fatal("expected context.Canceled", "got", err) 108 | } 109 | }) 110 | } 111 | 112 | func TestAccLostConnection(t *testing.T) { 113 | c := newTestCommunicatorFromEnv(t) 114 | 115 | var cmd Cmd 116 | cmd.Command = "sleep 5" 117 | 118 | ctx := context.Background() 119 | if err := c.Start(ctx, &cmd); err != nil { 120 | t.Fatalf("error executing remote command: %s", err) 121 | } 122 | 123 | // Disconnect the communicator transport, to cause the command to fail 124 | go func() { 125 | time.Sleep(100 * time.Millisecond) 126 | c.Disconnect() 127 | }() 128 | 129 | if err := cmd.Wait(); err == nil { 130 | t.Fatal("expected communicator error") 131 | } 132 | } 133 | 134 | func TestAccUploadFile(t *testing.T) { 135 | c := newTestCommunicatorFromEnv(t) 136 | 137 | const MB = int64(1 << 20) 138 | 139 | r := rand.NewSource(time.Now().Unix()) 140 | TempFile := func() string { 141 | return fmt.Sprint("/tmp/ssh-upload-test-", r.Int63()) 142 | } 143 | 144 | t.Run("small file", func(t *testing.T) { 145 | var ( 146 | path = TempFile() 147 | perm = os.FileMode(0700) 148 | content = []byte("this is the file content") 149 | ) 150 | 151 | ctx := context.Background() 152 | if err := c.Upload(ctx, path, perm, bytes.NewReader(content)); err != nil { 153 | t.Fatal("error uploading file", err) 154 | } 155 | 156 | assertStdout(t, c, "stat -c '%A' "+path, perm.String()) 157 | assertStdout(t, c, "stat -c '%s' "+path, fmt.Sprint(len(content))) 158 | }) 159 | 160 | t.Run("big file", func(t *testing.T) { 161 | var ( 162 | path = TempFile() 163 | perm = os.FileMode(0644) 164 | size = 100 * MB 165 | ) 166 | 167 | ctx := context.Background() 168 | if err := c.Upload(ctx, path, perm, io.LimitReader(rand.New(r), size)); err != nil { 169 | t.Fatal("error uploading file", err) 170 
| } 171 | 172 | assertStdout(t, c, "stat -c '%A' "+path, perm.String()) 173 | assertStdout(t, c, "stat -c '%s' "+path, fmt.Sprint(size)) 174 | }) 175 | 176 | t.Run("context cancelled", func(t *testing.T) { 177 | var ( 178 | path = TempFile() 179 | perm = os.FileMode(0420) 180 | size = 100 * MB 181 | ) 182 | 183 | ctx, cancel := context.WithCancel(context.Background()) 184 | source := &contextCancellingReader{ 185 | limit: MB, 186 | cancel: cancel, 187 | inner: rand.New(r), 188 | } 189 | 190 | if err := c.Upload(ctx, path, perm, io.LimitReader(source, size)); err != context.Canceled { 191 | t.Fatal("expected context.Canceled", "got", err) 192 | } 193 | 194 | var cmd Cmd 195 | cmd.Command = "stat -c '%A' " + path 196 | 197 | if err := c.Start(ctx, &cmd); err != nil { 198 | t.Fatalf("error executing remote command: %s", err) 199 | } 200 | if err := cmd.Wait(); err == nil { 201 | t.Fatal("expected file does not exist") 202 | } 203 | }) 204 | } 205 | 206 | type contextCancellingReader struct { 207 | size int 208 | limit int64 209 | cancel context.CancelFunc 210 | inner io.Reader 211 | 212 | read int64 213 | } 214 | 215 | func (r *contextCancellingReader) Read(p []byte) (n int, err error) { 216 | n, err = r.inner.Read(p) 217 | 218 | r.read += int64(n) 219 | if r.cancel != nil && r.read >= r.limit { 220 | r.cancel() 221 | r.cancel = nil 222 | } 223 | return 224 | } 225 | 226 | func (r *contextCancellingReader) Len() int { 227 | return r.size 228 | } 229 | 230 | func assertStdout(t *testing.T, c *Communicator, command string, expected string) { 231 | t.Helper() 232 | 233 | var cmd Cmd 234 | stdout := new(bytes.Buffer) 235 | cmd.Command = command 236 | cmd.Stdout = stdout 237 | 238 | ctx := context.Background() 239 | if err := c.Start(ctx, &cmd); err != nil { 240 | t.Fatalf("error executing remote command: %s", err) 241 | } 242 | if err := cmd.Wait(); err != nil { 243 | t.Fatal("command failed", err) 244 | } 245 | if stdout.String() != expected+"\n" { 246 | 
t.Fatal("expected", "foo", "got", stdout.String()) 247 | } 248 | } 249 | -------------------------------------------------------------------------------- /config.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 ScyllaDB 2 | 3 | package sshtools 4 | 5 | import ( 6 | "time" 7 | 8 | "github.com/pkg/errors" 9 | "go.uber.org/multierr" 10 | "golang.org/x/crypto/ssh" 11 | ) 12 | 13 | // Config specifies SSH configuration. 14 | type Config struct { 15 | ssh.ClientConfig `json:"-" yaml:"-"` 16 | // Port specifies the port number to connect on the remote host. 17 | Port int `yaml:"port"` 18 | // ServerAliveInterval specifies an interval to send keepalive message 19 | // through the encrypted channel and request a response from the server. 20 | ServerAliveInterval time.Duration `yaml:"server_alive_interval"` 21 | // ServerAliveCountMax specifies the number of server keepalive messages 22 | // which may be sent without receiving any messages back from the server. 23 | // If this threshold is reached while server keepalive messages are being sent, 24 | // ssh will disconnect from the server, terminating the session. 25 | ServerAliveCountMax int `yaml:"server_alive_count_max"` 26 | // Pty specifies if a pty should be associated with sessions on remote 27 | // hosts. Enabling pty would make Scylla banner to be printed to commands' 28 | // stdout. 29 | Pty bool `yaml:"pty"` 30 | } 31 | 32 | // DefaultConfig returns a Config initialized with default values. 33 | func DefaultConfig() Config { 34 | return Config{ 35 | Port: 22, 36 | ServerAliveInterval: 15 * time.Second, 37 | ServerAliveCountMax: 3, 38 | } 39 | } 40 | 41 | // Validate checks if all the fields are properly set. 
42 | func (c Config) Validate() (err error) { 43 | if c.Port <= 0 { 44 | err = multierr.Append(err, errors.New("invalid port, must be > 0")) 45 | } 46 | 47 | if c.ServerAliveInterval < 0 { 48 | err = multierr.Append(err, errors.New("invalid server_alive_interval, must be >= 0")) 49 | } 50 | 51 | if c.ServerAliveCountMax < 0 { 52 | err = multierr.Append(err, errors.New("invalid server_alive_count_max, must be >= 0")) 53 | } 54 | 55 | return 56 | } 57 | 58 | // WithIdentityFileAuth returns a copy of c with added user and identity file 59 | // authentication method. 60 | func (c Config) WithIdentityFileAuth(user string, identityFile []byte) (Config, error) { 61 | if user == "" { 62 | return Config{}, errors.New("missing user") 63 | } 64 | 65 | auth, err := keyPairAuthMethod(identityFile) 66 | if err != nil { 67 | return Config{}, errors.Wrap(err, "failed to parse identity file") 68 | } 69 | 70 | config := c 71 | config.User = user 72 | config.Auth = []ssh.AuthMethod{auth} 73 | config.HostKeyCallback = ssh.InsecureIgnoreHostKey() 74 | 75 | return config, nil 76 | } 77 | 78 | func keyPairAuthMethod(pemBytes []byte) (ssh.AuthMethod, error) { 79 | signer, err := ssh.ParsePrivateKey(pemBytes) 80 | if err != nil { 81 | return nil, err 82 | } 83 | 84 | return ssh.PublicKeys(signer), nil 85 | } 86 | 87 | // WithPasswordAuth returns a copy of c with added user and password 88 | // authentication method. 89 | func (c Config) WithPasswordAuth(user, passwd string) Config { 90 | config := c 91 | config.User = user 92 | config.Auth = []ssh.AuthMethod{ssh.Password(passwd)} 93 | config.HostKeyCallback = ssh.InsecureIgnoreHostKey() 94 | return config 95 | } 96 | 97 | // KeepaliveEnabled returns true if SSH keepalive should be enabled. 
98 | func (c Config) KeepaliveEnabled() bool { 99 | return c.ServerAliveInterval > 0 && c.ServerAliveCountMax > 0 100 | } 101 | -------------------------------------------------------------------------------- /dial.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 ScyllaDB 2 | 3 | package sshtools 4 | 5 | import ( 6 | "context" 7 | "net" 8 | 9 | "golang.org/x/crypto/ssh" 10 | ) 11 | 12 | // DialContextFunc creates SSH connection to host with a given address. 13 | type DialContextFunc func(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) 14 | 15 | // ContextDialer returns DialContextFunc based on dialer to make net connections. 16 | func ContextDialer(dialer *net.Dialer) DialContextFunc { 17 | return contextDialer{dialer}.DialContext 18 | } 19 | 20 | type contextDialer struct { 21 | dialer *net.Dialer 22 | } 23 | 24 | func (d contextDialer) DialContext(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { 25 | conn, err := d.dialer.DialContext(ctx, network, addr) 26 | if err != nil { 27 | return nil, err 28 | } 29 | 30 | type dialRes struct { 31 | client *ssh.Client 32 | err error 33 | } 34 | dialc := make(chan dialRes, 1) 35 | 36 | go func() { 37 | sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config) 38 | if err != nil { 39 | dialc <- dialRes{err: err} 40 | } else { 41 | dialc <- dialRes{client: ssh.NewClient(sshConn, chans, reqs)} 42 | } 43 | }() 44 | 45 | select { 46 | case v := <-dialc: 47 | // Our dial finished 48 | if v.client != nil { 49 | return v.client, nil 50 | } 51 | // Our dial failed 52 | conn.Close() 53 | // It wasn't an error due to cancellation, so 54 | // return the original error message: 55 | return nil, v.err 56 | case <-ctx.Done(): 57 | conn.Close() 58 | return nil, ctx.Err() 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /doc.go: 
-------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 ScyllaDB 2 | 3 | package sshtools 4 | 5 | // Package sshtools provides a wrapper around SSH client with the following features: 6 | // 7 | // * Context aware (dial and execution), 8 | // * Keepalive enabled, 9 | // * Can copy files using SCP. 10 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/scylladb/go-sshtools 2 | 3 | go 1.12 4 | 5 | require ( 6 | github.com/pkg/errors v0.8.1 7 | github.com/stretchr/testify v1.3.0 // indirect 8 | go.uber.org/atomic v1.4.0 // indirect 9 | go.uber.org/multierr v1.1.0 10 | golang.org/x/crypto v0.0.0-20190617133340-57b3e21c3d56 11 | ) 12 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 2 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= 4 | github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 5 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 6 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 7 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 8 | github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= 9 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 10 | go.uber.org/atomic v1.4.0 h1:cxzIVoETapQEqDhQu3QfnvXAV4AlzcvUCxkVUFw3+EU= 11 | go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= 12 | go.uber.org/multierr 
v1.1.0 h1:HoEmRHQPVSqub6w2z2d2EOVs2fjyFRGyofhKuyDq0QI= 13 | go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= 14 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 15 | golang.org/x/crypto v0.0.0-20190617133340-57b3e21c3d56 h1:ZpKuNIejY8P0ExLOVyKhb0WsgG8UdvHXe6TWjY7eL6k= 16 | golang.org/x/crypto v0.0.0-20190617133340-57b3e21c3d56/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 17 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 18 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 19 | golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI= 20 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 21 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 22 | -------------------------------------------------------------------------------- /httpshell/conn.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 ScyllaDB 2 | 3 | package httpshell 4 | 5 | import ( 6 | "io" 7 | "net" 8 | "time" 9 | 10 | "github.com/pkg/errors" 11 | "go.uber.org/multierr" 12 | "golang.org/x/crypto/ssh" 13 | ) 14 | 15 | // proxyConn is a net.Conn that writes to the SSH shell stdin and reads from 16 | // the SSH shell stdout. 17 | type proxyConn struct { 18 | client *ssh.Client 19 | session *ssh.Session 20 | stdin io.WriteCloser 21 | stdout io.Reader 22 | 23 | free func() 24 | } 25 | 26 | // newProxyConn opens a new session and start the shell. When the connection is 27 | // closed the client is closed and the free function is called. 
28 | func newProxyConn(client *ssh.Client, stderr io.Writer, free func()) (*proxyConn, error) { 29 | // Open new session to the agent 30 | session, err := client.NewSession() 31 | if err != nil { 32 | return nil, err 33 | } 34 | 35 | // Get a pipe to stdin so that we can send data down 36 | stdin, err := session.StdinPipe() 37 | if err != nil { 38 | return nil, err 39 | } 40 | 41 | // Get a pipe to stdout so that we can get responses back 42 | stdout, err := session.StdoutPipe() 43 | if err != nil { 44 | return nil, err 45 | } 46 | 47 | // Set stderr 48 | session.Stderr = stderr 49 | 50 | // Start the shell on the other side 51 | if err := session.Shell(); err != nil { 52 | return nil, err 53 | } 54 | 55 | return &proxyConn{ 56 | client: client, 57 | session: session, 58 | stdin: stdin, 59 | stdout: stdout, 60 | free: free, 61 | }, nil 62 | } 63 | 64 | func (conn *proxyConn) Read(b []byte) (n int, err error) { 65 | return conn.stdout.Read(b) 66 | } 67 | 68 | func (conn *proxyConn) Write(b []byte) (n int, err error) { 69 | return conn.stdin.Write(b) 70 | } 71 | 72 | func (conn *proxyConn) Close() error { 73 | var err error 74 | err = multierr.Append(err, conn.session.Close()) 75 | err = multierr.Append(err, conn.client.Close()) 76 | if conn.free != nil { 77 | conn.free() 78 | } 79 | return err 80 | } 81 | 82 | func (conn *proxyConn) LocalAddr() net.Addr { 83 | return conn.client.LocalAddr() 84 | } 85 | 86 | func (conn *proxyConn) RemoteAddr() net.Addr { 87 | return conn.client.RemoteAddr() 88 | } 89 | 90 | func (*proxyConn) SetDeadline(t time.Time) error { 91 | return errors.New("ssh: deadline not supported") 92 | } 93 | 94 | func (*proxyConn) SetReadDeadline(t time.Time) error { 95 | return errors.New("ssh: deadline not supported") 96 | } 97 | 98 | func (*proxyConn) SetWriteDeadline(t time.Time) error { 99 | return errors.New("ssh: deadline not supported") 100 | } 101 | -------------------------------------------------------------------------------- 
/httpshell/dialer.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 ScyllaDB 2 | 3 | package httpshell 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "net" 9 | 10 | "github.com/pkg/errors" 11 | "github.com/scylladb/go-sshtools" 12 | ) 13 | 14 | // Dialer allows for proxying connections over SSH. It can be used with HTTP 15 | // client to allow communication with a HTTP shell using Listener and serving 16 | // HTTP request over stdin and stdout. 17 | type Dialer struct { 18 | config sshtools.Config 19 | dial sshtools.DialContextFunc 20 | logger sshtools.Logger 21 | 22 | // OnDial is a listener that may be set to track openning SSH connection to 23 | // the remote host. It is called for both successful and failed trials. 24 | OnDial func(host string, err error) 25 | // OnConnClose is a listener that may be set to track closing of SSH 26 | // connection. 27 | OnConnClose func(host string) 28 | } 29 | 30 | func NewDialer(config sshtools.Config, dial sshtools.DialContextFunc, logger sshtools.Logger) *Dialer { 31 | return &Dialer{ 32 | config: config, 33 | dial: dial, 34 | logger: logger, 35 | } 36 | } 37 | 38 | // DialContext to addr HOST:PORT establishes an SSH connection to HOST and then 39 | // sends request to the SSH shell. 
40 | func (p *Dialer) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) { 41 | host, _, _ := net.SplitHostPort(addr) // nolint: errcheck 42 | 43 | defer func() { 44 | if p.OnDial != nil { 45 | p.OnDial(host, err) 46 | } 47 | }() 48 | 49 | p.logger.Println("Connecting to remote host...", "host", host) 50 | 51 | client, err := p.dial(ctx, network, net.JoinHostPort(host, fmt.Sprint(p.config.Port)), &p.config.ClientConfig) 52 | if err != nil { 53 | return nil, errors.Wrap(err, "ssh: dial failed") 54 | } 55 | 56 | p.logger.Println("Starting session", "host", host) 57 | 58 | keepaliveDone := make(chan struct{}) 59 | free := func() { 60 | close(keepaliveDone) 61 | if p.OnConnClose != nil { 62 | p.OnConnClose(host) 63 | } 64 | p.logger.Println("Connection closed", "host", host) 65 | } 66 | 67 | pconn, err := newProxyConn(client, &logStderr{host: host, logger: p.logger}, free) 68 | if err != nil { 69 | client.Close() 70 | return nil, errors.Wrap(err, "ssh: failed to connect") 71 | } 72 | 73 | p.logger.Println("Connected!", "host", host) 74 | 75 | // Init SSH keepalive if needed 76 | if p.config.KeepaliveEnabled() { 77 | p.logger.Println("Starting ssh KeepAlives", "host", host) 78 | go sshtools.StartKeepalive(client, p.config.ServerAliveInterval, p.config.ServerAliveCountMax, keepaliveDone) 79 | } 80 | 81 | return pconn, nil 82 | } 83 | 84 | type logStderr struct { 85 | host string 86 | logger sshtools.Logger 87 | } 88 | 89 | func (w *logStderr) Write(p []byte) (n int, err error) { 90 | w.logger.Println("host", w.host, "stderr", string(p)) 91 | return len(p), nil 92 | } 93 | -------------------------------------------------------------------------------- /httpshell/doc.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 ScyllaDB 2 | 3 | package httpshell 4 | 5 | // Package httpshell provides tools for building system login shells that you 6 | // can communicate with using HTTP 
protocol. This is useful for providing a fine 7 | // grained control what can be run. 8 | -------------------------------------------------------------------------------- /httpshell/listener.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 ScyllaDB 2 | 3 | package httpshell 4 | 5 | import ( 6 | "io" 7 | "net" 8 | "time" 9 | 10 | "github.com/pkg/errors" 11 | ) 12 | 13 | // Listener is a net.Listener that accepts only a single connection that uses 14 | // the given reader and writer. It's intended to be used with http.Server, 15 | // it can then expose http.Handlers over all sorts of transports. 16 | // 17 | // After accepting the first connection any calls to Accept will block until 18 | // the connection is closed, then they will end immediately with io.EOF error. 19 | // This is needed to block http.Server main loop and avoid termination of 20 | // the golden connection. 21 | type Listener struct { 22 | w io.Writer 23 | r io.ReadCloser 24 | done chan struct{} 25 | } 26 | 27 | func NewListener(w io.Writer, r io.ReadCloser) *Listener { 28 | return &Listener{w: w, r: r} 29 | } 30 | 31 | func (l *Listener) Accept() (net.Conn, error) { 32 | if l.done != nil { 33 | // Block the http.Server main loop and wait for the connection to end 34 | <-l.done 35 | return nil, io.EOF 36 | } 37 | 38 | // Return the connection consuming the reader and writer. 39 | l.done = make(chan struct{}) 40 | return &conn{ 41 | w: l.w, 42 | r: l.r, 43 | done: l.done, 44 | }, nil 45 | } 46 | 47 | func (l *Listener) Close() error { 48 | return errors.New("agent: closing Listener is not supported") 49 | } 50 | 51 | func (l *Listener) Addr() net.Addr { 52 | return nilAddr 53 | } 54 | 55 | // conn is a net.Conn that uses given reader and writer. 56 | // It should be only used by Listener. 
57 | type conn struct { 58 | w io.Writer 59 | r io.ReadCloser 60 | done chan struct{} 61 | } 62 | 63 | func (c *conn) Read(b []byte) (n int, err error) { 64 | return c.r.Read(b) 65 | } 66 | 67 | func (c *conn) Write(b []byte) (n int, err error) { 68 | return c.w.Write(b) 69 | } 70 | 71 | func (c *conn) Close() error { 72 | defer func() { 73 | if c.done != nil { 74 | close(c.done) 75 | c.done = nil 76 | } 77 | }() 78 | return c.r.Close() 79 | } 80 | 81 | func (*conn) LocalAddr() net.Addr { 82 | return nilAddr 83 | } 84 | 85 | func (*conn) RemoteAddr() net.Addr { 86 | return nilAddr 87 | } 88 | 89 | func (*conn) SetDeadline(t time.Time) error { 90 | return errors.New("agent: deadline not supported") 91 | } 92 | 93 | func (*conn) SetReadDeadline(t time.Time) error { 94 | return errors.New("agent: deadline not supported") 95 | } 96 | 97 | func (*conn) SetWriteDeadline(t time.Time) error { 98 | return errors.New("agent: deadline not supported") 99 | } 100 | 101 | // addr is a mock net.Addr. 102 | type addr string 103 | 104 | func (addr) Network() string { 105 | return "tcp" 106 | } 107 | 108 | func (a addr) String() string { 109 | return string(a) 110 | } 111 | 112 | var nilAddr = addr("127.0.0.1") 113 | -------------------------------------------------------------------------------- /httpshell/listener_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 ScyllaDB 2 | 3 | package httpshell 4 | 5 | import ( 6 | "bytes" 7 | "io" 8 | "io/ioutil" 9 | "testing" 10 | "time" 11 | ) 12 | 13 | func TestListenerAcceptOnce(t *testing.T) { 14 | w := &bytes.Buffer{} 15 | r := &bytes.Buffer{} 16 | 17 | l := NewListener(w, ioutil.NopCloser(r)) 18 | conn, err := l.Accept() 19 | if err != nil { 20 | t.Fatal(err) 21 | } 22 | time.AfterFunc(50*time.Millisecond, func() { 23 | conn.Close() 24 | }) 25 | _, err = l.Accept() 26 | if err != io.EOF { 27 | t.Fatal("expected io.EOF got", err) 28 | } 29 | } 30 | 31 | func 
TestListenerConn(t *testing.T) { 32 | const payload = "Foo" 33 | 34 | w := &bytes.Buffer{} 35 | r := bytes.NewBufferString(payload) 36 | 37 | l := NewListener(w, ioutil.NopCloser(r)) 38 | conn, err := l.Accept() 39 | if err != nil { 40 | t.Fatal(err) 41 | } 42 | 43 | t.Run("read", func(t *testing.T) { 44 | b := make([]byte, 5) 45 | n, err := conn.Read(b) 46 | if err != nil { 47 | t.Fatal(err) 48 | } 49 | if n != len(payload) { 50 | t.Fatal("expected", len(payload), "got", n) 51 | } 52 | if string(b[:n]) != payload { 53 | t.Fatal("expected", payload, "got", string(b)) 54 | } 55 | }) 56 | 57 | t.Run("write", func(t *testing.T) { 58 | b := []byte(payload) 59 | n, err := conn.Write(b) 60 | if err != nil { 61 | t.Fatal(err) 62 | } 63 | if n != len(payload) { 64 | t.Fatal("expected", len(payload), "got", n) 65 | } 66 | if w.String() != payload { 67 | t.Fatal("expected", payload, "got", w.String()) 68 | } 69 | }) 70 | } 71 | -------------------------------------------------------------------------------- /keepalive.go: -------------------------------------------------------------------------------- 1 | // Copyright (C) 2017 ScyllaDB 2 | 3 | package sshtools 4 | 5 | import ( 6 | "time" 7 | 8 | "golang.org/x/crypto/ssh" 9 | ) 10 | 11 | // StartKeepalive starts sending server keepalive messages until done channel 12 | // is closed. 
13 | func StartKeepalive(client *ssh.Client, interval time.Duration, countMax int, done <-chan struct{}) { 14 | t := time.NewTicker(interval) 15 | defer t.Stop() 16 | 17 | n := 0 18 | for { 19 | select { 20 | case <-t.C: 21 | if err := serverAliveCheck(client); err != nil { 22 | n++ 23 | if n >= countMax { 24 | client.Close() 25 | return 26 | } 27 | } else { 28 | n = 0 29 | } 30 | case <-done: 31 | return 32 | } 33 | } 34 | } 35 | 36 | func serverAliveCheck(client *ssh.Client) (err error) { 37 | // This is ported version of Open SSH client server_alive_check function 38 | // see: https://github.com/openssh/openssh-portable/blob/b5e412a8993ad17b9e1141c78408df15d3d987e1/clientloop.c#L482 39 | _, _, err = client.SendRequest("keepalive@openssh.com", true, nil) 40 | return 41 | } 42 | --------------------------------------------------------------------------------