├── .gitignore ├── LICENSE ├── README.md ├── cmd ├── sftp-server │ └── sftp-server.go ├── ssh2incus │ └── ssh2incus.go └── stdio-proxy │ └── stdio-proxy.go ├── go.mod ├── go.sum ├── init.go ├── justfile ├── packaging ├── nfpm.yaml ├── post-install.sh ├── ssh2incus.env └── ssh2incus.service ├── pkg ├── app.go ├── cache │ ├── README.md │ ├── cache.go │ ├── cache_test.go │ ├── sharded.go │ └── sharded_test.go ├── incus │ ├── config.go │ ├── exec.go │ ├── file.go │ ├── incus.go │ ├── instance.go │ ├── proxy-device.go │ └── user.go ├── isatty │ ├── LICENSE │ ├── README.md │ ├── doc.go │ ├── example_test.go │ ├── isatty_bsd.go │ ├── isatty_others.go │ ├── isatty_others_test.go │ ├── isatty_plan9.go │ ├── isatty_solaris.go │ ├── isatty_tcgets.go │ ├── isatty_windows.go │ └── isatty_windows_test.go ├── queue │ └── queue.go ├── shlex │ ├── LICENSE │ ├── README.md │ └── shlex.go ├── ssh │ ├── LICENSE │ ├── README.md │ ├── agent.go │ ├── conn.go │ ├── context.go │ ├── context_test.go │ ├── doc.go │ ├── example_test.go │ ├── options.go │ ├── options_test.go │ ├── server.go │ ├── server_test.go │ ├── session.go │ ├── session_test.go │ ├── ssh.go │ ├── ssh_test.go │ ├── tcpip.go │ ├── tcpip_test.go │ ├── util.go │ └── wrap.go ├── user │ ├── LICENSE │ ├── README.md │ ├── current_cgo.go │ ├── current_default.go │ ├── ds.go │ ├── ds_group.go │ ├── errors.go │ ├── group.go │ ├── id.go │ ├── lookup.go │ ├── lookup_group.go │ ├── lookup_group_windows.go │ ├── lookup_windows.go │ ├── luser.go │ ├── misc.go │ ├── misc_group.go │ ├── nss.go │ └── nss_group.go └── util │ ├── buffer │ ├── bytes.go │ └── output.go │ ├── devicereg │ └── devicereg.go │ ├── dns.go │ ├── goagain │ ├── goagain.go │ └── legacy.go │ ├── gz.go │ ├── io.go │ ├── io │ ├── bytesreadcloser.go │ ├── filesystem.go │ ├── filesystem_unix.go │ ├── filesystem_windows.go │ ├── quotawriter.go │ ├── readseeker.go │ └── writer.go │ ├── ip.go │ ├── md5.go │ ├── port.go │ ├── rand.go │ ├── shadow │ ├── shadow.go │ ├── 
shadow_test.go │ └── test │ │ └── shadow.txt │ ├── string.go │ ├── structs │ ├── LICENSE │ ├── README.md │ ├── field.go │ ├── structs.go │ └── tags.go │ └── utmp │ ├── utmp_bsd.go │ ├── utmp_darwin.go │ └── utmp_linux.go └── server ├── auth.go ├── banner.go ├── config.go ├── device-registry.go ├── incus.go ├── port-forward.go ├── server.go ├── sftp-server-binary └── sftp-server-binary.go ├── sftp.go ├── shell.go ├── ssh-request.go ├── stdio-proxy-binary └── stdio-proxy-binary.go ├── subsystem.go ├── user.go └── user_test.go

/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_*
2 | /.idea
3 | /dist
4 | /build
5 | /release
6 | /.go
7 | _examples/
8 | server/sftp-server-binary/bin/
9 | server/stdio-proxy-binary/bin/
10 | 
11 | /ansible
12 | Makefile
13 | cmd/helper
14 | cmd/nutshell
15 | /handoff
16 | /shell
17 | /docker*
18 | /scripts
19 | /nutshell
20 | /toybox
21 | 
22 | .envrc
23 | 
--------------------------------------------------------------------------------
/cmd/sftp-server/sftp-server.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"errors"
5 | 	"flag"
6 | 	"fmt"
7 | 	"io"
8 | 	"os"
9 | 	"strconv"
10 | 	"syscall"
11 | 
12 | 	"github.com/pkg/sftp"
13 | )
14 | 
15 | // stderr receives diagnostic output; it discards everything unless -e is set.
16 | var stderr = io.Discard
17 | 
18 | // errNotSet is returned by toInt when the environment variable is absent.
19 | var errNotSet = errors.New("not set")
20 | 
21 | // main serves SFTP on stdin/stdout. Before serving it chroots into -d,
22 | // changes into $HOME and drops privileges to the numeric $GID and $UID.
23 | func main() {
24 | 	var (
25 | 		help        bool
26 | 		readOnly    bool
27 | 		debugStderr bool
28 | 		debugLevel  string
29 | 		startDir    string
30 | 		umask       int
31 | 		options     []sftp.ServerOption
32 | 	)
33 | 
34 | 	flag.BoolVar(&readOnly, "R", readOnly, "read-only server")
35 | 	flag.BoolVar(&debugStderr, "e", debugStderr, "debug to stderr")
36 | 	flag.StringVar(&startDir, "d", startDir, "change root directory")
37 | 	flag.IntVar(&umask, "u", umask, "explicit umask")
38 | 	flag.StringVar(&debugLevel, "l", debugLevel, "debug level (ignored)")
39 | 	flag.BoolVar(&help, "h", help, "print help")
40 | 	flag.Parse()
41 | 
42 | 	if help {
43 | 		flag.Usage()
44 | 		exit(nil)
45 | 	}
46 | 
47 | 	if debugStderr {
48 | 		stderr = os.Stderr
49 | 	}
50 | 
51 | 	if err := syscall.Chroot(startDir); err != nil {
52 | 		exit(err)
53 | 	}
54 | 
55 | 	home, ok := os.LookupEnv("HOME")
56 | 	if !ok {
57 | 		exit(errors.New("HOME environment variable not set"))
58 | 	}
59 | 
60 | 	gid, err := toInt(os.LookupEnv("GID"))
61 | 	if err != nil {
62 | 		exit(envError("GID", err))
63 | 	}
64 | 
65 | 	uid, err := toInt(os.LookupEnv("UID"))
66 | 	if err != nil {
67 | 		exit(envError("UID", err))
68 | 	}
69 | 
70 | 	if err = syscall.Chdir(home); err != nil {
71 | 		exit(err)
72 | 	}
73 | 
74 | 	// Replace the supplementary groups inherited from the privileged parent
75 | 	// before changing GID/UID; Setgid/Setuid alone leave them in place.
76 | 	if err = syscall.Setgroups([]int{gid}); err != nil {
77 | 		exit(err)
78 | 	}
79 | 
80 | 	if err = syscall.Setgid(gid); err != nil {
81 | 		exit(err)
82 | 	}
83 | 
84 | 	if err = syscall.Setuid(uid); err != nil {
85 | 		exit(err)
86 | 	}
87 | 
88 | 	syscall.Umask(umask)
89 | 
90 | 	options = append(options, sftp.WithDebug(stderr))
91 | 
92 | 	if readOnly {
93 | 		options = append(options, sftp.ReadOnly())
94 | 	}
95 | 
96 | 	server, err := sftp.NewServer(
97 | 		struct {
98 | 			io.Reader
99 | 			io.WriteCloser
100 | 		}{
101 | 			os.Stdin,
102 | 			os.Stdout,
103 | 		},
104 | 		options...,
105 | 	)
106 | 	if err != nil {
107 | 		exit(fmt.Errorf("sftp server could not initialize: %v", err))
108 | 	}
109 | 
110 | 	if err = server.Serve(); err != nil {
111 | 		exit(fmt.Errorf("sftp server completed with error: %v", err))
112 | 	}
113 | }
114 | 
115 | // exit terminates the process: status 1 after printing err to stderr
116 | // (visible only with -e), or status 0 when err is nil.
117 | func exit(err error) {
118 | 	if err != nil {
119 | 		fmt.Fprintln(stderr, err)
120 | 		os.Exit(1)
121 | 	}
122 | 	os.Exit(0)
123 | }
124 | 
125 | // toInt converts the (value, present) pair returned by os.LookupEnv into an
126 | // int. It fails with errNotSet when the variable is absent.
127 | func toInt(s string, ok bool) (int, error) {
128 | 	if !ok {
129 | 		return 0, errNotSet
130 | 	}
131 | 	return strconv.Atoi(s)
132 | }
133 | 
134 | // envError describes why the numeric environment variable name could not be used.
135 | func envError(name string, err error) error {
136 | 	if errors.Is(err, errNotSet) {
137 | 		return fmt.Errorf("%s environment variable not set", name)
138 | 	}
139 | 	return fmt.Errorf("%s environment variable is not a valid integer: %w", name, err)
140 | }
141 | 
--------------------------------------------------------------------------------
/cmd/ssh2incus/ssh2incus.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	_ "ssh2incus"
5 | )
6 | 
7 | // main is intentionally empty: this command only imports the root ssh2incus
8 | // package for its side effects (its init functions).
9 | func main() {}
10 | 
--------------------------------------------------------------------------------
/cmd/stdio-proxy/stdio-proxy.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"flag"
5 | 	"fmt"
6 | 	"io"
7 | 	"log"
8 | 	"net"
9 | 	"os"
10 | 	"os/signal"
11 | 	"strings"
12 | 	"sync"
13 | 	"syscall"
14 | )
15 | 
16 | // stderr receives log and error output; it discards everything unless -v is set.
17 | var stderr = io.Discard
18 | 
19 | // main relays bytes between stdin/stdout and the [protocol]:[host]:[port]
20 | // target named by its single positional argument, until both copy directions
21 | // finish or SIGINT/SIGTERM is received.
22 | func main() {
23 | 	// Define a verbose flag
24 | 	verbose := flag.Bool("v", false, "Enable verbose logging")
25 | 	flag.Parse()
26 | 
27 | 	// If not in verbose mode, discard logs
28 | 	if *verbose {
29 | 		stderr = os.Stderr
30 | 	}
31 | 	log.SetOutput(stderr)
32 | 
33 | 	// Check the positional arguments left after flag parsing; counting
34 | 	// os.Args would reject every invocation that passes -v.
35 | 	if flag.NArg() != 1 {
36 | 		fmt.Fprintf(os.Stderr, "Usage: %s [protocol]:[host]:[port]\n", os.Args[0])
37 | 		fmt.Fprintf(os.Stderr, "Example: %s tcp:example.com:80\n", os.Args[0])
38 | 		fmt.Fprintf(os.Stderr, "Example: %s udp:dns.server:53\n", os.Args[0])
39 | 		os.Exit(1)
40 | 	}
41 | 
42 | 	// Parse the argument. The protocol ends at the first colon; the rest is
43 | 	// split by net.SplitHostPort so bracketed IPv6 hosts ([::1]:80) work.
44 | 	protocol, hostPort, found := strings.Cut(flag.Arg(0), ":")
45 | 	host, port, err := net.SplitHostPort(hostPort)
46 | 	if !found || err != nil {
47 | 		fmt.Fprintf(os.Stderr, "Invalid argument format. Expected [protocol]:[host]:[port]\n")
48 | 		os.Exit(1)
49 | 	}
50 | 
51 | 	// Create connection based on the specified protocol
52 | 	var conn net.Conn
53 | 
54 | 	address := net.JoinHostPort(host, port)
55 | 
56 | 	switch strings.ToLower(protocol) {
57 | 	case "tcp":
58 | 		log.Printf("Connecting to TCP %s...", address)
59 | 		conn, err = net.Dial("tcp", address)
60 | 	case "udp":
61 | 		log.Printf("Connecting to UDP %s...", address)
62 | 		conn, err = net.Dial("udp", address)
63 | 	default:
64 | 		exit(fmt.Errorf("unsupported protocol: %s (use 'tcp' or 'udp')", protocol))
65 | 	}
66 | 
67 | 	if err != nil {
68 | 		exit(fmt.Errorf("error connecting to %s://%s: %w", protocol, address, err))
69 | 	}
70 | 	defer conn.Close()
71 | 
72 | 	log.Printf("Connected to %s://%s", protocol, address)
73 | 
74 | 	// Set up channel to handle signals
75 | 	sigChan := make(chan os.Signal, 1)
76 | 	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
77 | 
78 | 	// Set up wait group for the goroutines
79 | 	var wg sync.WaitGroup
80 | 	wg.Add(2)
81 | 
82 | 	// Copy from stdin to connection
83 | 	go func() {
84 | 		defer wg.Done()
85 | 		if _, err := io.Copy(conn, os.Stdin); err != nil {
86 | 			log.Printf("Error copying from stdin to connection: %v", err)
87 | 		}
88 | 
89 | 		// If we're done reading from stdin, signal that we're done to the connection
90 | 		// This is important for TCP connections
91 | 		if tcpConn, ok := conn.(*net.TCPConn); ok {
92 | 			tcpConn.CloseWrite()
93 | 		}
94 | 	}()
95 | 
96 | 	// Copy from connection to stdout
97 | 	go func() {
98 | 		defer wg.Done()
99 | 		if _, err := io.Copy(os.Stdout, conn); err != nil {
100 | 			log.Printf("Error copying from connection to stdout: %v", err)
101 | 		}
102 | 	}()
103 | 
104 | 	// Wait for either a signal or for both goroutines to finish
105 | 	go func() {
106 | 		wg.Wait()
107 | 		// If both goroutines are done, we can exit
108 | 		sigChan <- syscall.SIGTERM
109 | 	}()
110 | 
111 | 	// Wait for termination signal
112 | 	<-sigChan
113 | 	log.Printf("Shutting down connection to %s://%s", protocol, address)
114 | }
115 | 
116 | // exit terminates the process: status 1 after writing err to stderr
117 | // (visible only with -v), or status 0 when err is nil.
118 | func exit(err error) {
119 | 	if err != nil {
120 | 		fmt.Fprintln(stderr, err)
121 | 		os.Exit(1)
122 | 	}
123 | 	os.Exit(0)
124 | }
125 | 
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module ssh2incus
2 | 
3 | go 1.24.2
4 | 
5 | require (
6 | 	github.com/GehirnInc/crypt v0.0.0-20230320061759-8cc1b52080c5
7 | 	github.com/creack/pty v1.1.24
8 | 	github.com/gorilla/websocket v1.5.3
9 | 	github.com/lxc/incus/v6 v6.11.0
10 | 	github.com/peterh/liner v1.2.2
11 | 	github.com/pkg/sftp v1.13.9
12 | 	github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2
13 | 	github.com/sirupsen/logrus v1.9.3
14 | 	github.com/spf13/pflag v1.0.6
15 | 	github.com/stretchr/testify v1.10.0
16 | 	golang.org/x/crypto v0.36.0
17 | 	golang.org/x/sys v0.32.0
18 | 	gopkg.in/robfig/cron.v2 v2.0.0-20150107220207-be2e0b0deed5
19 | )
20 | 
21 | require (
22 | 	github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
23 | 	github.com/go-jose/go-jose/v4 v4.0.5 // indirect
24 | 	github.com/go-logr/logr v1.4.2 // indirect
25 | 	github.com/go-logr/stdr v1.2.2 // indirect
26 | 	github.com/google/uuid v1.6.0 // indirect
27 | 	github.com/gorilla/securecookie v1.1.2 // indirect
28 | 	github.com/kr/fs v0.1.0 // indirect
29 | 	github.com/mattn/go-runewidth v0.0.16 // indirect
30 | 	github.com/muhlemmer/gu v0.3.1 // indirect
31 | 	github.com/rivo/uniseg v0.4.7 // indirect
32 | 	github.com/zitadel/logging v0.6.2 // indirect
33 | 	github.com/zitadel/oidc/v3 v3.37.0 // indirect
34 | 	github.com/zitadel/schema v1.3.1 // indirect
35 | 	go.opentelemetry.io/auto/sdk v1.1.0 // indirect
36 | 	go.opentelemetry.io/otel v1.35.0 // indirect
37 | 	go.opentelemetry.io/otel/metric v1.35.0 // indirect
38 | 	go.opentelemetry.io/otel/trace v1.35.0 // indirect
39 | 	golang.org/x/oauth2 v0.29.0 // indirect
40 | 	golang.org/x/term v0.30.0 // indirect
41 | 	golang.org/x/text v0.23.0 // indirect
42 | 	gopkg.in/yaml.v2 v2.4.0 // indirect
43 | 	gopkg.in/yaml.v3 v3.0.1 // indirect
44 | )
45 | 
--------------------------------------------------------------------------------
/packaging/nfpm.yaml:
--------------------------------------------------------------------------------
1 | name: ssh2incus
2 | arch: ${ARCH}
3 | platform: linux
4 | version: ${VERSION}
5 | release: ${RELEASE}
6 | version_schema: none
7 | section: default
8 | priority: "extra"
9 | maintainer: ""
10 | description: |
11 |   SSH server for Incus instances
12 | vendor: mobydeck
13 | homepage: https://ssh2incus.com
14 | license: GPL-3.0
15 | 
16 | provides:
17 |   - ssh2incus
18 | 
19 | contents:
20 |   - src: ./build/ssh2incus
21 |     dst: /bin/ssh2incus
22 | 
23 |   - src: ./packaging/ssh2incus.env
24 |     dst: /etc/default/ssh2incus
25 |     type: config|noreplace
26 | 
27 |   - src: ./packaging/ssh2incus.service
28 |     dst: /lib/systemd/system/ssh2incus.service
29 | 
30 |   - src: ./README.md
31 |     dst: /usr/share/doc/ssh2incus/README.md
32 | 
33 |   - src: ./LICENSE
34 |     dst: /usr/share/licenses/ssh2incus/LICENSE
35 | 
36 | scripts:
37 |   postinstall: ./packaging/post-install.sh
38 | 
--------------------------------------------------------------------------------
/packaging/post-install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | 
3 | /bin/systemctl daemon-reload
--------------------------------------------------------------------------------
/packaging/ssh2incus.env:
--------------------------------------------------------------------------------
1 | # Arguments to pass to ssh2incus daemon
2 | ARGS=-m
3 | 
4 | # -b, --banner show banner on login
5 | # -c, --client-cert string client certificate for remote
6 | # -k, --client-key string client key for remote
7 | # -d, --debug enable debug log
8 | # -g, --groups string list of groups members of which allowed to connect (default "incus,incus-admin")
9 | # --healthcheck string enable Incus health check every X minutes, e.g. "5m"
10 | # -h, --help print help
11 | # --inauth enable authentication using instance keys
12 | # -l, --listen string listen on ":port" or "host:port" (default ":2222")
13 | # -m, --master start master process and spawn workers
14 | # --noauth disable SSH authentication completely
15 | # --pprof enable pprof
16 | # --pprof-listen string pprof listen on ":port" or "host:port" (default ":6060")
17 | # -r, --remote string default Incus remote to use
18 | # -t, --server-cert string server certificate for remote
19 | # --shell string shell access command: login, su, sush or user shell
20 | # -s, --socket string Incus socket to connect to (optional, defaults to INCUS_SOCKET env)
21 | # -u, --url string Incus remote url to connect to (should start with https://)
22 | # -v, --version print version
23 | # -w, --welcome show welcome message to users connecting to shell
24 | 
--------------------------------------------------------------------------------
/packaging/ssh2incus.service:
--------------------------------------------------------------------------------
1 | [Unit]
2 | Description=SSH server for Incus instances
3 | After=network.target
4 | 
5 | [Service]
6 | EnvironmentFile=-/etc/default/ssh2incus
7 | ExecStart=/bin/ssh2incus $ARGS
8 | KillMode=process
9 | Restart=on-failure
10 | RestartSec=3s
11 | 
12 | [Install]
13 | WantedBy=multi-user.target
14 | 
--------------------------------------------------------------------------------
/pkg/app.go:
--------------------------------------------------------------------------------
1 | package pkg
2 | 
3 | import (
4 | 	"fmt"
5 | 	"os"
6 | 	"strings"
7 | 	"time"
8 | 
9 | 	"ssh2incus/pkg/isatty"
10 | )
11 | 
12 | type AppConfig struct {
13 | 	Name    string
14 | 	Version string
15 | 	Edition string
16 | 	GitHash string
17 | 	BuiltAt string
18 | }
19 | 
20 | type App struct {
21 | 	name     string
22 | 	version  string
23 | 	edition  string
24 | 	gitHash  string
25 | 	longName string
26 | 	builtAt  string
27 | 	isTTY    bool
28 | 
29 | 	startTime time.Time
30 | }
31
|
32 | // NewApp builds an App from c, stamping the start time used by Uptime and
33 | // recording whether stdout is a terminal.
34 | func NewApp(c AppConfig) *App {
35 | 	return &App{
36 | 		name:      c.Name,
37 | 		version:   c.Version,
38 | 		edition:   c.Edition,
39 | 		gitHash:   c.GitHash,
40 | 		builtAt:   c.BuiltAt,
41 | 		startTime: time.Now(),
42 | 		isTTY:     isatty.IsTerminal(os.Stdout.Fd()),
43 | 	}
44 | }
45 | 
46 | // Name returns the configured application name.
47 | func (a *App) Name() string { return a.name }
48 | 
49 | // NAME returns the application name in upper case.
50 | func (a *App) NAME() string { return strings.ToUpper(a.name) }
51 | 
52 | // LongName returns String(), followed by " (<git hash>)" when a hash is known.
53 | func (a *App) LongName() string {
54 | 	if a.gitHash == "" {
55 | 		return a.String()
56 | 	}
57 | 	return a.String() + fmt.Sprintf(" (%s)", a.gitHash)
58 | }
59 | 
60 | // Version returns the configured version string.
61 | func (a *App) Version() string { return a.version }
62 | 
63 | // GitHash returns the configured git commit hash.
64 | func (a *App) GitHash() string { return a.gitHash }
65 | 
66 | // Commit is an alias for GitHash.
67 | func (a *App) Commit() string { return a.GitHash() }
68 | 
69 | // String returns the name and version separated by a single space.
70 | func (a *App) String() string { return a.name + " " + a.version }
71 | 
72 | // BuiltAt returns the configured build timestamp string.
73 | func (a *App) BuiltAt() string { return a.builtAt }
74 | 
75 | // IsTTY reports whether stdout was a terminal when NewApp ran.
76 | func (a *App) IsTTY() bool { return a.isTTY }
77 | 
78 | // IsTerminal is an alias for IsTTY.
79 | func (a *App) IsTerminal() bool { return a.IsTTY() }
80 | 
81 | // Uptime returns the time elapsed since NewApp in human-readable form.
82 | func (a *App) Uptime() string {
83 | 	return formatDuration(time.Since(a.startTime))
84 | }
85 | 
86 | // formatDuration rounds d to whole seconds and renders its non-zero day, hour,
87 | // minute and second components as a comma-separated list, for example
88 | // "1 day, 3 hours, 5 seconds". A duration that rounds to zero yields
89 | // "0 seconds", so the result is never empty.
90 | func formatDuration(d time.Duration) string {
91 | 	d = d.Round(time.Second)
92 | 
93 | 	components := []struct {
94 | 		value    int
95 | 		singular string
96 | 		plural   string
97 | 	}{
98 | 		{int(d.Hours() / 24), "1 day", "%d days"},
99 | 		{int(d.Hours()) % 24, "1 hour", "%d hours"},
100 | 		{int(d.Minutes()) % 60, "1 minute", "%d minutes"},
101 | 		{int(d.Seconds()) % 60, "1 second", "%d seconds"},
102 | 	}
103 | 
104 | 	var parts []string
105 | 	for idx, c := range components {
106 | 		// The seconds component is emitted unconditionally when nothing else was.
107 | 		forced := idx == len(components)-1 && len(parts) == 0
108 | 		switch {
109 | 		case c.value <= 0 && !forced:
110 | 			continue
111 | 		case c.value == 1:
112 | 			parts = append(parts, c.singular)
113 | 		default:
114 | 			parts = append(parts, fmt.Sprintf(c.plural, c.value))
115 | 		}
116 | 	}
117 | 	return strings.Join(parts, ", ")
118 | }
119 | 
--------------------------------------------------------------------------------
/pkg/cache/README.md:
--------------------------------------------------------------------------------
1 | # go-cache
2 | 
3 | go-cache is an in-memory key:value store/cache similar to memcached that is
4 | suitable for applications running on a single machine. Its major advantage is
5 | that, being essentially a thread-safe `map[string]interface{}` with expiration
6 | times, it doesn't need to serialize or transmit its contents over the network.
7 | 
8 | Any object can be stored, for a given duration or forever, and the cache can be
9 | safely used by multiple goroutines.
10 | 
11 | Although go-cache isn't meant to be used as a persistent datastore, the entire
12 | cache can be saved to and loaded from a file (using `c.Items()` to retrieve the
13 | items map to serialize, and `NewFrom()` to create a cache from a deserialized
14 | one) to recover from downtime quickly. (See the docs for `NewFrom()` for caveats.)
15 | 16 | ### Installation 17 | 18 | `go get github.com/patrickmn/go-cache` 19 | 20 | ### Usage 21 | 22 | ```go 23 | import ( 24 | "fmt" 25 | "github.com/patrickmn/go-cache" 26 | "time" 27 | ) 28 | 29 | func main() { 30 | // Create a cache with a default expiration time of 5 minutes, and which 31 | // purges expired items every 10 minutes 32 | c := cache.New(5*time.Minute, 10*time.Minute) 33 | 34 | // Set the value of the key "foo" to "bar", with the default expiration time 35 | c.Set("foo", "bar", cache.DefaultExpiration) 36 | 37 | // Set the value of the key "baz" to 42, with no expiration time 38 | // (the item won't be removed until it is re-set, or removed using 39 | // c.Delete("baz") 40 | c.Set("baz", 42, cache.NoExpiration) 41 | 42 | // Get the string associated with the key "foo" from the cache 43 | foo, found := c.Get("foo") 44 | if found { 45 | fmt.Println(foo) 46 | } 47 | 48 | // Since Go is statically typed, and cache values can be anything, type 49 | // assertion is needed when values are being passed to functions that don't 50 | // take arbitrary types, (i.e. interface{}). The simplest way to do this for 51 | // values which will only be used once--e.g. for passing to another 52 | // function--is: 53 | foo, found := c.Get("foo") 54 | if found { 55 | MyFunction(foo.(string)) 56 | } 57 | 58 | // This gets tedious if the value is used several times in the same function. 59 | // You might do either of the following instead: 60 | if x, found := c.Get("foo"); found { 61 | foo := x.(string) 62 | // ... 63 | } 64 | // or 65 | var foo string 66 | if x, found := c.Get("foo"); found { 67 | foo = x.(string) 68 | } 69 | // ... 70 | // foo can then be passed around freely as a string 71 | 72 | // Want performance? Store pointers! 73 | c.Set("foo", &MyStruct, cache.DefaultExpiration) 74 | if x, found := c.Get("foo"); found { 75 | foo := x.(*MyStruct) 76 | // ... 
77 | 	}
78 | }
79 | ```
80 | 
81 | ### Reference
82 | 
83 | `godoc` or [http://godoc.org/github.com/patrickmn/go-cache](http://godoc.org/github.com/patrickmn/go-cache)
--------------------------------------------------------------------------------
/pkg/cache/sharded.go:
--------------------------------------------------------------------------------
1 | package cache
2 | 
3 | import (
4 | 	"crypto/rand"
5 | 	"math"
6 | 	"math/big"
7 | 	insecurerand "math/rand"
8 | 	"os"
9 | 	"runtime"
10 | 	"time"
11 | )
12 | 
13 | // This is an experimental and unexported (for now) attempt at making a cache
14 | // with better algorithmic complexity than the standard one, namely by
15 | // preventing write locks of the entire cache when an item is added. As of the
16 | // time of writing, the overhead of selecting buckets results in cache
17 | // operations being about twice as slow as for the standard cache with small
18 | // total cache sizes, and faster for larger ones.
19 | //
20 | // See cache_test.go for a few benchmarks.
21 | 
22 | type unexportedShardedCache struct {
23 | 	*shardedCache
24 | }
25 | 
26 | type shardedCache struct {
27 | 	seed    uint32
28 | 	m       uint32
29 | 	cs      []*cache
30 | 	janitor *shardedJanitor
31 | }
32 | 
33 | // djb33 is djb2 with better shuffling: every byte of k is folded into a hash
34 | // seeded with seed and len(k), and the high bits are finally mixed into the
35 | // low ones. 5x faster than FNV with the hash.Hash overhead.
36 | func djb33(seed uint32, k string) uint32 {
37 | 	var (
38 | 		l = uint32(len(k))
39 | 		d = 5381 + seed + l
40 | 		i = uint32(0)
41 | 	)
42 | 	// Unrolled main loop: four bytes per iteration, stopping with 1-4 bytes
43 | 	// left (none for an empty key).
44 | 	if l >= 4 {
45 | 		for i < l-4 {
46 | 			d = (d * 33) ^ uint32(k[i])
47 | 			d = (d * 33) ^ uint32(k[i+1])
48 | 			d = (d * 33) ^ uint32(k[i+2])
49 | 			d = (d * 33) ^ uint32(k[i+3])
50 | 			i += 4
51 | 		}
52 | 	}
53 | 	// Fold in every remaining byte, including the last one, so keys that
54 | 	// differ only in their final character ("foo1", "foo2", ...) do not
55 | 	// collide and pile into the same shard.
56 | 	for ; i < l; i++ {
57 | 		d = (d * 33) ^ uint32(k[i])
58 | 	}
59 | 	return d ^ (d >> 16)
60 | }
61 | 
62 | // bucket returns the shard responsible for key k.
63 | func (sc *shardedCache) bucket(k string) *cache {
64 | 	return sc.cs[djb33(sc.seed, k)%sc.m]
65 | }
66 | 
67 | func (sc *shardedCache) Set(k string, x any, d time.Duration) {
68 | 	sc.bucket(k).Set(k, x, d)
69 | }
70 | 
71 | func (sc *shardedCache) Add(k string, x any, d time.Duration) error {
72 | 	return sc.bucket(k).Add(k, x, d)
73 | }
74 | 
75 | func (sc *shardedCache) Replace(k string, x any, d time.Duration) error {
76 | 	return sc.bucket(k).Replace(k, x, d)
77 | }
78 | 
79 | func (sc *shardedCache) Get(k string) (any, bool) {
80 | 	return sc.bucket(k).Get(k)
81 | }
82 | 
83 | func (sc *shardedCache) Increment(k string, n int64) error {
84 | 	return sc.bucket(k).Increment(k, n)
85 | }
86 | 
87 | func (sc *shardedCache) IncrementFloat(k string, n float64) error {
88 | 	return sc.bucket(k).IncrementFloat(k, n)
89 | }
90 | 
91 | func (sc *shardedCache) Decrement(k string, n int64) error {
92 | 	return sc.bucket(k).Decrement(k, n)
93 | }
94 | 
95 | func (sc *shardedCache) Delete(k string) {
96 | 	sc.bucket(k).Delete(k)
97 | }
98 | 
99 | func (sc *shardedCache) DeleteExpired() {
100 | 	for _, v := range sc.cs {
101 | 		v.DeleteExpired()
102 | 	}
103 | }
104 | 
105 | // Returns the items in the cache. This may include items that have expired,
106 | // but have not yet been cleaned up. If this is significant, the Expiration
107 | // fields of the items should be checked. Note that explicit synchronization
108 | // is needed to use a cache and its corresponding Items() return values at
109 | // the same time, as the maps are shared.
110 | func (sc *shardedCache) Items() []map[string]Item {
111 | 	res := make([]map[string]Item, len(sc.cs))
112 | 	for i, v := range sc.cs {
113 | 		res[i] = v.Items()
114 | 	}
115 | 	return res
116 | }
117 | 
118 | func (sc *shardedCache) Flush() {
119 | 	for _, v := range sc.cs {
120 | 		v.Flush()
121 | 	}
122 | }
123 | 
124 | type shardedJanitor struct {
125 | 	Interval time.Duration
126 | 	stop     chan bool
127 | }
128 | 
129 | // Run deletes expired items from sc every Interval until a value is received
130 | // on j.stop.
131 | func (j *shardedJanitor) Run(sc *shardedCache) {
132 | 	ticker := time.NewTicker(j.Interval)
133 | 	defer ticker.Stop()
134 | 	for {
135 | 		select {
136 | 		case <-ticker.C:
137 | 			sc.DeleteExpired()
138 | 		case <-j.stop:
139 | 			return
140 | 		}
141 | 	}
142 | }
143 | 
144 | // stopShardedJanitor is registered as the finalizer of sc; it tells the
145 | // janitor goroutine to exit.
146 | func stopShardedJanitor(sc *unexportedShardedCache) {
147 | 	sc.janitor.stop <- true
148 | }
149 | 
150 | // runShardedJanitor starts a janitor goroutine that cleans sc every ci. The
151 | // stop channel is created here, before the goroutine starts, so
152 | // stopShardedJanitor can never send on a nil channel (which blocks forever).
153 | func runShardedJanitor(sc *shardedCache, ci time.Duration) {
154 | 	j := &shardedJanitor{
155 | 		Interval: ci,
156 | 		stop:     make(chan bool),
157 | 	}
158 | 	sc.janitor = j
159 | 	go j.Run(sc)
160 | }
161 | 
162 | func newShardedCache(n int, de time.Duration) *shardedCache {
163 | 	max := big.NewInt(0).SetUint64(uint64(math.MaxUint32))
164 | 	rnd, err := rand.Int(rand.Reader, max)
165 | 	var seed uint32
166 | 	if err != nil {
167 | 		os.Stderr.Write([]byte("WARNING: go-cache's newShardedCache failed to read from the system CSPRNG (/dev/urandom or equivalent.) Your system's security may be compromised. Continuing with an insecure seed.\n"))
168 | 		seed = insecurerand.Uint32()
169 | 	} else {
170 | 		seed = uint32(rnd.Uint64())
171 | 	}
172 | 	sc := &shardedCache{
173 | 		seed: seed,
174 | 		m:    uint32(n),
175 | 		cs:   make([]*cache, n),
176 | 	}
177 | 	for i := 0; i < n; i++ {
178 | 		c := &cache{
179 | 			defaultExpiration: de,
180 | 			items:             map[string]Item{},
181 | 		}
182 | 		sc.cs[i] = c
183 | 	}
184 | 	return sc
185 | }
186 | 
187 | func unexportedNewSharded(defaultExpiration, cleanupInterval time.Duration, shards int) *unexportedShardedCache {
188 | 	if defaultExpiration == 0 {
189 | 		defaultExpiration = -1
190 | 	}
191 | 	sc := newShardedCache(shards, defaultExpiration)
192 | 	SC := &unexportedShardedCache{sc}
193 | 	if cleanupInterval > 0 {
194 | 		runShardedJanitor(sc, cleanupInterval)
195 | 		runtime.SetFinalizer(SC, stopShardedJanitor)
196 | 	}
197 | 	return SC
198 | }
199 | 
--------------------------------------------------------------------------------
/pkg/cache/sharded_test.go:
--------------------------------------------------------------------------------
1 | package cache
2 | 
3 | import (
4 | 	"strconv"
5 | 	"sync"
6 | 	"testing"
7 | 	"time"
8 | )
9 | 
10 | func TestDjb33(t *testing.T) {
11 | 	// Every byte of the key, including the last one, must affect the hash.
12 | 	if djb33(0, "foo1") == djb33(0, "foo2") {
13 | 		t.Error(`djb33 ignores the last byte: "foo1" and "foo2" collide`)
14 | 	}
15 | 	if djb33(0, "a") == djb33(0, "b") {
16 | 		t.Error(`djb33 ignores the only byte: "a" and "b" collide`)
17 | 	}
18 | }
19 | 
20 | var shardedKeys = []string{
21 | 	"f",
22 | 	"fo",
23 | 	"foo",
24 | 	"barf",
25 | 	"barfo",
26 | 	"foobar",
27 | 	"bazbarf",
28 | 	"bazbarfo",
29 | 	"bazbarfoo",
30 | 	"foobarbazq",
31 | 	"foobarbazqu",
32 | 	"foobarbazquu",
33 | 	"foobarbazquux",
34 | }
35 | 
36 | func TestShardedCache(t *testing.T) {
37 | 	tc := unexportedNewSharded(DefaultExpiration, 0, 13)
38 | 	for _, v := range shardedKeys {
39 | 		tc.Set(v, "value", DefaultExpiration)
40 | 	}
41 | }
42 | 
43 | func BenchmarkShardedCacheGetExpiring(b *testing.B) {
44 | 	benchmarkShardedCacheGet(b, 5*time.Minute)
45 | }
46 | 
47 | func BenchmarkShardedCacheGetNotExpiring(b *testing.B) {
48 | 	benchmarkShardedCacheGet(b, NoExpiration)
49 | }
50 | 
51 | func benchmarkShardedCacheGet(b *testing.B, exp time.Duration) {
52 | 	b.StopTimer()
53 | 	tc := unexportedNewSharded(exp, 0, 10)
54 | 	tc.Set("foobarba", "zquux", DefaultExpiration)
55 | 	b.StartTimer()
56 | 	for i := 0; i < b.N; i++ {
57 | 		tc.Get("foobarba")
58 | 	}
59 | }
60 | 
61 | func BenchmarkShardedCacheGetManyConcurrentExpiring(b *testing.B) {
62 | 	benchmarkShardedCacheGetManyConcurrent(b, 5*time.Minute)
63 | }
64 | 
65 | func BenchmarkShardedCacheGetManyConcurrentNotExpiring(b *testing.B) {
66 | 	benchmarkShardedCacheGetManyConcurrent(b, NoExpiration)
67 | }
68 | 
69 | func benchmarkShardedCacheGetManyConcurrent(b *testing.B, exp time.Duration) {
70 | 	b.StopTimer()
71 | 	n := 10000
72 | 	tsc := unexportedNewSharded(exp, 0, 20)
73 | 	keys := make([]string, n)
74 | 	for i := 0; i < n; i++ {
75 | 		k := "foo" + strconv.Itoa(i)
76 | 		keys[i] = k
77 | 		tsc.Set(k, "bar", DefaultExpiration)
78 | 	}
79 | 	each := b.N / n
80 | 	wg := new(sync.WaitGroup)
81 | 	wg.Add(n)
82 | 	for _, v := range keys {
83 | 		go func(k string) {
84 | 			for j := 0; j < each; j++ {
85 | 				tsc.Get(k)
86 | 			}
87 | 			wg.Done()
88 | 		}(v)
89 | 	}
90 | 	b.StartTimer()
91 | 	wg.Wait()
92 | }
93 | 
--------------------------------------------------------------------------------
/pkg/incus/config.go:
--------------------------------------------------------------------------------
1 | package incus
2 | 
3 | import (
4 | 	"fmt"
5 | 	"os"
6 | 	"strings"
7 | 
8 | 	"github.com/lxc/incus/v6/shared/cliconfig"
9 | )
10 | 
11 | // RemoteConnectParams resolves connection parameters for the named Incus
12 | // remote (the CLI's default remote when remote is empty) from the Incus CLI
13 | // configuration. For https:// remotes it also resolves the client
14 | // certificate/key paths, which must exist, and the remote's server
15 | // certificate path; for unix:// remotes the scheme prefix is stripped.
16 | func RemoteConnectParams(remote string) (*ConnectParams, error) {
17 | 	clicfg, err := cliconfig.LoadConfig("")
18 | 	if err != nil {
19 | 		return nil, fmt.Errorf("failed to load Incus CLI config: %w", err)
20 | 	}
21 | 	if clicfg == nil {
22 | 		// Reported separately: wrapping the nil err would print "%!w(<nil>)".
23 | 		return nil, fmt.Errorf("failed to load Incus CLI config: no configuration returned")
24 | 	}
25 | 
26 | 	if remote == "" {
27 | 		remote = clicfg.DefaultRemote
28 | 	}
29 | 
30 | 	remoteConfig, ok := clicfg.Remotes[remote]
31 | 	if !ok {
32 | 		return nil, fmt.Errorf("remote '%s' not found in incus configuration", remote)
33 | 	}
34 | 
35 | 	url := remoteConfig.Addr
36 | 
37 | 	var serverCertFile, certFile, keyFile string
38 | 
39 | 	// For HTTPS connections, determine client certificate paths
40 | 	if strings.HasPrefix(url, "https://") {
41 | 		// The remote's server certificate path: servercerts/<remote>.crt
42 | 		serverCertFile = clicfg.ConfigPath("servercerts", remote+".crt")
43 | 		// Use default Incus client cert/key which are stored in the same directory as config.yml
44 | 		certFile = clicfg.ConfigPath("client.crt")
45 | 		keyFile = clicfg.ConfigPath("client.key")
46 | 
47 | 		// Ensure certificate files exist
48 | 		if _, err := os.Stat(certFile); err != nil {
49 | 			return nil, fmt.Errorf("client certificate not found at %s: %w", certFile, err)
50 | 		}
51 | 		if _, err := os.Stat(keyFile); err != nil {
52 | 			return nil, fmt.Errorf("client key not found at %s: %w", keyFile, err)
53 | 		}
54 | 	} else if strings.HasPrefix(url, "unix://") {
55 | 		url = strings.TrimPrefix(url, "unix://")
56 | 	}
57 | 
58 | 	return &ConnectParams{
59 | 		Remote:         remote,
60 | 		Url:            url,
61 | 		CertFile:       certFile,
62 | 		KeyFile:        keyFile,
63 | 		ServerCertFile: serverCertFile,
64 | 	}, nil
65 | }
66 | 
--------------------------------------------------------------------------------
/pkg/incus/exec.go:
--------------------------------------------------------------------------------
1 | package incus
2 | 
3 | import (
4 | 	"bytes"
5 | 	"context"
6 | 	"fmt"
7 | 	"io"
8 | 	"strconv"
9 | 	"sync"
10 | 
11 | 	"ssh2incus/pkg/shlex"
12 | 
13 | 	"github.com/gorilla/websocket"
14 | 	incus "github.com/lxc/incus/v6/client"
15 | 	"github.com/lxc/incus/v6/shared/api"
16 | )
17 | 
18 | // Window is a terminal window size as passed to Incus (width and height).
19 | type Window struct {
20 | 	Width  int
21 | 	Height int
22 | }
23 | 
24 | // WindowChannel delivers window-size changes for an interactive exec.
25 | type WindowChannel chan Window
26 | 
27 | // NewInstanceExec returns an InstanceExec bound to c, with the public fields
28 | // copied from e.
29 | func (c *Client) NewInstanceExec(e InstanceExec) *InstanceExec {
30 | 	return &InstanceExec{
31 | 		client:   c,
32 | 		Instance: e.Instance,
33 | 		Cmd:      e.Cmd,
34 | 		Env:      e.Env,
35 | 		IsPty:    e.IsPty,
36 | 		Window:   e.Window,
37 | 		WinCh:    e.WinCh,
38 | 		User:     e.User,
39 | 		Group:    e.Group,
40 | 		Cwd:      e.Cwd,
41 | 		Stdin:    e.Stdin,
42 | 		Stdout:   e.Stdout,
43 | 		Stderr:   e.Stderr,
44 | 	}
45 | }
46 | 
47 | // InstanceExec describes a command to run inside an Incus instance and the
48 | // streams attached to it.
49 | type InstanceExec struct {
50 | 	client   *Client
51 | 	Instance string
52 | 	Cmd      string
53 | 	Env      map[string]string
54 | 	IsPty    bool
55 | 	Window
56 | 	WinCh WindowChannel
57 | 	User  int
58 | 	Group int
59 | 	Cwd   string
60 | 
61 | 	Stdin  io.ReadCloser
62 | 	Stdout io.WriteCloser
63 | 	Stderr io.WriteCloser
64 | 
65 | 	execPost api.InstanceExecPost
66 | 	execArgs *incus.InstanceExecArgs
67 | 
68 | 	ctx    context.Context
69 | 	cancel context.CancelFunc
70 | }
71 | 
72 | // BuildExecRequest prepares the execution parameters
73 | func (e *InstanceExec) BuildExecRequest() {
74 | 	// NOTE(review): a shlex parse error is ignored here — confirm intended.
75 | 	args, _ := shlex.Split(e.Cmd, true)
76 | 
77 | 	e.execPost = api.InstanceExecPost{
78 | 		Command:     args,
79 | 		WaitForWS:   true,
80 | 		Interactive: e.IsPty,
81 | 		Environment: e.Env,
82 | 		Width:       e.Window.Width,
83 | 		Height:      e.Window.Height,
84 | 		User:        uint32(e.User),
85 | 		Group:       uint32(e.Group),
86 | 		Cwd:         e.Cwd,
87 | 	}
88 | 
89 | 	// Setup context with cancellation if not already done
90 | 	if e.ctx == nil {
91 | 		e.ctx, e.cancel = context.WithCancel(context.Background())
92 | 	}
93 | }
94 | 
95 | // Exec runs the command in the instance, streaming Stdin/Stdout and copying
96 | // the command's stderr both to Stderr (when set) and to an internal buffer.
97 | // It returns the command's exit code; when anything was written to stderr,
98 | // that output is additionally returned as an error alongside the exit code.
99 | func (e *InstanceExec) Exec() (int, error) {
100 | 	client := e.client
101 | 
102 | 	e.BuildExecRequest()
103 | 
104 | 	// Setup error capturing
105 | 	errWriter, errBuf := e.setupErrorCapture()
106 | 
107 | 	// Setup websocket control handler
108 | 	control, wg := e.setupControlHandler()
109 | 
110 | 	// Setup execution args
111 | 	dataDone := make(chan bool)
112 | 	e.execArgs = &incus.InstanceExecArgs{
113 | 		Stdin:    e.Stdin,
114 | 		Stdout:   e.Stdout,
115 | 		Stderr:   errWriter,
116 | 		Control:  control,
117 | 		DataDone: dataDone,
118 | 	}
119 | 
120 | 	// Execute the command
121 | 	op, err := client.srv.ExecInstance(e.Instance, e.execPost, e.execArgs)
122 | 	if err != nil {
123 | 		return -1, fmt.Errorf("exec instance: %w", err)
124 | 	}
125 | 
126 | 	// Wait for operation to complete
127 | 	if err = op.Wait(); err != nil {
128 | 		return -1, fmt.Errorf("operation wait: %w", err)
129 | 	}
130 | 
131 | 	// Wait for data transfer to complete
132 | 	<-dataDone
133 | 
134 | 	// Wait for control handler to finish
135 | 	wg.Wait()
136 | 
137 | 	// Get execution result. The exit code is read from the operation metadata
138 | 	// as a float64; check the assertion instead of panicking when the value is
139 | 	// missing or has another type.
140 | 	opAPI := op.Get()
141 | 	rc, ok := opAPI.Metadata["return"].(float64)
142 | 	if !ok {
143 | 		return -1, fmt.Errorf("exec instance: no numeric return code in operation metadata")
144 | 	}
145 | 	ret := int(rc)
146 | 
147 | 	errs := errBuf.String()
148 | 	if errs != "" {
149 | 		return ret, fmt.Errorf("stderr: %s", errs)
150 | 	}
151 | 
152 | 	return ret, nil
153 | }
154 | 
155 | // setupControlHandler prepares the websocket control handler. The returned
156 | // WaitGroup lets Exec wait until the handler has returned.
157 | func (e *InstanceExec) setupControlHandler() (func(*websocket.Conn), *sync.WaitGroup) {
158 | 	var ws *websocket.Conn
159 | 	var wg sync.WaitGroup
160 | 
161 | 	control := func(conn *websocket.Conn) {
162 | 		ws = conn
163 | 		// NOTE(review): Add runs inside the handler, so a Wait that happens
164 | 		// before the handler starts will not wait for it — confirm the client
165 | 		// always starts Control before DataDone is closed.
166 | 		wg.Add(1)
167 | 		defer wg.Done()
168 | 
169 | 		// Start window resize listener if channel is provided
170 | 		if e.WinCh != nil {
171 | 			go windowResizeListener(e.WinCh, ws)
172 | 		}
173 | 
174 | 		// Read messages until connection is closed or context is canceled
175 | 		done := make(chan struct{})
176 | 		go func() {
177 | 			defer close(done)
178 | 			// gorilla websocket may panic sometimes; swallow it so the
179 | 			// deferred close(done) still runs and the handler can return.
180 | 			defer func() { _ = recover() }()
181 | 
182 | 			for {
183 | 				_, _, err := ws.ReadMessage()
184 | 				if err != nil {
185 | 					return
186 | 				}
187 | 			}
188 | 		}()
189 | 
190 | 		select {
191 | 		case <-done:
192 | 			return
193 | 		case <-e.ctx.Done():
194 | 			ws.Close()
195 | 			return
196 | 		}
197 | 	}
198 | 
199 | 	return control, &wg
200 | }
201 | 
202 | // setupErrorCapture returns the writer to use for the command's stderr:
203 | // everything written is copied into the returned buffer and, when set, into
204 | // e.Stderr.
205 | func (e *InstanceExec) setupErrorCapture() (io.Writer, *bytes.Buffer) {
206 | 	var errBuf bytes.Buffer
207 | 	if e.Stderr == nil {
208 | 		// io.MultiWriter would panic on the first write to a nil writer.
209 | 		return &errBuf, &errBuf
210 | 	}
211 | 	return io.MultiWriter(e.Stderr, &errBuf), &errBuf
212 | }
213 | 
214 | // windowResizeListener forwards every size received on c to ws until c is closed.
215 | func windowResizeListener(c WindowChannel, ws *websocket.Conn) {
216 | 	for win := range c {
217 | 		resizeWindow(ws, win.Width, win.Height)
218 | 	}
219 | }
220 | 
221 | // resizeWindow sends an Incus "window-resize" control message over ws. The
222 | // write error is ignored (best effort).
223 | func resizeWindow(ws *websocket.Conn, width int, height int) {
224 | 	msg := api.InstanceExecControl{}
225 | 	msg.Command = "window-resize"
226 | 	msg.Args = make(map[string]string)
227 | 	msg.Args["width"] = strconv.Itoa(width)
228 | 	msg.Args["height"] = strconv.Itoa(height)
229 | 
230 | 	ws.WriteJSON(msg)
231 | }
232 | 
--------------------------------------------------------------------------------
/pkg/incus/file.go:
--------------------------------------------------------------------------------
1 | package incus
2 | 
3 | import (
4 | 	"fmt"
5 | 	"io"
6 | 	"os"
7 | 	"path/filepath"
8 | 	"strings"
9 | 	"sync"
10 | 	"time"
11 | 
12 | 	"ssh2incus/pkg/cache"
13 | 	"ssh2incus/pkg/queue"
14 | 	"ssh2incus/pkg/util/buffer"
15 | 	uio "ssh2incus/pkg/util/io"
16 | 
17 | 	incus "github.com/lxc/incus/v6/client"
18 | )
19 | 
20 | var (
21 | 	fileExistsCache *cache.Cache
22 | 	fileExistsQueue *queue.Queueable[bool]
23 | 	fileExistsOnce  sync.Once
24 | )
25 | 
26 | func init() {
27 | 	fileExistsOnce.Do(func() {
28 | 		fileExistsCache = cache.New(20*time.Minute, 30*time.Minute)
29 | 		fileExistsQueue = queue.New[bool](10000)
30 | 	})
31 | }
32 | 
33 | func (c *Client) UploadFile(project, instance string, src string, dest string) error {
34 | 	info, err := os.Stat(src)
35 | 	if err != nil {
36 | 		//log.Debugf("couldn't stat file %s", src)
37 | 		return err
38 | 	}
39 | 
40 | 	mode, uid, gid := uio.GetOwnerMode(info)
41 | 
42 | 	f, err := os.OpenFile(src, os.O_RDONLY, 0)
43 | 	if err != nil {
44 | 		//log.Debugf("couldn't open file %s for reading", src)
45 | 		return err
46 | 	}
47 | 	defer f.Close()
48 | 
49 | 	err = c.UploadBytes(project, instance, dest, f, int64(uid), int64(gid), int(mode.Perm()))
50 | 
51 | 	return err
52 | }
53 | 
54 | func (c *Client) UploadBytes(project, instance, dest string, b io.ReadSeeker, uid, gid int64, mode int) error {
55 | 	args := incus.InstanceFileArgs{
56 | 		Content:   b,
57 | 		UID:       uid,
58 | 		GID:       gid,
59 | 		Mode:      mode,
60 | 		Type:      "file",
61 | 		WriteMode: "overwrite",
62 | 	}
63 | 
64 | 	err := c.UseProject(project)
65 | 	if err != nil {
66 | 		return err
67 | 	}
68 | 
69 | 	err = c.srv.CreateInstanceFile(instance, dest, args)
70 | 
71 | 	return err
72 | }
73 | 
74 | type FileExistsParams struct {
75 | 	Project     string
76 | 	Instance    string
77 | 	Path        string
78 | 	Md5sum      string
79 | 	ShouldCache bool
80
| } 81 | 82 | func (c *Client) FileExists(params *FileExistsParams) bool { 83 | return queue.EnqueueFnWithParam(fileExistsQueue, func(p *FileExistsParams) bool { 84 | var fileHash string 85 | if p.ShouldCache { 86 | fileHash = FileHash(p.Project, p.Instance, p.Path, p.Md5sum) 87 | if exists, ok := fileExistsCache.Get(fileHash); ok { 88 | //log.Debugf("file cache hit for %s", fileHash) 89 | return exists.(bool) 90 | } 91 | //log.Debugf("file cache miss for %s", fileHash) 92 | } 93 | 94 | stdout := buffer.NewOutputBuffer() 95 | stderr := buffer.NewOutputBuffer() 96 | cmd := fmt.Sprintf("test -f %s", p.Path) 97 | ie := c.NewInstanceExec(InstanceExec{ 98 | Instance: p.Instance, 99 | Cmd: cmd, 100 | Stdout: stdout, 101 | Stderr: stderr, 102 | }) 103 | ret, _ := ie.Exec() 104 | 105 | if ret != 0 { 106 | return false 107 | } 108 | 109 | exists := true 110 | 111 | if p.Md5sum != "" { 112 | ie.Cmd = fmt.Sprintf("md5sum %s", p.Path) 113 | ret, _ := ie.Exec() 114 | if ret != 0 { 115 | //log.Debug(stderr.Lines()) 116 | return false 117 | } 118 | out := stdout.Lines() 119 | if len(out) == 0 { 120 | return false 121 | } 122 | m := strings.Split(out[0], " ") 123 | if len(m) < 2 { 124 | return false 125 | } 126 | //log.Debugf("comparing md5 for %s: %s <=> %s", p.Path, p.Md5sum, m[0]) 127 | exists = p.Md5sum == m[0] 128 | } 129 | 130 | if p.ShouldCache && exists { 131 | fileExistsCache.SetDefault(fileHash, exists) 132 | } 133 | 134 | return exists 135 | }, params) 136 | } 137 | 138 | type InstanceFile struct { 139 | Project string 140 | Instance string 141 | Name string 142 | Path string 143 | Size int64 144 | Mode int 145 | Uid int 146 | Gid int 147 | Type string 148 | Content *buffer.BytesBuffer 149 | } 150 | 151 | func (c *Client) DownloadFile(project, instance string, path string) (*InstanceFile, error) { 152 | content, resp, err := c.srv.GetInstanceFile(instance, path) 153 | if err != nil { 154 | return nil, err 155 | } 156 | 157 | if resp.Type != "file" { 158 | return nil, 
fmt.Errorf("not a file: %s", path) 159 | } 160 | 161 | //sftpConn, err := c.srv.GetInstanceFileSFTP(instance) 162 | //if err != nil { 163 | // return nil, err 164 | //} 165 | //defer sftpConn.Close() 166 | // 167 | //src, err := sftpConn.Open(path) 168 | //if err != nil { 169 | // return nil, err 170 | //} 171 | 172 | buf := buffer.NewBytesBuffer() 173 | defer buf.Close() 174 | 175 | for { 176 | _, err = io.CopyN(buf, content, 1024*1024) 177 | if err != nil { 178 | if err == io.EOF { 179 | break 180 | } 181 | return nil, err 182 | } 183 | } 184 | content.Close() 185 | 186 | //contentBytes, err := io.ReadAll(content) 187 | //if err != nil { 188 | // return nil, err 189 | //} 190 | 191 | //srcInfo, err := sftpConn.Lstat(path) 192 | //if err != nil { 193 | // return nil, err 194 | //} 195 | 196 | //targetIsLink := false 197 | //if srcInfo.Mode()&os.ModeSymlink == os.ModeSymlink { 198 | // targetIsLink = true 199 | //} 200 | 201 | //var linkName string 202 | //if targetIsLink { 203 | // linkName, err = sftpConn.ReadLink(path) 204 | // if err != nil { 205 | // return nil, err 206 | // } 207 | //} 208 | 209 | //log.Debugf("read %d bytes from %s", buf.Size(), path) 210 | //log.Debugf("GetInstanceFile resp %#v", resp) 211 | 212 | return &InstanceFile{ 213 | Project: project, 214 | Instance: instance, 215 | Name: filepath.Base(path), 216 | Path: path, 217 | Size: buf.Size(), 218 | Mode: resp.Mode, 219 | Uid: int(resp.UID), 220 | Gid: int(resp.GID), 221 | Type: resp.Type, 222 | Content: buf, 223 | }, nil 224 | } 225 | 226 | func FileHash(project, instance, path, md5sum string) string { 227 | return fmt.Sprintf("%s/%s/%s:%s", project, instance, path, md5sum) 228 | } 229 | -------------------------------------------------------------------------------- /pkg/incus/incus.go: -------------------------------------------------------------------------------- 1 | package incus 2 | 3 | import ( 4 | "context" 5 | "crypto/ecdsa" 6 | "crypto/rsa" 7 | "crypto/tls" 8 | "crypto/x509" 9 | 
"encoding/pem" 10 | "fmt" 11 | "os" 12 | "strings" 13 | "sync" 14 | "time" 15 | 16 | "ssh2incus/pkg/cache" 17 | "ssh2incus/pkg/queue" 18 | "ssh2incus/pkg/util/structs" 19 | 20 | "github.com/lxc/incus/v6/client" 21 | "github.com/lxc/incus/v6/shared/api" 22 | ) 23 | 24 | var ( 25 | instanceStateCache *cache.Cache 26 | instanceStateQueue *queue.Queueable[*api.InstanceState] 27 | instanceStateOnce sync.Once 28 | ) 29 | 30 | func init() { 31 | instanceStateOnce.Do(func() { 32 | instanceStateCache = cache.New(1*time.Minute, 2*time.Minute) 33 | instanceStateQueue = queue.New[*api.InstanceState](100) 34 | }) 35 | } 36 | 37 | type ConnectParams struct { 38 | Remote string 39 | Url string 40 | CertFile string 41 | KeyFile string 42 | ServerCertFile string 43 | CaCertFile string 44 | } 45 | 46 | type Client struct { 47 | srv incus.InstanceServer 48 | params *ConnectParams 49 | project string 50 | } 51 | 52 | func NewClientWithParams(p *ConnectParams) *Client { 53 | c := new(Client) 54 | c.params = p 55 | return c 56 | } 57 | 58 | func (c *Client) Connect(ctx context.Context) error { 59 | var err error 60 | params := *c.params 61 | // Check if the URL is an HTTPS URL 62 | if strings.HasPrefix(params.Url, "https://") { 63 | // HTTPS connection requires client certificates 64 | if params.CertFile == "" || params.KeyFile == "" { 65 | return fmt.Errorf("client certificate and key files are required for HTTPS connections") 66 | } 67 | 68 | // Load client certificate and key 69 | keyPair, err := tls.LoadX509KeyPair(params.CertFile, params.KeyFile) 70 | if err != nil { 71 | return fmt.Errorf("failed to load client certificate and key: %w", err) 72 | } 73 | 74 | certPEM := pem.EncodeToMemory(&pem.Block{ 75 | Type: "CERTIFICATE", 76 | Bytes: keyPair.Certificate[0], 77 | }) 78 | 79 | // Convert the private key to PEM format 80 | // We need to determine the type of private key and encode accordingly 81 | var keyPEM []byte 82 | switch key := keyPair.PrivateKey.(type) { 83 | case 
*rsa.PrivateKey: 84 | keyPEM = pem.EncodeToMemory(&pem.Block{ 85 | Type: "RSA PRIVATE KEY", 86 | Bytes: x509.MarshalPKCS1PrivateKey(key), 87 | }) 88 | case *ecdsa.PrivateKey: 89 | keyBytes, err := x509.MarshalECPrivateKey(key) 90 | if err != nil { 91 | return fmt.Errorf("failed to marshal EC private key: %w", err) 92 | } 93 | keyPEM = pem.EncodeToMemory(&pem.Block{ 94 | Type: "EC PRIVATE KEY", 95 | Bytes: keyBytes, 96 | }) 97 | default: 98 | // For other types like ed25519, we'd need specific handling 99 | return fmt.Errorf("unsupported private key type: %T", keyPair.PrivateKey) 100 | } 101 | 102 | var serverCertPEM []byte 103 | if params.ServerCertFile != "" { 104 | serverCertPEM, err = os.ReadFile(params.ServerCertFile) 105 | if err != nil { 106 | return fmt.Errorf("failed to read CA cert file: %w", err) 107 | } 108 | } 109 | 110 | // Connect using HTTPS 111 | args := &incus.ConnectionArgs{ 112 | TLSClientCert: string(certPEM), 113 | TLSClientKey: string(keyPEM), 114 | TLSServerCert: string(serverCertPEM), 115 | SkipGetServer: true, 116 | } 117 | c.srv, err = incus.ConnectIncusWithContext(ctx, params.Url, args) 118 | return err 119 | } else { 120 | // If not HTTPS, treat as Unix socket path 121 | c.srv, err = incus.ConnectIncusUnixWithContext(ctx, params.Url, nil) 122 | return err 123 | } 124 | } 125 | 126 | func (c *Client) UseProject(project string) error { 127 | if project == "" { 128 | project = "default" 129 | } 130 | if project == c.project { 131 | return nil 132 | } 133 | p, _, err := c.srv.GetProject(project) 134 | if err != nil { 135 | return err 136 | } 137 | project = p.Name 138 | c.srv = c.srv.UseProject(project) 139 | c.project = project 140 | return nil 141 | } 142 | 143 | func (c *Client) GetConnectionInfo() map[string]interface{} { 144 | info, _ := c.srv.GetConnectionInfo() 145 | return structs.Map(info) 146 | } 147 | 148 | func (c *Client) Disconnect() { 149 | c.srv.Disconnect() 150 | } 151 | 152 | func IsDefaultProject(project string) bool { 153 
| if project == "" || project == "default" { 154 | return true 155 | } 156 | return false 157 | } 158 | -------------------------------------------------------------------------------- /pkg/incus/instance.go: -------------------------------------------------------------------------------- 1 | package incus 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path" 7 | "strings" 8 | "sync" 9 | "time" 10 | 11 | "ssh2incus/pkg/cache" 12 | "ssh2incus/pkg/queue" 13 | "ssh2incus/pkg/util/buffer" 14 | 15 | incus "github.com/lxc/incus/v6/client" 16 | "github.com/lxc/incus/v6/shared/api" 17 | ) 18 | 19 | var ( 20 | instanceCache *cache.Cache 21 | instanceQueue *queue.Queueable[*api.InstanceFull] 22 | instanceInitOnce sync.Once 23 | ) 24 | 25 | func init() { 26 | instanceInitOnce.Do(func() { 27 | instanceCache = cache.New(1*time.Minute, 2*time.Minute) 28 | instanceQueue = queue.New[*api.InstanceFull](100) 29 | }) 30 | } 31 | 32 | func (c *Client) GetInstance(project, name string) (*api.Instance, string, error) { 33 | err := c.UseProject(project) 34 | if err != nil { 35 | return nil, "", err 36 | } 37 | return c.srv.GetInstance(name) 38 | } 39 | 40 | func (c *Client) GetCachedInstance(project, instance string) (*api.InstanceFull, error) { 41 | cacheName := fmt.Sprintf("%s/%s", project, instance) 42 | if in, ok := instanceCache.Get(cacheName); ok { 43 | return in.(*api.InstanceFull), nil 44 | } 45 | in, err := queue.EnqueueWithParam(instanceQueue, func(i string) (*api.InstanceFull, error) { 46 | full, _, err := c.srv.GetInstanceFull(instance) 47 | return full, err 48 | }, instance) 49 | if err == nil { 50 | instanceCache.SetDefault(cacheName, in) 51 | } 52 | return in, err 53 | } 54 | 55 | func (c *Client) GetInstanceMetadata(instance string) (*api.ImageMetadata, string, error) { 56 | meta, etag, err := c.srv.GetInstanceMetadata(instance) 57 | return meta, etag, err 58 | } 59 | 60 | func (c *Client) GetCachedInstanceState(project, instance string) (*api.InstanceState, error) { 61 | 
cacheName := fmt.Sprintf("%s/%s", project, instance) 62 | if state, ok := instanceStateCache.Get(cacheName); ok { 63 | return state.(*api.InstanceState), nil 64 | } 65 | err := c.UseProject(project) 66 | if err != nil { 67 | return nil, err 68 | } 69 | state, err := queue.EnqueueWithParam(instanceStateQueue, func(i string) (*api.InstanceState, error) { 70 | s, _, err := c.srv.GetInstanceState(instance) 71 | return s, err 72 | }, instance) 73 | if err == nil { 74 | instanceStateCache.SetDefault(cacheName, state) 75 | } 76 | return state, err 77 | } 78 | 79 | func (c *Client) UpdateInstance(name string, instance api.InstancePut, ETag string) (incus.Operation, error) { 80 | return c.srv.UpdateInstance(name, instance, ETag) 81 | } 82 | 83 | func (c *Client) GetInstancesAllProjects(t api.InstanceType) (instances []api.Instance, err error) { 84 | return c.srv.GetInstancesAllProjects(t) 85 | } 86 | 87 | func (c *Client) GetInstanceNetworks(project, instance string) (map[string]api.InstanceStateNetwork, error) { 88 | state, err := c.GetCachedInstanceState(project, instance) 89 | if err != nil { 90 | return nil, err 91 | } 92 | return state.Network, nil 93 | } 94 | 95 | func (c *Client) DeleteInstanceDevice(i *api.Instance, name string) error { 96 | if !strings.HasPrefix(name, ProxyDevicePrefix) { 97 | return nil 98 | } 99 | 100 | // Need new ETag for each operation 101 | in, etag, err := c.srv.GetInstance(i.Name) 102 | if err != nil { 103 | return fmt.Errorf("failed to get instance %s.%s: %v", i.Name, i.Project, err) 104 | } 105 | 106 | device, ok := in.Devices[name] 107 | if !ok { 108 | return fmt.Errorf("device %s does not exist for %s.%s", device, in.Name, in.Project) 109 | } 110 | delete(in.Devices, name) 111 | 112 | op, err := c.UpdateInstance(in.Name, in.Writable(), etag) 113 | if err != nil { 114 | return err 115 | } 116 | 117 | err = op.Wait() 118 | if err != nil { 119 | return err 120 | } 121 | 122 | // Cleanup socket files 123 | if 
strings.HasPrefix(device["connect"], "unix:") { 124 | source := strings.TrimPrefix(device["connect"], "unix:") 125 | os.RemoveAll(path.Dir(source)) 126 | } 127 | 128 | if strings.HasPrefix(device["listen"], "unix:") { 129 | target := strings.TrimPrefix(device["listen"], "unix:") 130 | cmd := fmt.Sprintf("rm -f %s", target) 131 | stdout := buffer.NewOutputBuffer() 132 | stderr := buffer.NewOutputBuffer() 133 | defer stdout.Close() 134 | defer stderr.Close() 135 | ie := c.NewInstanceExec(InstanceExec{ 136 | Instance: in.Name, 137 | Cmd: cmd, 138 | Stdout: stdout, 139 | Stderr: stderr, 140 | }) 141 | ret, err := ie.Exec() 142 | 143 | if ret != 0 { 144 | return err 145 | } 146 | } 147 | 148 | return nil 149 | } 150 | -------------------------------------------------------------------------------- /pkg/incus/user.go: -------------------------------------------------------------------------------- 1 | package incus 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "strconv" 7 | "strings" 8 | "sync" 9 | "time" 10 | 11 | "ssh2incus/pkg/cache" 12 | "ssh2incus/pkg/queue" 13 | "ssh2incus/pkg/util/buffer" 14 | ) 15 | 16 | var ( 17 | instanceUserCache *cache.Cache 18 | instanceUserQueue *queue.Queueable[*InstanceUser] 19 | instanceUserOnce sync.Once 20 | ) 21 | 22 | func init() { 23 | instanceUserOnce.Do(func() { 24 | instanceUserCache = cache.New(20*time.Minute, 30*time.Minute) 25 | instanceUserQueue = queue.New[*InstanceUser](100) 26 | }) 27 | } 28 | 29 | type InstanceUser struct { 30 | Project string 31 | Instance string 32 | User string 33 | Uid int 34 | Gid int 35 | Dir string 36 | Shell string 37 | Ent string 38 | } 39 | 40 | func (i *InstanceUser) Welcome() string { 41 | return fmt.Sprintf("Welcome %q to incus shell on %s", i.User, i.FullInstance()) 42 | } 43 | 44 | func (i *InstanceUser) FullInstance() string { 45 | return fmt.Sprintf("%s.%s", i.Instance, i.Project) 46 | } 47 | 48 | func (c *Client) GetInstanceUser(project, instance, user string) (*InstanceUser, error) { 49 | 
iu, err := queue.EnqueueWithParam(instanceUserQueue, func(i string) (*InstanceUser, error) { 50 | stdout := buffer.NewOutputBuffer() 51 | stderr := buffer.NewOutputBuffer() 52 | 53 | err := c.UseProject(project) 54 | if err != nil { 55 | return nil, err 56 | } 57 | 58 | cmd := fmt.Sprintf("getent passwd %s", user) 59 | 60 | ie := c.NewInstanceExec(InstanceExec{ 61 | Instance: instance, 62 | Cmd: cmd, 63 | Stdout: stdout, 64 | Stderr: stderr, 65 | }) 66 | 67 | ret, err := ie.Exec() 68 | if err != nil { 69 | return nil, err 70 | } 71 | if ret != 0 { 72 | return nil, errors.New("user not found") 73 | } 74 | 75 | out := stdout.Lines() 76 | 77 | if len(out) < 1 { 78 | return nil, errors.New("user not found") 79 | } 80 | ent := strings.Split(out[0], ":") 81 | user = ent[0] 82 | uid, _ := strconv.Atoi(ent[2]) 83 | gid, _ := strconv.Atoi(ent[3]) 84 | dir := ent[5] 85 | shell := ent[6] 86 | iu := &InstanceUser{ 87 | Instance: instance, 88 | Project: project, 89 | User: user, 90 | Uid: uid, 91 | Gid: gid, 92 | Dir: dir, 93 | Shell: shell, 94 | Ent: out[0], 95 | } 96 | return iu, nil 97 | }, instance) 98 | 99 | return iu, err 100 | } 101 | 102 | func (c *Client) GetCachedInstanceUser(project, instance, user string) (*InstanceUser, error) { 103 | cacheKey := instanceUserKey(project, instance, user) 104 | if iu, ok := instanceUserCache.Get(cacheKey); ok { 105 | return iu.(*InstanceUser), nil 106 | } 107 | 108 | iu, err := c.GetInstanceUser(project, instance, user) 109 | 110 | if err == nil { 111 | instanceUserCache.SetDefault(cacheKey, iu) 112 | } 113 | 114 | return iu, err 115 | } 116 | 117 | func instanceUserKey(project, instance, user string) string { 118 | return fmt.Sprintf("%s/%s/%s", project, instance, user) 119 | } 120 | -------------------------------------------------------------------------------- /pkg/isatty/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) Yasuhiro MATSUMOTO 2 | 3 | MIT License (Expat) 4 | 5 
| Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
10 | -------------------------------------------------------------------------------- /pkg/isatty/README.md: -------------------------------------------------------------------------------- 1 | # go-isatty 2 | 3 | [![Godoc Reference](https://godoc.org/github.com/mattn/go-isatty?status.svg)](http://godoc.org/github.com/mattn/go-isatty) 4 | [![Codecov](https://codecov.io/gh/mattn/go-isatty/branch/master/graph/badge.svg)](https://codecov.io/gh/mattn/go-isatty) 5 | [![Coverage Status](https://coveralls.io/repos/github/mattn/go-isatty/badge.svg?branch=master)](https://coveralls.io/github/mattn/go-isatty?branch=master) 6 | [![Go Report Card](https://goreportcard.com/badge/mattn/go-isatty)](https://goreportcard.com/report/mattn/go-isatty) 7 | 8 | isatty for golang 9 | 10 | ## Usage 11 | 12 | ```go 13 | package main 14 | 15 | import ( 16 | "fmt" 17 | "github.com/mattn/go-isatty" 18 | "os" 19 | ) 20 | 21 | func main() { 22 | if isatty.IsTerminal(os.Stdout.Fd()) { 23 | fmt.Println("Is Terminal") 24 | } else if isatty.IsCygwinTerminal(os.Stdout.Fd()) { 25 | fmt.Println("Is Cygwin/MSYS2 Terminal") 26 | } else { 27 | fmt.Println("Is Not Terminal") 28 | } 29 | } 30 | ``` 31 | 32 | ## Installation 33 | 34 | ``` 35 | $ go get github.com/mattn/go-isatty 36 | ``` 37 | 38 | ## License 39 | 40 | MIT 41 | 42 | ## Author 43 | 44 | Yasuhiro Matsumoto (a.k.a mattn) 45 | 46 | ## Thanks 47 | 48 | * k-takata: base idea for IsCygwinTerminal 49 | 50 | https://github.com/k-takata/go-iscygpty 51 | -------------------------------------------------------------------------------- /pkg/isatty/doc.go: -------------------------------------------------------------------------------- 1 | // Package isatty reports whether a file descriptor refers to a terminal, including Cygwin/MSYS2 ptys on Windows. 2 | package isatty 3 | -------------------------------------------------------------------------------- /pkg/isatty/example_test.go: -------------------------------------------------------------------------------- 1 | package isatty_test 2 | 3 | import ( 4 | "fmt" 5 |
"os" 6 | 7 | "ssh2incus/pkg/isatty" 8 | ) 9 | // Example prints whether stdout is a terminal, a Cygwin/MSYS2 terminal, or neither. 10 | func Example() { 11 | if isatty.IsTerminal(os.Stdout.Fd()) { 12 | fmt.Println("Is Terminal") 13 | } else if isatty.IsCygwinTerminal(os.Stdout.Fd()) { 14 | fmt.Println("Is Cygwin/MSYS2 Terminal") 15 | } else { 16 | fmt.Println("Is Not Terminal") 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /pkg/isatty/isatty_bsd.go: -------------------------------------------------------------------------------- 1 | //go:build (darwin || freebsd || openbsd || netbsd || dragonfly || hurd) && !appengine && !tinygo 2 | // +build darwin freebsd openbsd netbsd dragonfly hurd 3 | // +build !appengine 4 | // +build !tinygo 5 | 6 | package isatty 7 | 8 | import "golang.org/x/sys/unix" 9 | 10 | // IsTerminal returns true if the file descriptor is a terminal, i.e. the TIOCGETA termios ioctl succeeds on it. 11 | func IsTerminal(fd uintptr) bool { 12 | _, err := unix.IoctlGetTermios(int(fd), unix.TIOCGETA) 13 | return err == nil 14 | } 15 | 16 | // IsCygwinTerminal returns true if the file descriptor is a cygwin or msys2 17 | // terminal. It always returns false in this environment. 18 | func IsCygwinTerminal(fd uintptr) bool { 19 | return false 20 | } 21 | -------------------------------------------------------------------------------- /pkg/isatty/isatty_others.go: -------------------------------------------------------------------------------- 1 | //go:build (appengine || js || nacl || tinygo || wasm) && !windows 2 | // +build appengine js nacl tinygo wasm 3 | // +build !windows 4 | 5 | package isatty 6 | 7 | // IsTerminal returns true if the file descriptor is a terminal, which 8 | // is always false on the targets selected by the build constraint above. 9 | func IsTerminal(fd uintptr) bool { 10 | return false 11 | } 12 | 13 | // IsCygwinTerminal returns true if the file descriptor is a cygwin or msys2 14 | // terminal. It always returns false in this environment.
15 | func IsCygwinTerminal(fd uintptr) bool { 16 | return false 17 | } 18 | -------------------------------------------------------------------------------- /pkg/isatty/isatty_others_test.go: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | // +build !windows 3 | 4 | package isatty 5 | 6 | import ( 7 | "os" 8 | "testing" 9 | ) 10 | // TestTerminal only checks that IsTerminal does not panic on stdout. 11 | func TestTerminal(t *testing.T) { 12 | // Log instead of asserting: whether stdout is a terminal depends on how the tests are run. 13 | t.Log("os.Stdout:", IsTerminal(os.Stdout.Fd())) 14 | } 15 | // TestCygwinPipeName checks that IsCygwinTerminal(stdout) is false on non-Windows builds. 16 | func TestCygwinPipeName(t *testing.T) { 17 | if IsCygwinTerminal(os.Stdout.Fd()) { 18 | t.Fatal("should be false always") 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /pkg/isatty/isatty_plan9.go: -------------------------------------------------------------------------------- 1 | //go:build plan9 2 | // +build plan9 3 | 4 | package isatty 5 | 6 | import ( 7 | "syscall" 8 | ) 9 | 10 | // IsTerminal returns true if the given file descriptor is a terminal, i.e. its path is /dev/cons or /mnt/term/dev/cons. 11 | func IsTerminal(fd uintptr) bool { 12 | path, err := syscall.Fd2path(int(fd)) 13 | if err != nil { 14 | return false 15 | } 16 | return path == "/dev/cons" || path == "/mnt/term/dev/cons" 17 | } 18 | 19 | // IsCygwinTerminal returns true if the file descriptor is a cygwin or msys2 20 | // terminal. It always returns false in this environment. 21 | func IsCygwinTerminal(fd uintptr) bool { 22 | return false 23 | } 24 | -------------------------------------------------------------------------------- /pkg/isatty/isatty_solaris.go: -------------------------------------------------------------------------------- 1 | //go:build solaris && !appengine 2 | // +build solaris,!appengine 3 | 4 | package isatty 5 | 6 | import ( 7 | "golang.org/x/sys/unix" 8 | ) 9 | 10 | // IsTerminal returns true if the given file descriptor is a terminal, i.e. the TCGETA termio ioctl succeeds on it.
11 | // see: https://src.illumos.org/source/xref/illumos-gate/usr/src/lib/libc/port/gen/isatty.c 12 | func IsTerminal(fd uintptr) bool { 13 | _, err := unix.IoctlGetTermio(int(fd), unix.TCGETA) 14 | return err == nil 15 | } 16 | 17 | // IsCygwinTerminal returns true if the file descriptor is a cygwin or msys2 18 | // terminal. It always returns false in this environment. 19 | func IsCygwinTerminal(fd uintptr) bool { 20 | return false 21 | } 22 | -------------------------------------------------------------------------------- /pkg/isatty/isatty_tcgets.go: -------------------------------------------------------------------------------- 1 | //go:build (linux || aix || zos) && !appengine && !tinygo 2 | // +build linux aix zos 3 | // +build !appengine 4 | // +build !tinygo 5 | 6 | package isatty 7 | 8 | import "golang.org/x/sys/unix" 9 | 10 | // IsTerminal returns true if the file descriptor is a terminal, i.e. the TCGETS termios ioctl succeeds on it. 11 | func IsTerminal(fd uintptr) bool { 12 | _, err := unix.IoctlGetTermios(int(fd), unix.TCGETS) 13 | return err == nil 14 | } 15 | 16 | // IsCygwinTerminal returns true if the file descriptor is a cygwin or msys2 17 | // terminal. It always returns false in this environment.
18 | func IsCygwinTerminal(fd uintptr) bool { 19 | return false 20 | } 21 | -------------------------------------------------------------------------------- /pkg/isatty/isatty_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows && !appengine 2 | // +build windows,!appengine 3 | 4 | package isatty 5 | 6 | import ( 7 | "errors" 8 | "strings" 9 | "syscall" 10 | "unicode/utf16" 11 | "unsafe" 12 | ) 13 | // Information-class and file-type constants passed to or returned by the syscalls below. 14 | const ( 15 | objectNameInfo uintptr = 1 // ObjectNameInformation (NtQueryObject) 16 | fileNameInfo = 2 // FileNameInfo (GetFileInformationByHandleEx) 17 | fileTypePipe = 3 // FILE_TYPE_PIPE (GetFileType result) 18 | ) 19 | // Lazily loaded DLLs and procedures; nothing is resolved until Find or Addr is called. 20 | var ( 21 | kernel32 = syscall.NewLazyDLL("kernel32.dll") 22 | ntdll = syscall.NewLazyDLL("ntdll.dll") 23 | procGetConsoleMode = kernel32.NewProc("GetConsoleMode") 24 | procGetFileInformationByHandleEx = kernel32.NewProc("GetFileInformationByHandleEx") 25 | procGetFileType = kernel32.NewProc("GetFileType") 26 | procNtQueryObject = ntdll.NewProc("NtQueryObject") 27 | ) 28 | // init sets procGetFileInformationByHandleEx to nil when that export is missing, which IsCygwinTerminal checks. 29 | func init() { 30 | // Check if GetFileInformationByHandleEx is available. 31 | if procGetFileInformationByHandleEx.Find() != nil { 32 | procGetFileInformationByHandleEx = nil 33 | } 34 | } 35 | 36 | // IsTerminal returns true if the file descriptor is a console handle, i.e. GetConsoleMode succeeds on it. 37 | func IsTerminal(fd uintptr) bool { 38 | var st uint32 39 | r, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, fd, uintptr(unsafe.Pointer(&st)), 0) 40 | return r != 0 && e == 0 41 | } 42 | 43 | // isCygwinPipeName reports whether name matches the pipe naming scheme of a cygwin/msys2 pty.
44 | // Cygwin/MSYS2 PTY has a name like: 45 | // 46 | // \{cygwin,msys}-XXXXXXXXXXXXXXXX-ptyN-{from,to}-master 47 | func isCygwinPipeName(name string) bool { 48 | token := strings.Split(name, "-") 49 | if len(token) < 5 { 50 | return false 51 | } 52 | 53 | if token[0] != `\msys` && 54 | token[0] != `\cygwin` && 55 | token[0] != `\Device\NamedPipe\msys` && 56 | token[0] != `\Device\NamedPipe\cygwin` { 57 | return false 58 | } 59 | 60 | if token[1] == "" { 61 | return false 62 | } 63 | 64 | if !strings.HasPrefix(token[2], "pty") { 65 | return false 66 | } 67 | 68 | if token[3] != `from` && token[3] != `to` { 69 | return false 70 | } 71 | 72 | if token[4] != "master" { 73 | return false 74 | } 75 | 76 | return true 77 | } 78 | 79 | // getFileNameByHandle uses the undocumented ntdll NtQueryObject to get the full file name of a file handle. 80 | // GetFileInformationByHandleEx is not available before Windows Vista, so this is the fallback for 81 | // Windows XP; it also works on Windows Vista through 10. 82 | // NOTE(review): buf[4:] assumes an 8-byte OBJECT_NAME_INFORMATION header (32-bit layout); confirm for 64-bit. 83 | // see https://stackoverflow.com/a/18792477 for details 84 | func getFileNameByHandle(fd uintptr) (string, error) { 85 | if procNtQueryObject.Find() != nil { // NewProc never returns nil and Addr panics on a missing export, so probe with Find 86 | return "", errors.New("ntdll.dll: NtQueryObject not supported") 87 | } 88 | 89 | var buf [4 + syscall.MAX_PATH]uint16 90 | var result int 91 | r, _, e := syscall.Syscall6(procNtQueryObject.Addr(), 5, 92 | fd, objectNameInfo, uintptr(unsafe.Pointer(&buf)), uintptr(2*len(buf)), uintptr(unsafe.Pointer(&result)), 0) 93 | if r != 0 { // non-zero NTSTATUS means failure 94 | return "", e 95 | } 96 | return string(utf16.Decode(buf[4 : 4+buf[0]/2])), nil 97 | } 98 | 99 | // IsCygwinTerminal returns true if the file descriptor is a cygwin or msys2 100 | // terminal.
101 | func IsCygwinTerminal(fd uintptr) bool { 102 | if procGetFileInformationByHandleEx == nil { // set to nil by init when the API is missing: fall back to NtQueryObject 103 | name, err := getFileNameByHandle(fd) 104 | if err != nil { 105 | return false 106 | } 107 | return isCygwinPipeName(name) 108 | } 109 | 110 | // Cygwin/msys's pty is a pipe. 111 | ft, _, e := syscall.Syscall(procGetFileType.Addr(), 1, fd, 0, 0) 112 | if ft != fileTypePipe || e != 0 { 113 | return false 114 | } 115 | 116 | var buf [2 + syscall.MAX_PATH]uint16 117 | r, _, e := syscall.Syscall6(procGetFileInformationByHandleEx.Addr(), 118 | 4, fd, fileNameInfo, uintptr(unsafe.Pointer(&buf)), 119 | uintptr(len(buf)*2), 0, 0) 120 | if r == 0 || e != 0 { 121 | return false 122 | } 123 | 124 | l := *(*uint32)(unsafe.Pointer(&buf)) // name length in bytes, followed by the UTF-16 name at buf[2:] 125 | return isCygwinPipeName(string(utf16.Decode(buf[2 : 2+l/2]))) 126 | } 127 | -------------------------------------------------------------------------------- /pkg/isatty/isatty_windows_test.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | // +build windows 3 | 4 | package isatty 5 | 6 | import ( 7 | "testing" 8 | ) 9 | 10 | func TestCygwinPipeName(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | result bool 14 | }{ 15 | {``, false}, 16 | {`\msys-`, false}, 17 | {`\cygwin-----`, false}, 18 | {`\msys-x-PTY5-pty1-from-master`, false}, 19 | {`\cygwin-x-PTY5-from-master`, false}, 20 | {`\cygwin-x-pty2-from-toaster`, false}, 21 | {`\cygwin--pty2-from-master`, false}, 22 | {`\\cygwin-x-pty2-from-master`, false}, 23 | {`\cygwin-x-pty2-from-master-`, true}, // extra trailing tokens are accepted 24 | {`\cygwin-e022582115c10879-pty4-from-master`, true}, 25 | {`\msys-e022582115c10879-pty4-to-master`, true}, 26 | {`\cygwin-e022582115c10879-pty4-to-master`, true}, 27 | {`\Device\NamedPipe\cygwin-e022582115c10879-pty4-from-master`, true}, 28 | {`\Device\NamedPipe\msys-e022582115c10879-pty4-to-master`, true}, 29 | {`Device\NamedPipe\cygwin-e022582115c10879-pty4-to-master`, false}, 30 | } 31 | 32 |
for _, test := range tests { 33 | want := test.result 34 | got := isCygwinPipeName(test.name) 35 | if want != got { 36 | t.Fatalf("isatty(%q): got %v, want %v:", test.name, got, want) 37 | } 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /pkg/queue/queue.go: -------------------------------------------------------------------------------- 1 | package queue 2 | 3 | import "sync" 4 | 5 | // Queueable handles sequential processing of operations with typed results 6 | type Queueable[T any] struct { 7 | mu sync.Mutex 8 | queue chan func() (T, error) 9 | results chan queueResult[T] 10 | done chan struct{} 11 | } 12 | 13 | type queueResult[T any] struct { 14 | value T 15 | err error 16 | } 17 | 18 | // New creates a new queue with the specified buffer size 19 | func New[T any](bufferSize int) *Queueable[T] { 20 | q := &Queueable[T]{ 21 | queue: make(chan func() (T, error), bufferSize), 22 | results: make(chan queueResult[T]), 23 | done: make(chan struct{}), 24 | } 25 | go q.process() 26 | return q 27 | } 28 | 29 | // process continuously processes operations from the queue in sequence 30 | func (q *Queueable[T]) process() { 31 | for { 32 | select { 33 | case op := <-q.queue: 34 | value, err := op() 35 | q.results <- queueResult[T]{value, err} 36 | case <-q.done: 37 | return 38 | } 39 | } 40 | } 41 | 42 | // Enqueue runs op on the queue's worker goroutine and returns its result; callers are serialized for the whole round trip so each one receives the result of its own op. 43 | func (q *Queueable[T]) Enqueue(op func() (T, error)) (T, error) { 44 | q.mu.Lock() 45 | defer q.mu.Unlock() // results is shared by all callers: unlocking right after the send let concurrent callers receive each other's results 46 | q.queue <- op 47 | 48 | // Wait for the result; while the lock is held it can only belong to op 49 | r := <-q.results 50 | return r.value, r.err 51 | } 52 | 53 | // EnqueueError adds an error-only operation to the queue 54 | func (q *Queueable[T]) EnqueueError(op func() error) error { 55 | var zero T 56 | wrappedOp := func() (T, error) { 57 | return zero, op() 58 | } 59 | 60 | _, err := q.Enqueue(wrappedOp) 61 | return err 62 | } 63 | 64 | // EnqueueFn adds a function that returns a value
without error 65 | func (q *Queueable[T]) EnqueueFn(op func() T) T { 66 | wrappedOp := func() (T, error) { 67 | return op(), nil 68 | } 69 | 70 | result, _ := q.Enqueue(wrappedOp) 71 | return result 72 | } 73 | 74 | // Shutdown terminates the queue processor 75 | func (q *Queueable[T]) Shutdown() { 76 | close(q.done) 77 | } 78 | 79 | // EnqueueWithParam wraps a function that takes a parameter to be queued 80 | func EnqueueWithParam[T any, P any](q *Queueable[T], fn func(P) (T, error), param P) (T, error) { 81 | return q.Enqueue(func() (T, error) { 82 | return fn(param) 83 | }) 84 | } 85 | 86 | // EnqueueErrorWithParam wraps an error-returning function with a parameter 87 | func EnqueueErrorWithParam[T any, P any](q *Queueable[T], fn func(P) error, param P) error { 88 | wrappedOp := func() error { 89 | return fn(param) 90 | } 91 | return q.EnqueueError(wrappedOp) 92 | } 93 | 94 | // EnqueueFnWithParam wraps a value-returning function with a parameter 95 | func EnqueueFnWithParam[T any, P any](q *Queueable[T], fn func(P) T, param P) T { 96 | wrappedOp := func() T { 97 | return fn(param) 98 | } 99 | return q.EnqueueFn(wrappedOp) 100 | } 101 | 102 | // EnqueueBoolFn specifically for boolean-returning functions 103 | func EnqueueBoolFn[P any](q *Queueable[bool], fn func(P) bool, param P) bool { 104 | return EnqueueFnWithParam(q, fn, param) 105 | } 106 | -------------------------------------------------------------------------------- /pkg/shlex/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) anmitsu 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining 4 | a copy of this software and associated documentation files (the 5 | "Software"), to deal in the Software without restriction, including 6 | without limitation the rights to use, copy, modify, merge, publish, 7 | distribute, sublicense, and/or sell copies of the Software, and to 8 | permit persons to whom the Software is furnished to do 
so, subject to 9 | the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be 12 | included in all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 15 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 16 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 17 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 20 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /pkg/shlex/README.md: -------------------------------------------------------------------------------- 1 | # go-shlex 2 | 3 | go-shlex is a library to make a lexical analyzer like Unix shell for 4 | Go. 5 | 6 | ## Install 7 | 8 | go get -u "github.com/anmitsu/go-shlex" 9 | 10 | ## Usage 11 | 12 | ```go 13 | package main 14 | 15 | import ( 16 | "fmt" 17 | "log" 18 | 19 | "github.com/anmitsu/go-shlex" 20 | ) 21 | 22 | func main() { 23 | cmd := `cp -Rdp "file name" 'file name2' dir\ name` 24 | words, err := shlex.Split(cmd, true) 25 | if err != nil { 26 | log.Fatal(err) 27 | } 28 | 29 | for _, w := range words { 30 | fmt.Println(w) 31 | } 32 | } 33 | ``` 34 | output 35 | 36 | cp 37 | -Rdp 38 | file name 39 | file name2 40 | dir name 41 | 42 | ## Documentation 43 | 44 | http://godoc.org/github.com/anmitsu/go-shlex 45 | 46 | -------------------------------------------------------------------------------- /pkg/shlex/shlex.go: -------------------------------------------------------------------------------- 1 | // Package shlex provides a simple lexical analysis like Unix shell. 
2 | package shlex 3 | 4 | import ( 5 | "bufio" 6 | "errors" 7 | "io" 8 | "strings" 9 | "unicode" 10 | ) 11 | 12 | var ( 13 | ErrNoClosing = errors.New("No closing quotation") 14 | ErrNoEscaped = errors.New("No escaped character") 15 | ) 16 | 17 | // Tokenizer is the interface that classifies a token according to 18 | // words, whitespaces, quotations, escapes and escaped quotations. 19 | type Tokenizer interface { 20 | IsWord(rune) bool 21 | IsWhitespace(rune) bool 22 | IsQuote(rune) bool 23 | IsEscape(rune) bool 24 | IsEscapedQuote(rune) bool 25 | } 26 | 27 | // DefaultTokenizer implements a simple tokenizer like Unix shell. 28 | type DefaultTokenizer struct{} 29 | 30 | func (t *DefaultTokenizer) IsWord(r rune) bool { 31 | return r == '_' || unicode.IsLetter(r) || unicode.IsNumber(r) 32 | } 33 | func (t *DefaultTokenizer) IsQuote(r rune) bool { 34 | switch r { 35 | case '\'', '"': 36 | return true 37 | default: 38 | return false 39 | } 40 | } 41 | func (t *DefaultTokenizer) IsWhitespace(r rune) bool { 42 | return unicode.IsSpace(r) 43 | } 44 | func (t *DefaultTokenizer) IsEscape(r rune) bool { 45 | return r == '\\' 46 | } 47 | func (t *DefaultTokenizer) IsEscapedQuote(r rune) bool { 48 | return r == '"' 49 | } 50 | 51 | // Lexer represents a lexical analyzer. 52 | type Lexer struct { 53 | reader *bufio.Reader 54 | tokenizer Tokenizer 55 | posix bool 56 | whitespacesplit bool 57 | } 58 | 59 | // NewLexer creates a new Lexer reading from io.Reader. This Lexer 60 | // has a DefaultTokenizer according to posix and whitespacesplit 61 | // rules. 62 | func NewLexer(r io.Reader, posix, whitespacesplit bool) *Lexer { 63 | return &Lexer{ 64 | reader: bufio.NewReader(r), 65 | tokenizer: &DefaultTokenizer{}, 66 | posix: posix, 67 | whitespacesplit: whitespacesplit, 68 | } 69 | } 70 | 71 | // NewLexerString creates a new Lexer reading from a string. This 72 | // Lexer has a DefaultTokenizer according to posix and whitespacesplit 73 | // rules. 
74 | func NewLexerString(s string, posix, whitespacesplit bool) *Lexer { 75 | return NewLexer(strings.NewReader(s), posix, whitespacesplit) 76 | } 77 | 78 | // Split splits a string according to posix or non-posix rules. 79 | func Split(s string, posix bool) ([]string, error) { 80 | return NewLexerString(s, posix, true).Split() 81 | } 82 | 83 | // SetTokenizer sets a Tokenizer. 84 | func (l *Lexer) SetTokenizer(t Tokenizer) { 85 | l.tokenizer = t 86 | } 87 | 88 | func (l *Lexer) Split() ([]string, error) { 89 | result := make([]string, 0) 90 | for { 91 | token, err := l.readToken() 92 | if token != "" { 93 | result = append(result, token) 94 | } 95 | 96 | if err == io.EOF { 97 | break 98 | } else if err != nil { 99 | return result, err 100 | } 101 | } 102 | return result, nil 103 | } 104 | 105 | func (l *Lexer) readToken() (string, error) { 106 | t := l.tokenizer 107 | token := "" 108 | quoted := false 109 | state := ' ' 110 | escapedstate := ' ' 111 | scanning: 112 | for { 113 | next, _, err := l.reader.ReadRune() 114 | if err != nil { 115 | if t.IsQuote(state) { 116 | return token, ErrNoClosing 117 | } else if t.IsEscape(state) { 118 | return token, ErrNoEscaped 119 | } 120 | return token, err 121 | } 122 | 123 | switch { 124 | case t.IsWhitespace(state): 125 | switch { 126 | case t.IsWhitespace(next): 127 | break scanning 128 | case l.posix && t.IsEscape(next): 129 | escapedstate = 'a' 130 | state = next 131 | case t.IsWord(next): 132 | token += string(next) 133 | state = 'a' 134 | case t.IsQuote(next): 135 | if !l.posix { 136 | token += string(next) 137 | } 138 | state = next 139 | default: 140 | token = string(next) 141 | if l.whitespacesplit { 142 | state = 'a' 143 | } else if token != "" || (l.posix && quoted) { 144 | break scanning 145 | } 146 | } 147 | case t.IsQuote(state): 148 | quoted = true 149 | switch { 150 | case next == state: 151 | if !l.posix { 152 | token += string(next) 153 | break scanning 154 | } else { 155 | state = 'a' 156 | } 157 | case 
l.posix && t.IsEscape(next) && t.IsEscapedQuote(state): 158 | escapedstate = state 159 | state = next 160 | default: 161 | token += string(next) 162 | } 163 | case t.IsEscape(state): 164 | if t.IsQuote(escapedstate) && next != state && next != escapedstate { 165 | token += string(state) 166 | } 167 | token += string(next) 168 | state = escapedstate 169 | case t.IsWord(state): 170 | switch { 171 | case t.IsWhitespace(next): 172 | if token != "" || (l.posix && quoted) { 173 | break scanning 174 | } 175 | case l.posix && t.IsQuote(next): 176 | state = next 177 | case l.posix && t.IsEscape(next): 178 | escapedstate = 'a' 179 | state = next 180 | case t.IsWord(next) || t.IsQuote(next): 181 | token += string(next) 182 | default: 183 | if l.whitespacesplit { 184 | token += string(next) 185 | } else if token != "" { 186 | l.reader.UnreadRune() 187 | break scanning 188 | } 189 | } 190 | } 191 | } 192 | return token, nil 193 | } 194 | -------------------------------------------------------------------------------- /pkg/ssh/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016 Glider Labs. All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are 5 | met: 6 | 7 | * Redistributions of source code must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 9 | * Redistributions in binary form must reproduce the above 10 | copyright notice, this list of conditions and the following disclaimer 11 | in the documentation and/or other materials provided with the 12 | distribution. 13 | * Neither the name of Glider Labs nor the names of its 14 | contributors may be used to endorse or promote products derived from 15 | this software without specific prior written permission. 
16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 18 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 19 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 20 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 21 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 22 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 23 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 24 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 25 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /pkg/ssh/README.md: -------------------------------------------------------------------------------- 1 | # gliderlabs/ssh 2 | 3 | [![GoDoc](https://godoc.org/github.com/gliderlabs/ssh?status.svg)](https://godoc.org/github.com/gliderlabs/ssh) 4 | [![CircleCI](https://img.shields.io/circleci/project/github/gliderlabs/ssh.svg)](https://circleci.com/gh/gliderlabs/ssh) 5 | [![Go Report Card](https://goreportcard.com/badge/github.com/gliderlabs/ssh)](https://goreportcard.com/report/github.com/gliderlabs/ssh) 6 | [![OpenCollective](https://opencollective.com/ssh/sponsors/badge.svg)](#sponsors) 7 | [![Slack](http://slack.gliderlabs.com/badge.svg)](http://slack.gliderlabs.com) 8 | [![Email Updates](https://img.shields.io/badge/updates-subscribe-yellow.svg)](https://app.convertkit.com/landing_pages/243312) 9 | 10 | > The Glider Labs SSH server package is dope. —[@bradfitz](https://twitter.com/bradfitz), Go team member 11 | 12 | This Go package wraps the [crypto/ssh 13 | package](https://godoc.org/golang.org/x/crypto/ssh) with a higher-level API for 14 | building SSH servers. 
The goal of the API was to make it as simple as using 15 | [net/http](https://golang.org/pkg/net/http/), so the API is very similar: 16 | 17 | ```go 18 | package main 19 | 20 | import ( 21 | "github.com/gliderlabs/ssh" 22 | "io" 23 | "log" 24 | ) 25 | 26 | func main() { 27 | ssh.Handle(func(s ssh.Session) { 28 | io.WriteString(s, "Hello world\n") 29 | }) 30 | 31 | log.Fatal(ssh.ListenAndServe(":2222", nil)) 32 | } 33 | 34 | ``` 35 | This package was built by [@progrium](https://twitter.com/progrium) after working on nearly a dozen projects at Glider Labs using SSH and collaborating with [@shazow](https://twitter.com/shazow) (known for [ssh-chat](https://github.com/shazow/ssh-chat)). 36 | 37 | ## Examples 38 | 39 | Examples live in the `_examples` directory of the upstream [gliderlabs/ssh](https://github.com/gliderlabs/ssh) repository; that directory is not included in this vendored copy. 40 | 41 | ## Usage 42 | 43 | [See the package reference on pkg.go.dev.](https://pkg.go.dev/github.com/gliderlabs/ssh) 44 | 45 | ## Contributing 46 | 47 | Pull requests are welcome! However, since this project is very much about API 48 | design, please submit API changes as issues to discuss before submitting PRs. 49 | 50 | You can also [join our Slack](http://slack.gliderlabs.com) to discuss changes. 51 | 52 | ## Roadmap 53 | 54 | * Non-session channel handlers 55 | * Cleanup callback API 56 | * 1.0 release 57 | * High-level client? 58 | 59 | ## Sponsors 60 | 61 | Become a sponsor and get your logo on our README on GitHub with a link to your site.
[[Become a sponsor](https://opencollective.com/ssh#sponsor)] 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | ## License 95 | 96 | [BSD](LICENSE) 97 | -------------------------------------------------------------------------------- /pkg/ssh/agent.go: -------------------------------------------------------------------------------- 1 | package ssh 2 | 3 | import ( 4 | "io" 5 | "net" 6 | "os" 7 | "path" 8 | "sync" 9 | 10 | gossh "golang.org/x/crypto/ssh" 11 | ) 12 | 13 | const ( 14 | agentRequestType = "auth-agent-req@openssh.com" 15 | agentChannelType = "auth-agent@openssh.com" 16 | 17 | agentTempDir = "auth-agent" 18 | agentListenFile = "listener.sock" 19 | ) 20 | 21 | // contextKeyAgentRequest is an internal context key for storing if the 22 | // client requested agent forwarding 23 | var contextKeyAgentRequest = &contextKey{"auth-agent-req"} 24 | 25 | // SetAgentRequested sets up the session context so that AgentRequested 26 | // returns true. 27 | func SetAgentRequested(ctx Context) { 28 | ctx.SetValue(contextKeyAgentRequest, true) 29 | } 30 | 31 | // AgentRequested returns true if the client requested agent forwarding. 32 | func AgentRequested(sess Session) bool { 33 | return sess.Context().Value(contextKeyAgentRequest) == true 34 | } 35 | 36 | // NewAgentListener sets up a temporary Unix socket that can be communicated 37 | // to the session environment and used for forwarding connections. 38 | func NewAgentListener() (net.Listener, error) { 39 | dir, err := os.MkdirTemp("", agentTempDir) 40 | if err != nil { 41 | return nil, err 42 | } 43 | l, err := net.Listen("unix", path.Join(dir, agentListenFile)) 44 | if err != nil { 45 | return nil, err 46 | } 47 | return l, nil 48 | } 49 | 50 | // ForwardAgentConnections takes connections from a listener to proxy into the 51 | // session on the OpenSSH channel for agent connections. 
It blocks and services 52 | // connections until the listener stop accepting. 53 | func ForwardAgentConnections(l net.Listener, s Session) { 54 | sshConn := s.Context().Value(ContextKeyConn).(gossh.Conn) 55 | for { 56 | conn, err := l.Accept() 57 | if err != nil { 58 | return 59 | } 60 | go func(conn net.Conn) { 61 | defer conn.Close() 62 | channel, reqs, err := sshConn.OpenChannel(agentChannelType, nil) 63 | if err != nil { 64 | return 65 | } 66 | defer channel.Close() 67 | go gossh.DiscardRequests(reqs) 68 | var wg sync.WaitGroup 69 | wg.Add(2) 70 | go func() { 71 | io.Copy(conn, channel) 72 | conn.(*net.UnixConn).CloseWrite() 73 | wg.Done() 74 | }() 75 | go func() { 76 | io.Copy(channel, conn) 77 | channel.CloseWrite() 78 | wg.Done() 79 | }() 80 | wg.Wait() 81 | }(conn) 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /pkg/ssh/conn.go: -------------------------------------------------------------------------------- 1 | package ssh 2 | 3 | import ( 4 | "context" 5 | "net" 6 | "time" 7 | ) 8 | 9 | type serverConn struct { 10 | net.Conn 11 | 12 | idleTimeout time.Duration 13 | handshakeDeadline time.Time 14 | maxDeadline time.Time 15 | closeCanceler context.CancelFunc 16 | } 17 | 18 | func (c *serverConn) Write(p []byte) (n int, err error) { 19 | if c.idleTimeout > 0 { 20 | c.updateDeadline() 21 | } 22 | n, err = c.Conn.Write(p) 23 | if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil { 24 | c.closeCanceler() 25 | } 26 | return 27 | } 28 | 29 | func (c *serverConn) Read(b []byte) (n int, err error) { 30 | if c.idleTimeout > 0 { 31 | c.updateDeadline() 32 | } 33 | n, err = c.Conn.Read(b) 34 | if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil { 35 | c.closeCanceler() 36 | } 37 | return 38 | } 39 | 40 | func (c *serverConn) Close() (err error) { 41 | err = c.Conn.Close() 42 | if c.closeCanceler != nil { 43 | c.closeCanceler() 44 | } 45 | return 46 | } 47 | 48 | func (c 
*serverConn) updateDeadline() { 49 | deadline := c.maxDeadline 50 | 51 | if !c.handshakeDeadline.IsZero() && (deadline.IsZero() || c.handshakeDeadline.Before(deadline)) { 52 | deadline = c.handshakeDeadline 53 | } 54 | 55 | if c.idleTimeout > 0 { 56 | idleDeadline := time.Now().Add(c.idleTimeout) 57 | if deadline.IsZero() || idleDeadline.Before(deadline) { 58 | deadline = idleDeadline 59 | } 60 | } 61 | 62 | c.Conn.SetDeadline(deadline) 63 | } 64 | -------------------------------------------------------------------------------- /pkg/ssh/context.go: -------------------------------------------------------------------------------- 1 | package ssh 2 | 3 | import ( 4 | "context" 5 | "encoding/hex" 6 | "net" 7 | "sync" 8 | 9 | gossh "golang.org/x/crypto/ssh" 10 | ) 11 | 12 | // contextKey is a value for use with context.WithValue. It's used as 13 | // a pointer so it fits in an interface{} without allocation. 14 | type contextKey struct { 15 | name string 16 | } 17 | 18 | var ( 19 | // ContextKeyUser is a context key for use with Contexts in this package. 20 | // The associated value will be of type string. 21 | ContextKeyUser = &contextKey{"user"} 22 | 23 | // ContextKeySessionID is a context key for use with Contexts in this package. 24 | // The associated value will be of type string. 25 | ContextKeySessionID = &contextKey{"session-id"} 26 | 27 | // ContextKeyPermissions is a context key for use with Contexts in this package. 28 | // The associated value will be of type *Permissions. 29 | ContextKeyPermissions = &contextKey{"permissions"} 30 | 31 | // ContextKeyClientVersion is a context key for use with Contexts in this package. 32 | // The associated value will be of type string. 33 | ContextKeyClientVersion = &contextKey{"client-version"} 34 | 35 | // ContextKeyServerVersion is a context key for use with Contexts in this package. 36 | // The associated value will be of type string. 
37 | ContextKeyServerVersion = &contextKey{"server-version"} 38 | 39 | // ContextKeyLocalAddr is a context key for use with Contexts in this package. 40 | // The associated value will be of type net.Addr. 41 | ContextKeyLocalAddr = &contextKey{"local-addr"} 42 | 43 | // ContextKeyRemoteAddr is a context key for use with Contexts in this package. 44 | // The associated value will be of type net.Addr. 45 | ContextKeyRemoteAddr = &contextKey{"remote-addr"} 46 | 47 | // ContextKeyServer is a context key for use with Contexts in this package. 48 | // The associated value will be of type *Server. 49 | ContextKeyServer = &contextKey{"ssh-server"} 50 | 51 | // ContextKeyConn is a context key for use with Contexts in this package. 52 | // The associated value will be of type gossh.ServerConn. 53 | ContextKeyConn = &contextKey{"ssh-conn"} 54 | 55 | // ContextKeyPublicKey is a context key for use with Contexts in this package. 56 | // The associated value will be of type PublicKey. 57 | ContextKeyPublicKey = &contextKey{"public-key"} 58 | 59 | ContextKeyCancelFunc = &contextKey{"cancel-func"} 60 | ) 61 | 62 | // Context is a package specific context interface. It exposes connection 63 | // metadata and allows new values to be easily written to it. It's used in 64 | // authentication handlers and callbacks, and its underlying context.Context is 65 | // exposed on Session in the session Handler. A connection-scoped lock is also 66 | // embedded in the context to make it easier to limit operations per-connection. 67 | type Context interface { 68 | context.Context 69 | sync.Locker 70 | 71 | // User returns the username used when establishing the SSH connection. 72 | User() string 73 | 74 | // SessionID returns the session hash. 75 | SessionID() string 76 | 77 | // ShortSessionID returns first 8 of session hash. 78 | ShortSessionID() string 79 | 80 | // ClientVersion returns the version reported by the client. 
81 | ClientVersion() string 82 | 83 | // ServerVersion returns the version reported by the server. 84 | ServerVersion() string 85 | 86 | // RemoteAddr returns the remote address for this connection. 87 | RemoteAddr() net.Addr 88 | 89 | // LocalAddr returns the local address for this connection. 90 | LocalAddr() net.Addr 91 | 92 | // Permissions returns the Permissions object used for this connection. 93 | Permissions() *Permissions 94 | 95 | // SetValue allows you to easily write new values into the underlying context. 96 | SetValue(key, value interface{}) 97 | } 98 | 99 | type SshContext struct { 100 | context.Context 101 | *sync.Mutex 102 | 103 | values map[interface{}]interface{} 104 | valuesMu sync.Mutex 105 | } 106 | 107 | func NewContext(srv *Server) (*SshContext, context.CancelFunc) { 108 | innerCtx, cancel := context.WithCancel(context.Background()) 109 | ctx := &SshContext{Context: innerCtx, Mutex: &sync.Mutex{}, values: make(map[interface{}]interface{})} 110 | ctx.SetValue(ContextKeyServer, srv) 111 | perms := &Permissions{&gossh.Permissions{}} 112 | ctx.SetValue(ContextKeyPermissions, perms) 113 | return ctx, cancel 114 | } 115 | 116 | // this is separate from newContext because we will get ConnMetadata 117 | // at different points so it needs to be applied separately 118 | func applyConnMetadata(ctx Context, conn gossh.ConnMetadata) { 119 | if ctx.Value(ContextKeySessionID) != nil { 120 | return 121 | } 122 | ctx.SetValue(ContextKeySessionID, hex.EncodeToString(conn.SessionID())) 123 | ctx.SetValue(ContextKeyClientVersion, string(conn.ClientVersion())) 124 | ctx.SetValue(ContextKeyServerVersion, string(conn.ServerVersion())) 125 | ctx.SetValue(ContextKeyUser, conn.User()) 126 | ctx.SetValue(ContextKeyLocalAddr, conn.LocalAddr()) 127 | ctx.SetValue(ContextKeyRemoteAddr, conn.RemoteAddr()) 128 | } 129 | 130 | func (ctx *SshContext) Value(key interface{}) interface{} { 131 | ctx.valuesMu.Lock() 132 | defer ctx.valuesMu.Unlock() 133 | if v, ok := 
ctx.values[key]; ok { 134 | return v 135 | } 136 | return ctx.Context.Value(key) 137 | } 138 | 139 | func (ctx *SshContext) SetValue(key, value interface{}) { 140 | ctx.valuesMu.Lock() 141 | defer ctx.valuesMu.Unlock() 142 | ctx.values[key] = value 143 | } 144 | 145 | func (ctx *SshContext) User() string { 146 | return ctx.Value(ContextKeyUser).(string) 147 | } 148 | 149 | func (ctx *SshContext) SessionID() string { 150 | return ctx.Value(ContextKeySessionID).(string) 151 | } 152 | 153 | func (ctx *SshContext) ShortSessionID() string { 154 | if ses, ok := ctx.Value(ContextKeySessionID).(string); ok { 155 | return ses[:8] 156 | } 157 | return "unknown" 158 | } 159 | 160 | func (ctx *SshContext) ClientVersion() string { 161 | return ctx.Value(ContextKeyClientVersion).(string) 162 | } 163 | 164 | func (ctx *SshContext) ServerVersion() string { 165 | return ctx.Value(ContextKeyServerVersion).(string) 166 | } 167 | 168 | func (ctx *SshContext) RemoteAddr() net.Addr { 169 | if addr, ok := ctx.Value(ContextKeyRemoteAddr).(net.Addr); ok { 170 | return addr 171 | } 172 | return nil 173 | } 174 | 175 | func (ctx *SshContext) LocalAddr() net.Addr { 176 | return ctx.Value(ContextKeyLocalAddr).(net.Addr) 177 | } 178 | 179 | func (ctx *SshContext) Permissions() *Permissions { 180 | return ctx.Value(ContextKeyPermissions).(*Permissions) 181 | } 182 | -------------------------------------------------------------------------------- /pkg/ssh/context_test.go: -------------------------------------------------------------------------------- 1 | package ssh 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | ) 7 | 8 | func TestSetPermissions(t *testing.T) { 9 | t.Parallel() 10 | permsExt := map[string]string{ 11 | "foo": "bar", 12 | } 13 | session, _, cleanup := newTestSessionWithOptions(t, &Server{ 14 | Handler: func(s Session) { 15 | if _, ok := s.Permissions().Extensions["foo"]; !ok { 16 | t.Fatalf("got %#v; want %#v", s.Permissions().Extensions, permsExt) 17 | } 18 | }, 19 | }, nil, 
PasswordAuth(func(ctx Context, password string) bool { 20 | ctx.Permissions().Extensions = permsExt 21 | return true 22 | })) 23 | defer cleanup() 24 | if err := session.Run(""); err != nil { 25 | t.Fatal(err) 26 | } 27 | } 28 | 29 | func TestSetValue(t *testing.T) { 30 | t.Parallel() 31 | value := map[string]string{ 32 | "foo": "bar", 33 | } 34 | key := "testValue" 35 | session, _, cleanup := newTestSessionWithOptions(t, &Server{ 36 | Handler: func(s Session) { 37 | v := s.Context().Value(key).(map[string]string) 38 | if v["foo"] != value["foo"] { 39 | t.Fatalf("got %#v; want %#v", v, value) 40 | } 41 | }, 42 | }, nil, PasswordAuth(func(ctx Context, password string) bool { 43 | ctx.SetValue(key, value) 44 | return true 45 | })) 46 | defer cleanup() 47 | if err := session.Run(""); err != nil { 48 | t.Fatal(err) 49 | } 50 | } 51 | 52 | func TestSetValueConcurrency(t *testing.T) { 53 | ctx, cancel := newContext(nil) 54 | defer cancel() 55 | 56 | go func() { 57 | for { // use a loop to access context.Context functions to make sure they are thread-safe with SetValue 58 | _, _ = ctx.Deadline() 59 | _ = ctx.Err() 60 | _ = ctx.Value("foo") 61 | select { 62 | case <-ctx.Done(): 63 | break 64 | default: 65 | } 66 | } 67 | }() 68 | ctx.SetValue("bar", -1) // a context value which never changes 69 | now := time.Now() 70 | var cnt int64 71 | go func() { 72 | for time.Since(now) < 100*time.Millisecond { 73 | cnt++ 74 | ctx.SetValue("foo", cnt) // a context value which changes a lot 75 | } 76 | cancel() 77 | }() 78 | <-ctx.Done() 79 | if ctx.Value("foo") != cnt { 80 | t.Fatal("context.Value(foo) doesn't match latest SetValue") 81 | } 82 | if ctx.Value("bar") != -1 { 83 | t.Fatal("context.Value(bar) doesn't match latest SetValue") 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /pkg/ssh/doc.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package ssh wraps the crypto/ssh package 
with a higher-level API for building 3 | SSH servers. The goal of the API was to make it as simple as using net/http, so 4 | the API is very similar. 5 | 6 | You should be able to build any SSH server using only this package, which wraps 7 | relevant types and some functions from crypto/ssh. However, you still need to 8 | use crypto/ssh for building SSH clients. 9 | 10 | ListenAndServe starts an SSH server with a given address, handler, and options. The 11 | handler is usually nil, which means to use DefaultHandler. Handle sets DefaultHandler: 12 | 13 | ssh.Handle(func(s ssh.Session) { 14 | io.WriteString(s, "Hello world\n") 15 | }) 16 | 17 | log.Fatal(ssh.ListenAndServe(":2222", nil)) 18 | 19 | If you don't specify a host key, it will generate one every time. This is convenient 20 | except you'll have to deal with clients being confused that the host key is different. 21 | It's a better idea to generate or point to an existing key on your system: 22 | 23 | log.Fatal(ssh.ListenAndServe(":2222", nil, ssh.HostKeyFile("/Users/progrium/.ssh/id_rsa"))) 24 | 25 | Although all options have functional option helpers, another way to control the 26 | server's behavior is by creating a custom Server: 27 | 28 | s := &ssh.Server{ 29 | Addr: ":2222", 30 | Handler: sessionHandler, 31 | PublicKeyHandler: authHandler, 32 | } 33 | s.AddHostKey(hostKeySigner) 34 | 35 | log.Fatal(s.ListenAndServe()) 36 | 37 | This package automatically handles basic SSH requests like setting environment 38 | variables, requesting PTY, and changing window size. These requests are 39 | processed, responded to, and any relevant state is updated. This state is then 40 | exposed to you via the Session interface. 41 | 42 | The one big feature missing from the Session abstraction is signals. This was 43 | started, but not completed. Pull Requests welcome! 
44 | */ 45 | package ssh 46 | -------------------------------------------------------------------------------- /pkg/ssh/example_test.go: -------------------------------------------------------------------------------- 1 | package ssh_test 2 | 3 | import ( 4 | "io" 5 | "os" 6 | 7 | "ssh2incus/pkg/ssh" 8 | ) 9 | 10 | func ExampleListenAndServe() { 11 | ssh.ListenAndServe(":2222", func(s ssh.Session) { 12 | io.WriteString(s, "Hello world\n") 13 | }) 14 | } 15 | 16 | func ExamplePasswordAuth() { 17 | ssh.ListenAndServe(":2222", nil, 18 | ssh.PasswordAuth(func(ctx ssh.Context, pass string) bool { 19 | return pass == "secret" 20 | }), 21 | ) 22 | } 23 | 24 | func ExampleNoPty() { 25 | ssh.ListenAndServe(":2222", nil, ssh.NoPty()) 26 | } 27 | 28 | func ExamplePublicKeyAuth() { 29 | ssh.ListenAndServe(":2222", nil, 30 | ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool { 31 | data, _ := os.ReadFile("/path/to/allowed/key.pub") 32 | allowed, _, _, _, _ := ssh.ParseAuthorizedKey(data) 33 | return ssh.KeysEqual(key, allowed) 34 | }), 35 | ) 36 | } 37 | 38 | func ExampleHostKeyFile() { 39 | ssh.ListenAndServe(":2222", nil, ssh.HostKeyFile("/path/to/host/key")) 40 | } 41 | -------------------------------------------------------------------------------- /pkg/ssh/options.go: -------------------------------------------------------------------------------- 1 | package ssh 2 | 3 | import ( 4 | "os" 5 | 6 | gossh "golang.org/x/crypto/ssh" 7 | ) 8 | 9 | // PasswordAuth returns a functional option that sets PasswordHandler on the server. 10 | func PasswordAuth(fn PasswordHandler) Option { 11 | return func(srv *Server) error { 12 | srv.PasswordHandler = fn 13 | return nil 14 | } 15 | } 16 | 17 | // PublicKeyAuth returns a functional option that sets PublicKeyHandler on the server. 
18 | func PublicKeyAuth(fn PublicKeyHandler) Option {
19 | 	return func(srv *Server) error {
20 | 		srv.PublicKeyHandler = fn
21 | 		return nil
22 | 	}
23 | }
24 | 
25 | // HostKeyFile returns a functional option that adds HostSigners to the server
26 | // from a PEM file at filepath. The file is read when the option is applied;
27 | // a read or parse error is returned by the option.
28 | func HostKeyFile(filepath string) Option {
29 | 	return func(srv *Server) error {
30 | 		pemBytes, err := os.ReadFile(filepath)
31 | 		if err != nil {
32 | 			return err
33 | 		}
34 | 
35 | 		// Delegate to HostKeyPEM so file-based and in-memory keys share a
36 | 		// single parse-and-add path.
37 | 		return HostKeyPEM(pemBytes)(srv)
38 | 	}
39 | }
40 | 
41 | // KeyboardInteractiveAuth returns a functional option that sets
42 | // KeyboardInteractiveHandler on the server, the callback used for
43 | // keyboard-interactive authentication.
44 | func KeyboardInteractiveAuth(fn KeyboardInteractiveHandler) Option {
45 | 	return func(srv *Server) error {
46 | 		srv.KeyboardInteractiveHandler = fn
47 | 		return nil
48 | 	}
49 | }
50 | 
51 | // HostKeyPEM returns a functional option that adds HostSigners to the server
52 | // from a PEM file as bytes. The bytes are parsed when the option is applied;
53 | // a parse error is returned by the option.
54 | func HostKeyPEM(bytes []byte) Option {
55 | 	return func(srv *Server) error {
56 | 		signer, err := gossh.ParsePrivateKey(bytes)
57 | 		if err != nil {
58 | 			return err
59 | 		}
60 | 
61 | 		srv.AddHostKey(signer)
62 | 
63 | 		return nil
64 | 	}
65 | }
66 | 
67 | // NoPty returns a functional option that sets PtyCallback to return false,
68 | // denying PTY requests.
69 | func NoPty() Option {
70 | 	return func(srv *Server) error {
71 | 		srv.PtyCallback = func(ctx Context, pty Pty) bool {
72 | 			return false
73 | 		}
74 | 		return nil
75 | 	}
76 | }
77 | 
78 | // WrapConn returns a functional option that sets ConnCallback on the server.
79 | func WrapConn(fn ConnCallback) Option {
80 | 	return func(srv *Server) error {
81 | 		srv.ConnCallback = fn
82 | 		return nil
83 | 	}
84 | }
85 | 
--------------------------------------------------------------------------------
/pkg/ssh/options_test.go:
--------------------------------------------------------------------------------
1 | package ssh
2 | 
3 | import (
4 | 	"net"
5 | 	"strings"
6 | 	"sync/atomic"
7 | 	"testing"
8 | 
9 | 	gossh "golang.org/x/crypto/ssh"
10 | )
11 | 
12 | func newTestSessionWithOptions(t *testing.T, srv *Server, cfg *gossh.ClientConfig, options ...Option) (*gossh.Session, *gossh.Client, func()) {
13 | 	for _, option := range options {
14 | 		if err := srv.SetOption(option); err != nil {
15 | 			t.Fatal(err)
16 | 		}
17 | 	}
18 | 	return newTestSession(t, srv, cfg)
19 | }
20 | 
21 | func TestPasswordAuth(t *testing.T) {
22 | 	t.Parallel()
23 | 	testUser := "testuser"
24 | 	testPass := "testpass"
25 | 	session, _, cleanup := newTestSessionWithOptions(t, &Server{
26 | 		Handler: func(s Session) {
27 | 			// noop
28 | 		},
29 | 	}, &gossh.ClientConfig{
30 | 		User: testUser,
31 | 		Auth: []gossh.AuthMethod{
32 | 			gossh.Password(testPass),
33 | 		},
34 | 		HostKeyCallback: gossh.InsecureIgnoreHostKey(),
35 | 	}, PasswordAuth(func(ctx Context, password string) bool {
36 | 		// This callback runs on a server goroutine, where t.Fatal must not
37 | 		// be used; report with t.Errorf and reject the attempt instead.
38 | 		if ctx.User() != testUser {
39 | 			t.Errorf("user = %#v; want %#v", ctx.User(), testUser)
40 | 			return false
41 | 		}
42 | 		if password != testPass {
43 | 			t.Errorf("password = %#v; want %#v", password, testPass)
44 | 			return false
45 | 		}
46 | 		return true
47 | 	}))
48 | 	defer cleanup()
49 | 	if err := session.Run(""); err != nil {
50 | 		t.Fatal(err)
51 | 	}
52 | }
53 | 
54 | func TestPasswordAuthBadPass(t *testing.T) {
55 | 	t.Parallel()
56 | 	l := newLocalListener()
57 | 	srv := &Server{Handler: func(s Session) {}}
58 | 	if err := srv.SetOption(PasswordAuth(func(ctx Context, password string) bool {
59 | 		return false
60 | 	})); err != nil {
61 | 		t.Fatal(err)
62 | 	}
63 | 	go srv.serveOnce(l)
64 | 	_, err := gossh.Dial("tcp", l.Addr().String(), &gossh.ClientConfig{
65 | 		User: "testuser",
66 | 		Auth: []gossh.AuthMethod{
67 | 			gossh.Password("testpass"),
68 | 		},
69 | 		HostKeyCallback: gossh.InsecureIgnoreHostKey(),
70 | 	})
71 | 	// A rejecting PasswordHandler must make the dial fail, so a nil error
72 | 	// is itself a test failure.
73 | 	if err == nil {
74 | 		t.Fatal("expected authentication to fail, but dial succeeded")
75 | 	}
76 | 	if !strings.Contains(err.Error(), "unable to authenticate") {
77 | 		t.Fatal(err)
78 | 	}
79 | }
80 | 
81 | type wrappedConn struct {
82 | 	net.Conn
83 | 	written int32
84 | }
85 | 
86 | func (c *wrappedConn) Write(p []byte) (n int, err error) {
87 | 	n, err = c.Conn.Write(p)
88 | 	atomic.AddInt32(&(c.written), int32(n))
89 | 	return
90 | }
91 | 
92 | func TestConnWrapping(t *testing.T) {
93 | 	t.Parallel()
94 | 	var wrapped *wrappedConn
95 | 	session, _, cleanup := newTestSessionWithOptions(t, &Server{
96 | 		Handler: func(s Session) {
97 | 			// nothing
98 | 		},
99 | 	}, &gossh.ClientConfig{
100 | 		User: "testuser",
101 | 		Auth: []gossh.AuthMethod{
102 | 			gossh.Password("testpass"),
103 | 		},
104 | 		HostKeyCallback: gossh.InsecureIgnoreHostKey(),
105 | 	}, PasswordAuth(func(ctx Context, password string) bool {
106 | 		return true
107 | 	}), WrapConn(func(ctx Context, conn net.Conn) net.Conn {
108 | 		wrapped = &wrappedConn{conn, 0}
109 | 		return wrapped
110 | 	}))
111 | 	defer cleanup()
112 | 	if err := session.Shell(); err != nil {
113 | 		t.Fatal(err)
114 | 	}
115 | 	if atomic.LoadInt32(&(wrapped.written)) == 0 {
116 | 		t.Fatal("wrapped conn not written to")
117 | 	}
118 | }
119 | 
--------------------------------------------------------------------------------
/pkg/ssh/server_test.go:
--------------------------------------------------------------------------------
1 | package ssh
2 | 
3 | import (
4 | 	"bytes"
5 | 	"context"
6 | 	"io"
7 | 	"net"
8 | 	"testing"
9 | 	"time"
10 | )
11 | 
12 | func TestAddHostKey(t *testing.T) {
13 | 	s := Server{}
14 | 	signer, err := generateSigner()
15 | 	if err != nil {
16 | 		t.Fatal(err)
17 | 	}
18 | 	s.AddHostKey(signer)
19 | 	if len(s.HostSigners) != 1 {
20 | 		t.Fatal("Key was not properly added")
21 | 	}
22 | 	signer, err = generateSigner()
23 | 	if err != nil {
24 | 		t.Fatal(err)
25 | 	}
26 | 	s.AddHostKey(signer)
27 | 	if len(s.HostSigners) != 1 {
28 | 		t.Fatal("Key was not properly replaced")
29 | 	}
30 | }
31 | 
32 | func TestServerShutdown(t *testing.T) {
33 | 	l := newLocalListener()
34 | 	testBytes := []byte("Hello world\n")
35 | 	s := &Server{
36 | 		Handler: func(s Session) {
37 | 			s.Write(testBytes)
38 | 			time.Sleep(50 * time.Millisecond)
39 | 		},
40 | 	}
41 | 	// Helper goroutines report with t.Error/t.Errorf: t.Fatal may only be
42 | 	// called from the goroutine running the test.
43 | 	go func() {
44 | 		err := s.Serve(l)
45 | 		if err != nil && err != ErrServerClosed {
46 | 			t.Error(err)
47 | 		}
48 | 	}()
49 | 	sessDone := make(chan struct{})
50 | 	sess, _, cleanup := newClientSession(t, l.Addr().String(), nil)
51 | 	go func() {
52 | 		defer cleanup()
53 | 		defer close(sessDone)
54 | 		var stdout bytes.Buffer
55 | 		sess.Stdout = &stdout
56 | 		if err := sess.Run(""); err != nil {
57 | 			t.Error(err)
58 | 			return
59 | 		}
60 | 		if !bytes.Equal(stdout.Bytes(), testBytes) {
61 | 			t.Errorf("expected = %s; got %s", testBytes, stdout.Bytes())
62 | 		}
63 | 	}()
64 | 
65 | 	srvDone := make(chan struct{})
66 | 	go func() {
67 | 		defer close(srvDone)
68 | 		err := s.Shutdown(context.Background())
69 | 		if err != nil {
70 | 			t.Error(err)
71 | 		}
72 | 	}()
73 | 
74 | 	timeout := time.After(2 * time.Second)
75 | 	select {
76 | 	case <-timeout:
77 | 		t.Fatal("timeout")
78 | 	case <-srvDone:
79 | 		// The session must also finish within the same overall deadline.
80 | 		select {
81 | 		case <-sessDone:
82 | 		case <-timeout:
83 | 			t.Fatal("timeout waiting for session to finish")
84 | 		}
85 | 	}
86 | }
87 | 
88 | func TestServerClose(t *testing.T) {
89 | 	l := newLocalListener()
90 | 	s := &Server{
91 | 		Handler: func(s Session) {
92 | 			time.Sleep(5 * time.Second)
93 | 		},
94 | 	}
95 | 	go func() {
96 | 		err := s.Serve(l)
97 | 		if err != nil && err != ErrServerClosed {
98 | 			t.Error(err)
99 | 		}
100 | 	}()
101 | 
102 | 	clientDoneChan := make(chan struct{})
103 | 	closeDoneChan := make(chan struct{})
104 | 
105 | 	sess, _, cleanup := newClientSession(t, l.Addr().String(), nil)
106 | 	go func() {
107 | 		defer cleanup()
108 | 		defer close(clientDoneChan)
109 | 		<-closeDoneChan
110 | 		if err := sess.Run(""); err != nil && err != io.EOF {
111 | 			t.Error(err)
112 | 		}
113 | 	}()
114 | 
115 | 	go func() {
116 | 		// Release the client goroutine even if Close reports an error.
117 | 		defer close(closeDoneChan)
118 | 		if err := s.Close(); err != nil {
119 | 			t.Error(err)
120 | 		}
121 | 	}()
122 | 
123 | 	timeout := time.After(100 * time.Millisecond)
124 | 	select {
125 | 	case <-timeout:
126 | 		t.Error("timeout")
127 | 		return
128 | 	case <-s.getDoneChan():
129 | 		<-clientDoneChan
130 | 		return
131 | 	}
132 | }
133 | 
134 | func TestServerHandshakeTimeout(t *testing.T) {
135 | 	l := newLocalListener()
136 | 
137 | 	s := &Server{
138 | 		HandshakeTimeout: time.Millisecond,
139 | 	}
140 | 	go func() {
141 | 		if err := s.Serve(l); err != nil {
142 | 			t.Error(err)
143 | 		}
144 | 	}()
145 | 
146 | 	conn, err := net.Dial("tcp", l.Addr().String())
147 | 	if err != nil {
148 | 		t.Fatal(err)
149 | 	}
150 | 	defer conn.Close()
151 | 
152 | 	ch := make(chan struct{})
153 | 	go func() {
154 | 		defer close(ch)
155 | 		io.Copy(io.Discard, conn)
156 | 	}()
157 | 
158 | 	select {
159 | 	case <-ch:
160 | 		return
161 | 	case <-time.After(time.Second):
162 | 		t.Fatal("client connection was not force-closed")
163 | 		return
164 | 	}
165 | }
166 | 
--------------------------------------------------------------------------------
/pkg/ssh/ssh.go:
--------------------------------------------------------------------------------
1 | package ssh
2 | 
3 | import (
4 | 	"crypto/subtle"
5 | 	"net"
6 | 
7 | 	gossh "golang.org/x/crypto/ssh"
8 | )
9 | 
10 | type Signal string
11 | 
12 | // POSIX signals as listed in RFC 4254 Section 6.10.
13 | const (
14 | 	SIGABRT Signal = "ABRT"
15 | 	SIGALRM Signal = "ALRM"
16 | 	SIGFPE  Signal = "FPE"
17 | 	SIGHUP  Signal = "HUP"
18 | 	SIGILL  Signal = "ILL"
19 | 	SIGINT  Signal = "INT"
20 | 	SIGKILL Signal = "KILL"
21 | 	SIGPIPE Signal = "PIPE"
22 | 	SIGQUIT Signal = "QUIT"
23 | 	SIGSEGV Signal = "SEGV"
24 | 	SIGTERM Signal = "TERM"
25 | 	SIGUSR1 Signal = "USR1"
26 | 	SIGUSR2 Signal = "USR2"
27 | )
28 | 
29 | // DefaultHandler is the default Handler used by Serve.
30 | var DefaultHandler Handler
31 | 
32 | // Option is a functional option handler for Server.
33 | type Option func(*Server) error
34 | 
35 | // Handler is a callback for handling established SSH sessions.
36 | type Handler func(Session)
37 | 
38 | // BannerHandler is a callback for displaying the server banner.
39 | type BannerHandler func(ctx Context) string
40 | 
41 | // PublicKeyHandler is a callback for performing public key authentication; returning true accepts the key.
42 | type PublicKeyHandler func(ctx Context, key PublicKey) bool
43 | 
44 | // PasswordHandler is a callback for performing password authentication; returning false rejects the password.
45 | type PasswordHandler func(ctx Context, password string) bool
46 | 
47 | // KeyboardInteractiveHandler is a callback for performing keyboard-interactive authentication.
48 | type KeyboardInteractiveHandler func(ctx Context, challenger gossh.KeyboardInteractiveChallenge) bool
49 | 
50 | // PtyCallback is a hook for allowing or denying PTY requests; returning false denies the request (see NoPty).
51 | type PtyCallback func(ctx Context, pty Pty) bool
52 | 
53 | // SessionRequestCallback is a callback for allowing or denying SSH sessions, given the session and its request type.
54 | type SessionRequestCallback func(sess Session, requestType string) bool
55 | 
56 | // ConnCallback is a hook for new connections before handling.
57 | // It allows wrapping for timeouts and limiting by returning
58 | // the net.Conn that will be used as the underlying connection.
59 | type ConnCallback func(ctx Context, conn net.Conn) net.Conn
60 | 
61 | // LocalPortForwardingCallback is a hook for allowing local (direct-tcpip) port forwarding; a nil callback or a false result rejects the request.
62 | type LocalPortForwardingCallback func(ctx Context, destinationHost string, destinationPort uint32) bool
63 | 
64 | // ReversePortForwardingCallback is a hook for allowing reverse (tcpip-forward) port forwarding; a nil callback or a false result rejects the request.
65 | type ReversePortForwardingCallback func(ctx Context, bindHost string, bindPort uint32) bool
66 | 
67 | // ServerConfigCallback is a hook for creating custom default server configs.
68 | type ServerConfigCallback func(ctx Context) *gossh.ServerConfig
69 | 
70 | // ConnectionFailedCallback is a hook for reporting failed connections.
71 | // Please note: the net.Conn is likely to be closed at this point.
72 | type ConnectionFailedCallback func(conn net.Conn, err error)
73 | 
74 | // Window represents the size of a PTY window.
75 | type Window struct {
76 | 	Width  int
77 | 	Height int
78 | }
79 | 
80 | // Pty represents a PTY request and configuration.
81 | type Pty struct {
82 | 	Term   string
83 | 	Window Window
84 | 	// HELP WANTED: terminal modes!
85 | }
86 | 
87 | // Serve accepts incoming SSH connections on the listener l, creating a new
88 | // connection goroutine for each. The connection goroutines read requests and
89 | // then call handler to handle sessions. Handler is typically nil, in which
90 | // case the DefaultHandler is used. Any error from applying options is returned before serving.
91 | func Serve(l net.Listener, handler Handler, options ...Option) error {
92 | 	srv := &Server{Handler: handler}
93 | 	for _, option := range options {
94 | 		if err := srv.SetOption(option); err != nil {
95 | 			return err
96 | 		}
97 | 	}
98 | 	return srv.Serve(l)
99 | }
100 | 
101 | // ListenAndServe listens on the TCP network address addr and then calls Serve
102 | // with handler to handle sessions on incoming connections. Handler is typically
103 | // nil, in which case the DefaultHandler is used.
104 | func ListenAndServe(addr string, handler Handler, options ...Option) error {
105 | 	srv := &Server{Addr: addr, Handler: handler}
106 | 	for _, option := range options {
107 | 		if err := srv.SetOption(option); err != nil {
108 | 			return err
109 | 		}
110 | 	}
111 | 	return srv.ListenAndServe()
112 | }
113 | 
114 | // Handle registers the handler as the DefaultHandler.
115 | func Handle(handler Handler) {
116 | 	DefaultHandler = handler
117 | }
118 | 
119 | // KeysEqual reports whether two keys have identical wire encodings, comparing in constant time to avoid timing attacks; a nil key never matches.
120 | func KeysEqual(ak, bk PublicKey) bool {
121 | 	// avoid panic if one of the keys is nil, return false instead
122 | 	if ak == nil || bk == nil {
123 | 		return false
124 | 	}
125 | 
126 | 	a := ak.Marshal()
127 | 	b := bk.Marshal()
128 | 	return (len(a) == len(b) && subtle.ConstantTimeCompare(a, b) == 1)
129 | }
130 | 
--------------------------------------------------------------------------------
/pkg/ssh/ssh_test.go:
--------------------------------------------------------------------------------
1 | package ssh
2 | 
3 | import (
4 | 	"testing"
5 | )
6 | 
7 | func TestKeysEqual(t *testing.T) {
8 | 	defer func() {
9 | 		if r := recover(); r != nil {
10 | 			t.Errorf("The code did panic")
11 | 		}
12 | 	}()
13 | 
14 | 	if KeysEqual(nil, nil) {
15 | 		t.Error("two nil keys should not return true")
16 | 	}
17 | }
18 | 
--------------------------------------------------------------------------------
/pkg/ssh/tcpip.go:
--------------------------------------------------------------------------------
1 | package ssh
2 | 
3 | import (
4 | 	"io"
5 | 	"log"
6 | 	"net"
7 | 	"strconv"
8 | 	"sync"
9 | 
10 | 	gossh "golang.org/x/crypto/ssh"
11 | )
12 | 
13 | const (
14 | 	forwardedTCPChannelType = "forwarded-tcpip"
15 | )
16 | 
17 | // localForwardChannelData is the direct-tcpip channel payload, as specified in RFC 4254, Section 7.2.
18 | type localForwardChannelData struct {
19 | 	DestAddr string
20 | 	DestPort uint32
21 | 
22 | 	OriginAddr string
23 | 	OriginPort uint32
24 | }
25 | 
26 | // DirectTCPIPHandler can be enabled by adding it to the server's
27 | // ChannelHandlers under direct-tcpip. It dials the destination only after LocalPortForwardingCallback approves it, then relays bytes in both directions.
28 | func DirectTCPIPHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) {
29 | 	d := localForwardChannelData{}
30 | 	if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil {
31 | 		newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error())
32 | 		return
33 | 	}
34 | 
35 | 	if srv.LocalPortForwardingCallback == nil || !srv.LocalPortForwardingCallback(ctx, d.DestAddr, d.DestPort) {
36 | 		newChan.Reject(gossh.Prohibited, "port forwarding is disabled")
37 | 		return
38 | 	}
39 | 
40 | 	dest := net.JoinHostPort(d.DestAddr, strconv.FormatInt(int64(d.DestPort), 10))
41 | 
42 | 	var dialer net.Dialer
43 | 	dconn, err := dialer.DialContext(ctx, "tcp", dest)
44 | 	if err != nil {
45 | 		newChan.Reject(gossh.ConnectionFailed, err.Error())
46 | 		return
47 | 	}
48 | 
49 | 	ch, reqs, err := newChan.Accept()
50 | 	if err != nil {
51 | 		dconn.Close()
52 | 		return
53 | 	}
54 | 	go gossh.DiscardRequests(reqs) // channel-level requests are not serviced for direct-tcpip
55 | 
56 | 	go func() { // destination -> client; whichever copy ends first closes both ends
57 | 		defer ch.Close()
58 | 		defer dconn.Close()
59 | 		io.Copy(ch, dconn)
60 | 	}()
61 | 	go func() { // client -> destination
62 | 		defer ch.Close()
63 | 		defer dconn.Close()
64 | 		io.Copy(dconn, ch)
65 | 	}()
66 | }
67 | 
68 | type remoteForwardRequest struct {
69 | 	BindAddr string
70 | 	BindPort uint32
71 | }
72 | 
73 | type remoteForwardSuccess struct {
74 | 	BindPort uint32
75 | }
76 | 
77 | type remoteForwardCancelRequest struct {
78 | 	BindAddr string
79 | 	BindPort uint32
80 | }
81 | 
82 | type remoteForwardChannelData struct {
83 | 	DestAddr   string
84 | 	DestPort   uint32
85 | 	OriginAddr string
86 | 	OriginPort uint32
87 | }
88 | 
89 | // ForwardedTCPHandler can be enabled by creating a ForwardedTCPHandler and
90 | // adding the HandleSSHRequest callback to the server's RequestHandlers under
91 | // tcpip-forward and cancel-tcpip-forward.
92 | type ForwardedTCPHandler struct {
93 | 	forwards map[string]net.Listener
94 | 	sync.Mutex
95 | }
96 | 
97 | // HandleSSHRequest services tcpip-forward and cancel-tcpip-forward global
98 | // requests (RFC 4254, Section 7.1). Listeners are tracked by
99 | // "bindAddr:boundPort"; any other request type is refused.
100 | func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) {
101 | 	h.Lock()
102 | 	if h.forwards == nil {
103 | 		h.forwards = make(map[string]net.Listener)
104 | 	}
105 | 	h.Unlock()
106 | 	conn := ctx.Value(ContextKeyConn).(*gossh.ServerConn)
107 | 	switch req.Type {
108 | 	case "tcpip-forward":
109 | 		var reqPayload remoteForwardRequest
110 | 		if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
111 | 			// TODO: log parse failure
112 | 			return false, []byte{}
113 | 		}
114 | 		if srv.ReversePortForwardingCallback == nil || !srv.ReversePortForwardingCallback(ctx, reqPayload.BindAddr, reqPayload.BindPort) {
115 | 			return false, []byte("port forwarding is disabled")
116 | 		}
117 | 		addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort)))
118 | 		ln, err := net.Listen("tcp", addr)
119 | 		if err != nil {
120 | 			// TODO: log listen failure
121 | 			return false, []byte{}
122 | 		}
123 | 		_, destPortStr, _ := net.SplitHostPort(ln.Addr().String())
124 | 		destPort, _ := strconv.Atoi(destPortStr)
125 | 		// Track the listener under the port actually bound. For a request
126 | 		// with port 0 the kernel picks the port, and that assigned port is
127 | 		// what a later cancel-tcpip-forward names; keying by "host:0" also
128 | 		// made concurrent port-0 forwards overwrite each other.
129 | 		addr = net.JoinHostPort(reqPayload.BindAddr, destPortStr)
130 | 		h.Lock()
131 | 		h.forwards[addr] = ln
132 | 		h.Unlock()
133 | 		go func() {
134 | 			// Close exactly the listener created here when the connection
135 | 			// ends. Looking it up by addr could close a listener that a
136 | 			// later forward (possibly from another connection sharing this
137 | 			// handler) registered under the same address. A second Close
138 | 			// after a cancel only returns an error, which is ignored.
139 | 			<-ctx.Done()
140 | 			ln.Close()
141 | 		}()
142 | 		go func() {
143 | 			for {
144 | 				c, err := ln.Accept()
145 | 				if err != nil {
146 | 					// TODO: log accept failure
147 | 					break
148 | 				}
149 | 				originAddr, originPortStr, _ := net.SplitHostPort(c.RemoteAddr().String())
150 | 				originPort, _ := strconv.Atoi(originPortStr)
151 | 				payload := gossh.Marshal(&remoteForwardChannelData{
152 | 					DestAddr:   reqPayload.BindAddr,
153 | 					DestPort:   uint32(destPort),
154 | 					OriginAddr: originAddr,
155 | 					OriginPort: uint32(originPort),
156 | 				})
157 | 				go func() {
158 | 					ch, reqs, err := conn.OpenChannel(forwardedTCPChannelType, payload)
159 | 					if err != nil {
160 | 						// TODO: log failure to open channel
161 | 						log.Println(err)
162 | 						c.Close()
163 | 						return
164 | 					}
165 | 					go gossh.DiscardRequests(reqs)
166 | 					go func() {
167 | 						defer ch.Close()
168 | 						defer c.Close()
169 | 						io.Copy(ch, c)
170 | 					}()
171 | 					go func() {
172 | 						defer ch.Close()
173 | 						defer c.Close()
174 | 						io.Copy(c, ch)
175 | 					}()
176 | 				}()
177 | 			}
178 | 			// Drop the map entry only if it still refers to this listener;
179 | 			// the address may have been re-registered by a newer forward.
180 | 			h.Lock()
181 | 			if h.forwards[addr] == ln {
182 | 				delete(h.forwards, addr)
183 | 			}
184 | 			h.Unlock()
185 | 		}()
186 | 		return true, gossh.Marshal(&remoteForwardSuccess{uint32(destPort)})
187 | 
188 | 	case "cancel-tcpip-forward":
189 | 		var reqPayload remoteForwardCancelRequest
190 | 		if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
191 | 			// TODO: log parse failure
192 | 			return false, []byte{}
193 | 		}
194 | 		addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort)))
195 | 		h.Lock()
196 | 		ln, ok := h.forwards[addr]
197 | 		h.Unlock()
198 | 		if ok {
199 | 			ln.Close()
200 | 		}
201 | 		return true, nil
202 | 	default:
203 | 		return false, nil
204 | 	}
205 | }
206 | 
--------------------------------------------------------------------------------
/pkg/ssh/tcpip_test.go:
--------------------------------------------------------------------------------
1 | package ssh
2 | 
3 | import (
4 | 	"bytes"
5 | 	"io"
6 | 	"net"
7 | 	"strconv"
8 | 	"strings"
9 | 	"testing"
10 | 
11 | 	gossh "golang.org/x/crypto/ssh"
12 | )
13 | 
14 | var sampleServerResponse = []byte("Hello world")
15 | 
16 | func sampleSocketServer() net.Listener {
17 | 	l := newLocalListener()
18 | 
19 | 	go func() {
20 | 		conn, err := l.Accept()
21 | 		if err != nil {
22 | 			return
23 | 		}
24 | 		conn.Write(sampleServerResponse)
25 | 		conn.Close()
26 | 	}()
27 | 
28 | 	return l
29 | }
30 | 
31 | func newTestSessionWithForwarding(t *testing.T, forwardingEnabled bool) (net.Listener, *gossh.Client, func()) {
32 | 	l := sampleSocketServer()
33 | 
34 | 	_, client, cleanup := newTestSession(t, &Server{
35 | 		Handler: func(s Session) {},
36 | 		LocalPortForwardingCallback: func(ctx Context, destinationHost string, destinationPort uint32) bool {
37 | 			addr := net.JoinHostPort(destinationHost, strconv.FormatInt(int64(destinationPort), 10))
38 | 			if addr != l.Addr().String() {
39 | 				panic("unexpected destinationHost: " + addr)
40 | 			}
41 | 			return forwardingEnabled
42 | 		},
43 | 	}, nil)
44 | 
45 | 	return l, client, func() {
46 | 		cleanup()
47 | 		l.Close()
48 | 	}
49 | }
50 | 
51 | func TestLocalPortForwardingWorks(t *testing.T) {
52 | 	t.Parallel()
53 | 
54 | 	l, client, cleanup := newTestSessionWithForwarding(t, true)
55 | 	defer cleanup()
56 | 
57 | 	conn, err := client.Dial("tcp", l.Addr().String())
58 | 	if err != nil {
59 | 		t.Fatalf("Error connecting to %v: %v", l.Addr().String(), err)
60 | 	}
61 | 	result, err := io.ReadAll(conn)
62 | 	if err != nil {
63 | 		t.Fatal(err)
64 | 	}
65 | 	if !bytes.Equal(result, sampleServerResponse) {
66 | 		t.Fatalf("result = %#v; want %#v", result, sampleServerResponse)
67 | 	}
68 | }
69 | 
70 | func TestLocalPortForwardingRespectsCallback(t *testing.T) {
71 | 	t.Parallel()
72 | 
73 | 	l, client, cleanup := newTestSessionWithForwarding(t, false)
74 | 	defer cleanup()
75 | 
76 | 	_, err := client.Dial("tcp", l.Addr().String())
77 | 	if err == nil {
78 | 		t.Fatalf("Expected error connecting to %v but it succeeded", l.Addr().String())
79 | 	}
80 | 	if !strings.Contains(err.Error(), "port forwarding is disabled") {
81 | 		t.Fatalf("Expected permission error but got %#v", err)
82 | 	}
83 | }
84 | 
--------------------------------------------------------------------------------
/pkg/ssh/util.go:
--------------------------------------------------------------------------------
1 | package ssh
2 | 
3 | import (
4 | 	"crypto/rand"
5 | 	"crypto/rsa"
6 | 	"encoding/binary"
7 | 
8 | 	"golang.org/x/crypto/ssh"
9 | )
10 | 
11 | func generateSigner() (ssh.Signer, error) {
12 | 	key, err := rsa.GenerateKey(rand.Reader, 2048)
13 | 	if err != nil {
14 | 		return nil, err
15 | 	}
16 | 	return ssh.NewSignerFromKey(key)
17 | }
18 | 
19 | func parsePtyRequest(s []byte) (pty Pty, ok bool) {
20 | 	term, s, ok := parseString(s)
21 | 	if !ok {
22 | 		return
23 | 	}
24 | 	width32, s, ok := parseUint32(s)
25 | 	if !ok {
26 | 		return
27 | 	}
28 | 	height32, _, ok := parseUint32(s)
29 | 	if !ok {
30 | 		return
31 | 	}
32 | 	pty = Pty{
33 | 		Term: term,
34 | 		Window: Window{
35 | 			Width:  int(width32),
36 | 			Height: int(height32),
37 | 		},
38 | 	}
39 | 	return
40 | }
41 | 
42 | func parseWinchRequest(s []byte) (win Window, ok bool) {
43 | 	width32, s, ok := parseUint32(s)
44 | 	if width32 < 1 {
45 | 		ok = false
46 | 	}
47 | 	if !ok {
48 | 		return
49 | 	}
50 | 	height32, _, ok := parseUint32(s)
51 | 	if height32 < 1 {
52 | 		ok = false
53 | 	}
54 | 	if !ok {
55 | 		return
56 | 	}
57 | 	win = Window{
58 | 		Width:  int(width32),
59 | 		Height: int(height32),
60 | 	}
61 | 	return
62 | }
63 | 
64 | // parseString decodes an SSH wire "string" (uint32 big-endian length, then
65 | // that many bytes) from the front of in. ok is false if in is too short.
66 | func parseString(in []byte) (out string, rest []byte, ok bool) {
67 | 	if len(in) < 4 {
68 | 		return
69 | 	}
70 | 	length := binary.BigEndian.Uint32(in)
71 | 	in = in[4:]
72 | 	// Compare against the bytes that remain instead of computing 4+length:
73 | 	// that sum is a uint32 and wraps for lengths near 1<<32, which let a
74 | 	// malformed request reach an out-of-range slice expression and panic.
75 | 	if uint64(len(in)) < uint64(length) {
76 | 		return
77 | 	}
78 | 	out = string(in[:length])
79 | 	rest = in[length:]
80 | 	ok = true
81 | 	return
82 | }
83 | 
84 | func parseUint32(in []byte) (uint32, []byte, bool) {
85 | 	if len(in) < 4 {
86 | 		return 0, nil, false
87 | 	}
88 | 	return binary.BigEndian.Uint32(in), in[4:], true
89 | }
90 | 
--------------------------------------------------------------------------------
/pkg/ssh/wrap.go:
--------------------------------------------------------------------------------
1 | package ssh
2 | 
3 | import gossh "golang.org/x/crypto/ssh"
4 | 
5 | // PublicKey is an abstraction of different types of public keys.
6 | type PublicKey interface {
7 | 	gossh.PublicKey
8 | }
9 | 
10 | // The Permissions type holds fine-grained permissions that are specific to a
11 | // user or a specific authentication method for a user. Permissions, except for
12 | // "source-address", must be enforced in the server application layer, after
13 | // successful authentication.
14 | type Permissions struct {
15 | 	*gossh.Permissions
16 | }
17 | 
18 | // A Signer can create signatures that verify against a public key.
19 | type Signer interface {
20 | 	gossh.Signer
21 | }
22 | 
23 | // ParseAuthorizedKey parses a public key from an authorized_keys file used in
24 | // OpenSSH according to the sshd(8) manual page. It delegates to gossh.ParseAuthorizedKey.
25 | func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) {
26 | 	return gossh.ParseAuthorizedKey(in)
27 | }
28 | 
29 | // ParsePublicKey parses an SSH public key formatted for use in
30 | // the SSH wire protocol according to RFC 4253, section 6.6. It delegates to gossh.ParsePublicKey.
31 | func ParsePublicKey(in []byte) (out PublicKey, err error) {
32 | 	return gossh.ParsePublicKey(in)
33 | }
34 | 
--------------------------------------------------------------------------------
/pkg/user/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License
2 | Copyright (c) 2016 Tommy Allen
3 | 
4 | Permission is hereby granted, free of charge, to any person obtaining a copy of
5 | this software and associated documentation files (the "Software"), to deal in
6 | the Software without restriction, including without limitation the rights to
7 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
8 | of the Software, and to permit persons to whom the Software is furnished to do
9 | so, subject to the following conditions:
10 | 
11 | The above copyright notice and this permission notice shall be included in all
12 | copies or substantial portions of the Software.
13 | 
14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20 | SOFTWARE.
21 | -------------------------------------------------------------------------------- /pkg/user/README.md: -------------------------------------------------------------------------------- 1 | # luser 2 | 3 | [![GoDoc](https://godoc.org/github.com/tweekmonster/luser?status.svg)](https://godoc.org/github.com/tweekmonster/luser) 4 | [![Build Status](https://travis-ci.org/tweekmonster/luser.svg?branch=master)](https://travis-ci.org/tweekmonster/luser) 5 | 6 | `luser` is a drop-in replacement for `os/user` which allows you to lookup users 7 | and groups in cross-compiled builds without `cgo`. 8 | 9 | 10 | ## Overview 11 | 12 | `os/user` requires `cgo` to lookup users using the target OS's API. This is 13 | the most reliable way to look up user and group information. However, 14 | cross-compiling means that `os/user` will only work for the OS you're using. 15 | `user.Current()` is usable when building without `cgo`, but doesn't always 16 | work. The `$USER` and `$HOME` variables could be different from what you 17 | expect or not even exist. 18 | 19 | If you want to cross-compile a relatively simple program that needs to write a 20 | config file somewhere in the user's directory, the last thing you want to do is 21 | figure out some elaborate build scheme involving virtual machines. 22 | 23 | 24 | ## Usage 25 | 26 | `luser` has the same API as `os/user`. You should be able to just replace 27 | `user.` with `luser.` in your files and let `goimports` do the rest. The 28 | returned `*User` and `*Group` types will have an `IsLuser` field indicating 29 | whether or not a fallback method was used. 
30 | 
31 | 
32 | ## Install
33 | 
34 | Install the package with:
35 | 
36 | ```shell
37 | $ go get github.com/tweekmonster/luser
38 | ```
39 | 
40 | A sample program called `luser` can be installed if you want to see the
41 | fallback results on your platform:
42 | 
43 | ```shell
44 | $ CGO_ENABLED=0 go install github.com/tweekmonster/luser/cmd/luser
45 | $ luser -c
46 | $ luser username
47 | $ luser 1000
48 | ```
49 | 
50 | 
51 | ## Fallback lookup methods
52 | 
53 | `os/user` functions are used when built with `cgo`. Otherwise, it falls back
54 | to one of the following lookup methods:
55 | 
56 | | Method        | Used for                                                       |
57 | |---------------|----------------------------------------------------------------|
58 | | `/etc/passwd` | Parsed to lookup user information. (Unix, Linux)               |
59 | | `/etc/group`  | Parsed to lookup group information. (Unix, Linux)              |
60 | | `getent`      | Optional. Find user/group information. (Unix, Linux)           |
61 | | `dscacheutil` | Lookup user/group information via Directory Services. (Darwin) |
62 | | `id`          | Finding a user's groups when using `GroupIds()`.               |
63 | 
64 | **Note:** Windows should always work regardless of the build platform since it
65 | uses `syscall` instead of `cgo`.
66 | 
67 | 
68 | ## Caveats
69 | 
70 | - `luser.User` and `luser.Group` are new types. The underlying `user.*` types
71 |   are embedded, however (e.g. `u.User`).
72 | - The lookup methods use `exec.Command()` and will be noticeably slower if
73 |   you're looking up users and groups a lot.
74 | - Group-related functions will only be available when compiling with Go 1.7+.
75 | -------------------------------------------------------------------------------- /pkg/user/current_cgo.go: -------------------------------------------------------------------------------- 1 | //go:build !windows && cgo 2 | // +build !windows,cgo 3 | 4 | package user 5 | 6 | import "os/user" 7 | 8 | func currentUser() (*User, error) { 9 | u, err := user.Current() 10 | if err != nil { 11 | return nil, err 12 | } 13 | return &User{User: u}, nil 14 | } 15 | -------------------------------------------------------------------------------- /pkg/user/current_default.go: -------------------------------------------------------------------------------- 1 | //go:build !windows && !cgo 2 | // +build !windows,!cgo 3 | 4 | package user 5 | 6 | import ( 7 | "os" 8 | "os/user" 9 | "strconv" 10 | "sync" 11 | ) 12 | 13 | var currentUid = -1 14 | var current *User 15 | var currentMu sync.Mutex 16 | 17 | func init() { 18 | fallbackEnabled = true 19 | } 20 | 21 | // The default stub uses $USER and $HOME to fill in the user. A more reliable 22 | // method will be tried before falling back to the stub. 23 | func currentUser() (*User, error) { 24 | uid := os.Getuid() 25 | if uid >= 0 { 26 | if uid == currentUid && current != nil { 27 | return current, nil 28 | } 29 | 30 | currentMu.Lock() 31 | defer currentMu.Unlock() 32 | 33 | currentUid = uid 34 | u, err := lookupId(strconv.Itoa(uid)) 35 | current = u 36 | if err == nil { 37 | return u, nil 38 | } 39 | } 40 | 41 | if u, err := user.Current(); err == nil { 42 | return luser(u), nil 43 | } 44 | 45 | return nil, ErrCurrentUser 46 | } 47 | -------------------------------------------------------------------------------- /pkg/user/ds.go: -------------------------------------------------------------------------------- 1 | package user 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "os/exec" 7 | "os/user" 8 | "strings" 9 | ) 10 | 11 | var dscacheutilExe = "" // path to 'dscacheutil' program. 
12 | 
13 | func init() {
14 | 	path, err := exec.LookPath("dscacheutil")
15 | 	if err != nil {
16 | 		return
17 | 	}
18 | 	dscacheutilExe = path
19 | }
20 | 
21 | // dsParseUser builds a user.User from the "key: value" lines printed by
22 | // dscacheutil. Lines without ": " and unrecognized keys are ignored.
23 | func dsParseUser(data []byte) *user.User {
24 | 	parsed := &user.User{}
25 | 
26 | 	lines := bufio.NewScanner(bytes.NewReader(data))
27 | 	for lines.Scan() {
28 | 		text := lines.Text()
29 | 		sep := strings.Index(text, ": ")
30 | 		if sep < 0 {
31 | 			continue
32 | 		}
33 | 		key, val := text[:sep], text[sep+2:]
34 | 
35 | 		switch key {
36 | 		case "name":
37 | 			parsed.Username = val
38 | 		case "uid":
39 | 			parsed.Uid = val
40 | 		case "gid":
41 | 			parsed.Gid = val
42 | 		case "dir":
43 | 			parsed.HomeDir = val
44 | 		case "gecos":
45 | 			// Only the text before the first comma is kept as the Name.
46 | 			if comma := strings.Index(val, ","); comma >= 0 {
47 | 				val = val[:comma]
48 | 			}
49 | 			parsed.Name = val
50 | 		}
51 | 	}
52 | 
53 | 	return parsed
54 | }
55 | 
56 | // dsQueryUser asks dscacheutil (through command) for the user whose attr
57 | // equals value; errNotFound is returned unless uid, name and dir are present.
58 | func dsQueryUser(attr, value string) (*user.User, error) {
59 | 	data, err := command(dscacheutilExe, "-q", "user", "-a", attr, value)
60 | 	if err != nil {
61 | 		return nil, err
62 | 	}
63 | 
64 | 	u := dsParseUser(data)
65 | 	if u.Uid == "" || u.Username == "" || u.HomeDir == "" {
66 | 		return nil, errNotFound
67 | 	}
68 | 	return u, nil
69 | }
70 | 
71 | func dsUser(username string) (*user.User, error) {
72 | 	return dsQueryUser("name", username)
73 | }
74 | 
75 | func dsUserId(uid string) (*user.User, error) {
76 | 	return dsQueryUser("uid", uid)
77 | }
78 | 
--------------------------------------------------------------------------------
/pkg/user/ds_group.go:
--------------------------------------------------------------------------------
1 | //go:build go1.7
2 | // +build go1.7
3 | 
4 | package user
5 | 
6 | import (
7 | 	"bufio"
8 | 	"bytes"
9 | 	"os/user"
10 | 	"strings"
11 | )
12 | 
13 | // dsParseGroup builds a user.Group from the "key: value" lines printed by
14 | // dscacheutil, reading only the name and gid keys.
15 | func dsParseGroup(data []byte) *user.Group {
16 | 	parsed := &user.Group{}
17 | 
18 | 	lines := bufio.NewScanner(bytes.NewReader(data))
19 | 	for lines.Scan() {
20 | 		text := lines.Text()
21 | 		sep := strings.Index(text, ": ")
22 | 		if sep < 0 {
23 | 			continue
24 | 		}
25 | 
26 | 		switch val := text[sep+2:]; text[:sep] {
27 | 		case "name":
28 | 			parsed.Name = val
29 | 		case "gid":
30 | 			parsed.Gid = val
31 | 		}
32 | 	}
33 | 
34 | 	return parsed
35 | }
36 | 
37 | // dsQueryGroup asks dscacheutil (through command) for the group whose attr
38 | // equals value; errNotFound is returned unless both name and gid are present.
39 | func dsQueryGroup(attr, value string) (*user.Group, error) {
40 | 	data, err := command(dscacheutilExe, "-q", "group", "-a", attr, value)
41 | 	if err != nil {
42 | 		return nil, err
43 | 	}
44 | 
45 | 	g := dsParseGroup(data)
46 | 	if g.Name == "" || g.Gid == "" {
47 | 		return nil, errNotFound
48 | 	}
49 | 	return g, nil
50 | }
51 | 
52 | func dsGroup(group string) (*user.Group, error) {
53 | 	return dsQueryGroup("name", group)
54 | }
55 | 
56 | func dsGroupId(group string) (*user.Group, error) {
57 | 	return dsQueryGroup("gid", group)
58 | }
59 | 
--------------------------------------------------------------------------------
/pkg/user/errors.go:
--------------------------------------------------------------------------------
1 | package user
2 | 
3 | import "strconv"
4 | 
5 | // UnknownUserIdError is returned by LookupId when
6 | // a user cannot be found.
7 | type UnknownUserIdError int
8 | 
9 | func (e UnknownUserIdError) Error() string {
10 | 	return "user: unknown userid " + strconv.Itoa(int(e))
11 | }
12 | 
13 | // UnknownUserError is returned by Lookup when
14 | // a user cannot be found.
15 | type UnknownUserError string
16 | 
17 | func (e UnknownUserError) Error() string {
18 | 	return "user: unknown user " + string(e)
19 | }
20 | 
21 | // UnknownGroupIdError is returned by LookupGroupId when
22 | // a group cannot be found.
23 | type UnknownGroupIdError string
24 | 
25 | func (e UnknownGroupIdError) Error() string {
26 | 	return "group: unknown groupid " + string(e)
27 | }
28 | 
29 | // UnknownGroupError is returned by LookupGroup when
30 | // a group cannot be found.
31 | type UnknownGroupError string 32 | 33 | func (e UnknownGroupError) Error() string { 34 | return "group: unknown group " + string(e) 35 | } 36 | -------------------------------------------------------------------------------- /pkg/user/group.go: -------------------------------------------------------------------------------- 1 | //go:build go1.7 2 | // +build go1.7 3 | 4 | package user 5 | 6 | import "os/user" 7 | 8 | // Group represents a grouping of users. Embedded *user.Group reference: 9 | // https://golang.org/pkg/os/user/#Group 10 | type Group struct { 11 | *user.Group 12 | 13 | // IsLuser is a flag indicating if the user was found without cgo. 14 | IsLuser bool 15 | } 16 | 17 | // LookupGroup looks up a group by name. If the group cannot be found, the 18 | // returned error is of type UnknownGroupError. 19 | func LookupGroup(name string) (*Group, error) { 20 | if fallbackEnabled { 21 | return lookupGroup(name) 22 | } 23 | 24 | g, err := user.LookupGroup(name) 25 | if err == nil { 26 | return &Group{Group: g}, err 27 | } 28 | 29 | return nil, err 30 | } 31 | 32 | // LookupGroupId looks up a group by groupid. If the group cannot be found, the 33 | // returned error is of type UnknownGroupIdError. 34 | func LookupGroupId(gid string) (*Group, error) { 35 | if fallbackEnabled { 36 | return lookupGroupId(gid) 37 | } 38 | 39 | g, err := user.LookupGroupId(gid) 40 | if err == nil { 41 | return &Group{Group: g}, err 42 | } 43 | 44 | return nil, err 45 | } 46 | 47 | // GroupIds returns the list of group IDs that the user is a member of. 48 | func (u *User) GroupIds() ([]string, error) { 49 | if u.IsLuser { 50 | return u.lookupUserGroupIds() 51 | } 52 | return u.User.GroupIds() 53 | } 54 | 55 | // GroupNames returns the list of group names that the user is a member of. 
56 | func (u *User) GroupNames() ([]string, error) { 57 | if u.IsLuser { 58 | return u.lookupUserGroupNames() 59 | } 60 | groupIds, err := u.User.GroupIds() 61 | if err != nil { 62 | return nil, err 63 | } 64 | var names []string 65 | for _, gid := range groupIds { 66 | group, err := user.LookupGroupId(gid) 67 | if err != nil { 68 | continue 69 | } 70 | names = append(names, group.Name) 71 | } 72 | 73 | return names, nil 74 | } 75 | -------------------------------------------------------------------------------- /pkg/user/id.go: -------------------------------------------------------------------------------- 1 | package user 2 | 3 | import ( 4 | "bytes" 5 | "log" 6 | "os/exec" 7 | "strings" 8 | ) 9 | 10 | var idExe = "" // path to the 'id' program. 11 | 12 | func init() { 13 | if path, err := exec.LookPath("/usr/bin/id"); err == nil { 14 | idExe = path 15 | } else { 16 | log.Fatal(err) 17 | } 18 | } 19 | 20 | func idGroupList(u *User) ([]string, error) { 21 | data, err := command(idExe, "-G", u.Username) 22 | if err != nil { 23 | return nil, err 24 | } 25 | 26 | data = bytes.TrimSpace(data) 27 | if len(data) > 0 { 28 | return strings.Fields(string(data)), nil 29 | } 30 | 31 | return nil, errNotFound 32 | } 33 | 34 | func idGroupNameList(u *User) ([]string, error) { 35 | data, err := command(idExe, "-Gn", u.Username) 36 | if err != nil { 37 | return nil, err 38 | } 39 | 40 | data = bytes.TrimSpace(data) 41 | if len(data) > 0 { 42 | return strings.Fields(string(data)), nil 43 | } 44 | 45 | return nil, errNotFound 46 | } 47 | -------------------------------------------------------------------------------- /pkg/user/lookup.go: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | // +build !windows 3 | 4 | package user 5 | 6 | import "strconv" 7 | 8 | // Lookup a user by username. 
9 | func lookupUser(username string) (*User, error) { 10 | if _, err := strconv.Atoi(username); err == nil { 11 | return nil, UnknownUserError(username) 12 | } 13 | 14 | if dscacheutilExe != "" { 15 | u, err := dsUser(username) 16 | if err == nil { 17 | return luser(u), nil 18 | } 19 | } 20 | 21 | u, err := getentUser(username) 22 | if err == nil { 23 | return luser(u), nil 24 | } 25 | 26 | return nil, UnknownUserError(username) 27 | } 28 | 29 | // Lookup user by UID. 30 | func lookupId(uid string) (*User, error) { 31 | id, err := strconv.Atoi(uid) 32 | if err != nil { 33 | return nil, err 34 | } 35 | 36 | if dscacheutilExe != "" { 37 | u, err := dsUserId(uid) 38 | if err == nil { 39 | return luser(u), nil 40 | } 41 | } 42 | 43 | u, err := getentUser(uid) 44 | if err == nil { 45 | return luser(u), nil 46 | } 47 | 48 | return nil, UnknownUserIdError(id) 49 | } 50 | 51 | func (u *User) lookupUserGroupIds() ([]string, error) { 52 | if idExe != "" { 53 | return idGroupList(u) 54 | } 55 | 56 | return nil, ErrListGroups 57 | } 58 | 59 | func (u *User) lookupUserGroupNames() ([]string, error) { 60 | if idExe != "" { 61 | return idGroupNameList(u) 62 | } 63 | 64 | return nil, ErrListGroups 65 | } 66 | -------------------------------------------------------------------------------- /pkg/user/lookup_group.go: -------------------------------------------------------------------------------- 1 | //go:build !windows && go1.7 2 | // +build !windows,go1.7 3 | 4 | package user 5 | 6 | import "strconv" 7 | 8 | func lookupGroup(name string) (*Group, error) { 9 | if _, err := strconv.Atoi(name); err == nil { 10 | return nil, UnknownGroupError(name) 11 | } 12 | 13 | if dscacheutilExe != "" { 14 | g, err := dsGroup(name) 15 | if err == nil { 16 | return lgroup(g), nil 17 | } 18 | } 19 | 20 | g, err := getentGroup(name) 21 | if err == nil { 22 | return lgroup(g), nil 23 | } 24 | 25 | return nil, UnknownGroupError(name) 26 | } 27 | 28 | func lookupGroupId(gid string) (*Group, error) { 29 
| _, err := strconv.Atoi(gid) 30 | if err != nil { 31 | return nil, err 32 | } 33 | 34 | if dscacheutilExe != "" { 35 | g, err := dsGroupId(gid) 36 | if err == nil { 37 | return lgroup(g), nil 38 | } 39 | } 40 | 41 | g, err := getentGroup(gid) 42 | if err == nil { 43 | return lgroup(g), nil 44 | } 45 | 46 | return nil, UnknownGroupIdError(gid) 47 | } 48 | -------------------------------------------------------------------------------- /pkg/user/lookup_group_windows.go: -------------------------------------------------------------------------------- 1 | package user 2 | 3 | func lookupGroup(name string) (*Group, error) { 4 | return nil, errWin 5 | } 6 | 7 | func lookupGroupId(gid string) (*Group, error) { 8 | return nil, errWin 9 | } 10 | 11 | func (u *User) lookupUserGroupIds() ([]string, error) { 12 | return nil, ErrListGroups 13 | } 14 | -------------------------------------------------------------------------------- /pkg/user/lookup_windows.go: -------------------------------------------------------------------------------- 1 | // Windows stub since syscall is used instead of cgo. os/user will work fine 2 | // when cross-compiling Windows. Aside from currentUser(), the other functions 3 | // should never excute since the original errors would need to include 4 | // "requires cgo" in the message before calling these. 
5 | 6 | package user 7 | 8 | import ( 9 | "errors" 10 | "os/user" 11 | ) 12 | 13 | var errWin = errors.New("user: you should not get this error") 14 | 15 | func currentUser() (*User, error) { 16 | u, err := user.Current() 17 | if err != nil { 18 | return nil, err 19 | } 20 | return luser(u), nil 21 | } 22 | 23 | func lookupUser(username string) (*User, error) { 24 | return nil, errWin 25 | } 26 | 27 | func lookupId(uid string) (*User, error) { 28 | return nil, errWin 29 | } 30 | -------------------------------------------------------------------------------- /pkg/user/luser.go: -------------------------------------------------------------------------------- 1 | // Package luser is a drop-in replacement for 'os/user' which allows you to 2 | // lookup users and groups in cross-compiled builds without 'cgo'. 3 | // 4 | // 'os/user' requires 'cgo' to lookup users using the target OS's API. This is 5 | // the most reliable way to look up user and group information. However, 6 | // cross-compiling means that 'os/user' will only work for the OS you're using. 7 | // 'user.Current()' is usable when building without 'cgo', but doesn't always 8 | // work. The '$USER' and '$HOME' variables could be different from what you 9 | // expect or not even exist. 10 | // 11 | // If you want to cross-compile a relatively simple program that needs to write 12 | // a config file somewhere in the user's directory, the last thing you want to 13 | // do is figure out some elaborate build scheme involving virtual machines. 14 | // 15 | // When cgo is not available for a build, one of the following methods will be 16 | // used to lookup user and group information: 17 | // 18 | // | Method | Used for | 19 | // |---------------|----------------------------------------------------------------| 20 | // | `/etc/passwd` | Parsed to lookup user information. (Unix, Linux) | 21 | // | `/etc/group` | Parsed to lookup group information. (Unix, Linux) | 22 | // | `getent` | Optional. 
Find user/group information. (Unix, Linux) | 23 | // | `dscacheutil` | Lookup user/group information via Directory Services. (Darwin) | 24 | // | `id` | Finding a user's groups when using `GroupIds()`. | 25 | // 26 | // You should be able to simply replace 'user.' with 'luser.' (in most cases). 27 | package user 28 | 29 | import "os/user" 30 | 31 | // Switched to true in current_default.go. Windows does not use a fallback. 32 | var fallbackEnabled = false 33 | 34 | // User represents a user account. Embedded *user.User reference: 35 | // https://golang.org/pkg/os/user/#User 36 | type User struct { 37 | *user.User 38 | 39 | IsLuser bool // flag indicating if the user was found without cgo. 40 | } 41 | 42 | // Current returns the current user. On builds where cgo is available, this 43 | // returns the result from user.Current(). Otherwise, alternate lookup methods 44 | // are used before falling back to the built-in stub. 45 | func Current() (*User, error) { 46 | return currentUser() 47 | } 48 | 49 | // Lookup looks up a user by username. If the user cannot be found, the 50 | // returned error is of type UnknownUserError. 51 | func Lookup(username string) (*User, error) { 52 | if fallbackEnabled { 53 | return lookupUser(username) 54 | } 55 | 56 | u, err := user.Lookup(username) 57 | if err == nil { 58 | return &User{User: u}, nil 59 | } 60 | 61 | return nil, err 62 | } 63 | 64 | // LookupId looks up a user by userid. If the user cannot be found, the 65 | // returned error is of type UnknownUserIdError. 
66 | func LookupId(uid string) (*User, error) { 67 | if fallbackEnabled { 68 | return lookupId(uid) 69 | } 70 | 71 | u, err := user.LookupId(uid) 72 | if err == nil { 73 | return &User{User: u}, nil 74 | } 75 | 76 | return nil, err 77 | } 78 | -------------------------------------------------------------------------------- /pkg/user/misc.go: -------------------------------------------------------------------------------- 1 | package user 2 | 3 | import ( 4 | "errors" 5 | "os/exec" 6 | "os/user" 7 | ) 8 | 9 | // ErrListGroups returned when LookupGroupId() has no fallback. 10 | var ErrListGroups = errors.New("user: unable to list groups") 11 | 12 | // ErrCurrentUser returned when Current() fails to get the user. 13 | var ErrCurrentUser = errors.New("user: unable to get current user") 14 | 15 | // Generic internal error indicating that a search function did its job but 16 | // still couldn't find a user/group. 17 | var errNotFound = errors.New("not found") 18 | 19 | // Wrap user.User with luser.User 20 | func luser(u *user.User) *User { 21 | return &User{User: u, IsLuser: true} 22 | } 23 | 24 | // Convenience function for running a program and returning the output. 25 | func command(program string, args ...string) ([]byte, error) { 26 | cmd := exec.Command(program, args...) 
27 | out, err := cmd.Output() 28 | if err != nil { 29 | return nil, err 30 | } 31 | 32 | return out, nil 33 | } 34 | -------------------------------------------------------------------------------- /pkg/user/misc_group.go: -------------------------------------------------------------------------------- 1 | //go:build go1.7 2 | // +build go1.7 3 | 4 | package user 5 | 6 | import "os/user" 7 | 8 | // Wrap user.Group with luser.Group 9 | func lgroup(g *user.Group) *Group { 10 | return &Group{Group: g, IsLuser: true} 11 | } 12 | -------------------------------------------------------------------------------- /pkg/user/nss.go: -------------------------------------------------------------------------------- 1 | package user 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "os" 7 | "os/exec" 8 | "os/user" 9 | "strconv" 10 | "strings" 11 | ) 12 | 13 | // GetentParseFiles tells the getent fallback to prefer parsing the 14 | // '/etc/passwd' and '/etc/group' files instead of executing the 'getent' 15 | // program. This is on by default since the speed of parsing the files 16 | // directly is comparible to the C API and is many times faster than executing 17 | // the 'getent' program. If false, the files will still be parsed if 'getent' 18 | // fails or isn't found. 19 | var GetentParseFiles = true 20 | 21 | var ( 22 | getentExe = "" // path to the 'getent' program. 23 | passwdFilePath = "" // path to the 'passwd' file. 24 | groupFilePath = "" // path to the 'group' file. 25 | ) 26 | 27 | func init() { 28 | if path, err := exec.LookPath("getent"); err == nil { 29 | getentExe = path 30 | } 31 | 32 | if _, err := os.Stat("/etc/passwd"); err == nil { 33 | passwdFilePath = "/etc/passwd" 34 | } 35 | 36 | if _, err := os.Stat("/etc/group"); err == nil { 37 | groupFilePath = "/etc/group" 38 | } 39 | } 40 | 41 | // Lookup user by parsing an database file first. If that fails, try with the 42 | // 'getent' program. 
43 | func getent(database, key string) (string, error) { 44 | if !GetentParseFiles && getentExe != "" { 45 | data, err := command(getentExe, database, key) 46 | if err == nil { 47 | return string(bytes.TrimSpace(data)), nil 48 | } 49 | } 50 | 51 | dbfile := "" 52 | switch database { 53 | case "passwd": 54 | dbfile = passwdFilePath 55 | case "group": 56 | dbfile = groupFilePath 57 | } 58 | 59 | if dbfile != "" { 60 | if line, err := searchEntityDatabase(dbfile, key); err == nil { 61 | return line, nil 62 | } 63 | } 64 | 65 | return "", errNotFound 66 | } 67 | 68 | func getentUser(key string) (*user.User, error) { 69 | entity, err := getent("passwd", key) 70 | if err != nil { 71 | return nil, err 72 | } 73 | 74 | parts := strings.Split(entity, ":") 75 | if len(parts) < 6 { 76 | return nil, errNotFound 77 | } 78 | 79 | // Get name from GECOS. 80 | name := parts[4] 81 | if i := strings.Index(name, ","); i >= 0 { 82 | name = name[:i] 83 | } 84 | 85 | return &user.User{ 86 | Username: parts[0], 87 | Uid: parts[2], 88 | Gid: parts[3], 89 | Name: name, 90 | HomeDir: parts[5], 91 | }, nil 92 | } 93 | 94 | func entityIndex(line string, index int) int { 95 | f := func(c rune) bool { 96 | if c == ':' { 97 | index-- 98 | return index == 0 99 | } 100 | return false 101 | } 102 | 103 | return strings.IndexFunc(line, f) 104 | } 105 | 106 | // Searches an entity database for a matching name or id. 
107 | func searchEntityDatabase(filename, match string) (string, error) { 108 | file, err := os.Open(filename) 109 | if err != nil { 110 | return "", err 111 | } 112 | defer file.Close() 113 | 114 | id := false 115 | if _, err := strconv.Atoi(match); err == nil { 116 | id = true 117 | } 118 | 119 | prefix := match + ":" 120 | scanner := bufio.NewScanner(file) 121 | for scanner.Scan() { 122 | line := scanner.Text() 123 | if id { 124 | i := entityIndex(line, 2) 125 | if strings.HasPrefix(line[i+1:], prefix) { 126 | return line, nil 127 | } 128 | } else if strings.HasPrefix(line, prefix) { 129 | return line, nil 130 | } 131 | } 132 | 133 | return "", errNotFound 134 | } 135 | -------------------------------------------------------------------------------- /pkg/user/nss_group.go: -------------------------------------------------------------------------------- 1 | //go:build go1.7 2 | // +build go1.7 3 | 4 | package user 5 | 6 | import ( 7 | "os/user" 8 | "strings" 9 | ) 10 | 11 | func getentGroup(key string) (*user.Group, error) { 12 | entity, err := getent("group", key) 13 | if err != nil { 14 | return nil, err 15 | } 16 | 17 | parts := strings.Split(entity, ":") 18 | if len(parts) < 3 { 19 | return nil, errNotFound 20 | } 21 | 22 | return &user.Group{ 23 | Name: parts[0], 24 | Gid: parts[2], 25 | }, nil 26 | } 27 | -------------------------------------------------------------------------------- /pkg/util/buffer/bytes.go: -------------------------------------------------------------------------------- 1 | package buffer 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "sync" 7 | ) 8 | 9 | type BytesBuffer struct { 10 | buf *bytes.Buffer 11 | lines [][]byte 12 | *sync.Mutex 13 | } 14 | 15 | func NewBytesBuffer() *BytesBuffer { 16 | out := &BytesBuffer{ 17 | buf: &bytes.Buffer{}, 18 | lines: [][]byte{}, 19 | Mutex: &sync.Mutex{}, 20 | } 21 | return out 22 | } 23 | 24 | func (b *BytesBuffer) Write(p []byte) (n int, err error) { 25 | b.Lock() 26 | n, err = b.buf.Write(p) // 
and bytes.Buffer implements io.Writer 27 | b.Unlock() 28 | return // implicit 29 | } 30 | 31 | func (b *BytesBuffer) Close() error { 32 | b.byteLines() 33 | return nil 34 | } 35 | 36 | func (b *BytesBuffer) Lines() [][]byte { 37 | if b.lines != nil { 38 | return b.lines 39 | } 40 | b.Lock() 41 | b.byteLines() 42 | b.Unlock() 43 | return b.lines 44 | } 45 | 46 | func (b *BytesBuffer) byteLines() { 47 | s := bufio.NewScanner(b.buf) 48 | for s.Scan() { 49 | b.lines = append(b.lines, s.Bytes()) 50 | } 51 | } 52 | 53 | func (b *BytesBuffer) Bytes() []byte { 54 | return b.buf.Bytes() 55 | } 56 | 57 | func (b *BytesBuffer) Size() int64 { 58 | return int64(b.buf.Len()) 59 | } 60 | -------------------------------------------------------------------------------- /pkg/util/buffer/output.go: -------------------------------------------------------------------------------- 1 | package buffer 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "sync" 7 | ) 8 | 9 | type OutputBuffer struct { 10 | buf *bytes.Buffer 11 | lines []string 12 | *sync.Mutex 13 | } 14 | 15 | func NewOutputBuffer() *OutputBuffer { 16 | out := &OutputBuffer{ 17 | buf: &bytes.Buffer{}, 18 | lines: []string{}, 19 | Mutex: &sync.Mutex{}, 20 | } 21 | return out 22 | } 23 | 24 | func (b *OutputBuffer) Write(p []byte) (n int, err error) { 25 | b.Lock() 26 | n, err = b.buf.Write(p) // and bytes.Buffer implements io.Writer 27 | b.Unlock() 28 | return // implicit 29 | } 30 | 31 | func (b *OutputBuffer) Close() error { 32 | return nil 33 | } 34 | 35 | func (b *OutputBuffer) Lines() []string { 36 | b.Lock() 37 | s := bufio.NewScanner(b.buf) 38 | for s.Scan() { 39 | b.lines = append(b.lines, s.Text()) 40 | } 41 | b.Unlock() 42 | return b.lines 43 | } 44 | -------------------------------------------------------------------------------- /pkg/util/devicereg/devicereg.go: -------------------------------------------------------------------------------- 1 | package devicereg 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | ) 7 | 8 | 
// Device represents any resource that needs cleanup 9 | type Device interface { 10 | ID() string 11 | Shutdown() error 12 | } 13 | 14 | // DeviceRegistry keeps track of all devices and handles graceful shutdown 15 | type DeviceRegistry struct { 16 | devices map[string]Device 17 | mu sync.RWMutex 18 | } 19 | 20 | // NewDeviceRegistry creates a new device registry 21 | func NewDeviceRegistry() *DeviceRegistry { 22 | return &DeviceRegistry{ 23 | devices: make(map[string]Device), 24 | } 25 | } 26 | 27 | // AddDevice adds a device to the registry 28 | func (r *DeviceRegistry) AddDevice(device Device) { 29 | r.mu.Lock() 30 | defer r.mu.Unlock() 31 | 32 | r.devices[device.ID()] = device 33 | } 34 | 35 | // RemoveDevice removes a device from the registry 36 | func (r *DeviceRegistry) RemoveDevice(device Device) bool { 37 | r.mu.Lock() 38 | defer r.mu.Unlock() 39 | 40 | if _, exists := r.devices[device.ID()]; exists { 41 | delete(r.devices, device.ID()) 42 | return true 43 | } 44 | return false 45 | } 46 | 47 | // ShutdownAllDevices gracefully shuts down all devices 48 | func (r *DeviceRegistry) ShutdownAllDevices(ctx context.Context) error { 49 | r.mu.RLock() 50 | // Create a copy of device IDs to avoid holding the lock during shutdown 51 | deviceIDs := make([]string, 0, len(r.devices)) 52 | for id := range r.devices { 53 | deviceIDs = append(deviceIDs, id) 54 | } 55 | r.mu.RUnlock() 56 | 57 | // Process each device one by one 58 | for _, id := range deviceIDs { 59 | // Check if context is canceled during shutdown 60 | select { 61 | case <-ctx.Done(): 62 | return ctx.Err() 63 | default: 64 | // Get the device (with read lock) 65 | r.mu.RLock() 66 | device, exists := r.devices[id] 67 | r.mu.RUnlock() 68 | 69 | if exists { 70 | if err := device.Shutdown(); err != nil { 71 | } 72 | } 73 | } 74 | } 75 | 76 | return nil 77 | } 78 | 79 | // Count returns the number of devices in the registry 80 | func (r *DeviceRegistry) Count() int { 81 | r.mu.RLock() 82 | defer r.mu.RUnlock() 
83 | return len(r.devices) 84 | } 85 | -------------------------------------------------------------------------------- /pkg/util/dns.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net" 7 | "time" 8 | ) 9 | 10 | type DNSResolver struct { 11 | Timeout time.Duration 12 | Server string 13 | } 14 | 15 | func NewDNSResolver() *DNSResolver { 16 | return &DNSResolver{} 17 | } 18 | 19 | func (r *DNSResolver) LookupHost(hostname string) ([]net.IP, error) { 20 | if r.Timeout == 0 { 21 | r.Timeout = 2 * time.Second 22 | } 23 | 24 | ctx, cancel := context.WithTimeout(context.Background(), r.Timeout) 25 | defer cancel() 26 | 27 | dialFn := func(ctx context.Context, network, address string) (net.Conn, error) { 28 | d := net.Dialer{ 29 | Timeout: r.Timeout, 30 | } 31 | return d.DialContext(ctx, network, address) 32 | } 33 | 34 | if r.Server != "" { 35 | dialFn = func(ctx context.Context, network, address string) (net.Conn, error) { 36 | d := net.Dialer{ 37 | Timeout: r.Timeout, 38 | } 39 | return d.DialContext(ctx, network, r.Server) 40 | } 41 | } 42 | 43 | // Create a custom resolver 44 | resolver := &net.Resolver{ 45 | PreferGo: true, 46 | Dial: dialFn, 47 | } 48 | 49 | addrs, err := resolver.LookupIPAddr(ctx, hostname) 50 | if err != nil { 51 | return nil, fmt.Errorf("failed to resolve %s: %w", hostname, err) 52 | } 53 | 54 | ips := make([]net.IP, len(addrs)) 55 | for i, addr := range addrs { 56 | ips[i] = addr.IP 57 | } 58 | 59 | return ips, nil 60 | } 61 | -------------------------------------------------------------------------------- /pkg/util/goagain/legacy.go: -------------------------------------------------------------------------------- 1 | package goagain 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "os" 7 | "syscall" 8 | ) 9 | 10 | // Block this goroutine awaiting signals. Signals are handled as they 11 | // are by Nginx and Unicorn: . 
12 | func AwaitSignals(l net.Listener) (err error) { 13 | _, err = Wait(l) 14 | return 15 | } 16 | 17 | // Convert and validate the GOAGAIN_FD, GOAGAIN_NAME, and GOAGAIN_PPID 18 | // environment variables. If all three are present and in order, this 19 | // is a child process that may pick up where the parent left off. 20 | func GetEnvs() (l net.Listener, ppid int, err error) { 21 | if _, err = fmt.Sscan(os.Getenv("GOAGAIN_PPID"), &ppid); nil != err { 22 | return 23 | } 24 | l, err = Listener() 25 | return 26 | } 27 | 28 | // Send SIGQUIT to the given ppid in order to complete the handoff to the 29 | // child process. 30 | func KillParent(ppid int) error { 31 | return syscall.Kill(ppid, syscall.SIGQUIT) 32 | } 33 | -------------------------------------------------------------------------------- /pkg/util/gz.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "bytes" 5 | "compress/gzip" 6 | "io" 7 | ) 8 | 9 | func Ungz(b []byte) ([]byte, error) { 10 | reader := bytes.NewReader(b) 11 | gzreader, err := gzip.NewReader(reader) 12 | if err != nil { 13 | return nil, err 14 | } 15 | defer gzreader.Close() 16 | 17 | b, err = io.ReadAll(gzreader) 18 | return b, err 19 | } 20 | -------------------------------------------------------------------------------- /pkg/util/io.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "io" 5 | ) 6 | 7 | // PipeReader represents anything that can read (like gossh.Channel) 8 | type PipeReader interface { 9 | Read(p []byte) (n int, err error) 10 | } 11 | 12 | // PipeWriter represents anything that can write (like io.WriteCloser) 13 | type PipeWriter interface { 14 | Write(p []byte) (n int, err error) 15 | Close() error 16 | } 17 | 18 | // ErrorWriter represents anything that can handle error output 19 | // Similar to gossh.Channel's Stderr() method 20 | type ErrorWriter interface { 21 | Stderr() 
io.ReadWriter 22 | } 23 | 24 | // SetupPipes sets up bidirectional piping between a reader/writer and standard pipes 25 | // Returns pipes for stdin and stderr with proper types 26 | func SetupPipes[T interface { 27 | PipeReader 28 | PipeWriter 29 | ErrorWriter 30 | }](channel T) (stdin io.ReadCloser, stderr io.WriteCloser, cleanup func()) { 31 | // Create pipes with correct orientation 32 | stdinReader, stdinWriter := io.Pipe() // stdinWriter is for input to the process 33 | stderrReader, stderrWriter := io.Pipe() // stderrReader is for reading error output 34 | 35 | // Forward data from channel to stdinWriter 36 | go func(c T, w io.WriteCloser) { 37 | defer w.Close() 38 | io.Copy(w, c) 39 | }(channel, stdinWriter) 40 | 41 | // Forward data from stderrReader to channel's stderr 42 | go func(c T, e io.ReadCloser) { 43 | defer e.Close() 44 | io.Copy(c.Stderr(), e) 45 | }(channel, stderrReader) 46 | 47 | // Return cleanup function 48 | cleanup = func() { 49 | stdinReader.Close() 50 | stdinWriter.Close() 51 | stderrReader.Close() 52 | stderrWriter.Close() 53 | } 54 | 55 | return stdinReader, stderrWriter, cleanup 56 | } 57 | 58 | // Alternative version with more flexibility by using separate generics 59 | func SetupFlexiblePipes[ 60 | R PipeReader, 61 | E interface { 62 | Stderr() io.WriteCloser 63 | }, 64 | ](reader R, errorWriter E) (stdin *io.PipeWriter, stderr *io.PipeReader, cleanup func()) { 65 | stdinReader, stdinWriter := io.Pipe() 66 | stderrReader, stderrWriter := io.Pipe() 67 | 68 | // Forward data from reader to stdinWriter 69 | go func(r R, w *io.PipeWriter) { 70 | defer w.Close() 71 | io.Copy(w, r) 72 | }(reader, stdinWriter) 73 | 74 | // Forward data from stderrReader to errorWriter's stderr 75 | go func(e E, er *io.PipeReader) { 76 | defer er.Close() 77 | io.Copy(e.Stderr(), er) 78 | }(errorWriter, stderrReader) 79 | 80 | // Return cleanup function 81 | cleanup = func() { 82 | stdinReader.Close() 83 | stdinWriter.Close() 84 | stderrReader.Close() 85 | 
stderrWriter.Close() 86 | } 87 | 88 | return stdinWriter, stderrReader, cleanup 89 | } 90 | -------------------------------------------------------------------------------- /pkg/util/io/bytesreadcloser.go: -------------------------------------------------------------------------------- 1 | package io 2 | 3 | import ( 4 | "bytes" 5 | ) 6 | 7 | // BytesReadCloser is a basic in-memory reader with a closer interface. 8 | type BytesReadCloser struct { 9 | Buf *bytes.Buffer 10 | } 11 | 12 | // Read just returns the buffer. 13 | func (r BytesReadCloser) Read(b []byte) (n int, err error) { 14 | return r.Buf.Read(b) 15 | } 16 | 17 | // Close is a no-op. 18 | func (r BytesReadCloser) Close() error { 19 | return nil 20 | } 21 | -------------------------------------------------------------------------------- /pkg/util/io/filesystem.go: -------------------------------------------------------------------------------- 1 | package io 2 | 3 | import ( 4 | "os" 5 | ) 6 | 7 | // GetPathMode returns a os.FileMode for the provided path. 
8 | func GetPathMode(path string) (os.FileMode, error) { 9 | fi, err := os.Stat(path) 10 | if err != nil { 11 | return os.FileMode(0000), err 12 | } 13 | 14 | mode, _, _ := GetOwnerMode(fi) 15 | return mode, nil 16 | } 17 | -------------------------------------------------------------------------------- /pkg/util/io/filesystem_unix.go: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | 3 | package io 4 | 5 | import ( 6 | "os" 7 | "syscall" 8 | ) 9 | 10 | func GetOwnerMode(fInfo os.FileInfo) (os.FileMode, int, int) { 11 | mode := fInfo.Mode() 12 | uid := int(fInfo.Sys().(*syscall.Stat_t).Uid) 13 | gid := int(fInfo.Sys().(*syscall.Stat_t).Gid) 14 | return mode, uid, gid 15 | } 16 | -------------------------------------------------------------------------------- /pkg/util/io/filesystem_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package io 4 | 5 | import ( 6 | "os" 7 | ) 8 | 9 | func GetOwnerMode(fInfo os.FileInfo) (os.FileMode, int, int) { 10 | return fInfo.Mode(), -1, -1 11 | } 12 | -------------------------------------------------------------------------------- /pkg/util/io/quotawriter.go: -------------------------------------------------------------------------------- 1 | package io 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | ) 7 | 8 | // QuotaWriter returns an error once a given write quota gets exceeded. 9 | type QuotaWriter struct { 10 | writer io.Writer 11 | quota int64 12 | n int64 13 | } 14 | 15 | // NewQuotaWriter returns a new QuotaWriter wrapping the given writer. 16 | // 17 | // If the given quota is negative, then no quota is applied. 18 | func NewQuotaWriter(writer io.Writer, quota int64) *QuotaWriter { 19 | return &QuotaWriter{ 20 | writer: writer, 21 | quota: quota, 22 | } 23 | } 24 | 25 | // Write implements the Writer interface. 
26 | func (w *QuotaWriter) Write(p []byte) (n int, err error) {
27 | 	if w.quota >= 0 {
28 | 		// Count the attempted bytes before writing: once the quota has been
29 | 		// exceeded, w.n stays above it and every later Write fails as well.
30 | 		w.n += int64(len(p))
31 | 		if w.n > w.quota {
32 | 			return 0, fmt.Errorf("reached %d bytes, exceeding quota of %d", w.n, w.quota)
33 | 		}
34 | 	}
35 |
36 | 	return w.writer.Write(p)
37 | }
38 |
--------------------------------------------------------------------------------
/pkg/util/io/readseeker.go:
--------------------------------------------------------------------------------
1 | package io
2 |
3 | import (
4 | 	"io"
5 | )
6 |
7 | type readSeeker struct {
8 | 	io.Reader
9 | 	io.Seeker
10 | }
11 |
12 | // NewReadSeeker combines provided io.Reader and io.Seeker into a new io.ReadSeeker.
13 | func NewReadSeeker(reader io.Reader, seeker io.Seeker) io.ReadSeeker {
14 | 	return &readSeeker{Reader: reader, Seeker: seeker}
15 | }
16 |
17 | func (r *readSeeker) Read(p []byte) (n int, err error) {
18 | 	return r.Reader.Read(p)
19 | }
20 |
21 | func (r *readSeeker) Seek(offset int64, whence int) (int64, error) {
22 | 	return r.Seeker.Seek(offset, whence)
23 | }
24 |
--------------------------------------------------------------------------------
/pkg/util/io/writer.go:
--------------------------------------------------------------------------------
1 | package io
2 |
3 | import (
4 | 	"bytes"
5 | 	"io"
6 | )
7 |
8 | // WriteAll writes all of data to w, returning the first error encountered.
9 | func WriteAll(w io.Writer, data []byte) error {
10 | 	buf := bytes.NewBuffer(data)
11 |
12 | 	toWrite := int64(buf.Len())
13 | 	for {
14 | 		n, err := io.Copy(w, buf)
15 | 		if err != nil {
16 | 			return err
17 | 		}
18 |
19 | 		toWrite -= n
20 | 		if toWrite <= 0 {
21 | 			return nil
22 | 		}
23 | 	}
24 | }
25 |
--------------------------------------------------------------------------------
/pkg/util/ip.go:
--------------------------------------------------------------------------------
1 | package util
2 |
3 | import "net"
4 |
5 | // IsIPAddress reports whether input parses as an IPv4 or IPv6 address.
6 | func IsIPAddress(input string) bool {
7 | 	ip := net.ParseIP(input)
8 | 	return ip != nil
9 | }
10 |
11 | // IsIPv4 reports whether input parses as an IP address that has an IPv4
12 | // representation (net.IP.To4 also accepts IPv4-mapped IPv6 addresses).
13 | func IsIPv4(input string) bool {
14 | 	ip := net.ParseIP(input)
15 | 	return ip != nil && ip.To4() != nil
16 | }
17 |
--------------------------------------------------------------------------------
/pkg/util/md5.go:
--------------------------------------------------------------------------------
1 | package util
2 |
3 | import (
4 | 	"bytes"
5 | 	"crypto/md5"
6 | 	"crypto/sha256"
7 | 	"fmt"
8 | 	"io"
9 | 	"os"
10 | )
11 |
12 | // Sha256Bytes returns the lowercase hex SHA-256 digest of b.
13 | func Sha256Bytes(b []byte) string {
14 | 	return sha256hash(bytes.NewReader(b))
15 | }
16 |
17 | // sha256hash returns the lowercase hex SHA-256 digest of everything read
18 | // from b, or "" if reading fails.
19 | func sha256hash(b io.Reader) string {
20 | 	h := sha256.New()
21 | 	_, err := io.Copy(h, b)
22 | 	if err != nil {
23 | 		return ""
24 | 	}
25 | 	return fmt.Sprintf("%x", h.Sum(nil))
26 | }
27 |
28 | // Md5Bytes returns the lowercase hex MD5 digest of b.
29 | func Md5Bytes(b []byte) string {
30 | 	return md5hash(bytes.NewReader(b))
31 | }
32 |
33 | // Md5File returns the lowercase hex MD5 digest of the named file's contents,
34 | // or "" if the file cannot be opened or read.
35 | func Md5File(file string) string {
36 | 	f, err := os.Open(file)
37 | 	if err != nil {
38 | 		return ""
39 | 	}
40 | 	defer f.Close()
41 |
42 | 	return md5hash(f)
43 | }
44 |
45 | // md5hash returns the lowercase hex MD5 digest of everything read from b,
46 | // or "" if reading fails.
47 | func md5hash(b io.Reader) string {
48 | 	h := md5.New()
49 | 	_, err := io.Copy(h, b)
50 | 	if err != nil {
51 | 		return ""
52 | 	}
53 | 	return fmt.Sprintf("%x", h.Sum(nil))
54 | }
55 |
--------------------------------------------------------------------------------
/pkg/util/port.go:
--------------------------------------------------------------------------------
1 | package util
2 |
3 | import "net"
4 |
5 | // GetFreePort binds a TCP listener to localhost:0, reads back the port the
6 | // kernel assigned and closes the listener. Another process may claim the
7 | // port before the caller binds it.
8 | func GetFreePort() (int, error) {
9 | 	addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
10 | 	if err != nil {
11 | 		return 0, err
12 | 	}
13 |
14 | 	l, err := net.ListenTCP("tcp", addr)
15 | 	if err != nil {
16 | 		return 0, err
17 | 	}
18 | 	defer l.Close()
19 | 	return l.Addr().(*net.TCPAddr).Port, nil
20 | }
21 |
--------------------------------------------------------------------------------
/pkg/util/rand.go:
--------------------------------------------------------------------------------
1 | package util
2 |
3 | import (
4 | 	"math/rand"
5 | 	"strings"
6 | )
7 |
8 | // No explicit seeding: the package-level math/rand source is seeded randomly
9 | // at program start (Go 1.20+). Creating a new *rand.Rand in init() and
10 | // discarding it, as was done before, never affected rand.Intn below.
11 |
12 | // RandomString returns n characters drawn from [a-zA-Z0-9] using math/rand.
13 | // math/rand is not cryptographically secure; use crypto/rand for secrets.
14 | func RandomString(n int) string {
15 | 	letterRunes := []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890")
16 |
17 | 	b := make([]rune, n)
18 | 	for i := range b {
19 | 		b[i] = letterRunes[rand.Intn(len(letterRunes))]
20 | 	}
21 | 	return string(b)
22 | }
23 |
24 | // RandomStringLower returns RandomString(n) converted to lower case.
25 | func RandomStringLower(n int) string {
26 | 	return strings.ToLower(RandomString(n))
27 | }
28 |
--------------------------------------------------------------------------------
/pkg/util/shadow/shadow.go:
--------------------------------------------------------------------------------
1 | package shadow
2 |
3 | import (
4 | 	"bufio"
5 | 	"errors"
6 | 	"fmt"
7 | 	"os"
8 | 	"strconv"
9 | 	"strings"
10 | 	"time"
11 |
12 | 	"github.com/GehirnInc/crypt"
13 | 	_ "github.com/GehirnInc/crypt/sha256_crypt"
14 | 	_ "github.com/GehirnInc/crypt/sha512_crypt"
15 | )
16 |
17 | const secsInDay = 86400 //int64(24*time.Hour/time.Second)
18 | const ShadowFile = "/etc/shadow"
19 |
20 | var (
21 | 	ErrNoSuchUser    = errors.New("shadow: user entry is not present in database")
22 | 	ErrWrongPassword = errors.New("shadow: wrong password")
23 | )
24 |
25 | type Shadow struct {
26 | 	shadowFile string
27 | 	entries    []Entry
28 | }
29 |
30 | func New() *Shadow {
31 | 	return &Shadow{
32 | 		shadowFile: ShadowFile,
33 | 		entries:    []Entry{},
34 | 	}
35 | }
36 |
37 | func (s *Shadow) Read() error {
38 | 	var err error
39 | 	s.entries, err = Read(s.shadowFile)
40 | 	return err
41 | }
42 |
43 | func (s *Shadow) ReadFile(shadowFile string) error {
44 | 	var err error
45 | 	s.shadowFile = shadowFile
46 | 	s.entries, err = Read(s.shadowFile)
47 | 	return err
48 | }
49 |
50 | func (s *Shadow) Lookup(name string) (*Entry, error) {
51 | 	for _, entry := range s.entries {
52 | 		if entry.Name == name {
53 | 			return &entry, nil
54 | 		}
55 | 	}
56 |
57 | 	return nil, ErrNoSuchUser
58 | }
59 |
60 | type Entry struct {
61 | 	// User login name.
62 | 	Name string
63 |
64 | 	// Hashed user password.
65 | 	Pass string
66 |
67 | 	// Days since Jan 1, 1970 password was last changed (-1 if the field is empty).
68 | 	LastChange int
69 |
70 | 	// The number of days the user will have to wait before she will be allowed to
71 | 	// change her password again.
72 | 	//
73 | 	// -1 if password aging is disabled.
74 | 	MinPassAge int
75 |
76 | 	// The number of days after which the user will have to change her password.
77 | 	//
78 | 	// -1 if password aging is disabled.
79 | 	MaxPassAge int
80 |
81 | 	// The number of days before a password is going to expire (see the maximum
82 | 	// password age above) during which the user should be warned.
83 | 	//
84 | 	// -1 if password aging is disabled.
85 | 	WarnPeriod int
86 |
87 | 	// The number of days after a password has expired (see the maximum
88 | 	// password age above) during which the password should still be accepted.
89 | 	//
90 | 	// -1 if password aging is disabled.
91 | 	InactivityPeriod int
92 |
93 | 	// The date of expiration of the account, expressed as the number of days
94 | 	// since Jan 1, 1970.
95 | 	//
96 | 	// -1 if the account never expires.
97 | 	AcctExpiry int
98 |
99 | 	// Unused now.
100 | 	Flags int
101 | }
102 |
103 | func (e *Entry) IsAccountValid() bool {
104 | 	if e.AcctExpiry == -1 {
105 | 		return true
106 | 	}
107 |
108 | 	nowDays := int(time.Now().Unix() / secsInDay)
109 | 	return nowDays < e.AcctExpiry
110 | }
111 |
112 | func (e *Entry) IsPasswordValid() bool {
113 | 	if e.LastChange == -1 || e.MaxPassAge == -1 || e.InactivityPeriod == -1 {
114 | 		return true
115 | 	}
116 |
117 | 	nowDays := int(time.Now().Unix() / secsInDay)
118 | 	return nowDays < e.LastChange+e.MaxPassAge+e.InactivityPeriod
119 | }
120 |
121 | func (e *Entry) VerifyPassword(pass string) (err error) {
122 | 	// Do not permit null and locked passwords.
123 | 	if e.Pass == "" {
124 | 		return errors.New("verify: null password")
125 | 	}
126 | 	if e.Pass[0] == '!' {
127 | 		return errors.New("verify: locked password")
128 | 	}
129 |
130 | 	// crypt.NewFromHash may panic on a hash scheme with no registered implementation (only sha256_crypt and sha512_crypt are imported above); turn that into an error.
131 | 	defer func() {
132 | 		if rcvr := recover(); rcvr != nil {
133 | 			err = fmt.Errorf("%v", rcvr)
134 | 		}
135 | 	}()
136 |
137 | 	if err := crypt.NewFromHash(e.Pass).Verify(e.Pass, []byte(pass)); err != nil {
138 | 		if errors.Is(err, crypt.ErrKeyMismatch) {
139 | 			return ErrWrongPassword
140 | 		}
141 | 		return err
142 | 	}
143 | 	return nil
144 | }
145 |
146 | // Read parses the shadow database at shadowFile and returns all entries in it.
147 | func Read(shadowFile string) ([]Entry, error) {
148 | 	f, err := os.Open(shadowFile)
149 | 	if err != nil {
150 | 		return nil, err
151 | 	}
152 | 	defer f.Close()
153 | 	s := bufio.NewScanner(f)
154 |
155 | 	var res []Entry
156 | 	for s.Scan() {
157 | 		ent, err := parseEntry(s.Text())
158 | 		if err != nil {
159 | 			return res, err
160 | 		}
161 |
162 | 		res = append(res, *ent)
163 | 	}
164 | 	if err := s.Err(); err != nil {
165 | 		return res, err
166 | 	}
167 | 	return res, nil
168 | }
169 |
170 | func parseEntry(line string) (*Entry, error) {
171 | 	parts := strings.Split(line, ":")
172 | 	if len(parts) != 9 {
173 | 		return nil, errors.New("read: malformed entry")
174 | 	}
175 |
176 | 	res := &Entry{
177 | 		Name: parts[0],
178 | 		Pass: parts[1],
179 | 	}
180 |
181 | 	for i, value := range [...]*int{
182 | 		&res.LastChange, &res.MinPassAge, &res.MaxPassAge,
183 | 		&res.WarnPeriod, &res.InactivityPeriod, &res.AcctExpiry, &res.Flags,
184 | 	} {
185 | 		if parts[2+i] == "" {
186 | 			*value = -1
187 | 		} else {
188 | 			var err error
189 | 			*value, err = strconv.Atoi(parts[2+i])
190 | 			if err != nil {
191 | 				return nil, fmt.Errorf("read: invalid value for field %d", 2+i)
192 | 			}
193 | 		}
194 | 	}
195 |
196 | 	return res, nil
197 | }
198 |
199 | func Lookup(name string) (*Entry, error) {
200 | 	return LookupFile(name, ShadowFile)
201 | }
202 |
203 | func LookupFile(name, shadowFile string) (*Entry, error) {
204 | 	entries, err := Read(shadowFile)
205 | 	if err != nil {
206 | 		return nil, err
207 | 	}
208 |
209 | 	for _, entry := range entries {
210 | 		if entry.Name == name {
211 | 			return &entry, nil
212 | 		}
213 | 	}
214 |
215 | 	return nil, ErrNoSuchUser
216 | }
217 |
--------------------------------------------------------------------------------
/pkg/util/shadow/shadow_test.go:
--------------------------------------------------------------------------------
1 | package shadow_test
2 |
3 | import (
4 | 	"fmt"
5 | 	"ssh2incus/pkg/util/shadow"
6 | 	"testing"
7 |
8 | 	"github.com/stretchr/testify/assert"
9 | )
10 |
11 | func TestShadow(t *testing.T) {
12 | 	t.Run("null", func(t *testing.T) {
13 | 		null, err := shadow.LookupFile("root", "/dev/null")
14 | 		assert.Error(t, err, fmt.Sprintf("%s", err))
15 | 		assert.Nil(t, null)
16 | 	})
17 |
18 | 	t.Run("root", func(t *testing.T) {
19 | 		root, err := shadow.LookupFile("root", "test/shadow.txt")
20 | 		assert.Nil(t, err, fmt.Sprintf("%s", err))
21 | 		assert.True(t, root.IsAccountValid())
22 | 		assert.True(t, root.IsPasswordValid())
23 | 	})
24 | 	t.Run("nobody", func(t *testing.T) {
25 | 		root, err := shadow.LookupFile("nobody", "test/shadow.txt")
26 | 		assert.Nil(t, err, fmt.Sprintf("%s", err))
27 | 		assert.True(t, root.IsAccountValid())
28 | 		assert.True(t, root.IsPasswordValid())
29 | 	})
30 |
31 | 	t.Run("ubuntu", func(t *testing.T) {
32 | 		root, err := shadow.LookupFile("ubuntu", "test/shadow.txt")
33 | 		assert.Nil(t, err, fmt.Sprintf("%s", err))
34 | 		err = root.VerifyPassword("test")
35 | 		assert.Nil(t, err, fmt.Sprintf("%s", err))
36 | 		assert.True(t, root.IsPasswordValid())
37 | 	})
38 |
39 | }
40 |
--------------------------------------------------------------------------------
/pkg/util/shadow/test/shadow.txt:
--------------------------------------------------------------------------------
1 | root:*:19597:0:99999:7:::
2 | nobody:*:19597:0:99999:7:::
3 | ubuntu:$y$j9T$5ptKOfIPetXNfZ6lHZNL4/$KzAqgQ0uYqb2XSG1Rj9Ny0T7JoUAvJB.114H43j9tZ9:19678:0:99999:7:::
--------------------------------------------------------------------------------
/pkg/util/string.go:
--------------------------------------------------------------------------------
1 | package util
2 |
3 | import (
4 | 	"slices"
5 | 	"strings"
6 | )
7 |
8 | // MapToEnvString converts m to space-separated KEY=value pairs, sorted so the
9 | // result is deterministic. Keys and values are not quoted or escaped, and an
10 | // empty or nil map yields "".
11 | func MapToEnvString(m map[string]string) string {
12 | 	if len(m) == 0 {
13 | 		return ""
14 | 	}
15 |
16 | 	pairs := make([]string, 0, len(m))
17 | 	for key, value := range m {
18 | 		pairs = append(pairs, key+"="+value)
19 | 	}
20 |
21 | 	slices.Sort(pairs)
22 |
23 | 	return strings.Join(pairs, " ")
24 | }
25 |
--------------------------------------------------------------------------------
/pkg/util/structs/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2014 Fatih Arslan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
-------------------------------------------------------------------------------- /pkg/util/structs/README.md: -------------------------------------------------------------------------------- 1 | # Structs [![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](http://godoc.org/github.com/fatih/structs) [![Build Status](http://img.shields.io/travis/fatih/structs.svg?style=flat-square)](https://travis-ci.org/fatih/structs) [![Coverage Status](http://img.shields.io/coveralls/fatih/structs.svg?style=flat-square)](https://coveralls.io/r/fatih/structs) 2 | 3 | Structs contains various utilities to work with Go (Golang) structs. It was 4 | initially used by me to convert a struct into a `map[string]interface{}`. With 5 | time I've added other utilities for structs. It's basically a high-level 6 | package based on primitives from the reflect package. Feel free to add new 7 | functions or improve the existing code. 8 | 9 | ## Install 10 | 11 | ```bash 12 | go get github.com/fatih/structs 13 | ``` 14 | 15 | ## Usage and Examples 16 | 17 | Just like the standard lib `strings`, `bytes` and co packages, `structs` has 18 | many global functions to manipulate or organize your struct data. Let's define
Lets define 19 | and declare a struct: 20 | 21 | ```go 22 | type Server struct { 23 | Name string `json:"name,omitempty"` 24 | ID int 25 | Enabled bool 26 | users []string // not exported 27 | http.Server // embedded 28 | } 29 | 30 | server := &Server{ 31 | Name: "gopher", 32 | ID: 123456, 33 | Enabled: true, 34 | } 35 | ``` 36 | 37 | ```go 38 | // Convert a struct to a map[string]interface{} 39 | // => {"Name":"gopher", "ID":123456, "Enabled":true} 40 | m := structs.Map(server) 41 | 42 | // Convert the values of a struct to a []interface{} 43 | // => ["gopher", 123456, true] 44 | v := structs.Values(server) 45 | 46 | // Convert the names of a struct to a []string 47 | // (see "Names methods" for more info about fields) 48 | n := structs.Names(server) 49 | 50 | // Convert the values of a struct to a []*Field 51 | // (see "Field methods" for more info about fields) 52 | f := structs.Fields(server) 53 | 54 | // Return the struct name => "Server" 55 | n := structs.Name(server) 56 | 57 | // Check if any field of a struct is initialized or not. 58 | h := structs.HasZero(server) 59 | 60 | // Check if all fields of a struct is initialized or not. 61 | z := structs.IsZero(server) 62 | 63 | // Check if server is a struct or a pointer to struct 64 | i := structs.IsStruct(server) 65 | ``` 66 | 67 | ### Struct methods 68 | 69 | The structs functions can be also used as independent methods by creating a new 70 | `*structs.Struct`. This is handy if you want to have more control over the 71 | structs (such as retrieving a single Field). 
72 | 73 | ```go 74 | // Create a new struct type: 75 | s := structs.New(server) 76 | 77 | m := s.Map() // Get a map[string]interface{} 78 | v := s.Values() // Get a []interface{} 79 | f := s.Fields() // Get a []*Field 80 | n := s.Names() // Get a []string 81 | f := s.Field(name) // Get a *Field based on the given field name 82 | f, ok := s.FieldOk(name) // Get a *Field based on the given field name 83 | n := s.Name() // Get the struct name 84 | h := s.HasZero() // Check if any field is uninitialized 85 | z := s.IsZero() // Check if all fields are uninitialized 86 | ``` 87 | 88 | ### Field methods 89 | 90 | We can easily examine a single Field for more detail. Below you can see how we 91 | get and interact with various field methods: 92 | 93 | 94 | ```go 95 | s := structs.New(server) 96 | 97 | // Get the Field struct for the "Name" field 98 | name := s.Field("Name") 99 | 100 | // Get the underlying value, value => "gopher" 101 | value := name.Value().(string) 102 | 103 | // Set the field's value 104 | name.Set("another gopher") 105 | 106 | // Get the field's kind, kind => "string" 107 | name.Kind() 108 | 109 | // Check if the field is exported or not 110 | if name.IsExported() { 111 | fmt.Println("Name field is exported") 112 | } 113 | 114 | // Check if the value is a zero value, such as "" for string, 0 for int 115 | if !name.IsZero() { 116 | fmt.Println("Name is initialized") 117 | } 118 | 119 | // Check if the field is an anonymous (embedded) field 120 | if !name.IsEmbedded() { 121 | fmt.Println("Name is not an embedded field") 122 | } 123 | 124 | // Get the Field's tag value for tag name "json", tag value => "name,omitempty" 125 | tagValue := name.Tag("json") 126 | ``` 127 | 128 | Nested structs are supported too: 129 | 130 | ```go 131 | addrField := s.Field("Server").Field("Addr") 132 | 133 | // Get the value for addr 134 | a := addrField.Value().(string) 135 | 136 | // Or get all fields 137 | httpServer := s.Field("Server").Fields() 138 | ``` 139 | 140 | We 
can also get a slice of Fields from the Struct type to iterate over all
141 | fields. This is handy if you wish to examine all fields:
142 |
143 | ```go
144 | s := structs.New(server)
145 |
146 | for _, f := range s.Fields() {
147 | 	fmt.Printf("field name: %+v\n", f.Name())
148 |
149 | 	if f.IsExported() {
150 | 		fmt.Printf("value : %+v\n", f.Value())
151 | 		fmt.Printf("is zero : %+v\n", f.IsZero())
152 | 	}
153 | }
154 | ```
155 |
156 | ## Credits
157 |
158 |  * [Fatih Arslan](https://github.com/fatih)
159 |  * [Cihangir Savas](https://github.com/cihangir)
160 |
161 | ## License
162 |
163 | The MIT License (MIT) - see LICENSE for more details
164 |
--------------------------------------------------------------------------------
/pkg/util/structs/field.go:
--------------------------------------------------------------------------------
1 | package structs
2 |
3 | import (
4 | 	"errors"
5 | 	"fmt"
6 | 	"reflect"
7 | )
8 |
9 | var (
10 | 	errNotExported = errors.New("field is not exported")
11 | 	errNotSettable = errors.New("field is not settable")
12 | )
13 |
14 | // Field represents a single struct field and provides high-level helpers for
15 | // inspecting and modifying it via reflection.
16 | type Field struct {
17 | 	value      reflect.Value       // the field's value inside its parent struct
18 | 	field      reflect.StructField // static description: name, type, tag, ...
19 | 	defaultTag string              // tag name handed to getFields by Fields()
20 | }
21 |
22 | // Tag returns the value associated with key in the tag string. If there is no
23 | // such key in the tag, Tag returns the empty string.
24 | func (f *Field) Tag(key string) string {
25 | 	return f.field.Tag.Get(key)
26 | }
27 |
28 | // Value returns the underlying value of the field. It panics if the field
29 | // is not exported (reflect refuses Interface() on unexported fields).
30 | func (f *Field) Value() interface{} {
31 | 	return f.value.Interface()
32 | }
33 |
34 | // IsEmbedded returns true if the given field is an anonymous (embedded) field.
35 | func (f *Field) IsEmbedded() bool {
36 | 	return f.field.Anonymous
37 | }
38 |
39 | // IsExported returns true if the given field is exported.
40 | func (f *Field) IsExported() bool {
41 | 	return f.field.PkgPath == ""
42 | }
43 |
44 | // IsZero returns true if the given field is not initialized (has a zero value).
45 | // It panics if the field is not exported.
46 | func (f *Field) IsZero() bool {
47 | 	zero := reflect.Zero(f.value.Type()).Interface()
48 | 	current := f.Value()
49 |
50 | 	return reflect.DeepEqual(current, zero)
51 | }
52 |
53 | // Name returns the name of the given field.
54 | func (f *Field) Name() string {
55 | 	return f.field.Name
56 | }
57 |
58 | // Kind returns the field's kind, such as "string", "map", "bool", etc.
59 | func (f *Field) Kind() reflect.Kind {
60 | 	return f.value.Kind()
61 | }
62 |
63 | // Set sets the field to the given value val. It returns an error if the field
64 | // is not settable (not addressable or not exported), if val's kind differs
65 | // from the field's kind, or if val's type is not assignable to the field's
66 | // type (for example a plain int given for a field of a named integer type,
67 | // which reflect.Value.Set would otherwise panic on).
68 | func (f *Field) Set(val interface{}) error {
69 | 	// we can't set unexported fields, so be sure this field is exported
70 | 	if !f.IsExported() {
71 | 		return errNotExported
72 | 	}
73 |
74 | 	// Exported but still not settable: the value is not addressable (e.g. the
75 | 	// parent struct was passed by value rather than by pointer).
76 | 	if !f.value.CanSet() {
77 | 		return errNotSettable
78 | 	}
79 |
80 | 	given := reflect.ValueOf(val)
81 |
82 | 	if f.value.Kind() != given.Kind() {
83 | 		return fmt.Errorf("wrong kind. got: %s want: %s", given.Kind(), f.value.Kind())
84 | 	}
85 |
86 | 	// Same kind is not enough: reflect.Value.Set panics unless the types are
87 | 	// assignable, so report that case as an error instead.
88 | 	if !given.Type().AssignableTo(f.value.Type()) {
89 | 		return fmt.Errorf("wrong type. got: %s want: %s", given.Type(), f.value.Type())
90 | 	}
91 |
92 | 	f.value.Set(given)
93 | 	return nil
94 | }
95 |
96 | // Zero sets the field to its zero value. It returns an error if the field is not
97 | // settable (not addressable or not exported).
98 | func (f *Field) Zero() error {
99 | 	if !f.IsExported() {
100 | 		return errNotExported
101 | 	}
102 | 	if !f.value.CanSet() {
103 | 		return errNotSettable
104 | 	}
105 |
106 | 	// Set the typed zero Value directly. Going through Set(zero.Interface())
107 | 	// breaks interface-typed fields: their zero is a nil interface, which
108 | 	// reflect.ValueOf turns into an invalid Value and Set rejects as a kind
109 | 	// mismatch.
110 | 	f.value.Set(reflect.Zero(f.value.Type()))
111 | 	return nil
112 | }
113 |
114 | // Fields returns a slice of Fields. This is particularly handy to get the fields
115 | // of a nested struct. A struct tag with the content of "-" ignores the
116 | // checking of that particular field. Example:
117 | //
118 | //	// Field is ignored by this package.
119 | //	Field *http.Request `structs:"-"`
120 | //
121 | // It panics if field is not exported or if field's kind is not struct.
122 | func (f *Field) Fields() []*Field {
123 | 	return getFields(f.value, f.defaultTag)
124 | }
125 |
126 | // Field returns the field from a nested struct. It panics if the nested struct
127 | // is not exported or if the field was not found.
128 | func (f *Field) Field(name string) *Field {
129 | 	field, ok := f.FieldOk(name)
130 | 	if !ok {
131 | 		panic("field not found")
132 | 	}
133 |
134 | 	return field
135 | }
136 |
137 | // FieldOk returns the field from a nested struct. The boolean returns whether
138 | // the field was found (true) or not (false). The returned Field carries the
139 | // same default tag as f, so calling Fields on it uses the same tag name.
140 | func (f *Field) FieldOk(name string) (*Field, bool) {
141 | 	value := &f.value
142 | 	// value must be settable so we need to make sure it holds the address of the
143 | 	// variable and not a copy, so we can pass the pointer to strctVal instead of a
144 | 	// copy (which is not assigned to any variable, hence not settable).
145 | 	// see "https://blog.golang.org/laws-of-reflection#TOC_8."
146 | 	if f.value.Kind() != reflect.Ptr {
147 | 		a := f.value.Addr()
148 | 		value = &a
149 | 	}
150 | 	v := strctVal(value.Interface())
151 | 	t := v.Type()
152 |
153 | 	field, ok := t.FieldByName(name)
154 | 	if !ok {
155 | 		return nil, false
156 | 	}
157 |
158 | 	return &Field{
159 | 		field:      field,
160 | 		value:      v.FieldByName(name),
161 | 		defaultTag: f.defaultTag,
162 | 	}, true
163 | }
164 |
--------------------------------------------------------------------------------
/pkg/util/structs/tags.go:
--------------------------------------------------------------------------------
1 | package structs
2 |
3 | import "strings"
4 |
5 | // tagOptions contains a slice of tag options.
6 | type tagOptions []string
7 |
8 | // Has returns true if the given option is available in tagOptions.
9 | func (t tagOptions) Has(opt string) bool {
10 | 	for _, tagOpt := range t {
11 | 		if tagOpt == opt {
12 | 			return true
13 | 		}
14 | 	}
15 |
16 | 	return false
17 | }
18 |
19 | // parseTag splits a struct field's tag into its name and a list of options
20 | // which come after the name. A tag is in the form of: "name,option1,option2".
21 | // The name can be omitted.
22 | func parseTag(tag string) (string, tagOptions) {
23 | 	// tag is one of the following:
24 | 	// ""
25 | 	// "name"
26 | 	// "name,opt"
27 | 	// "name,opt,opt2"
28 | 	// ",opt"
29 |
30 | 	res := strings.Split(tag, ",")
31 | 	return res[0], res[1:]
32 | }
33 |
--------------------------------------------------------------------------------
/pkg/util/utmp/utmp_bsd.go:
--------------------------------------------------------------------------------
1 | //go:build freebsd
2 | // +build freebsd
3 |
4 | // Golang bindings for basic login/utmp accounting
5 | package utmp
6 |
7 | //#include
8 | //#include
9 | //#include
10 | //#include
11 | //#include
12 | //#include
13 | //#include
14 | //#include
15 | //
16 | //#include
17 | //#include <stdlib.h>
18 | //
19 | //typedef char char_t;
20 | //
21 | ///* Copies are bounded by each fixed-size utmpx field so that long user,
22 | //   tty or host names cannot overflow the struct. */
23 | //void pututmpx(struct utmpx* entry, char* uname, char* ptsname, char* host) {
24 | //  entry->ut_type = USER_PROCESS;
25 | //  entry->ut_pid = getpid();
26 | //  strncpy(entry->ut_line, ptsname + strlen("/dev/"), sizeof(entry->ut_line));
27 | //
28 | //  strncpy(entry->ut_id, ptsname + strlen("/dev/pts/"), sizeof(entry->ut_id));
29 | //
30 | //  //entry->ut_time = time(NULL);
31 | //  strncpy(entry->ut_user, uname, sizeof(entry->ut_user));
32 | //  strncpy(entry->ut_host, host, sizeof(entry->ut_host));
33 | //  //entry->ut_addr = 0;
34 | //  setutxent();
35 | //  pututxline(entry);
36 | //}
37 | //
38 | //void unpututmpx(struct utmpx* entry) {
39 | //  entry->ut_type = DEAD_PROCESS;
40 | //  entry->ut_line[0] = '\0';
41 | //  //entry->ut_time = 0;
42 | //  entry->ut_user[0] = '\0';
43 | //  setutxent();
44 | //  pututxline(entry);
45 | //
46 | //  endutxent();
47 | //}
48 | //
49 | //#if 0
50 | //int putlastlogentry(int64_t t, int uid, char* line, char* host) {
51 | //  int retval = 0;
52 | //  FILE *f;
53 | //  struct lastlog l;
54 | //
55 | //  strncpy(l.ll_line, line, UT_LINESIZE);
56 | //  l.ll_line[UT_LINESIZE-1] = '\0';
57 | //  strncpy(l.ll_host, host, UT_HOSTSIZE);
58 | //  l.ll_host[UT_HOSTSIZE-1] = '\0';
59 | //
60 | //  l.ll_time = (time_t)t;
61 | //  //printf("l: ll_line '%s', ll_host '%s', ll_time %d\n", l.ll_line, l.ll_host, l.ll_time);
62 | //
63 | //  /* Write lastlog entry at fixed offset (uid * sizeof(struct lastlog) */
64 | //  if( NULL != (f = fopen("/var/log/lastlog", "r+")) ) {
65 | //    if( !fseek(f, (uid * sizeof(struct lastlog)), SEEK_SET) ) {
66 | //      int fd = fileno(f);
67 | //      if( write(fd, &l, sizeof(l)) == sizeof(l) ) {
68 | //        retval = 1;
69 | //        //int32_t stat = system("echo ---- lastlog ----; lastlog");
70 | //      }
71 | //    }
72 | //    fclose(f);
73 | //  }
74 | //  return retval;
75 | //}
76 | //#else
77 | //int putlastlogentry(int64_t t, int uid, char* line, char* host) {
78 | //  return 0;
79 | //}
80 | //#endif
81 | import "C"
82 |
83 | import (
84 | 	"fmt"
85 | 	"net"
86 | 	"os/user"
87 | 	"strings"
88 | 	"time"
89 | 	"unsafe"
90 | )
91 |
92 | // UtmpEntry wraps the C struct utmpx.
93 | type UtmpEntry struct {
94 | 	entry C.struct_utmpx
95 | }
96 |
97 | // GetHost returns the remote client's hostname, or its IP if the reverse
98 | // lookup fails or yields no names. addr is expected to be in the format
99 | // produced by net.Addr.String(), e.g. "127.0.0.1:80" or "[::1]:80".
100 | func GetHost(addr string) (h string) {
101 | 	if !strings.Contains(addr, "[") {
102 | 		h = strings.Split(addr, ":")[0]
103 | 	} else {
104 | 		h = strings.Split(strings.Split(addr, "[")[1], "]")[0]
105 | 	}
106 | 	hList, e := net.LookupAddr(h)
107 | 	if e == nil && len(hList) > 0 {
108 | 		h = hList[0]
109 | 	}
110 | 	return
111 | }
112 |
113 | // Put records user, logged in on ptsName from host, in utmpx and returns the
114 | // entry so it can later be removed with Unput. ptsName is expected to be a
115 | // "/dev/pts/N" path: the C side skips that prefix without checking it.
116 | func Put(user, ptsName, host string) UtmpEntry {
117 | 	var entry UtmpEntry
118 |
119 | 	cUser := C.CString(user)
120 | 	defer C.free(unsafe.Pointer(cUser))
121 | 	cPts := C.CString(ptsName)
122 | 	defer C.free(unsafe.Pointer(cPts))
123 | 	cHost := C.CString(host)
124 | 	defer C.free(unsafe.Pointer(cHost))
125 |
126 | 	// pututmpx copies the strings into entry, so the C copies can be freed
127 | 	// as soon as it returns.
128 | 	C.pututmpx(&entry.entry, cUser, cPts, cHost)
129 | 	return entry
130 | }
131 |
132 | // Unput marks a previously Put entry as DEAD_PROCESS in utmpx.
133 | func Unput(entry UtmpEntry) {
134 | 	C.unpututmpx(&entry.entry)
135 | }
136 |
137 | // PutLastlogEntry records app (as the line) and host in the lastlog record of
138 | // user usr; ptsname is currently unused. The C implementation is compiled out
139 | // (#if 0) in this file, so the call does nothing here.
140 | func PutLastlogEntry(app, usr, ptsname, host string) {
141 | 	u, e := user.Lookup(usr)
142 | 	if e != nil {
143 | 		return
144 | 	}
145 | 	var uid uint32
146 | 	// Bail out instead of falling back to uid 0, which would target root's
147 | 	// lastlog record.
148 | 	if _, err := fmt.Sscanf(u.Uid, "%d", &uid); err != nil {
149 | 		return
150 | 	}
151 |
152 | 	cApp := C.CString(app)
153 | 	defer C.free(unsafe.Pointer(cApp))
154 | 	cHost := C.CString(host)
155 | 	defer C.free(unsafe.Pointer(cHost))
156 |
157 | 	t := time.Now().Unix()
158 | 	_ = C.putlastlogentry(C.int64_t(t), C.int(uid), cApp, cHost)
159 | }
160 |
--------------------------------------------------------------------------------
/pkg/util/utmp/utmp_darwin.go:
--------------------------------------------------------------------------------
1 | //go:build darwin
2 | // +build darwin
3 |
4 | package utmp
5 |
6 | import (
7 | 	"net"
8 | 	"strings"
9 | )
10 |
11 | // UtmpEntry is an empty placeholder: utmp accounting is not implemented on
12 | // darwin, so Put, Unput and PutLastlogEntry are no-ops.
13 | type UtmpEntry struct {
14 | 	entry struct{}
15 | }
16 |
17 | // GetHost returns the remote client's hostname, or its IP if the reverse
18 | // lookup fails or yields no names. addr is expected to be in the format
19 | // produced by net.Addr.String(), e.g. "127.0.0.1:80" or "[::1]:80".
20 | // It gives the package the same API as on linux and freebsd.
21 | func GetHost(addr string) (h string) {
22 | 	if !strings.Contains(addr, "[") {
23 | 		h = strings.Split(addr, ":")[0]
24 | 	} else {
25 | 		h = strings.Split(strings.Split(addr, "[")[1], "]")[0]
26 | 	}
27 | 	hList, e := net.LookupAddr(h)
28 | 	if e == nil && len(hList) > 0 {
29 | 		h = hList[0]
30 | 	}
31 | 	return
32 | }
33 |
34 | // Put is a no-op on darwin and returns an empty entry.
35 | func Put(user, ptsName, host string) UtmpEntry {
36 | 	var entry UtmpEntry
37 | 	return entry
38 | }
39 |
40 | // Unput is a no-op on darwin.
41 | func Unput(entry UtmpEntry) {
42 | }
43 |
44 | // PutLastlogEntry is a no-op on darwin.
45 | func PutLastlogEntry(app, usr, ptsname, host string) {
46 | }
47 |
--------------------------------------------------------------------------------
/pkg/util/utmp/utmp_linux.go:
--------------------------------------------------------------------------------
1 | //go:build linux
2 | // +build linux
3 |
4 | // Golang bindings for basic login/utmp accounting
5 | package utmp
6 |
7 | //#include
8 | //#include
9 | //#include
10 | //#include
11 | //#include
12 | //#include
13 | //#include
14 | //#include
15 | //
16 | //#include
17 | //#include
18 | //#include <stdlib.h>
19 | //
20 | //typedef char char_t;
21 | //
22 | ///* Copies are bounded by each fixed-size utmp field so that long user,
23 | //   tty or host names cannot overflow the struct. */
24 | //void pututmp(struct utmp* entry, char* uname, char* ptsname, char* host) {
25 | //  entry->ut_type = USER_PROCESS;
26 | //  entry->ut_pid = getpid();
27 | //  strncpy(entry->ut_line, ptsname + strlen("/dev/"), sizeof(entry->ut_line));
28 | //
29 | //  strncpy(entry->ut_id, ptsname + strlen("/dev/pts/"), sizeof(entry->ut_id));
30 | //
31 | //  entry->ut_time = time(NULL);
32 | //  strncpy(entry->ut_user, uname, sizeof(entry->ut_user));
33 | //  strncpy(entry->ut_host, host, sizeof(entry->ut_host));
34 | //  entry->ut_addr = 0;
35 | //  setutent();
36 | //  pututline(entry);
37 | //}
38 | //
39 | //void unpututmp(struct utmp* entry) {
40 | //  entry->ut_type = DEAD_PROCESS;
41 | //  memset(entry->ut_line, 0, UT_LINESIZE);
42 | //  entry->ut_time = 0;
43 | //  memset(entry->ut_user, 0, UT_NAMESIZE);
44 | //  setutent();
45 | //  pututline(entry);
46 | //
47 | //  endutent();
48 | //}
49 | //
50 | //int putlastlogentry(int64_t t, int uid, char* line, char* host) {
51 | //  int retval = 0;
52 | //  FILE *f;
53 | //  struct lastlog l;
54 | //
55 | //  strncpy(l.ll_line, line, UT_LINESIZE);
56 | //  l.ll_line[UT_LINESIZE-1] = '\0';
57 | //  strncpy(l.ll_host, host, UT_HOSTSIZE);
58 | //  l.ll_host[UT_HOSTSIZE-1] = '\0';
59 | //
60 | //  l.ll_time = (time_t)t;
61 | //  //printf("l: ll_line '%s', ll_host '%s', ll_time %d\n", l.ll_line, l.ll_host, l.ll_time);
62 | //
63 | //  /* Write lastlog entry at fixed offset (uid * sizeof(struct lastlog) */
64 | //  /* "r+": update in place without truncating ("rw+" is not a standard fopen mode). */
65 | //  if( NULL != (f = fopen("/var/log/lastlog", "r+")) ) {
66 | //    if( !fseek(f, (uid * sizeof(struct lastlog)), SEEK_SET) ) {
67 | //      int fd = fileno(f);
68 | //      if( write(fd, &l, sizeof(l)) == sizeof(l) ) {
69 | //        retval = 1;
70 | //        //int32_t stat = system("echo ---- lastlog ----; lastlog");
71 | //      }
72 | //    }
73 | //    fclose(f);
74 | //  }
75 | //  return retval;
76 | //}
77 | import "C"
78 |
79 | import (
80 | 	"fmt"
81 | 	"net"
82 | 	"os/user"
83 | 	"strings"
84 | 	"time"
85 | 	"unsafe"
86 | )
87 |
88 | // UtmpEntry wraps the C struct utmp.
89 | type UtmpEntry struct {
90 | 	entry C.struct_utmp
91 | }
92 |
93 | // GetHost returns the remote client's hostname, or its IP if the reverse
94 | // lookup fails or yields no names. addr is expected to be in the format
95 | // produced by net.Addr.String(), e.g. "127.0.0.1:80" or "[::1]:80".
96 | func GetHost(addr string) (h string) {
97 | 	if !strings.Contains(addr, "[") {
98 | 		h = strings.Split(addr, ":")[0]
99 | 	} else {
100 | 		h = strings.Split(strings.Split(addr, "[")[1], "]")[0]
101 | 	}
102 | 	hList, e := net.LookupAddr(h)
103 | 	if e == nil && len(hList) > 0 {
104 | 		h = hList[0]
105 | 	}
106 | 	return
107 | }
108 |
109 | // Put records user, logged in on ptsName from host, in utmp and returns the
110 | // entry so it can later be removed with Unput. ptsName is expected to be a
111 | // "/dev/pts/N" path: the C side skips that prefix without checking it.
112 | func Put(user, ptsName, host string) UtmpEntry {
113 | 	var entry UtmpEntry
114 |
115 | 	cUser := C.CString(user)
116 | 	defer C.free(unsafe.Pointer(cUser))
117 | 	cPts := C.CString(ptsName)
118 | 	defer C.free(unsafe.Pointer(cPts))
119 | 	cHost := C.CString(host)
120 | 	defer C.free(unsafe.Pointer(cHost))
121 |
122 | 	// pututmp copies the strings into entry, so the C copies can be freed
123 | 	// as soon as it returns.
124 | 	C.pututmp(&entry.entry, cUser, cPts, cHost)
125 | 	return entry
126 | }
127 |
128 | // Unput marks a previously Put entry as DEAD_PROCESS in utmp.
129 | func Unput(entry UtmpEntry) {
130 | 	C.unpututmp(&entry.entry)
131 | }
132 |
133 | // PutLastlogEntry records app (as the line) and host in the lastlog record of
134 | // user usr; ptsname is currently unused.
135 | func PutLastlogEntry(app, usr, ptsname, host string) {
136 | 	u, e := user.Lookup(usr)
137 | 	if e != nil {
138 | 		return
139 | 	}
140 | 	var uid uint32
141 | 	// Bail out instead of falling back to uid 0, which would target root's
142 | 	// lastlog record.
143 | 	if _, err := fmt.Sscanf(u.Uid, "%d", &uid); err != nil {
144 | 		return
145 | 	}
146 |
147 | 	cApp := C.CString(app)
148 | 	defer C.free(unsafe.Pointer(cApp))
149 | 	cHost := C.CString(host)
150 | 	defer C.free(unsafe.Pointer(cHost))
151 |
152 | 	t := time.Now().Unix()
153 | 	_ = C.putlastlogentry(C.int64_t(t), C.int(uid), cApp, cHost)
154 | }
155 |
--------------------------------------------------------------------------------
/server/auth.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | 	"ssh2incus/pkg/ssh"
5 | 	"ssh2incus/pkg/user"
6 |
7 | 	log "github.com/sirupsen/logrus"
8 | 	gossh "golang.org/x/crypto/ssh"
9 | )
10 |
11 | // hostAuthHandler authenticates key against the host OS account named by the
12 | // login user (lu.User). When config.AllowedGroups is set, non-root accounts
13 | // must belong to one of those groups; the key must then match one of the
14 | // account's authorized keys. On success the key is stored on the login user.
15 | func hostAuthHandler(ctx ssh.Context, key ssh.PublicKey) bool {
16 | 	log := log.WithField("session", ctx.ShortSessionID())
17 |
18 | 	lu := LoginUserFromContext(ctx)
19 |
20 | 	log.Debugf("auth (host): attempting key auth for %s: %s %s", lu, key.Type(), gossh.FingerprintSHA256(key))
21 |
22 | 	osUser, err := getOsUser(lu.User)
23 | 	if err != nil {
24 | 		log.Errorf("auth (host): %s", err)
25 | 		return false
26 | 	}
27 |
28 | 	if osUser.Uid != "0" && len(config.AllowedGroups) > 0 {
29 | 		userGroups, err := getUserGroups(osUser)
30 | 		if err != nil {
31 | 			log.Errorf("auth (host): %s", err)
32 | 			return false
33 | 		}
34 |
35 | 		gid, match := groupMatch(config.AllowedGroups, userGroups)
36 | 		if !match {
37 | 			log.Warnf("auth (host): no group match for %s %v in %v", lu.User, userGroups, config.AllowedGroups)
38 | 			return false
39 | 		}
40 |
41 | 		// The group name is only needed for logging. If the lookup fails, log
42 | 		// the error and continue instead of reading group.Name from a result
43 | 		// that is not valid on error.
44 | 		if group, err := user.LookupGroupId(gid); err != nil {
45 | 			log.Errorf("auth (host): %s", err)
46 | 		} else {
47 | 			log.Debugf("auth (host): host user %q matched %q group", lu.User, group.Name)
48 | 		}
49 | 	}
50 |
51 | 	keys, err := getUserAuthKeys(osUser)
52 | 	if err != nil {
53 | 		log.Errorf("auth (host): %s", err)
54 | 		return false
55 | 	}
56 |
57 | 	if len(keys) == 0 {
58 | 		log.Warnf("auth (host): no keys for %s", lu)
59 | 		return false
60 | 	}
61 |
62 | 	for _, k := range keys {
63 | 		equal, err := keysEqual(key, k)
64 | 		if err != nil {
65 | 			log.Errorf("auth (host): failed to compare keys for %s: %s", lu, err)
66 | 		}
67 | 		if equal {
68 | 			log.Infof("auth (host): succeeded for %s: %s %s", lu, key.Type(), gossh.FingerprintSHA256(key))
69 | 			if !lu.IsValid() {
70 | 				return false
71 | 			}
72 | 			lu.PublicKey = key
73 | 			return true
74 | 		}
75 | 	}
76 |
77 | 	log.Warnf("auth (host): failed for %s: %s %s", lu, key.Type(), gossh.FingerprintSHA256(key))
78 | 	return false
79 | }
80 |
81 | // inAuthHandler performs host auth and instance auth. A key accepted by
82 | // hostAuthHandler is sufficient. Otherwise, for valid logins that are not
83 | // commands, the key is checked against <iu.Dir>/.ssh/authorized_keys
84 | // downloaded from the target instance.
85 | func inAuthHandler(ctx ssh.Context, key ssh.PublicKey) bool {
86 | 	log := log.WithField("session", ctx.ShortSessionID())
87 |
88 | 	lu := LoginUserFromContext(ctx)
89 |
90 | 	// valid user on the host should be allowed
91 | 	if hostAuthHandler(ctx, key) {
92 | 		return true
93 | 	}
94 | 	if !lu.IsValid() {
95 | 		return false
96 | 	}
97 |
98 | 	// commands are allowed for host users only
99 | 	if lu.IsCommand() {
100 | 		return false
101 | 	}
102 |
103 | 	log.Debugf("auth (instance): attempting key auth for %s: %s %s", lu, key.Type(), gossh.FingerprintSHA256(key))
104 |
105 | 	client, err := NewDefaultIncusClientWithContext(ctx)
106 | 	if err != nil {
107 | 		log.Error(err)
108 | 		return false
109 | 	}
110 |
111 | 	// User handling
112 | 	iu, err := client.GetCachedInstanceUser(lu.Project, lu.Instance, lu.InstanceUser)
113 | 	if err != nil {
114 | 		log.Errorf("auth (instance): failed to get instance user %s for %s: %s", lu.InstanceUser, lu, err)
115 | 		return false
116 | 	}
117 |
118 | 	if iu == nil {
119 | 		log.Errorf("auth (instance): not found instance user for %s", lu)
120 | 		return false
121 | 	}
122 |
123 | 	path := iu.Dir + "/.ssh/authorized_keys"
124 | 	file, err := client.DownloadFile(iu.Project, iu.Instance, path)
125 | 	if err != nil {
126 | 		log.Warnf("auth (instance): failed to download %s for %s: %s", path, lu, err)
127 | 		return false
128 | 	}
129 |
130 | 	keys := file.Content.Lines()
131 |
132 | 	if len(keys) == 0 {
133 | 		log.Warnf("auth (instance): no keys for %s", lu)
134 | 		return false
135 | 	}
136 |
137 | 	for _, k := range keys {
138 | 		equal, err := keysEqual(key, k)
139 | 		if err != nil {
140 | 			log.Errorf("auth (instance): failed to compare keys for %s: %s", lu, err)
141 | 		}
142 | 		if equal {
143 | 			log.Infof("auth (instance): succeeded for %s: %s %s", lu, key.Type(), gossh.FingerprintSHA256(key))
144 | 			lu.PublicKey = key
145 | 			return true
146 | 		}
147 | 	}
148 |
149 | 	log.Warnf("auth (instance): failed for %s: %s %s", lu, key.Type(), gossh.FingerprintSHA256(key))
150 | 	return false
151 | }
152 |
153 | // noAuthHandler accepts any public key for a valid login user without
154 | // verifying it, and records the key on the login user.
155 | func noAuthHandler(ctx ssh.Context, key ssh.PublicKey) bool {
156 | 	log := log.WithField("session", ctx.ShortSessionID())
157 |
158 | 	lu := LoginUserFromContext(ctx)
159 | 	log.Infof("auth (noauth): noauth login key for %s: %s %s", lu, key.Type(), gossh.FingerprintSHA256(key))
160 | 	if !lu.IsValid() {
161 | 		return false
162 | 	}
163 | 	lu.PublicKey = key
164 | 	return true
165 | }
166 |
167 | // keysEqual reports whether key matches the public key on the single
168 | // authorized_keys line authKey. The remaining ParseAuthorizedKey results
169 | // (presumably comment, options, rest) are discarded, so per-key options such as
170 | // from= are not enforced here. NOTE(review): confirm that is intended.
171 | func keysEqual(key ssh.PublicKey, authKey []byte) (bool, error) {
172 | 	pkey, _, _, _, err := ssh.ParseAuthorizedKey(authKey)
173 | 	if err != nil {
174 | 		return false, err
175 | 	}
176 | 	return ssh.KeysEqual(pkey, key), nil
177 | }
178 |
--------------------------------------------------------------------------------
/server/banner.go:
--------------------------------------------------------------------------------
1 | package server
2 | 
3 | import (
4 | 	"fmt"
5 | 	"os"
6 | 
7 | 	"ssh2incus/pkg/ssh"
8 | )
9 | 
10 | const banner = `
11 | ┌──────────────────────────────────────────────┐
12 | │ _ ____ _ │
13 | │ ___ ___| |__ |___ \(_)_ __ ___ _ _ ___ │
14 | │ / __/ __| '_ \ __) | | '_ \ / __| | | / __| │
15 | │ \__ \__ \ | | |/ __/| | | | | (__| |_| \__ \ │
16 | │ |___/___/_| |_|_____|_|_| |_|\___|\__,_|___/ │
17 | └──────────────────────────────────────────────┘
18 | `
19 | 
20 | // bannerHandler returns the SSH banner text for the connection: nothing for
21 | // an invalid login, just the logo for command logins, and otherwise the logo
22 | // followed by the instance user, instance.project and, when the local
23 | // hostname is known, the remote (if any) and that hostname.
24 | func bannerHandler(ctx ssh.Context) string {
25 | 	lu := LoginUserFromContext(ctx)
26 | 	if !lu.IsValid() {
27 | 		return ""
28 | 	}
29 | 	if lu.IsCommand() {
30 | 		return banner
31 | 	}
32 | 	remote := lu.Remote
33 | 	if remote != "" {
34 | 		remote += " / "
35 | 	}
36 | 	hostname, _ := os.Hostname()
37 | 	if hostname != "" {
38 | 		hostname = fmt.Sprintf(" 💻 %s%s", remote, hostname)
39 | 	}
40 | 	b := banner + fmt.Sprintf(
41 | 		"👤 %s 📦 %s.%s%s\n────────────────────────────────────────────────\n",
42 | 		lu.InstanceUser, lu.Instance, lu.Project, hostname,
43 | 	)
44 | 	return b + "\n"
45 | }
46 | 
--------------------------------------------------------------------------------
/server/config.go:
--------------------------------------------------------------------------------
1 | package server
2 | 
3 | import (
4 | 	"os"
5 | 	"strconv"
6 | 	"strings"
7 | 	"time"
8 | 
9 | 	"ssh2incus/pkg"
10 | )
11 | 
12 | // config is the package-wide server configuration installed by SetConfig.
13 | var config *Config
14 | 
15 | // Config holds the settings the server runs with.
16 | type Config struct {
17 | 	App  *pkg.App
18 | 	Args []string
19 | 
20 | 	Listen        string
21 | 	Socket        string
22 | 	Shell         string
23 | 	Groups        string
24 | 	HealthCheck   string
25 | 	IncusSocket   string
26 | 	Remote        string
27 | 	URL           string
28 | 	ClientCert    string
29 | 	ClientKey     string
30 | 	ServerCert    string
31 | 	Master        bool
32 | 	Debug         bool
33 | 	Banner        bool
34 | 	NoAuth        bool
35 | 	InAuth        bool
36 | 	Welcome       bool
37 | 	AllowedGroups []string
38 | 	IdleTimeout   time.Duration
39 | 
40 | 	IncusInfo map[string]interface{}
41 | }
42 | 
43 | // SetConfig installs c as the package-wide configuration.
44 | func SetConfig(c *Config) {
45 | 	config = c
46 | }
47 | 
48 | // SocketFdEnvName returns the name of the environment variable used to pass a
49 | // socket file descriptor number, built from App.NAME().
50 | func (c *Config) SocketFdEnvName() string {
51 | 	return c.App.NAME() + "_SOCKET_FD"
52 | }
53 | 
54 | // SocketFdEnvValue returns the file descriptor of f as a decimal string.
55 | func (c *Config) SocketFdEnvValue(f *os.File) string {
56 | 	return strconv.Itoa(int(f.Fd()))
57 | }
58 | 
59 | // SocketFdEnvString returns the "NAME=fd" environment entry for f.
60 | func (c *Config) SocketFdEnvString(f *os.File) string {
61 | 	return c.SocketFdEnvName() + "=" + c.SocketFdEnvValue(f)
62 | }
63 | 
64 | // ArgsEnvName returns the name of the environment variable used to pass the
65 | // command-line arguments, built from App.NAME().
66 | func (c *Config) ArgsEnvName() string {
67 | 	return c.App.NAME() + "_ARGS"
68 | }
69 | 
70 | // ArgsEnvValue returns Args joined with single spaces.
71 | func (c *Config) ArgsEnvValue() string {
72 | 	return strings.Join(c.Args, " ")
73 | }
74 | 
75 | // ArgsEnvString returns the "NAME=args" environment entry.
76 | func (c *Config) ArgsEnvString() string {
77 | 	return c.ArgsEnvName() + "=" + c.ArgsEnvValue()
78 | }
79 | 
--------------------------------------------------------------------------------
/server/device-registry.go:
--------------------------------------------------------------------------------
1 | package server
2 | 
3 | import (
4 | 	"fmt"
5 | 	"strings"
6 | 
7 | 	"ssh2incus/pkg/incus"
8 | 	"ssh2incus/pkg/ssh"
9 | 	"ssh2incus/pkg/util/devicereg"
10 | 
11 | 	"github.com/lxc/incus/v6/shared/api"
12 | 	log "github.com/sirupsen/logrus"
13 | )
14 | 
15 | // deviceRegistry is the package-wide devicereg.DeviceRegistry, created in init.
16 | var deviceRegistry *devicereg.DeviceRegistry
17 | 
18 | func init() {
19 | 	deviceRegistry = devicereg.NewDeviceRegistry()
20 | }
21 | 
22 | // cleanLeftoverProxyDevices deletes every device whose name starts with
23 | // incus.ProxyDevicePrefix from all instances in all projects. Failures to
24 | // delete individual devices are logged and skipped; only connection and
25 | // listing errors are returned.
26 | func cleanLeftoverProxyDevices() error {
27 | 	ctx, cancel := ssh.NewContext(nil)
28 | 	defer cancel()
29 | 	client, err := NewDefaultIncusClientWithContext(ctx)
30 | 	if err != nil {
31 | 		return err
32 | 	}
33 | 	defer client.Disconnect()
34 | 
35 | 	allInstances, err := client.GetInstancesAllProjects(api.InstanceTypeAny)
36 | 	if err != nil {
37 | 		return fmt.Errorf("failed to get instances: %w", err)
38 | 	}
39 | 	for _, i := range allInstances {
40 | 		for device := range i.Devices {
41 | 			if !strings.HasPrefix(device, incus.ProxyDevicePrefix) {
42 | 				continue
43 | 			}
44 | 			err = client.DeleteInstanceDevice(&i, device)
45 | 			if err != nil {
46 | 				log.Errorf("delete instance %s.%s device %s: %v", i.Name, i.Project, device, err)
47 | 				continue
48 | 			}
49 | 			log.Infof("deleted leftover device %s on instance %s.%s", device, i.Name, i.Project)
50 | 		}
51 | 	}
52 | 	return nil
53 | }
54 | 
--------------------------------------------------------------------------------
/server/incus.go:
--------------------------------------------------------------------------------
1 | package server
2 | 
3 | import (
4 | 	"fmt"
5 | 	"os"
6 | 	"strings"
7 | 
8 | 	"ssh2incus/pkg/incus"
9 | 	"ssh2incus/pkg/ssh"
10 | 
11 | 	"github.com/lxc/incus/v6/shared/cliconfig"
12 | 	log "github.com/sirupsen/logrus"
13 | )
14 | 
15 | var (
16 | 	// ContextKeyIncusClient is the context key under which a connected
17 | 	// *incus.Client is cached for reuse.
18 | 	ContextKeyIncusClient = &contextKey{"incusClient"}
19 | 
20 | 	// DefaultParams are the parameters used by NewDefaultIncusClientWithContext;
21 | 	// nil means they are derived from the server configuration.
22 | 	DefaultParams *incus.ConnectParams
23 | 
24 | 	// incusConnectParams memoizes the result of getIncusConnectParams.
25 | 	incusConnectParams *incus.ConnectParams
26 | )
27 | 
28 | // NewIncusClient creates an Incus client without connecting it. A nil params
29 | // means the parameters are derived from the server configuration.
30 | func NewIncusClient(params *incus.ConnectParams) (*incus.Client, error) {
31 | 	if params == nil {
32 | 		var err error
33 | 		params, err = getIncusConnectParams()
34 | 		if err != nil {
35 | 			return nil, err
36 | 		}
37 | 	}
38 | 	return incus.NewClientWithParams(params), nil
39 | }
40 | 
41 | // NewDefaultIncusClientWithContext calls NewIncusClientWithContext with
42 | // DefaultParams.
43 | func NewDefaultIncusClientWithContext(ctx ssh.Context) (*incus.Client, error) {
44 | 	return NewIncusClientWithContext(ctx, DefaultParams)
45 | }
46 | 
47 | // NewIncusClientWithContext returns the client already cached in ctx, or
48 | // creates, connects and caches a new one. For contexts that carry an SSH
49 | // server, the parameters come from the login user's remote and params is
50 | // ignored.
51 | func NewIncusClientWithContext(ctx ssh.Context, params *incus.ConnectParams) (*incus.Client, error) {
52 | 	log := log.WithField("session", ctx.ShortSessionID())
53 | 
54 | 	if c, ok := ctx.Value(ContextKeyIncusClient).(*incus.Client); ok && c != nil {
55 | 		log.Debug("reusing existing incus client")
56 | 		return c, nil
57 | 	}
58 | 
59 | 	if params == nil {
60 | 		params = DefaultParams
61 | 	}
62 | 
63 | 	if srv, ok := ctx.Value(ssh.ContextKeyServer).(*ssh.Server); ok && srv != nil {
64 | 		lu := LoginUserFromContext(ctx)
65 | 		var err error
66 | 		params, err = incus.RemoteConnectParams(lu.Remote)
67 | 		if err != nil {
68 | 			return nil, err
69 | 		}
70 | 	}
71 | 
72 | 	c, err := NewIncusClient(params)
73 | 	if err != nil {
74 | 		return nil, fmt.Errorf("failed to initialize incus client: %w", err)
75 | 	}
76 | 
77 | 	if err := c.Connect(ctx); err != nil {
78 | 		return nil, fmt.Errorf("failed to connect to incus: %w", err)
79 | 	}
80 | 	log.Debug("new incus client created")
81 | 	ctx.SetValue(ContextKeyIncusClient, c)
82 | 	return c, nil
83 | }
84 | 
85 | // checkIncus verifies that Incus is reachable and stores the connection info
86 | // in config.IncusInfo.
87 | func (s *Server) checkIncus() error {
88 | 	ctx, cancel := ssh.NewContext(nil)
89 | 	defer cancel()
90 | 	client, err := NewDefaultIncusClientWithContext(ctx)
91 | 	if err != nil {
92 | 		return fmt.Errorf("failed to connect to incus: %w", err)
93 | 	}
94 | 	defer client.Disconnect()
95 | 
96 | 	info := client.GetConnectionInfo()
97 | 	config.IncusInfo = info
98 | 	log.Debugln(info)
99 | 
100 | 	return nil
101 | }
102 | 
103 | // getIncusConnectParams resolves how to reach Incus, in priority order:
104 | //  1. config.Remote, looked up in the incus CLI configuration (only when
105 | //     that configuration could be loaded);
106 | //  2. config.URL, with TLS files from config fields or INCUS_* variables;
107 | //  3. config.Socket;
108 | //  4. otherwise an empty URL, leaving the default to the Incus client.
109 | //
110 | // The first result is memoized in incusConnectParams.
111 | func getIncusConnectParams() (*incus.ConnectParams, error) {
112 | 	if incusConnectParams != nil {
113 | 		return incusConnectParams, nil
114 | 	}
115 | 
116 | 	clicfg, err := cliconfig.LoadConfig("")
117 | 	if err != nil {
118 | 		log.Debugf("Failed to load incus CLI config: %v", err)
119 | 	}
120 | 
121 | 	var url, remote string
122 | 	var certFile, keyFile, serverCertFile string
123 | 
124 | 	// First priority: Check if Remote is set
125 | 	if config.Remote != "" && clicfg != nil {
126 | 		remoteConfig, ok := clicfg.Remotes[config.Remote]
127 | 		if !ok {
128 | 			return nil, fmt.Errorf("remote '%s' not found in incus configuration", config.Remote)
129 | 		}
130 | 		remote = config.Remote
131 | 		url = remoteConfig.Addr
132 | 
133 | 		// For HTTPS connections, determine client certificate paths
134 | 		if strings.HasPrefix(url, "https://") {
135 | 			// Check if custom paths are provided in our config
136 | 			if config.ServerCert != "" {
137 | 				serverCertFile = config.ServerCert
138 | 			} else {
139 | 				serverCertFile = clicfg.ConfigPath("servercerts", config.Remote+".crt")
140 | 			}
141 | 			if config.ClientCert != "" && config.ClientKey != "" {
142 | 				certFile = config.ClientCert
143 | 				keyFile = config.ClientKey
144 | 			} else {
145 | 				// Use default Incus client cert/key which are stored in the same directory as config.yml
146 | 				certFile = clicfg.ConfigPath("client.crt")
147 | 				keyFile = clicfg.ConfigPath("client.key")
148 | 			}
149 | 
150 | 			// Ensure certificate files exist
151 | 			if _, err := os.Stat(certFile); err != nil {
152 | 				return nil, fmt.Errorf("client certificate not found at %s: %w", certFile, err)
153 | 			}
154 | 			if _, err := os.Stat(keyFile); err != nil {
155 | 				return nil, fmt.Errorf("client key not found at %s: %w", keyFile, err)
156 | 			}
157 | 		} else if strings.HasPrefix(url, "unix://") {
158 | 			url = strings.TrimPrefix(url, "unix://")
159 | 		}
160 | 	} else if config.URL != "" {
161 | 		// Second priority: Use URL if set
162 | 		url = config.URL
163 | 
164 | 		// For HTTPS connections, we need to get cert/key from config or environment
165 | 		if strings.HasPrefix(url, "https://") {
166 | 			// Server certificate: config field first, then environment
167 | 			if config.ServerCert != "" {
168 | 				serverCertFile = config.ServerCert
169 | 			} else {
170 | 				serverCertFile = os.Getenv("INCUS_SERVER_CERT")
171 | 			}
172 | 			// Client certificate and key: config fields first, then environment
173 | 			if config.ClientCert != "" && config.ClientKey != "" {
174 | 				certFile = config.ClientCert
175 | 				keyFile = config.ClientKey
176 | 			} else {
177 | 				certFile = os.Getenv("INCUS_CLIENT_CERT")
178 | 				keyFile = os.Getenv("INCUS_CLIENT_KEY")
179 | 			}
180 | 
181 | 			if certFile == "" || keyFile == "" {
182 | 				return nil, fmt.Errorf("HTTPS connection requires client certificate and key")
183 | 			}
184 | 		}
185 | 	} else if config.Socket != "" {
186 | 		// Third priority: Use Socket if set
187 | 		url = config.Socket
188 | 	}
189 | 	// Otherwise url stays empty and the Incus client uses its default socket path.
190 | 
191 | 	incusConnectParams = &incus.ConnectParams{
192 | 		Remote:         remote,
193 | 		Url:            url,
194 | 		CertFile:       certFile,
195 | 		KeyFile:        keyFile,
196 | 		ServerCertFile: serverCertFile,
197 | 	}
198 | 	return incusConnectParams, nil
199 | }
200 | 
--------------------------------------------------------------------------------
/server/sftp-server-binary/sftp-server-binary.go:
--------------------------------------------------------------------------------
1 | // Package sftp_server_binary embeds gzip-compressed ssh2incus-sftp-server
2 | // binaries, one per supported architecture.
3 | package sftp_server_binary
4 | 
5 | import (
6 | 	_ "embed"
7 | 	"fmt"
8 | )
9 | 
10 | var (
11 | 	//go:embed bin/ssh2incus-sftp-server-arm64.gz
12 | 	arm64Bytes []byte
13 | 	//go:embed bin/ssh2incus-sftp-server-amd64.gz
14 | 	amd64Bytes []byte
15 | 
16 | 	// binName is the absolute path of the binary inside an instance.
17 | 	binName = "/bin/ssh2incus-sftp-server"
18 | )
19 | 
20 | // init panics when an embedded binary is empty.
21 | func init() {
22 | 	if len(arm64Bytes) == 0 {
23 | 		panic("arm64Bytes is empty")
24 | 	}
25 | 	if len(amd64Bytes) == 0 {
26 | 		panic("amd64Bytes is empty")
27 | 	}
28 | }
29 | 
30 | // BinName returns the absolute path of the sftp-server binary inside an instance.
31 | func BinName() string {
32 | 	return binName
33 | }
34 | 
35 | // BinBytes returns the gzip-compressed binary for the given architecture,
36 | // or an error when the architecture is not supported.
37 | func BinBytes(arch string) ([]byte, error) {
38 | 	switch arch {
39 | 	case "arm64", "aarch64":
40 | 		return arm64Bytes, nil
41 | 	// NOTE(review): "x86" commonly denotes 32-bit x86, which cannot run the
42 | 	// amd64 binary; confirm this alias is intended.
43 | 	case "amd64", "x86_64", "x64", "x86-64", "x86":
44 | 		return amd64Bytes, nil
45 | 	default:
46 | 		return nil, fmt.Errorf("unsupported arch: %s", arch)
47 | 	}
48 | }
49 | 
--------------------------------------------------------------------------------
/server/sftp.go:
--------------------------------------------------------------------------------
1 | package server
2 | 
3 | import (
4 | 	"bytes"
5 | 	"context"
6 | 	"errors"
7 | 	"fmt"
8 | 	"io"
9 | 
10 | 	"ssh2incus/pkg/incus"
11 | 	"ssh2incus/pkg/ssh"
12 | 	"ssh2incus/pkg/util"
13 | 	"ssh2incus/server/sftp-server-binary"
14 | 
15 | 	log "github.com/sirupsen/logrus"
16 | )
17 | 
18 | const sftpSubsystem = "sftp"
19 | 
20 | // sftpSubsystemHandler serves the sftp subsystem. It ensures the embedded
21 | // sftp-server binary is present in the target instance (uploading it when
22 | // FileExists, which is given the binary's MD5, reports it missing) and runs it
23 | // there as root on the session's stdio, chrooted (-d) to non-root users' homes.
24 | func sftpSubsystemHandler(s ssh.Session) {
25 | 	log := log.WithField("session", s.Context().ShortSessionID())
26 | 
27 | 	lu := LoginUserFromContext(s.Context())
28 | 	if !lu.IsValid() {
29 | 		log.Errorf("invalid login for %s", lu)
30 | 		fmt.Fprintf(s, "Invalid login for %q (%s)\n", lu.OrigUser, lu)
31 | 		s.Exit(ExitCodeInvalidLogin)
32 | 		return
33 | 	}
34 | 	log.Debugf("sftp: connecting %s", lu)
35 | 
36 | 	client, err := NewDefaultIncusClientWithContext(s.Context())
37 | 	if err != nil {
38 | 		log.Error(err)
39 | 		s.Exit(ExitCodeConnectionError)
40 | 		return
41 | 	}
42 | 	defer client.Disconnect()
43 | 
44 | 	if !lu.IsDefaultProject() {
45 | 		err = client.UseProject(lu.Project)
46 | 		if err != nil {
47 | 			log.Errorf("sftp: using project %s error: %v", lu.Project, err)
48 | 			fmt.Fprintf(s, "unknown project %s\n", lu.Project)
49 | 			s.Exit(ExitCodeInvalidProject)
50 | 			return
51 | 		}
52 | 	}
53 | 
54 | 	instance, err := client.GetCachedInstance(lu.Project, lu.Instance)
55 | 	if err != nil {
56 | 		log.Errorf("sftp: cannot get instance for %s: %s", lu, err)
57 | 		fmt.Fprintf(s, "cannot get instance %s\n", lu.FullInstance())
58 | 		s.Exit(ExitCodeMetaError)
59 | 		return
60 | 	}
61 | 
62 | 	sftpServerBinBytes, err := sftp_server_binary.BinBytes(instance.Architecture)
63 | 	if err != nil {
64 | 		log.Errorf("sftp: failed to get sftp-server binary: %s", err)
65 | 		io.WriteString(s, "failed to get sftp-server binary\n")
66 | 		s.Exit(ExitCodeInternalError)
67 | 		return
68 | 	}
69 | 	sftpServerBinBytes, err = util.Ungz(sftpServerBinBytes)
70 | 	if err != nil {
71 | 		log.Errorf("sftp: failed to ungzip sftp-server: %s", err)
72 | 		io.WriteString(s, "failed to prepare sftp-server\n")
73 | 		s.Exit(ExitCodeInternalError)
74 | 		return
75 | 	}
76 | 
77 | 	existsParams := &incus.FileExistsParams{
78 | 		Project:     lu.Project,
79 | 		Instance:    lu.Instance,
80 | 		Path:        sftp_server_binary.BinName(),
81 | 		Md5sum:      util.Md5Bytes(sftpServerBinBytes),
82 | 		ShouldCache: true,
83 | 	}
84 | 	if !client.FileExists(existsParams) {
85 | 		err = client.UploadBytes(lu.Project, lu.Instance, sftp_server_binary.BinName(), bytes.NewReader(sftpServerBinBytes), 0, 0, 0755)
86 | 		if err != nil {
87 | 			log.Errorf("sftp: upload failed: %v", err)
88 | 			fmt.Fprintf(s, "sftp-server is not available on %s\n", lu.FullInstance())
89 | 			s.Exit(ExitCodeConnectionError)
90 | 			return
91 | 		}
92 | 		log.Debugf("sftp: sftp-server uploaded %s to %s", sftp_server_binary.BinName(), lu.FullInstance())
93 | 	}
94 | 	// Drop the reference so the decompressed binary can be collected while
95 | 	// the sftp session runs.
96 | 	sftpServerBinBytes = nil
97 | 
98 | 	iu, err := client.GetCachedInstanceUser(lu.Project, lu.Instance, lu.InstanceUser)
99 | 	if err != nil {
100 | 		log.Errorf("failed to get instance user %s for %s: %s", lu.InstanceUser, lu, err)
101 | 		fmt.Fprintf(s, "cannot get instance user %s\n", lu.InstanceUser)
102 | 		s.Exit(ExitCodeMetaError)
103 | 		return
104 | 	}
105 | 
106 | 	if iu == nil {
107 | 		io.WriteString(s, "not found user or instance\n")
108 | 		log.Errorf("sftp: not found instance user for %s", lu)
109 | 		s.Exit(ExitCodeInvalidLogin)
110 | 		return
111 | 	}
112 | 
113 | 	stdin, stderr, cleanup := util.SetupPipes(s)
114 | 	defer cleanup()
115 | 
116 | 	// The server process runs as uid/gid 0; non-root users are confined to
117 | 	// their home directory (-d), which then appears to them as "/".
118 | 	uid := 0
119 | 	gid := 0
120 | 	chroot := "/"
121 | 	home := iu.Dir
122 | 	if iu.Uid != 0 {
123 | 		chroot = iu.Dir
124 | 		home = "/"
125 | 	}
126 | 
127 | 	cmd := fmt.Sprintf("%s -e -d %s", sftp_server_binary.BinName(), chroot)
128 | 
129 | 	env := make(map[string]string)
130 | 	env["USER"] = iu.User
131 | 	env["UID"] = fmt.Sprintf("%d", iu.Uid)
132 | 	env["GID"] = fmt.Sprintf("%d", iu.Gid)
133 | 	env["HOME"] = home
134 | 	env["SSH_SESSION"] = s.Context().ShortSessionID()
135 | 
136 | 	log.Debugf("sftp: CMD %s", cmd)
137 | 	log.Debugf("sftp: ENV %s", util.MapToEnvString(env))
138 | 
139 | 	ie := client.NewInstanceExec(incus.InstanceExec{
140 | 		Instance: lu.Instance,
141 | 		Cmd:      cmd,
142 | 		Env:      env,
143 | 		Stdin:    stdin,
144 | 		Stdout:   s,
145 | 		Stderr:   stderr,
146 | 		User:     uid,
147 | 		Group:    gid,
148 | 	})
149 | 
150 | 	ret, err := ie.Exec()
151 | 	if err != nil && err != io.EOF && !errors.Is(err, context.Canceled) {
152 | 		io.WriteString(s, "sftp connection failed\n")
153 | 		log.Errorf("sftp: exec failed: %s", err)
154 | 	}
155 | 
156 | 	err = s.Exit(ret)
157 | 	if err != nil && err != io.EOF {
158 | 		log.Errorf("sftp: session exit failed: %v", err)
159 | 	}
160 | 	log.Debugf("sftp: exit %d", ret)
161 | }
162 | 
--------------------------------------------------------------------------------
/server/stdio-proxy-binary/stdio-proxy-binary.go:
--------------------------------------------------------------------------------
1 | // Package stdio_proxy_binary embeds gzip-compressed ssh2incus-stdio-proxy
2 | // binaries, one per supported architecture.
3 | package stdio_proxy_binary
4 | 
5 | import (
6 | 	_ "embed"
7 | 	"fmt"
8 | )
9 | 
10 | var (
11 | 	//go:embed bin/ssh2incus-stdio-proxy-arm64.gz
12 | 	arm64Bytes []byte
13 | 	//go:embed bin/ssh2incus-stdio-proxy-amd64.gz
14 | 	amd64Bytes []byte
15 | 
16 | 	// binName is the absolute path returned by BinName.
17 | 	binName = "/bin/ssh2incus-stdio-proxy"
18 | )
19 | 
20 | // init panics when an embedded binary is empty.
21 | func init() {
22 | 	if len(arm64Bytes) == 0 {
23 | 		panic("arm64Bytes is empty")
24 | 	}
25 | 	if len(amd64Bytes) == 0 {
26 | 		panic("amd64Bytes is empty")
27 | 	}
28 | }
29 | 
30 | // BinName returns the absolute path of the stdio-proxy binary.
31 | func BinName() string {
32 | 	return binName
33 | }
34 | 
35 | // BinBytes returns the gzip-compressed binary for the given architecture,
36 | // or an error when the architecture is not supported.
37 | func BinBytes(arch string) ([]byte, error) {
38 | 	switch arch {
39 | 	case "arm64", "aarch64":
40 | 		return arm64Bytes, nil
41 | 	// NOTE(review): "x86" commonly denotes 32-bit x86, which cannot run the
42 | 	// amd64 binary; confirm this alias is intended.
43 | 	case "amd64", "x86_64", "x64", "x86-64", "x86":
44 | 		return amd64Bytes, nil
45 | 	default:
46 | 		return nil, fmt.Errorf("unsupported arch: %s", arch)
47 | 	}
48 | }
49 | 
--------------------------------------------------------------------------------
/server/subsystem.go:
--------------------------------------------------------------------------------
1 | package server
2 | 
3 | import (
4 | 	"fmt"
5 | 
6 | 	"ssh2incus/pkg/ssh"
7 | )
8 | 
9 | // defaultSubsystemHandler writes a "not implemented" message naming the
10 | // requested subsystem and exits with ExitCodeNotImplemented.
11 | func defaultSubsystemHandler(s ssh.Session) {
12 | 	fmt.Fprintf(s, "%s subsystem not implemented\n", s.Subsystem())
13 | 	s.Exit(ExitCodeNotImplemented)
14 | }
15 | 
--------------------------------------------------------------------------------
/server/user.go:
--------------------------------------------------------------------------------
1 | package server
2 | 
3 | import (
4 | 	"bufio"
5 | 	"fmt"
6 | 	"os"
7 | 	"path/filepath"
8 | 	"strings"
9 | 	"time"
10 | 
11 | 	"ssh2incus/pkg/cache"
12 | 	"ssh2incus/pkg/incus"
13 | 	"ssh2incus/pkg/ssh"
14 | 	"ssh2incus/pkg/user"
15 | 
16 | 	log "github.com/sirupsen/logrus"
17 | )
18 | 
19 | var (
20 | 	// ContextKeyLoginUser is the context key under which the parsed
21 | 	// *LoginUser of a connection is cached.
22 | 	ContextKeyLoginUser = &contextKey{"loginUser"}
23 | 
24 | 	// loginUserCache and loginUserFailedCache remember, keyed by
25 | 	// LoginUser.Hash, logins that recently passed or failed IsValid.
26 | 	loginUserCache       *cache.Cache
27 | 	loginUserFailedCache *cache.Cache
28 | )
29 | 
30 | func init() {
31 | 	loginUserCache = cache.New(15*time.Minute, 20*time.Minute)
32 | 	loginUserFailedCache = cache.New(1*time.Minute, 2*time.Minute)
33 | }
34 | 
35 | // LoginUser is the parsed form of an SSH login string
36 | // [remote:][instanceuser@]instance[.project][+hostuser], or %command[+hostuser].
37 | type LoginUser struct {
38 | 	OrigUser     string
39 | 	Remote       string
40 | 	User         string
41 | 	Instance     string
42 | 	Project      string
43 | 	InstanceUser string
44 | 	Command      string
45 | 	PublicKey    ssh.PublicKey
46 | 
47 | 	ctx ssh.Context
48 | }
49 | 
50 | // LoginUserFromContext returns the LoginUser cached in ctx, parsing
51 | // ctx.User() and caching the result on first use. An empty remote defaults
52 | // to config.Remote.
53 | func LoginUserFromContext(ctx ssh.Context) *LoginUser {
54 | 	if lu, ok := ctx.Value(ContextKeyLoginUser).(*LoginUser); ok {
55 | 		return lu
56 | 	}
57 | 	lu := parseLoginUser(ctx.User())
58 | 	lu.ctx = ctx
59 | 	if lu.Remote == "" {
60 | 		lu.Remote = config.Remote
61 | 	}
62 | 	ctx.SetValue(ContextKeyLoginUser, lu)
63 | 	return lu
64 | }
65 | 
66 | // String formats the login as remote:instanceuser@instance.project+user, or
67 | // %command@user for command logins.
68 | func (lu *LoginUser) String() string {
69 | 	if lu == nil {
70 | 		return ""
71 | 	}
72 | 	remote := ""
73 | 	if lu.Remote != "" {
74 | 		remote = lu.Remote + ":"
75 | 	}
76 | 	if lu.Command != "" {
77 | 		return fmt.Sprintf("%%%s@%s", lu.Command, lu.User)
78 | 	}
79 | 	return fmt.Sprintf("%s%s@%s.%s+%s", remote, lu.InstanceUser, lu.Instance, lu.Project, lu.User)
80 | }
81 | 
82 | // FullInstance returns "instance.project".
83 | func (lu *LoginUser) FullInstance() string {
84 | 	if lu == nil {
85 | 		return ""
86 | 	}
87 | 	return fmt.Sprintf("%s.%s", lu.Instance, lu.Project)
88 | }
89 | 
90 | // IsDefaultProject reports whether the login targets the default project.
91 | func (lu *LoginUser) IsDefaultProject() bool {
92 | 	return incus.IsDefaultProject(lu.Project)
93 | }
94 | 
95 | // IsValid reports whether the login can be served: command logins only for
96 | // known commands ("shell"), instance logins when the instance user resolves.
97 | // Lookup outcomes are cached per Hash (failures for a shorter time); errors
98 | // creating the Incus client are not cached.
99 | func (lu *LoginUser) IsValid() bool {
100 | 	if lu == nil {
101 | 		return false
102 | 	}
103 | 
104 | 	if lu.IsCommand() {
105 | 		switch lu.Command {
106 | 		case "shell":
107 | 			return true
108 | 		default:
109 | 			return false
110 | 		}
111 | 	}
112 | 
113 | 	hash := lu.Hash()
114 | 	if _, ok := loginUserFailedCache.Get(hash); ok {
115 | 		return false
116 | 	}
117 | 	if _, ok := loginUserCache.Get(hash); ok {
118 | 		return true
119 | 	}
120 | 
121 | 	// Built only after the nil check: lu.ctx cannot be read on a nil receiver.
122 | 	log := log.WithField("session", lu.ctx.ShortSessionID())
123 | 
124 | 	client, err := NewDefaultIncusClientWithContext(lu.ctx)
125 | 	if err != nil {
126 | 		log.Errorf("failed to initialize incus client for %s: %v", lu, err)
127 | 		return false
128 | 	}
129 | 
130 | 	iu, err := client.GetCachedInstanceUser(lu.Project, lu.Instance, lu.InstanceUser)
131 | 	if err != nil || iu == nil {
132 | 		log.Errorf("instance user %s for %s error: %s", lu.InstanceUser, lu, err)
133 | 		loginUserFailedCache.SetDefault(hash, time.Now())
134 | 		return false
135 | 	}
136 | 
137 | 	loginUserFailedCache.Delete(hash)
138 | 	loginUserCache.SetDefault(hash, time.Now())
139 | 	return true
140 | }
141 | 
142 | // IsCommand reports whether this is a %command login.
143 | func (lu *LoginUser) IsCommand() bool {
144 | 	return lu.Command != ""
145 | }
146 | 
147 | // Hash returns a cache key built from remote, host user, project, instance
148 | // and instance user.
149 | func (lu *LoginUser) Hash() string {
150 | 	if lu == nil {
151 | 		return ""
152 | 	}
153 | 	return fmt.Sprintf("%s/%s/%s/%s/%s", lu.Remote, lu.User, lu.Project, lu.Instance, lu.InstanceUser)
154 | }
155 | 
156 | // InstanceHash returns a cache key built from remote, project and instance.
157 | func (lu *LoginUser) InstanceHash() string {
158 | 	if lu == nil {
159 | 		return ""
160 | 	}
161 | 	return fmt.Sprintf("%s/%s/%s", lu.Remote, lu.Project, lu.Instance)
162 | }
163 | 
164 | // getOsUser looks up a host OS user by name.
165 | func getOsUser(username string) (*user.User, error) {
166 | 	u, err := user.Lookup(username)
167 | 	if err != nil {
168 | 		log.Errorf("user lookup: %v", err)
169 | 		return nil, err
170 | 	}
171 | 	return u, nil
172 | }
173 | 
174 | // getUserAuthKeys returns the lines of u's ~/.ssh/authorized_keys on the
175 | // host, each in its own slice.
176 | func getUserAuthKeys(u *user.User) ([][]byte, error) {
177 | 	f, err := os.Open(filepath.Join(u.HomeDir, ".ssh", "authorized_keys"))
178 | 	if err != nil {
179 | 		log.Errorf("error with authorized_keys: %v", err)
180 | 		return nil, err
181 | 	}
182 | 	defer f.Close()
183 | 
184 | 	var keys [][]byte
185 | 	s := bufio.NewScanner(f)
186 | 	for s.Scan() {
187 | 		// Scanner.Bytes aliases the scanner's buffer, which later Scan calls
188 | 		// overwrite; copy each line so earlier keys stay intact.
189 | 		keys = append(keys, append([]byte(nil), s.Bytes()...))
190 | 	}
191 | 	if err := s.Err(); err != nil {
192 | 		log.Errorf("error reading authorized_keys: %v", err)
193 | 		return nil, err
194 | 	}
195 | 	return keys, nil
196 | }
197 | 
198 | // getUserGroups returns the IDs of the groups u belongs to.
199 | func getUserGroups(u *user.User) ([]string, error) {
200 | 	groups, err := u.GroupIds()
201 | 	if err != nil {
202 | 		log.Errorf("user groups: %v", err)
203 | 		return nil, err
204 | 	}
205 | 	return groups, nil
206 | }
207 | 
208 | // parseLoginUser parses an SSH login string of the form
209 | // [remote:][instanceuser@]instance[.project][+hostuser]. The instance user and
210 | // host user default to "root", the project to "default". An instance starting
211 | // with "%" denotes a command login (e.g. "%shell"), which has no instance or
212 | // instance user.
213 | func parseLoginUser(user string) *LoginUser {
214 | 	lu := new(LoginUser)
215 | 	lu.OrigUser = user
216 | 	lu.InstanceUser = "root"
217 | 	lu.Project = "default"
218 | 
219 | 	if r, u, ok := strings.Cut(user, ":"); ok {
220 | 		lu.Remote = r
221 | 		user = u
222 | 	}
223 | 
224 | 	instance := user
225 | 	if i, u, ok := strings.Cut(user, "+"); ok {
226 | 		instance = i
227 | 		lu.User = u
228 | 	} else {
229 | 		lu.User = "root"
230 | 	}
231 | 
232 | 	if u, i, ok := strings.Cut(instance, "@"); ok {
233 | 		instance = i
234 | 		lu.InstanceUser = u
235 | 	}
236 | 
237 | 	if i, p, ok := strings.Cut(instance, "."); ok {
238 | 		lu.Instance = i
239 | 		lu.Project = p
240 | 	} else {
241 | 		lu.Instance = instance
242 | 	}
243 | 
244 | 	if lu.Project == "" {
245 | 		lu.Project = "default"
246 | 	}
247 | 
248 | 	if strings.HasPrefix(lu.Instance, "%") {
249 | 		lu.Command = strings.TrimPrefix(lu.Instance, "%")
250 | 		lu.Instance = ""
251 | 		lu.InstanceUser = ""
252 | 	}
253 | 
254 | 	return lu
255 | }
256 | 
257 | // getGroupIds resolves group names to group IDs, logging and skipping names
258 | // that cannot be looked up.
259 | func getGroupIds(groups []string) []string {
260 | 	var ids []string
261 | 	for _, g := range groups {
262 | 		group, err := user.LookupGroup(g)
263 | 		if err != nil {
264 | 			log.Errorf("group lookup: %v", err)
265 | 			continue
266 | 		}
267 | 		ids = append(ids, group.Gid)
268 | 	}
269 | 	return ids
270 | }
271 | 
272 | // groupMatch returns the first element of a that also occurs in b, and
273 | // whether one was found.
274 | func groupMatch(a []string, b []string) (string, bool) {
275 | 	for _, i := range a {
276 | 		for _, j := range b {
277 | 			if i == j {
278 | 				return i, true
279 | 			}
280 | 		}
281 | 	}
282 | 	return "", false
283 | }
284 | 
--------------------------------------------------------------------------------
/server/user_test.go:
--------------------------------------------------------------------------------
1 | package server
2 | 
3 | import (
4 | 	"testing"
5 | 
6 | 	"github.com/stretchr/testify/assert"
7 | )
8 | 
9 | // TestParseUser checks parseLoginUser against the supported login string
10 | // forms.
11 | func TestParseUser(t *testing.T) {
12 | 	cases := map[string]LoginUser{
13 | 		"instance": {
14 | 			User:         "root",
15 | 			Instance:     "instance",
16 | 			Project:      "default",
17 | 			InstanceUser: "root",
18 | 		},
19 | 		"instance+user": {
20 | 			User:         "user",
21 | 			Instance:     "instance",
22 | 			Project:      "default",
23 | 			InstanceUser: "root",
24 | 		},
25 | 		"instance.project+user": {
26 | 			User:         "user",
27 | 			Instance:     "instance",
28 | 			Project:      "project",
29 | 			InstanceUser: "root",
30 | 		},
31 | 		"iuser@instance.project+user": {
32 | 			User:         "user",
33 | 			Instance:     "instance",
34 | 			Project:      "project",
35 | 			InstanceUser: "iuser",
36 | 		},
37 | 		"iuser@instance.project": {
38 | 			User:         "root",
39 | 			Instance:     "instance",
40 | 			Project:      "project",
41 | 			InstanceUser: "iuser",
42 | 		},
43 | 		"iuser@instance": {
44 | 			User:         "root",
45 | 			Instance:     "instance",
46 | 			Project:      "default",
47 | 			InstanceUser: "iuser",
48 | 		},
49 | 		"remote:iuser@instance": {
50 | 			Remote:       "remote",
51 | 			User:         "root",
52 | 			Instance:     "instance",
53 | 			Project:      "default",
54 | 			InstanceUser: "iuser",
55 | 		},
56 | 		"remote:iuser@instance.project+user": {
57 | 			Remote:       "remote",
58 | 			User:         "user",
59 | 			Instance:     "instance",
60 | 			Project:      "project",
61 | 			InstanceUser: "iuser",
62 | 		},
63 | 		"%shell": {
64 | 			User:    "root",
65 | 			Project: "default",
66 | 			Command: "shell",
67 | 		},
68 | 	}
69 | 
70 | 	for us, lu := range cases {
71 | 		t.Run(us, func(t *testing.T) {
72 | 			u := parseLoginUser(us)
73 | 			// Compare expected (lu) against parsed (u); every case pins all
74 | 			// parsed fields, including the empty ones.
75 | 			assert.Equal(t, lu.Remote, u.Remote)
76 | 			assert.Equal(t, lu.User, u.User)
77 | 			assert.Equal(t, lu.Instance, u.Instance)
78 | 			assert.Equal(t, lu.Project, u.Project)
79 | 			assert.Equal(t, lu.InstanceUser, u.InstanceUser)
80 | 			assert.Equal(t, lu.Command, u.Command)
81 | 			assert.Equal(t, us, u.OrigUser)
82 | 		})
83 | 	}
84 | }
85 | 
--------------------------------------------------------------------------------