├── examples ├── fileupload │ ├── test │ └── filetransfer.go └── localforward.go ├── .travis.yml ├── tunnel.go ├── forward.go ├── client_test.go ├── LICENSE ├── tool_test.go ├── common.go ├── localfs.go ├── README.md ├── tool.go ├── client.go └── session.go /examples/fileupload/test: -------------------------------------------------------------------------------- 1 | hellohellohellohellohellohellohellohello 2 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | sudo: false 3 | 4 | go: 5 | - 1.4 6 | - 1.5 7 | - 1.6 8 | - tip 9 | -------------------------------------------------------------------------------- /tunnel.go: -------------------------------------------------------------------------------- 1 | package gosshtool 2 | 3 | import ( 4 | "golang.org/x/crypto/ssh" 5 | ) 6 | 7 | type Tunnel struct { 8 | Client *ssh.Client 9 | } 10 | -------------------------------------------------------------------------------- /forward.go: -------------------------------------------------------------------------------- 1 | package gosshtool 2 | 3 | type ForwardConfig struct { 4 | LocalBindAddress string 5 | RemoteAddress string 6 | SshServerAddress string 7 | SshUserName string 8 | SshUserPassword string 9 | SshPrivateKey string 10 | } 11 | -------------------------------------------------------------------------------- /client_test.go: -------------------------------------------------------------------------------- 1 | package gosshtool 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func Test_Cmd(t *testing.T) { 8 | /* sshconfig := &SSHClientConfig{ 9 | User: "user", 10 | Password: "pwd", 11 | Host: "11.11.22.22", 12 | } 13 | sshclient := NewSSHClient(sshconfig) 14 | t.Log(sshclient.Host) 15 | stdout, stderr,session, err := sshclient.Cmd("pwd",nil,nil,0) 16 | if err != nil { 17 | t.Error(err) 18 | } 19 | t.Log(stdout) 20 | t.Log(stderr)*/ 21 | t.Log("test") 22 | } 23 | -------------------------------------------------------------------------------- /examples/fileupload/filetransfer.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | //"github.com/scottkiss/gosshtool" 5 | // "log" 6 | ) 7 | 8 | const ( 9 | HOST = "your host" 10 | USER = "your name" 11 | PASSWORD = "your password" 12 | ) 13 | 14 | func main() { 15 | //config := &gosshtool.SSHClientConfig{ 16 | // User: USER, 17 | // Password: PASSWORD, 18 | // Host: HOST, 19 | //} 20 | // sshclient := gosshtool.NewSSHClient(config) 21 | // sshclient.MaxDataThroughput = 6553600 22 | // stdout, stderr, err := gosshtool.UploadFile(HOST, "./test", "/root/test/test.txt") 23 | // if err != nil { 24 | // log.Panicln(err) 25 | // } 26 | // if stderr != "" { 27 | // log.Panicln(stderr) 28 | // } 29 | // log.Println("upload succeeded " + stdout) 30 | } 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | (The MIT License) 2 | 3 | Copyright (c) 2016 sk,http://cocosk.com/ ,skkmvp@hotmail.com 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining 6 | a copy of this software and associated documentation files (the 7 | 'Software'), to deal in the Software without restriction, including 8 | without limitation the rights to use, copy, modify, merge, publish, 9 | distribute, sublicense, and/or sell copies of the Software, and to 10 | permit persons to whom the Software is furnished to do so, subject to 11 | the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be 14 | included in all copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 20 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 21 | TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 22 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 23 | -------------------------------------------------------------------------------- /examples/localforward.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | _ "github.com/go-sql-driver/mysql" 5 | "github.com/scottkiss/gomagic/dbmagic" 6 | "github.com/scottkiss/gosshtool" 7 | //"io/ioutil" 8 | "log" 9 | ) 10 | 11 | func dbop() { 12 | ds := new(dbmagic.DataSource) 13 | ds.Charset = "utf8" 14 | ds.Host = "127.0.0.1" 15 | ds.Port = 9999 16 | ds.DatabaseName = "test" 17 | ds.User = "root" 18 | ds.Password = "password" 19 | dbm, err := dbmagic.Open("mysql", ds) 20 | if err != nil { 21 | log.Fatal(err) 22 | } 23 | row := dbm.Db.QueryRow("select name from provinces where id=?", 1) 24 | var name string 25 | err = row.Scan(&name) 26 | if err != nil { 27 | log.Fatal(err) 28 | } 29 | log.Println(name) 30 | dbm.Close() 31 | } 32 | 33 | func main() { 34 | server := new(gosshtool.LocalForwardServer) 35 | server.LocalBindAddress = ":9999" 36 | server.RemoteAddress = "remote.com:3306" 37 | server.SshServerAddress = "112.224.38.111" 38 | server.SshUserPassword = "passwd" 39 | //buf, _ := ioutil.ReadFile("/your/home/path/.ssh/id_rsa") 40 | //server.SshPrivateKey = string(buf) 41 | server.SshUserName = "sirk" 42 | server.Start(dbop) 43 | defer server.Stop() 44 | } 45 | -------------------------------------------------------------------------------- /tool_test.go: -------------------------------------------------------------------------------- 1 | package gosshtool 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func Test_newSSHClinet(t *testing.T) { 8 | /* config := &SSHClientConfig{ 9 | User: "s", 10 | Password: "1223", 11 | Host: "127.0.0.1", 12 | } 13 | sshclient := NewSSHClient(config) 14 | stdout, stderr,session, err := sshclient.Cmd("pwd",nil,nil,0) 15 | if err != nil { 16 | t.Error(err) 17 | } 18 | t.Log(stdout) 19 | t.Log(stderr) 20 | stdout, stderr,session, err = sshclient.Cmd("ls",nil,nil,0) 21 | t.Log(stdout) 22 | t.Log("test")*/ 23 | } 24 | 25 | func Test_mutiCmd(t *testing.T) { 26 | /* config := &SSHClientConfig{ 27 | User: "jack", 28 | Password: "assd", 29 | Host: "31.11.11.11", 30 | } 31 | NewSSHClient(config) 32 | 33 | config2 := &SSHClientConfig{ 34 | User: "asd", 35 | Password: "daas", 36 | Host: "8.8.8.8", 37 | } 38 | NewSSHClient(config2) 39 | stdout, _, _,err := ExecuteCmd("pwd", "8.8.8.8") 40 | if err != nil { 41 | t.Error(err) 42 | } 43 | t.Log(stdout) 44 | 45 | stdout, _,_, err = ExecuteCmd("pwd", "114.215.151.48") 46 | if err != nil { 47 | t.Error(err) 48 | } 49 | t.Log(stdout)*/ 50 | t.Log("test") 51 | } 52 | -------------------------------------------------------------------------------- /common.go: -------------------------------------------------------------------------------- 1 | package gosshtool 2 | 3 | import ( 4 | "golang.org/x/crypto/ssh" 5 | "io" 6 | "log" 7 | ) 8 | 9 | type PtyInfo struct { 10 | Term string 11 | H int 12 | W int 13 | Modes ssh.TerminalModes 14 | } 15 | 16 | type ReadWriteCloser interface { 17 | io.Reader 18 | io.WriteCloser 19 | } 20 | 21 | type SSHClientConfig struct { 22 | Host string 23 | User string 24 | Password string 25 | Privatekey string 26 | DialTimeoutSecond int 27 | MaxDataThroughput uint64 28 | } 29 | 30 | func makeConfig(user string, password string, privateKey string) (config *ssh.ClientConfig) { 31 | 32 | if password == "" && privateKey == "" { 33 | log.Fatal("No password or private key available") 34 | } 35 | if user == "" { 36 | log.Fatal("user is required parameter, not allow empyt!") 37 | } 38 | config = &ssh.ClientConfig{ 39 | User: user, 40 | Auth: []ssh.AuthMethod{ 41 | ssh.Password(password), 42 | }, 43 | HostKeyCallback: ssh.InsecureIgnoreHostKey(), 44 | } 45 | if privateKey != "" { 46 | signer, err := ssh.ParsePrivateKey([]byte(privateKey)) 47 | if err != nil { 48 | log.Fatalf("ssh.ParsePrivateKey error:%v", err) 49 | } 50 | clientkey := ssh.PublicKeys(signer) 51 | config = &ssh.ClientConfig{ 52 | User: user, 53 | Auth: []ssh.AuthMethod{ 54 | clientkey, 55 | }, 56 | HostKeyCallback: ssh.InsecureIgnoreHostKey(), 57 | } 58 | } 59 | return 60 | } 61 | -------------------------------------------------------------------------------- /localfs.go: -------------------------------------------------------------------------------- 1 | package gosshtool 2 | 3 | import ( 4 | "io" 5 | "log" 6 | "net" 7 | ) 8 | 9 | type LocalForwardServer struct { 10 | ForwardConfig 11 | tunnel *Tunnel 12 | } 13 | 14 | //create tunnel 15 | func (this *LocalForwardServer) createTunnel() { 16 | config := &SSHClientConfig{ 17 | User: this.SshUserName, 18 | Password: this.SshUserPassword, 19 | Host: this.SshServerAddress, 20 | Privatekey: this.SshPrivateKey, 21 | } 22 | sshclient := NewSSHClient(config) 23 | conn, err := sshclient.Connect() 24 | if err != nil { 25 | log.Fatal("Failed to dial: " + err.Error()) 26 | } 27 | log.Println("create ssh client ok") 28 | this.tunnel = &Tunnel{conn} 29 | } 30 | 31 | func (this *LocalForwardServer) handleConnectionAndForward(conn *net.Conn) { 32 | sshConn, err := this.tunnel.Client.Dial("tcp", this.RemoteAddress) 33 | if err != nil { 34 | log.Fatalf("ssh client dial error:%v", err) 35 | } 36 | log.Println("create ssh connection ok") 37 | go localReaderToRemoteWriter(*conn, sshConn) 38 | go remoteReaderToLoacalWriter(sshConn, *conn) 39 | } 40 | 41 | func localReaderToRemoteWriter(localConn net.Conn, sshConn net.Conn) { 42 | _, err := io.Copy(sshConn, localConn) 43 | if err != nil { 44 | log.Fatalf("io copy error:%v", err) 45 | } 46 | } 47 | 48 | func remoteReaderToLoacalWriter(sshConn net.Conn, localConn net.Conn) { 49 | _, err := io.Copy(localConn, sshConn) 50 | if err != nil { 51 | log.Fatalf("io copy error:%v", err) 52 | } 53 | } 54 | 55 | func (this *LocalForwardServer) Start(call func()) { 56 | this.createTunnel() 57 | ln, err := net.Listen("tcp", this.LocalBindAddress) 58 | if err != nil { 59 | log.Fatalf("net listen error :%v", err) 60 | } 61 | defer ln.Close() 62 | var called bool 63 | for { 64 | if !called && call != nil { 65 | go call() 66 | called = true 67 | } 68 | conn, err := ln.Accept() 69 | if err != nil { 70 | log.Println(err) 71 | } 72 | go this.handleConnectionAndForward(&conn) 73 | defer conn.Close() 74 | } 75 | } 76 | 77 | func (this *LocalForwardServer) Stop() { 78 | err := this.tunnel.Client.Close() 79 | if err != nil { 80 | log.Fatalf("ssh client stop error:%v", err) 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # gosshtool 2 | 3 | ssh tool library for Go,gosshtool provide some useful functions for ssh client in golang.implemented using golang.org/x/crypto/ssh. 4 | 5 | 6 | ## supports 7 | * command execution on multiple servers. 8 | * ssh tunnel local port forwarding. 9 | * ssh authentication using private keys or password. 10 | * ssh session timeout support. 11 | * ssh file upload support. 12 | 13 | ## Installation 14 | ```bash 15 | go get -u github.com/scottkiss/gosshtool 16 | ``` 17 | 18 | ## Examples 19 | 20 | ### command execution on single server 21 | 22 | ```golang 23 | import "github.com/scottkiss/gosshtool" 24 | sshconfig := &gosshtool.SSHClientConfig{ 25 | User: "user", 26 | Password: "pwd", 27 | Host: "11.11.22.22", 28 | } 29 | sshclient := gosshtool.NewSSHClient(sshconfig) 30 | t.Log(sshclient.Host) 31 | stdout, stderr,session, err := sshclient.Cmd("pwd",nil,nil,0) 32 | if err != nil { 33 | t.Error(err) 34 | } 35 | t.Log(stdout) 36 | t.Log(stderr) 37 | ``` 38 | 39 | 40 | ### command execution on multiple servers 41 | 42 | ```golang 43 | import "github.com/scottkiss/gosshtool" 44 | 45 | config := &gosshtool.SSHClientConfig{ 46 | User: "sam", 47 | Password: "123456", 48 | Host: "serverA", //ip:port 49 | } 50 | gosshtool.NewSSHClient(config) 51 | 52 | config2 := &gosshtool.SSHClientConfig{ 53 | User: "sirk", 54 | Privatekey: "sshprivatekey", 55 | Host: "serverB", 56 | } 57 | gosshtool.NewSSHClient(config2) 58 | stdout, _,_, err := gosshtool.ExecuteCmd("pwd", "serverA") 59 | if err != nil { 60 | t.Error(err) 61 | } 62 | t.Log(stdout) 63 | 64 | stdout, _,_, err = gosshtool.ExecuteCmd("pwd", "serverB") 65 | if err != nil { 66 | t.Error(err) 67 | } 68 | t.Log(stdout) 69 | ``` 70 | 71 | ### ssh tunnel port forwarding 72 | ```golang 73 | 74 | package main 75 | 76 | import ( 77 | _ "github.com/go-sql-driver/mysql" 78 | "github.com/scottkiss/gomagic/dbmagic" 79 | "github.com/scottkiss/gosshtool" 80 | //"io/ioutil" 81 | "log" 82 | ) 83 | 84 | func dbop() { 85 | ds := new(dbmagic.DataSource) 86 | ds.Charset = "utf8" 87 | ds.Host = "127.0.0.1" 88 | ds.Port = 9999 89 | ds.DatabaseName = "test" 90 | ds.User = "root" 91 | ds.Password = "password" 92 | dbm, err := dbmagic.Open("mysql", ds) 93 | if err != nil { 94 | log.Fatal(err) 95 | } 96 | row := dbm.Db.QueryRow("select name from provinces where id=?", 1) 97 | var name string 98 | err = row.Scan(&name) 99 | if err != nil { 100 | log.Fatal(err) 101 | } 102 | log.Println(name) 103 | dbm.Close() 104 | } 105 | 106 | func main() { 107 | server := new(gosshtool.LocalForwardServer) 108 | server.LocalBindAddress = ":9999" 109 | server.RemoteAddress = "remote.com:3306" 110 | server.SshServerAddress = "112.224.38.111" 111 | server.SshUserPassword = "passwd" 112 | //buf, _ := ioutil.ReadFile("/your/home/path/.ssh/id_rsa") 113 | //server.SshPrivateKey = string(buf) 114 | server.SshUserName = "sirk" 115 | server.Start(dbop) 116 | defer server.Stop() 117 | } 118 | 119 | ``` 120 | 121 | ## More Examples 122 | * [sshcmd](https://github.com/scottkiss/sshcmd) simple ssh command line client. 123 | * [gooverssh](https://github.com/scottkiss/gooverssh) port forward server over ssh. 124 | 125 | ## License 126 | View the [LICENSE](https://github.com/scottkiss/gosshtool/blob/master/LICENSE) file 127 | 128 | 129 | -------------------------------------------------------------------------------- /tool.go: -------------------------------------------------------------------------------- 1 | package gosshtool 2 | 3 | import ( 4 | crand "crypto/rand" 5 | "encoding/hex" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "io/ioutil" 10 | mrand "math/rand" 11 | "os" 12 | "regexp" 13 | "strings" 14 | "sync" 15 | "time" 16 | ) 17 | 18 | var ( 19 | sshClients map[string]*SSHClient 20 | sshClientsMutex sync.RWMutex 21 | ) 22 | 23 | var seeded bool = false 24 | 25 | var syncbufpool *sync.Pool 26 | 27 | var uuidRegex *regexp.Regexp = regexp.MustCompile(`^\{?([a-fA-F0-9]{8})-?([a-fA-F0-9]{4})-?([a-fA-F0-9]{4})-?([a-fA-F0-9]{4})-?([a-fA-F0-9]{12})\}?$`) 28 | 29 | type UUID [16]byte 30 | 31 | // Hex returns a hex string representation of the UUID in xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx format. 32 | func (this UUID) Hex() string { 33 | x := [16]byte(this) 34 | return fmt.Sprintf("%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", 35 | x[0], x[1], x[2], x[3], x[4], 36 | x[5], x[6], 37 | x[7], x[8], 38 | x[9], x[10], x[11], x[12], x[13], x[14], x[15]) 39 | 40 | } 41 | 42 | // Rand generates a new version 4 UUID. 43 | func Rand() UUID { 44 | var x [16]byte 45 | randBytes(x[:]) 46 | x[6] = (x[6] & 0x0F) | 0x40 47 | x[8] = (x[8] & 0x3F) | 0x80 48 | return x 49 | } 50 | func FromStr(s string) (id UUID, err error) { 51 | if s == "" { 52 | err = errors.New("Empty string") 53 | return 54 | } 55 | 56 | parts := uuidRegex.FindStringSubmatch(s) 57 | if parts == nil { 58 | err = errors.New("Invalid string format") 59 | return 60 | } 61 | 62 | var array [16]byte 63 | slice, _ := hex.DecodeString(strings.Join(parts[1:], "")) 64 | copy(array[:], slice) 65 | id = array 66 | return 67 | } 68 | 69 | func MustFromStr(s string) UUID { 70 | id, err := FromStr(s) 71 | if err != nil { 72 | panic(err) 73 | } 74 | return id 75 | } 76 | 77 | func randBytes(x []byte) { 78 | length := len(x) 79 | n, err := crand.Read(x) 80 | 81 | if n != length || err != nil { 82 | if !seeded { 83 | mrand.Seed(time.Now().UnixNano()) 84 | } 85 | for length > 0 { 86 | length-- 87 | x[length] = byte(mrand.Int31n(256)) 88 | } 89 | } 90 | } 91 | 92 | func init() { 93 | sshClients = make(map[string]*SSHClient) 94 | syncbufpool = &sync.Pool{} 95 | syncbufpool.New = func() interface{} { 96 | return make([]byte, 32*1024) 97 | } 98 | } 99 | 100 | func CopyIOAndUpdateSessionDeadline(dst io.Writer, src io.Reader, session *SshSession) (written int64, err error) { 101 | if wt, ok := src.(io.WriterTo); ok { 102 | return wt.WriteTo(dst) 103 | } 104 | if rt, ok := dst.(io.ReaderFrom); ok { 105 | return rt.ReadFrom(src) 106 | } 107 | 108 | buf := syncbufpool.Get().([]byte) 109 | defer syncbufpool.Put(buf) 110 | 111 | for { 112 | nr, er := src.Read(buf) 113 | if nr > 0 { 114 | if session.idleTimeout > 0 { 115 | deadlinenew := time.Now().Add(time.Second * time.Duration(session.idleTimeout)) 116 | session.SetDeadline(&deadlinenew) 117 | } 118 | nw, ew := dst.Write(buf[0:nr]) 119 | if nw > 0 { 120 | written += int64(nw) 121 | } 122 | if ew != nil { 123 | err = ew 124 | break 125 | } 126 | if nr != nw { 127 | err = io.ErrShortWrite 128 | break 129 | } 130 | } 131 | if er == io.EOF { 132 | break 133 | } 134 | if er != nil { 135 | err = er 136 | break 137 | } 138 | } 139 | return written, err 140 | } 141 | 142 | func NewSSHClient(config *SSHClientConfig) (client *SSHClient) { 143 | sshClientsMutex.RLock() 144 | client = sshClients[config.Host] 145 | if client != nil { 146 | return 147 | } 148 | sshClientsMutex.RUnlock() 149 | client = new(SSHClient) 150 | client.Host = config.Host 151 | client.User = config.User 152 | client.Password = config.Password 153 | client.Privatekey = config.Privatekey 154 | client.DialTimeoutSecond = config.DialTimeoutSecond 155 | sshClientsMutex.Lock() 156 | sshClients[config.Host] = client 157 | sshClientsMutex.Unlock() 158 | return client 159 | } 160 | 161 | func getClient(hostname string) (client *SSHClient, err error) { 162 | if hostname == "" { 163 | return nil, errors.New("host name is empty") 164 | } 165 | sshClientsMutex.RLock() 166 | client = sshClients[hostname] 167 | if client != nil { 168 | return client, nil 169 | } 170 | sshClientsMutex.RUnlock() 171 | return nil, errors.New("client not create") 172 | } 173 | 174 | func ExecuteCmd(cmd, hostname string) (output, errput string, currentSession *SshSession, err error) { 175 | client, err := getClient(hostname) 176 | if err != nil { 177 | return 178 | } 179 | return client.Cmd(cmd, nil, nil, 0) 180 | } 181 | 182 | func UploadFile(hostname, sourceFile, targetFile string) (stdout, stderr string, err error) { 183 | client, err := getClient(hostname) 184 | if err != nil { 185 | return 186 | } 187 | f, err := os.Open(sourceFile) 188 | if err != nil { 189 | return 190 | } 191 | defer f.Close() 192 | data, err := ioutil.ReadAll(f) 193 | if err != nil { 194 | return 195 | } 196 | return client.TransferData(targetFile, data) 197 | } 198 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | package gosshtool 2 | 3 | import ( 4 | "bytes" 5 | "golang.org/x/crypto/ssh" 6 | "log" 7 | "net" 8 | "strings" 9 | "sync" 10 | "time" 11 | ) 12 | 13 | const ( 14 | DEFAULT_CHUNK_SIZE = 65536 15 | MIN_CHUNKS = 10 16 | THROUGHPUT_SLEEP_INTERVAL = 100 17 | MIN_THROUGHPUT = DEFAULT_CHUNK_SIZE * MIN_CHUNKS * (1000 / THROUGHPUT_SLEEP_INTERVAL) 18 | ) 19 | 20 | var ( 21 | maxThroughputChan = make(chan bool, MIN_CHUNKS) 22 | maxThroughput uint64 23 | maxThroughputMutex sync.Mutex 24 | ) 25 | 26 | type SSHClient struct { 27 | SSHClientConfig 28 | remoteConn *ssh.Client 29 | isConnected bool 30 | } 31 | 32 | func (c *SSHClient) maxThroughputControl() { 33 | for { 34 | if c.MaxDataThroughput > 0 && c.MaxDataThroughput < MIN_THROUGHPUT { 35 | log.Panicf("Minimal throughput is %d Bps", MIN_THROUGHPUT) 36 | } 37 | maxThroughputMutex.Lock() 38 | throughput := c.MaxDataThroughput 39 | maxThroughputMutex.Unlock() 40 | chunks := throughput / DEFAULT_CHUNK_SIZE * THROUGHPUT_SLEEP_INTERVAL / 1000 41 | if chunks < MIN_CHUNKS { 42 | chunks = MIN_CHUNKS 43 | } 44 | for i := uint64(0); i < chunks; i++ { 45 | maxThroughputChan <- true 46 | } 47 | if throughput > 0 { 48 | time.Sleep(THROUGHPUT_SLEEP_INTERVAL * time.Millisecond) 49 | } 50 | } 51 | } 52 | 53 | func (c *SSHClient) Connect() (conn *ssh.Client, err error) { 54 | if c.remoteConn != nil { 55 | return 56 | } 57 | port := "22" 58 | host := c.Host 59 | hstr := strings.SplitN(host, ":", 2) 60 | if len(hstr) == 2 { 61 | host = hstr[0] 62 | port = hstr[1] 63 | } 64 | 65 | config := makeConfig(c.User, c.Password, c.Privatekey) 66 | 67 | if c.DialTimeoutSecond > 0 { 68 | connNet, err := net.DialTimeout("tcp", host+":"+port, time.Duration(c.DialTimeoutSecond)*time.Second) 69 | if err != nil { 70 | return nil, err 71 | } 72 | sc, chans, reqs, err := ssh.NewClientConn(connNet, host+":"+port, config) 73 | if err != nil { 74 | return nil, err 75 | } 76 | conn = ssh.NewClient(sc, chans, reqs) 77 | } else { 78 | conn, err = ssh.Dial("tcp", host+":"+port, config) 79 | if err != nil { 80 | return 81 | } 82 | } 83 | log.Println("dial ssh success") 84 | c.remoteConn = conn 85 | return 86 | } 87 | 88 | func (c *SSHClient) TransferData(target string, data []byte) (stdout, stderr string, err error) { 89 | go c.maxThroughputControl() 90 | 91 | if c.isConnected == false { 92 | _, err = c.Connect() 93 | if err != nil { 94 | return 95 | } 96 | } 97 | currentSession, err := NewSession(c.remoteConn, nil, 0) 98 | if err != nil { 99 | return 100 | } 101 | defer currentSession.Close() 102 | cmd := "cat >'" + strings.Replace(target, "'", "'\\''", -1) + "'" 103 | stdinPipe, err := currentSession.StdinPipe() 104 | if err != nil { 105 | return 106 | } 107 | var stdoutBuf bytes.Buffer 108 | var stderrBuf bytes.Buffer 109 | currentSession.Stdout = &stdoutBuf 110 | currentSession.Stderr = &stderrBuf 111 | err = currentSession.session.Start(cmd) 112 | if err != nil { 113 | return 114 | } 115 | for start, max := 0, len(data); start < max; start += DEFAULT_CHUNK_SIZE { 116 | <-maxThroughputChan 117 | end := start + DEFAULT_CHUNK_SIZE 118 | if end > max { 119 | end = max 120 | } 121 | _, err = stdinPipe.Write(data[start:end]) 122 | if err != nil { 123 | return 124 | } 125 | } 126 | err = stdinPipe.Close() 127 | if err != nil { 128 | return 129 | } 130 | err = currentSession.Wait() 131 | stdout = stdoutBuf.String() 132 | stderr = stderrBuf.String() 133 | return 134 | } 135 | 136 | func (c *SSHClient) Cmd(cmd string, sn *SshSession, deadline *time.Time, idleTimeout int) (output, errput string, currentSession *SshSession, err error) { 137 | if c.isConnected == false { 138 | _, err = c.Connect() 139 | if err != nil { 140 | return 141 | } 142 | } 143 | if sn == nil { 144 | currentSession, err = NewSession(c.remoteConn, deadline, idleTimeout) 145 | } else { 146 | currentSession = sn 147 | currentSession.SetDeadline(deadline) 148 | } 149 | if err != nil { 150 | return 151 | } 152 | var stdoutBuf bytes.Buffer 153 | var stderrBuf bytes.Buffer 154 | currentSession.Stdout = &stdoutBuf 155 | currentSession.Stderr = &stderrBuf 156 | err = currentSession.Run(cmd) 157 | defer currentSession.Close() 158 | output = stdoutBuf.String() 159 | errput = stderrBuf.String() 160 | return 161 | } 162 | 163 | func (c *SSHClient) Pipe(rw ReadWriteCloser, pty *PtyInfo, deadline *time.Time, idleTimeout int) (currentSession *SshSession, err error) { 164 | if c.isConnected == false { 165 | _, err := c.Connect() 166 | if err != nil { 167 | return nil, err 168 | } 169 | } 170 | currentSession, err = NewSession(c.remoteConn, deadline, idleTimeout) 171 | if err != nil { 172 | return 173 | } 174 | 175 | if err = currentSession.RequestPty(pty.Term, pty.H, pty.W, pty.Modes); err != nil { 176 | return 177 | } 178 | wc, err := currentSession.StdinPipe() 179 | if err != nil { 180 | return 181 | } 182 | 183 | go CopyIOAndUpdateSessionDeadline(wc, rw, currentSession) 184 | 185 | r, err := currentSession.StdoutPipe() 186 | if err != nil { 187 | return 188 | } 189 | go CopyIOAndUpdateSessionDeadline(rw, r, currentSession) 190 | er, err := currentSession.StderrPipe() 191 | if err != nil { 192 | return 193 | } 194 | go CopyIOAndUpdateSessionDeadline(rw, er, currentSession) 195 | err = currentSession.Shell() 196 | if err != nil { 197 | return 198 | } 199 | err = currentSession.Wait() 200 | if err != nil { 201 | return 202 | } 203 | defer currentSession.Close() 204 | return 205 | } 206 | -------------------------------------------------------------------------------- /session.go: -------------------------------------------------------------------------------- 1 | package gosshtool 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "golang.org/x/crypto/ssh" 7 | "io" 8 | "log" 9 | "time" 10 | ) 11 | 12 | // POSIX terminal mode flags as listed in RFC 4254 Section 8. 13 | const ( 14 | tty_OP_END = 0 15 | VINTR = 1 16 | VQUIT = 2 17 | VERASE = 3 18 | VKILL = 4 19 | VEOF = 5 20 | VEOL = 6 21 | VEOL2 = 7 22 | VSTART = 8 23 | VSTOP = 9 24 | VSUSP = 10 25 | VDSUSP = 11 26 | VREPRINT = 12 27 | VWERASE = 13 28 | VLNEXT = 14 29 | VFLUSH = 15 30 | VSWTCH = 16 31 | VSTATUS = 17 32 | VDISCARD = 18 33 | IGNPAR = 30 34 | PARMRK = 31 35 | INPCK = 32 36 | ISTRIP = 33 37 | INLCR = 34 38 | IGNCR = 35 39 | ICRNL = 36 40 | IUCLC = 37 41 | IXON = 38 42 | IXANY = 39 43 | IXOFF = 40 44 | IMAXBEL = 41 45 | ISIG = 50 46 | ICANON = 51 47 | XCASE = 52 48 | ECHO = 53 49 | ECHOE = 54 50 | ECHOK = 55 51 | ECHONL = 56 52 | NOFLSH = 57 53 | TOSTOP = 58 54 | IEXTEN = 59 55 | ECHOCTL = 60 56 | ECHOKE = 61 57 | PENDIN = 62 58 | OPOST = 70 59 | OLCUC = 71 60 | ONLCR = 72 61 | OCRNL = 73 62 | ONOCR = 74 63 | ONLRET = 75 64 | CS7 = 90 65 | CS8 = 91 66 | PARENB = 92 67 | PARODD = 93 68 | TTY_OP_ISPEED = 128 69 | TTY_OP_OSPEED = 129 70 | ) 71 | 72 | type SshSession struct { 73 | id string 74 | session *ssh.Session 75 | Stdout *bytes.Buffer 76 | Stderr *bytes.Buffer 77 | deadline *time.Time 78 | idleTimeout int 79 | ch ssh.Channel 80 | started bool 81 | } 82 | 83 | // RFC 4254 Section 6.2. 84 | type ptyRequestMsg struct { 85 | Term string 86 | Columns uint32 87 | Rows uint32 88 | Width uint32 89 | Height uint32 90 | Modelist string 91 | } 92 | 93 | func (sc *SshSession) start() error { 94 | sc.started = true 95 | return nil 96 | } 97 | 98 | func (sc *SshSession) Run(cmd string) (err error) { 99 | sc.session.Stdout = sc.Stdout 100 | sc.session.Stderr = sc.Stderr 101 | return sc.session.Run(cmd) 102 | } 103 | 104 | func (sc *SshSession) RequestPty(term string, h int, w int, termmodes ssh.TerminalModes) (err error) { 105 | if sc.session != nil { 106 | return sc.session.RequestPty(term, h, w, termmodes) 107 | } else { 108 | var tm []byte 109 | for k, v := range termmodes { 110 | kv := struct { 111 | Key byte 112 | Val uint32 113 | }{k, v} 114 | 115 | tm = append(tm, ssh.Marshal(&kv)...) 116 | } 117 | tm = append(tm, tty_OP_END) 118 | req := ptyRequestMsg{ 119 | Term: term, 120 | Columns: uint32(w), 121 | Rows: uint32(h), 122 | Width: uint32(w * 8), 123 | Height: uint32(h * 8), 124 | Modelist: string(tm), 125 | } 126 | ok, err := sc.ch.SendRequest("pty-req", true, ssh.Marshal(&req)) 127 | if err == nil && !ok { 128 | err = errors.New("ssh: pty-req failed") 129 | } 130 | return err 131 | } 132 | } 133 | 134 | func (sc *SshSession) StdinPipe() (io.WriteCloser, error) { 135 | return sc.session.StdinPipe() 136 | } 137 | 138 | func (sc *SshSession) StdoutPipe() (io.Reader, error) { 139 | return sc.session.StdoutPipe() 140 | } 141 | 142 | func (sc *SshSession) StderrPipe() (io.Reader, error) { 143 | return sc.session.StderrPipe() 144 | } 145 | 146 | func (sc *SshSession) SetDeadline(deadline *time.Time) { 147 | sc.deadline = deadline 148 | } 149 | 150 | func (sc *SshSession) Shell() error { 151 | if sc.session != nil { 152 | return sc.session.Shell() 153 | } else { 154 | if sc.started { 155 | return errors.New("ssh: session already started") 156 | } 157 | ok, err := sc.ch.SendRequest("shell", true, nil) 158 | if err == nil && !ok { 159 | return errors.New("ssh: could not start shell") 160 | } 161 | if err != nil { 162 | return err 163 | } 164 | return sc.start() 165 | } 166 | } 167 | 168 | func (sc *SshSession) Wait() error { 169 | return sc.session.Wait() 170 | } 171 | 172 | func (sc *SshSession) Close() error { 173 | if sc.session != nil { 174 | return sc.session.Close() 175 | } else { 176 | return sc.ch.Close() 177 | } 178 | } 179 | 180 | func NewSession(conn *ssh.Client, deadline *time.Time, idleTimeout int) (ss *SshSession, err error) { 181 | session, err := conn.NewSession() 182 | if err != nil { 183 | return nil, err 184 | } 185 | sshSession := new(SshSession) 186 | sshSession.session = session 187 | sshSession.deadline = deadline 188 | sshSession.idleTimeout = idleTimeout 189 | sshSession.id = Rand().Hex() 190 | //check session timeout 191 | go sshSession.checkSessionTimeout() 192 | return sshSession, nil 193 | } 194 | 195 | func NewSessionWithChannel(conn *ssh.Client, ch ssh.Channel, deadline *time.Time, idleTimeout int) (ss *SshSession, err error) { 196 | sshSession := new(SshSession) 197 | sshSession.deadline = deadline 198 | sshSession.idleTimeout = idleTimeout 199 | sshSession.id = Rand().Hex() 200 | sshSession.ch = ch 201 | //check session timeout 202 | go sshSession.checkSessionTimeout() 203 | return sshSession, nil 204 | } 205 | 206 | func (sc *SshSession) checkSessionTimeout() { 207 | timeout := make(chan bool, 1) 208 | go func() { 209 | t := time.NewTicker(time.Second * 1) 210 | for { 211 | <-t.C 212 | if sc.deadline != nil && time.Now().After(*sc.deadline) { 213 | timeout <- true 214 | } 215 | } 216 | }() 217 | ch := make(chan int) 218 | select { 219 | case <-ch: 220 | case <-timeout: 221 | log.Println("timeout!") 222 | sc.Close() 223 | } 224 | } 225 | --------------------------------------------------------------------------------