├── CODEOWNERS ├── example ├── help.png ├── echo │ └── main.go └── completetest │ └── main.go ├── .golangci.yaml ├── Makefile ├── .github └── workflows │ └── ci.yaml ├── completion ├── README.md ├── zsh.go ├── handlers.go ├── fish.go ├── bash.go ├── powershell.go └── all.go ├── env_test.go ├── completion.go ├── net.go ├── env.go ├── help.tpl ├── go.mod ├── serpent.go ├── values_test.go ├── README.md ├── yaml_test.go ├── LICENSE ├── option_test.go ├── yaml.go ├── help.go ├── completion_test.go ├── option.go ├── go.sum ├── values.go ├── command.go └── command_test.go /CODEOWNERS: -------------------------------------------------------------------------------- 1 | ./ @dannykopping @ethanndickson -------------------------------------------------------------------------------- /example/help.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coder/serpent/HEAD/example/help.png -------------------------------------------------------------------------------- /.golangci.yaml: -------------------------------------------------------------------------------- 1 | linter-settings: 2 | linters: 3 | disable: 4 | - errcheck 5 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SHELL = /bin/bash 2 | .ONESHELL: 3 | 4 | .PHONY: lint 5 | lint: 6 | go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.53.3 7 | ~/go/bin/golangci-lint run 8 | 9 | .PHONY: test 10 | test: 11 | go test -timeout=3m -race . 
12 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: ci 2 | on: [push] 3 | jobs: 4 | make: 5 | runs-on: ubuntu-latest 6 | steps: 7 | - uses: actions/checkout@v2 8 | - uses: actions/setup-go@v4 9 | with: 10 | go-version: "^1.21" 11 | - name: test 12 | run: make test 13 | -------------------------------------------------------------------------------- /completion/README.md: -------------------------------------------------------------------------------- 1 | # completion 2 | 3 | The `completion` package extends `serpent` to allow applications to generate rich auto-completions. 4 | 5 | 6 | ## Protocol 7 | 8 | The completion scripts call out to the serpent command to generate 9 | completions. The convention is to pass the exact args and flags (or 10 | cmdline) of the in-progress command with a `COMPLETION_MODE=1` environment variable. That environment variable lets the command know to generate completions instead of running the command. 11 | By default, completions will be generated based on available flags and subcommands. Additional completions can be added by supplying a `CompletionHandlerFunc` on an Option or Command. 
-------------------------------------------------------------------------------- /example/echo/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "os" 5 | "strings" 6 | 7 | "github.com/coder/serpent" 8 | ) 9 | 10 | func main() { 11 | var upper bool 12 | cmd := serpent.Command{ 13 | Use: "echo ", 14 | Short: "Prints the given text to the console.", 15 | Options: serpent.OptionSet{ 16 | { 17 | Name: "upper", 18 | Value: serpent.BoolOf(&upper), 19 | Flag: "upper", 20 | Description: "Prints the text in upper case.", 21 | }, 22 | }, 23 | Handler: func(inv *serpent.Invocation) error { 24 | if len(inv.Args) == 0 { 25 | inv.Stderr.Write([]byte("error: missing text\n")) 26 | os.Exit(1) // NOTE(review): exits the process directly instead of returning an error from the handler. 27 | } 28 | 29 | text := inv.Args[0] 30 | if upper { 31 | text = strings.ToUpper(text) 32 | } 33 | 34 | inv.Stdout.Write([]byte(text)) 35 | return nil 36 | }, 37 | } 38 | 39 | err := cmd.Invoke().WithOS().Run() 40 | if err != nil { 41 | panic(err) 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /env_test.go: -------------------------------------------------------------------------------- 1 | package serpent_test 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | serpent "github.com/coder/serpent" 8 | ) 9 | // TestFilterNamePrefix checks that ParseEnviron keeps only prefixed variables and strips the prefix. 10 | func TestFilterNamePrefix(t *testing.T) { 11 | t.Parallel() 12 | type args struct { 13 | environ []string 14 | prefix string 15 | } 16 | tests := []struct { 17 | name string 18 | args args 19 | want serpent.Environ 20 | }{ 21 | {"empty", args{[]string{}, "SHIRE"}, nil}, 22 | { 23 | "ONE", 24 | args{ 25 | []string{ 26 | "SHIRE_BRANDYBUCK=hmm", 27 | }, 28 | "SHIRE_", 29 | }, 30 | []serpent.EnvVar{ 31 | {Name: "BRANDYBUCK", Value: "hmm"}, 32 | }, 33 | }, 34 | } 35 | for _, tt := range tests { 36 | tt := tt // copy the range variable: the parallel subtest runs after the loop has moved on. 37 | t.Run(tt.name, func(t *testing.T) { 38 | t.Parallel() 39 | if got := serpent.ParseEnviron(tt.args.environ, tt.args.prefix); !reflect.DeepEqual(got,
tt.want) { 40 | t.Errorf("FilterNamePrefix() = %v, want %v", got, tt.want) 41 | } 42 | }) 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /completion/zsh.go: -------------------------------------------------------------------------------- 1 | package completion 2 | 3 | import ( 4 | "io" 5 | "path/filepath" 6 | 7 | home "github.com/mitchellh/go-homedir" 8 | ) 9 | 10 | type zsh struct { 11 | goos string 12 | programName string 13 | } 14 | 15 | var _ Shell = &zsh{} 16 | // Zsh returns a Shell that writes zsh completion config for programName. 17 | func Zsh(goos string, programName string) Shell { 18 | return &zsh{goos: goos, programName: programName} 19 | } 20 | 21 | func (z *zsh) Name() string { 22 | return "zsh" 23 | } 24 | // InstallPath returns the path of the user's ~/.zshrc. 25 | func (z *zsh) InstallPath() (string, error) { 26 | homeDir, err := home.Dir() 27 | if err != nil { 28 | return "", err 29 | } 30 | return filepath.Join(homeDir, ".zshrc"), nil 31 | } 32 | 33 | func (z *zsh) WriteCompletion(w io.Writer) error { 34 | return writeConfig(w, zshCompletionTemplate, z.programName) 35 | } 36 | 37 | func (z *zsh) ProgramName() string { 38 | return z.programName 39 | } 40 | 41 | const zshCompletionTemplate = ` 42 | _{{.Name}}_completions() { 43 | local -a args completions 44 | args=("${words[@]:1:$#words}") 45 | completions=(${(f)"$(COMPLETION_MODE=1 "{{.Name}}" "${args[@]}")"}) 46 | compadd -a completions 47 | } 48 | compdef _{{.Name}}_completions {{.Name}} 49 | ` 50 | -------------------------------------------------------------------------------- /completion/handlers.go: -------------------------------------------------------------------------------- 1 | package completion 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path/filepath" 7 | "strings" 8 | 9 | "github.com/coder/serpent" 10 | ) 11 | 12 | // FileHandler returns a handler that completes file and directory names relative to the 13 | // directory part of the current word, keeping only entries accepted by filter (which may be nil).
14 | func FileHandler(filter func(info os.FileInfo) bool) serpent.CompletionHandlerFunc { 15 | return func(inv *serpent.Invocation) []string { 16 | var out []string 17 | _, word := inv.CurWords() 18 | 19 | dir, _ := filepath.Split(word) 20 | if dir == "" { 21 | dir = "." 22 | } 23 | f, err := os.Open(dir) 24 | if err != nil { // unreadable or missing directory: offer no completions 25 | return out 26 | } 27 | defer f.Close() 28 | if dir == "." { // entries of the working directory are offered without a "./" prefix 29 | dir = "" 30 | } 31 | 32 | infos, err := f.Readdir(0) // n <= 0 reads every entry in a single call 33 | if err != nil { 34 | return out 35 | } 36 | 37 | for _, info := range infos { 38 | if filter != nil && !filter(info) { 39 | continue 40 | } 41 | 42 | var cur string 43 | if info.IsDir() { 44 | cur = fmt.Sprintf("%s%s%c", dir, info.Name(), os.PathSeparator) // trailing separator lets completion continue into the directory 45 | } else { 46 | cur = fmt.Sprintf("%s%s", dir, info.Name()) 47 | } 48 | 49 | if strings.HasPrefix(cur, word) { // keep only candidates that extend what the user typed 50 | out = append(out, cur) 51 | } 52 | } 53 | return out 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /completion.go: -------------------------------------------------------------------------------- 1 | package serpent 2 | 3 | import ( 4 | "strings" 5 | 6 | "github.com/spf13/pflag" 7 | ) 8 | 9 | // CompletionModeEnv is a special environment variable that is 10 | // set when the command is being run in completion mode. 11 | const CompletionModeEnv = "COMPLETION_MODE" 12 | 13 | // IsCompletionMode returns true if the command is being run in completion mode. 14 | func (inv *Invocation) IsCompletionMode() bool { 15 | _, ok := inv.Environ.Lookup(CompletionModeEnv) // only presence is checked; any value, even "0", counts 16 | return ok 17 | } 18 | 19 | // DefaultCompletionHandler is a handler that returns the names of all subcommands or, 20 | // if the current word starts with a dash, the flags of all options that haven't been 21 | // exhaustively set (unset, default-valued, or slice-valued).
22 | func DefaultCompletionHandler(inv *Invocation) []string { 23 | _, cur := inv.CurWords() 24 | var allResps []string 25 | if strings.HasPrefix(cur, "-") { 26 | for _, opt := range inv.Command.Options { 27 | // Options without a flag (e.g. env-only options) can't be typed on 28 | // the command line, so suggesting a bare "--" for them would be bogus. 29 | if opt.Flag == "" { 30 | continue 31 | } 32 | // Slice flags may be repeated, so they stay suggestible once set. 33 | _, isSlice := opt.Value.(pflag.SliceValue) 34 | if opt.ValueSource == ValueSourceNone || 35 | opt.ValueSource == ValueSourceDefault || 36 | isSlice { 37 | allResps = append(allResps, "--"+opt.Flag) 38 | } 39 | } 40 | return allResps 41 | } 42 | // Not completing a flag: suggest subcommand names. 43 | for _, cmd := range inv.Command.Children { 44 | allResps = append(allResps, cmd.Name()) 45 | } 46 | return allResps 47 | } 48 | -------------------------------------------------------------------------------- /completion/fish.go: -------------------------------------------------------------------------------- 1 | package completion 2 | 3 | import ( 4 | "io" 5 | "path/filepath" 6 | 7 | home "github.com/mitchellh/go-homedir" 8 | ) 9 | 10 | type fish struct { 11 | goos string 12 | programName string 13 | } 14 | 15 | var _ Shell = &fish{} 16 | // Fish returns a Shell that writes a fish completion script for programName. 17 | func Fish(goos string, programName string) Shell { 18 | return &fish{goos: goos, programName: programName} 19 | } 20 | 21 | func (f *fish) Name() string { 22 | return "fish" 23 | } 24 | // InstallPath returns the program's completion file under ~/.config/fish/completions. 25 | func (f *fish) InstallPath() (string, error) { 26 | homeDir, err := home.Dir() 27 | if err != nil { 28 | return "", err 29 | } 30 | return filepath.Join(homeDir, ".config/fish/completions/", f.programName+".fish"), nil 31 | } 32 | 33 | func (f *fish) WriteCompletion(w io.Writer) error { 34 | return writeConfig(w, fishCompletionTemplate, f.programName) 35 | } 36 | 37 | func (f *fish) ProgramName() string { 38 | return f.programName 39 | } 40 | 41 | const fishCompletionTemplate = ` 42 | function _{{.Name}}_completions 43 | # Capture the full command line as an array 44 | set -l args (commandline -opc) 45 | set -l current (commandline -ct) 46 | COMPLETION_MODE=1 $args $current 47 | end 48 | 49 | # Setup Fish to use the function for completions for '{{.Name}}' 50 | complete -c {{.Name}} -f -a
'(_{{.Name}}_completions)' 51 | ` 52 | -------------------------------------------------------------------------------- /completion/bash.go: -------------------------------------------------------------------------------- 1 | package completion 2 | 3 | import ( 4 | "io" 5 | "path/filepath" 6 | 7 | home "github.com/mitchellh/go-homedir" 8 | ) 9 | 10 | type bash struct { 11 | goos string 12 | programName string 13 | } 14 | 15 | var _ Shell = &bash{} 16 | // Bash returns a Shell that writes bash completion config for programName. 17 | func Bash(goos string, programName string) Shell { 18 | return &bash{goos: goos, programName: programName} 19 | } 20 | 21 | func (b *bash) Name() string { 22 | return "bash" 23 | } 24 | // InstallPath returns ~/.bash_profile when goos is "darwin" and ~/.bashrc otherwise. 25 | func (b *bash) InstallPath() (string, error) { 26 | homeDir, err := home.Dir() 27 | if err != nil { 28 | return "", err 29 | } 30 | if b.goos == "darwin" { 31 | return filepath.Join(homeDir, ".bash_profile"), nil 32 | } 33 | return filepath.Join(homeDir, ".bashrc"), nil 34 | } 35 | 36 | func (b *bash) WriteCompletion(w io.Writer) error { 37 | return writeConfig(w, bashCompletionTemplate, b.programName) 38 | } 39 | 40 | func (b *bash) ProgramName() string { 41 | return b.programName 42 | } 43 | 44 | const bashCompletionTemplate = ` 45 | _generate_{{.Name}}_completions() { 46 | local args=("${COMP_WORDS[@]:1:COMP_CWORD}") 47 | 48 | declare -a output 49 | mapfile -t output < <(COMPLETION_MODE=1 "{{.Name}}" "${args[@]}") 50 | 51 | declare -a completions 52 | mapfile -t completions < <( compgen -W "$(printf '%q ' "${output[@]}")" -- "$2" ) 53 | 54 | local comp 55 | COMPREPLY=() 56 | for comp in "${completions[@]}"; do 57 | COMPREPLY+=("$(printf "%q" "$comp")") 58 | done 59 | } 60 | # Setup Bash to use the function for completions for '{{.Name}}' 61 | complete -F _generate_{{.Name}}_completions {{.Name}} 62 | ` 63 | -------------------------------------------------------------------------------- /net.go: -------------------------------------------------------------------------------- 1 | package serpent 2 | 3 | import ( 4 | "net"
5 | 6 | "github.com/pion/udp" 7 | "golang.org/x/xerrors" 8 | ) 9 | 10 | // Net abstracts CLI commands interacting with the operating system networking. 11 | // 12 | // At present, it covers opening local listening sockets, since doing this 13 | // in testing is a challenge without flakes, since it's hard to pick a port we 14 | // know a priori will be free. 15 | type Net interface { 16 | // Listen has the same semantics as `net.Listen` but also supports `udp`, 17 | // `udp4` and `udp6`. 18 | Listen(network, address string) (net.Listener, error) 19 | } 20 | 21 | // osNet is an implementation that calls the real OS for networking. 22 | type osNet struct{} 23 | 24 | // Listen opens a listener on the host's network stack: stream networks are 25 | // delegated to net.Listen, UDP networks are served through pion/udp. 26 | func (osNet) Listen(network, address string) (net.Listener, error) { 27 | switch network { 28 | case "tcp", "tcp4", "tcp6", "unix", "unixpacket": 29 | return net.Listen(network, address) 30 | case "udp", "udp4", "udp6": 31 | // Resolve the address like net.Listen does, so a hostname such as 32 | // "localhost" binds to its resolved IP. net.ParseIP returns nil for 33 | // hostnames, and a nil IP makes the listener bind every interface. 34 | laddr, err := net.ResolveUDPAddr(network, address) 35 | if err != nil { 36 | return nil, xerrors.Errorf("resolve %q: %w", address, err) 37 | } 38 | 39 | // Use pion here so that we get a stream-style net.Conn listener, instead 40 | // of a packet-oriented connection that can read and write to multiple 41 | // addresses. 42 | return udp.Listen(network, laddr) 43 | default: 44 | return nil, xerrors.Errorf("unknown listen network %q", network) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /env.go: -------------------------------------------------------------------------------- 1 | package serpent 2 | 3 | import "strings" 4 | 5 | // envName returns the upper-cased name of the environment variable in line.
6 | func envName(line string) string { 7 | return strings.ToUpper( // NOTE(review): names are upper-cased, presumably so matching is case-insensitive (Windows-style); confirm. 8 | strings.SplitN(line, "=", 2)[0], 9 | ) 10 | } 11 | 12 | // envValue returns the value of the environment variable in line, or "" if line has no "=". 13 | func envValue(line string) string { 14 | tokens := strings.SplitN(line, "=", 2) 15 | if len(tokens) < 2 { 16 | return "" 17 | } 18 | return tokens[1] 19 | } 20 | 21 | // EnvVar represents a single environment variable of form 22 | // NAME=VALUE. 23 | type EnvVar struct { 24 | Name string 25 | Value string 26 | } 27 | // Environ is an ordered list of environment variables; lookups use the first match. 28 | type Environ []EnvVar 29 | // ToOS formats e as NAME=VALUE strings in order. It returns nil when e is empty. 30 | func (e Environ) ToOS() []string { 31 | var env []string 32 | for _, v := range e { 33 | env = append(env, v.Name+"="+v.Value) 34 | } 35 | return env 36 | } 37 | // Lookup returns the value of the first variable whose name equals name exactly (case-sensitive) and whether one was found. 38 | func (e Environ) Lookup(name string) (string, bool) { 39 | for _, v := range e { 40 | if v.Name == name { 41 | return v.Value, true 42 | } 43 | } 44 | return "", false 45 | } 46 | // Get is like Lookup but returns "" when name is not present. 47 | func (e Environ) Get(name string) string { 48 | v, _ := e.Lookup(name) 49 | return v 50 | } 51 | // Set overwrites the value of the first variable named name, or appends a new variable if none exists. 52 | func (e *Environ) Set(name, value string) { 53 | for i, v := range *e { 54 | if v.Name == name { 55 | (*e)[i].Value = value 56 | return 57 | } 58 | } 59 | *e = append(*e, EnvVar{Name: name, Value: value}) 60 | } 61 | 62 | // ParseEnviron returns all environment variables whose upper-cased name starts 63 | // with prefix, with the prefix trimmed from the returned names.
64 | func ParseEnviron(environ []string, prefix string) Environ { 65 | var filtered []EnvVar 66 | for _, line := range environ { 67 | name := envName(line) // upper-cased, so a prefix containing lower-case letters never matches 68 | if strings.HasPrefix(name, prefix) { 69 | filtered = append(filtered, EnvVar{ 70 | Name: strings.TrimPrefix(name, prefix), 71 | Value: envValue(line), 72 | }) 73 | } 74 | } 75 | return filtered // nil when nothing matched 76 | } 77 | -------------------------------------------------------------------------------- /help.tpl: -------------------------------------------------------------------------------- 1 | {{- /* Heavily inspired by the Go toolchain and fd */ -}} 2 | {{prettyHeader "Usage"}} 3 | {{indent .FullUsage 2}} 4 | 5 | 6 | {{ with .Short }} 7 | {{- indent . 2 | wrapTTY }} 8 | {{"\n"}} 9 | {{- end}} 10 | 11 | {{- with .Deprecated }} 12 | {{- indent (printf "DEPRECATED: %s" .) 2 | wrapTTY }} 13 | {{"\n"}} 14 | {{- end }} 15 | 16 | {{ with .Aliases }} 17 | {{" Aliases: "}} {{- joinStrings .}} 18 | {{- end }} 19 | 20 | {{- with .Long}} 21 | {{"\n"}} 22 | {{- indent . 2}} 23 | {{ "\n" }} 24 | {{- end }} 25 | {{ with visibleChildren . }} 26 | {{- range $index, $child := . }} 27 | {{- if eq $index 0 }} 28 | {{ prettyHeader "Subcommands"}} 29 | {{- end }} 30 | {{- "\n" }} 31 | {{- formatSubcommand . | trimNewline }} 32 | {{- end }} 33 | {{- "\n" }} 34 | {{- end }} 35 | {{- range $index, $group := optionGroups . }} 36 | {{ with $group.Name }} {{- print $group.Name " Options" | prettyHeader }} {{ else -}} {{ prettyHeader "Options"}}{{- end -}} 37 | {{- with $group.Description }} 38 | {{ formatGroupDescription . }} 39 | {{- else }} 40 | {{- end }} 41 | {{- range $index, $option := $group.Options }} 42 | {{- if not (eq $option.FlagShorthand "") }}{{- print "\n "}} {{ keyword "-"}}{{keyword $option.FlagShorthand }}{{", "}} 43 | {{- else }}{{- print "\n " -}} 44 | {{- end }} 45 | {{- with flagName $option }}{{keyword "--"}}{{ keyword . }}{{ end }} {{- with typeHelper $option }} {{ .
}}{{ end }} 46 | {{- with envName $option }}, {{ print "$" . | keyword }}{{ end }} 47 | {{- with $option.Default }} (default: {{ . }}){{ end }} 48 | {{- with $option.Description }} 49 | {{- $desc := $option.Description }} 50 | {{ indent $desc 10 }} 51 | {{- if isDeprecated $option }}{{ indent (printf "DEPRECATED: Use %s instead." (useInstead $option)) 10 }}{{ end }} 52 | {{- end -}} 53 | {{- end }} 54 | {{- end }} 55 | {{- if .Parent }} 56 | ——— 57 | Run `{{ rootCommandName . }} --help` for a list of global options. 58 | {{- else }} 59 | {{- end }} 60 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/coder/serpent 2 | 3 | go 1.21.4 4 | 5 | require ( 6 | cdr.dev/slog v1.6.2-0.20240126064726-20367d4aede6 7 | github.com/coder/pretty v0.0.0-20230908205945-e89ba86370e0 8 | github.com/hashicorp/go-multierror v1.1.1 9 | github.com/mitchellh/go-homedir v1.1.0 10 | github.com/mitchellh/go-wordwrap v1.0.1 11 | github.com/muesli/termenv v0.15.2 12 | github.com/natefinch/atomic v1.0.1 13 | github.com/pion/udp v0.1.4 14 | github.com/spf13/pflag v1.0.5 15 | github.com/stretchr/testify v1.8.4 16 | github.com/xhit/go-str2duration/v2 v2.1.0 17 | golang.org/x/crypto v0.19.0 18 | golang.org/x/exp v0.0.0-20240213143201-ec583247a57a 19 | golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 20 | gopkg.in/yaml.v3 v3.0.1 21 | ) 22 | 23 | require ( 24 | github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect 25 | github.com/charmbracelet/lipgloss v0.8.0 // indirect 26 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect 27 | github.com/go-logr/logr v1.4.1 // indirect 28 | github.com/google/go-cmp v0.6.0 // indirect 29 | github.com/hashicorp/errwrap v1.1.0 // indirect 30 | github.com/kr/pretty v0.3.1 // indirect 31 | github.com/lucasb-eyer/go-colorful v1.2.0 // indirect 32 | github.com/mattn/go-isatty v0.0.20 // indirect
33 | github.com/mattn/go-runewidth v0.0.15 // indirect 34 | github.com/pion/transport/v2 v2.0.0 // indirect 35 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect 36 | github.com/rivo/uniseg v0.4.4 // indirect 37 | github.com/rogpeppe/go-internal v1.10.0 // indirect 38 | go.opentelemetry.io/otel v1.19.0 // indirect 39 | go.opentelemetry.io/otel/sdk v1.19.0 // indirect 40 | go.opentelemetry.io/otel/trace v1.19.0 // indirect 41 | golang.org/x/net v0.21.0 // indirect 42 | golang.org/x/sys v0.17.0 // indirect 43 | golang.org/x/term v0.17.0 // indirect 44 | google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 // indirect 45 | google.golang.org/genproto/googleapis/api v0.0.0-20231106174013-bbf56f31fb17 // indirect 46 | google.golang.org/genproto/googleapis/rpc v0.0.0-20231120223509-83a465c0220f // indirect 47 | google.golang.org/grpc v1.61.0 // indirect 48 | google.golang.org/protobuf v1.32.0 // indirect 49 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect 50 | ) 51 | -------------------------------------------------------------------------------- /serpent.go: -------------------------------------------------------------------------------- 1 | // Package serpent offers an all-in-one solution for a highly configurable CLI 2 | // application. Within Coder, we use it for all of our subcommands, which 3 | // demands more functionality than cobra/viper offers. 4 | // 5 | // The Command interface is loosely based on the chi middleware pattern and 6 | // http.Handler/HandlerFunc. 7 | package serpent 8 | 9 | import ( 10 | "strings" 11 | 12 | "golang.org/x/exp/maps" 13 | ) 14 | 15 | // Group describes a hierarchy of groups that an option or command belongs to.
16 | type Group struct { 17 | Parent *Group `json:"parent,omitempty"` 18 | Name string `json:"name,omitempty"` 19 | YAML string `json:"yaml,omitempty"` 20 | Description string `json:"description,omitempty"` 21 | } 22 | 23 | // Ancestry returns the group and all of its parents, ordered from the outermost parent down to g itself. It returns nil for a nil g. 24 | func (g *Group) Ancestry() []Group { 25 | if g == nil { 26 | return nil 27 | } 28 | 29 | groups := []Group{*g} 30 | for p := g.Parent; p != nil; p = p.Parent { 31 | // Prepend so the outermost ancestor ends up first. 32 | groups = append([]Group{*p}, groups...) 33 | } 34 | return groups 35 | } 36 | // FullName joins the names of g's ancestry, outermost first, with " / ". It returns "" for a nil g. 37 | func (g *Group) FullName() string { 38 | var names []string 39 | for _, g := range g.Ancestry() { // the loop variable shadows the receiver 40 | names = append(names, g.Name) 41 | } 42 | return strings.Join(names, " / ") 43 | } 44 | 45 | // Annotations is an arbitrary key-mapping used to extend the Option and Command types. 46 | // Its methods won't panic if the map is nil. 47 | type Annotations map[string]string 48 | 49 | // Mark returns a copy of the annotations with key set to value, creating a 50 | // new map if a is nil. Mark never mutates the original, so it is suitable 51 | // for chaining. 52 | func (a Annotations) Mark(key string, value string) Annotations { 53 | var aa Annotations 54 | if a != nil { 55 | aa = maps.Clone(a) 56 | } else { 57 | aa = make(Annotations) 58 | } 59 | aa[key] = value 60 | return aa 61 | } 62 | 63 | // IsSet returns true if the key is set in the annotations map. 64 | func (a Annotations) IsSet(key string) bool { 65 | if a == nil { 66 | return false 67 | } 68 | _, ok := a[key] 69 | return ok 70 | } 71 | 72 | // Get retrieves a key from the map, returning false if the key is not found 73 | // or the map is nil.
74 | func (a Annotations) Get(key string) (string, bool) { 75 | if a == nil { 76 | return "", false 77 | } 78 | v, ok := a[key] 79 | return v, ok 80 | } 81 | -------------------------------------------------------------------------------- /completion/powershell.go: -------------------------------------------------------------------------------- 1 | package completion 2 | 3 | import ( 4 | "io" 5 | "os/exec" 6 | "strings" 7 | ) 8 | 9 | type powershell struct { 10 | goos string 11 | programName string 12 | } 13 | 14 | var _ Shell = &powershell{} 15 | 16 | func (p *powershell) Name() string { 17 | return "powershell" 18 | } 19 | // Powershell returns a Shell that writes PowerShell completion config for programName. 20 | func Powershell(goos string, programName string) Shell { 21 | return &powershell{goos: goos, programName: programName} 22 | } 23 | // InstallPath asks PowerShell (powershell on windows, pwsh elsewhere) for $PROFILE.CurrentUserAllHosts. NOTE(review): CombinedOutput also captures stderr into the returned path. 24 | func (p *powershell) InstallPath() (string, error) { 25 | var ( 26 | path []byte 27 | err error 28 | ) 29 | cmd := "$PROFILE.CurrentUserAllHosts" 30 | if p.goos == "windows" { 31 | path, err = exec.Command("powershell", cmd).CombinedOutput() 32 | } else { 33 | path, err = exec.Command("pwsh", "-Command", cmd).CombinedOutput() 34 | } 35 | if err != nil { 36 | return "", err 37 | } 38 | return strings.TrimSpace(string(path)), nil 39 | } 40 | 41 | func (p *powershell) WriteCompletion(w io.Writer) error { 42 | return writeConfig(w, pshCompletionTemplate, p.programName) 43 | } 44 | 45 | func (p *powershell) ProgramName() string { 46 | return p.programName 47 | } 48 | 49 | const pshCompletionTemplate = ` 50 | # Escaping output sourced from: 51 | # https://github.com/spf13/cobra/blob/e94f6d0dd9a5e5738dca6bce03c4b1207ffbc0ec/powershell_completions.go#L47 52 | filter _{{.Name}}_escapeStringWithSpecialChars { 53 | ` + " $_ -replace '\\s|#|@|\\$|;|,|''|\\{|\\}|\\(|\\)|\"|`|\\||<|>|&','`$&'" + ` 54 | } 55 | 56 | $_{{.Name}}_completions = { 57 | param( 58 | $wordToComplete, 59 | $commandAst, 60 | $cursorPosition 61 | ) 62 | # Legacy space handling sourced from: 63 | #
https://github.com/spf13/cobra/blob/e94f6d0dd9a5e5738dca6bce03c4b1207ffbc0ec/powershell_completions.go#L107 64 | if ($PSVersionTable.PsVersion -lt [version]'7.2.0' -or 65 | ($PSVersionTable.PsVersion -lt [version]'7.3.0' -and -not [ExperimentalFeature]::IsEnabled("PSNativeCommandArgumentPassing")) -or 66 | (($PSVersionTable.PsVersion -ge [version]'7.3.0' -or [ExperimentalFeature]::IsEnabled("PSNativeCommandArgumentPassing")) -and 67 | $PSNativeCommandArgumentPassing -eq 'Legacy')) { 68 | $Space =` + "' `\"`\"'" + ` 69 | } else { 70 | $Space = ' ""' 71 | } 72 | $Command = $commandAst.ToString().Substring(0, $cursorPosition - 1) 73 | if ($wordToComplete -ne "" ) { 74 | $wordToComplete = $Command.Split(" ")[-1] 75 | } else { 76 | $Command = $Command + $Space 77 | } 78 | # Get completions by calling the command with the COMPLETION_MODE environment variable set to 1 79 | $env:COMPLETION_MODE = 1 80 | Invoke-Expression $Command | Where-Object { $_ -like "$wordToComplete*" } | ForEach-Object { 81 | "$_" | _{{.Name}}_escapeStringWithSpecialChars 82 | } 83 | $env:COMPLETION_MODE = '' 84 | } 85 | Register-ArgumentCompleter -CommandName {{.Name}} -ScriptBlock $_{{.Name}}_completions 86 | ` 87 | -------------------------------------------------------------------------------- /values_test.go: -------------------------------------------------------------------------------- 1 | package serpent_test 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/stretchr/testify/require" 8 | 9 | serpent "github.com/coder/serpent" 10 | ) 11 | // TestDuration checks that Duration.Set parses standard and day/week units and that String() output parses back to the same value. 12 | func TestDuration(t *testing.T) { 13 | t.Parallel() 14 | 15 | tests := []struct { 16 | name string 17 | input string 18 | expected time.Duration 19 | wantErr bool 20 | }{ 21 | // Standard time.Duration formats (should still work) 22 | { 23 | name: "Nanoseconds", 24 | input: "100ns", 25 | expected: 100 * time.Nanosecond, 26 | }, 27 | { 28 | name: "Microseconds", 29 | input: "100us", 30 | expected: 100 * time.Microsecond, 31 | }, 32 | { 33 |
name: "Milliseconds", 34 | input: "100ms", 35 | expected: 100 * time.Millisecond, 36 | }, 37 | { 38 | name: "Seconds", 39 | input: "30s", 40 | expected: 30 * time.Second, 41 | }, 42 | { 43 | name: "Minutes", 44 | input: "5m", 45 | expected: 5 * time.Minute, 46 | }, 47 | { 48 | name: "Hours", 49 | input: "2h", 50 | expected: 2 * time.Hour, 51 | }, 52 | { 53 | name: "Combined", 54 | input: "1h30m", 55 | expected: 90 * time.Minute, 56 | }, 57 | // New formats with days and weeks support 58 | { 59 | name: "Days", 60 | input: "1d", 61 | expected: 24 * time.Hour, 62 | }, 63 | { 64 | name: "MultipleDays", 65 | input: "7d", 66 | expected: 7 * 24 * time.Hour, 67 | }, 68 | { 69 | name: "Weeks", 70 | input: "1w", 71 | expected: 7 * 24 * time.Hour, 72 | }, 73 | { 74 | name: "MultipleWeeks", 75 | input: "2w", 76 | expected: 14 * 24 * time.Hour, 77 | }, 78 | { 79 | name: "CombinedWithDays", 80 | input: "1d12h", 81 | expected: 36 * time.Hour, 82 | }, 83 | { 84 | name: "CombinedWithWeeks", 85 | input: "1w2d", 86 | expected: (7 + 2) * 24 * time.Hour, 87 | }, 88 | { 89 | name: "ComplexCombination", 90 | input: "2w3d4h5m6s", 91 | expected: (14 + 3) * 24 * time.Hour + 4*time.Hour + 5*time.Minute + 6*time.Second, 92 | }, 93 | // Error cases 94 | { 95 | name: "Invalid", 96 | input: "invalid", 97 | wantErr: true, 98 | }, 99 | { 100 | name: "Empty", 101 | input: "", 102 | wantErr: true, 103 | }, 104 | } 105 | 106 | for _, tt := range tests { 107 | // go.mod declares go 1.21, so the range variable is shared across 108 | // iterations; copy it or every parallel subtest sees the last case. 109 | tt := tt 110 | t.Run(tt.name, func(t *testing.T) { 111 | t.Parallel() 112 | 113 | var d serpent.Duration 114 | err := d.Set(tt.input) 115 | 116 | if tt.wantErr { 117 | require.Error(t, err) 118 | return 119 | } 120 | 121 | require.NoError(t, err) 122 | require.Equal(t, tt.expected, d.Value()) 123 | 124 | // Verify String() returns a parseable value 125 | str := d.String() 126 | var d2 serpent.Duration 127 | err = d2.Set(str) 128 | require.NoError(t, err) 129 | require.Equal(t, d.Value(), d2.Value(), "String() should return a parseable value") 130 | }) 131 | }
132 | } 133 | // TestDurationOf checks that writes through DurationOf's result update the wrapped time.Duration. 134 | func TestDurationOf(t *testing.T) { 135 | t.Parallel() 136 | 137 | td := 5 * time.Minute 138 | d := serpent.DurationOf(&td) 139 | require.NotNil(t, d) 140 | require.Equal(t, td, d.Value()) 141 | 142 | // Test modification through pointer 143 | newVal := 10 * time.Minute 144 | *d = serpent.Duration(newVal) 145 | require.Equal(t, newVal, time.Duration(td)) 146 | } 147 | -------------------------------------------------------------------------------- /example/completetest/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "strings" 7 | 8 | "github.com/coder/serpent" 9 | "github.com/coder/serpent/completion" 10 | ) 11 | 12 | // installCommand returns a serpent command that helps 13 | // a user configure their shell to use serpent's completion. 14 | func installCommand() *serpent.Command { 15 | var shell string 16 | return &serpent.Command{ 17 | Use: "completion [--shell ]", 18 | Short: "Generate completion scripts for the given shell.", 19 | Handler: func(inv *serpent.Invocation) error { 20 | defaultShell, err := completion.DetectUserShell(inv.Command.Parent.Name()) // NOTE(review): the --shell value bound to `shell` is never read, and err is dropped below. 21 | if err != nil { 22 | return fmt.Errorf("Could not detect user shell, please specify a shell using `--shell`") 23 | } 24 | return defaultShell.WriteCompletion(inv.Stdout) 25 | }, 26 | Options: serpent.OptionSet{ 27 | { 28 | Flag: "shell", 29 | FlagShorthand: "s", 30 | Description: "The shell to generate a completion script for.", 31 | Value: completion.ShellOptions(&shell), 32 | }, 33 | }, 34 | } 35 | } 36 | 37 | func main() { 38 | var ( 39 | print bool 40 | upper bool 41 | fileType string 42 | fileArr []string 43 | types []string 44 | ) 45 | cmd := serpent.Command{ 46 | Use: "completetest ", 47 | Short: "Prints the given text to the console.", 48 | Options: serpent.OptionSet{ 49 | { 50 | Name: "different", 51 | Value: serpent.BoolOf(&upper), 52 | Flag: "different", 53 | Description: "Do
the command differently.", 54 | }, 55 | }, 56 | Handler: func(inv *serpent.Invocation) error { 57 | if len(inv.Args) == 0 { 58 | inv.Stderr.Write([]byte("error: missing text\n")) 59 | os.Exit(1) 60 | } 61 | 62 | text := inv.Args[0] 63 | if upper { 64 | text = strings.ToUpper(text) 65 | } 66 | 67 | inv.Stdout.Write([]byte(text)) 68 | return nil 69 | }, 70 | Children: []*serpent.Command{ 71 | { 72 | Use: "sub", 73 | Short: "A subcommand", 74 | Handler: func(inv *serpent.Invocation) error { 75 | inv.Stdout.Write([]byte("subcommand")) 76 | return nil 77 | }, 78 | Options: serpent.OptionSet{ 79 | { 80 | Name: "upper", 81 | Value: serpent.BoolOf(&upper), 82 | Flag: "upper", 83 | Description: "Prints the text in upper case.", 84 | }, 85 | }, 86 | }, 87 | { 88 | Use: "file ", 89 | Handler: func(inv *serpent.Invocation) error { 90 | return nil 91 | }, 92 | Options: serpent.OptionSet{ 93 | { 94 | Name: "print", 95 | Value: serpent.BoolOf(&print), 96 | Flag: "print", 97 | Description: "Print the file.", 98 | }, 99 | { 100 | Name: "type", 101 | Value: serpent.EnumOf(&fileType, "binary", "text"), 102 | Flag: "type", 103 | Description: "The type of file.", 104 | }, 105 | { 106 | Name: "extra", 107 | Flag: "extra", 108 | Description: "Extra files.", 109 | Value: serpent.StringArrayOf(&fileArr), 110 | }, 111 | { 112 | Name: "types", 113 | Flag: "types", 114 | Value: serpent.EnumArrayOf(&types, "binary", "text"), 115 | }, 116 | }, 117 | CompletionHandler: completion.FileHandler(nil), 118 | Middleware: serpent.RequireNArgs(1), 119 | }, 120 | installCommand(), 121 | }, 122 | } 123 | 124 | inv := cmd.Invoke().WithOS() 125 | 126 | err := inv.Run() 127 | if err != nil { 128 | panic(err) 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # serpent 2 | 3 | [![Go
Reference](https://pkg.go.dev/badge/github.com/coder/serpent.svg)](https://pkg.go.dev/github.com/coder/serpent)

`serpent` is a Go CLI configuration framework based on [cobra](https://github.com/spf13/cobra) and used by [coder/coder](https://github.com/coder/coder).
It's designed for large-scale CLIs with dozens of commands and hundreds
of options. If you're building a small, self-contained tool, go with
cobra.

![help example](./example/help.png)

When compared to cobra, serpent strives for:

* Better default help output inspired by the Go toolchain
* Greater flexibility in accepting options that span across multiple sources
* Composition via middleware
* Testability (e.g. commands can only use OS stdout and stderr when they are explicitly provided)

## Basic Usage

See `example/echo`:

```go
package main

import (
	"os"
	"strings"

	"github.com/coder/serpent"
)

func main() {
	var upper bool
	cmd := serpent.Command{
		Use:   "echo <text>",
		Short: "Prints the given text to the console.",
		Options: serpent.OptionSet{
			{
				Name:        "upper",
				Value:       serpent.BoolOf(&upper),
				Flag:        "upper",
				Description: "Prints the text in upper case.",
			},
		},
		Handler: func(inv *serpent.Invocation) error {
			if len(inv.Args) == 0 {
				inv.Stderr.Write([]byte("error: missing text\n"))
				os.Exit(1)
			}

			text := inv.Args[0]
			if upper {
				text = strings.ToUpper(text)
			}

			inv.Stdout.Write([]byte(text))
			return nil
		},
	}

	err := cmd.Invoke().WithOS().Run()
	if err != nil {
		panic(err)
	}
}
```

## Design

This Design section assumes you have a good understanding of how `cobra` works.

### Options

Serpent is designed for high configurability.
To us, that means providing 75 | many ways to configure the same value (env, YAML, flags, etc.) and keeping 76 | the code clean and testable as you scale the number of options. 77 | 78 | Serpent's [Option](https://pkg.go.dev/github.com/coder/serpent#Option) type looks like: 79 | 80 | ```go 81 | type Option struct { 82 | Name string 83 | Flag string 84 | Env string 85 | Default string 86 | Value pflag.Value 87 | // ... 88 | } 89 | ``` 90 | 91 | And is used by each [Command](https://pkg.go.dev/github.com/coder/serpent#Command) when 92 | passed as an array to the `Options` field. 93 | 94 | ## Comparison with Cobra 95 | 96 | Here is a comparison of the `help` output between a simple `echo` command in Cobra and Serpent. 97 | 98 | ### Cobra 99 | 100 | ``` 101 | echo is for echoing anything back. Echo works a lot like print, except it has a child command. 102 | 103 | Usage: 104 | echo [string to echo] [flags] 105 | 106 | Flags: 107 | -h, --help help for echo 108 | -u, --upper make the output uppercase 109 | ``` 110 | 111 | ### Serpent 112 | 113 | ``` 114 | USAGE: 115 | echo 116 | 117 | Prints the given text to the console. 118 | 119 | OPTIONS: 120 | --upper bool 121 | Prints the text in upper case. 122 | ``` 123 | 124 | ## Migrating from Cobra 125 | 126 | Serpent is designed to be a replacement for Cobra and Viper. If you are familiar with Cobra, the transition to Serpent should be relatively straightforward. The main differences are: 127 | 128 | * **Command Structure:** Serpent uses a `serpent.Command` struct which is similar to `cobra.Command`. 129 | * **Options:** Serpent has a more flexible and powerful option system that allows you to define options from multiple sources (flags, environment variables, config files, etc.) in a single place. 130 | * **Middleware:** Serpent has a middleware system that allows you to compose functionality and apply it to your commands. 131 | * **Testability:** Serpent is designed to be more testable than Cobra. 
For example, OS stdout and stderr are only available to commands explicitly. 132 | 133 | ## Serpent vs. Cobra and Viper 134 | 135 | Serpent is intended to be a complete replacement for both Cobra and Viper. While Viper is often used with Cobra to provide environment and config file support, Serpent has this functionality built-in. This results in a more integrated and streamlined experience. 136 | 137 | ## Examples 138 | 139 | For a more comprehensive example of a large-scale CLI built with Serpent, please see the [coder/coder](https://github.com/coder/coder) repository. -------------------------------------------------------------------------------- /yaml_test.go: -------------------------------------------------------------------------------- 1 | package serpent_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/spf13/pflag" 7 | "github.com/stretchr/testify/require" 8 | "golang.org/x/exp/slices" 9 | "gopkg.in/yaml.v3" 10 | 11 | "github.com/coder/serpent" 12 | ) 13 | 14 | func TestOptionSet_YAML(t *testing.T) { 15 | t.Parallel() 16 | 17 | t.Run("RequireKey", func(t *testing.T) { 18 | t.Parallel() 19 | var workspaceName serpent.String 20 | os := serpent.OptionSet{ 21 | serpent.Option{ 22 | Name: "Workspace Name", 23 | Value: &workspaceName, 24 | Default: "billie", 25 | }, 26 | } 27 | 28 | node, err := os.MarshalYAML() 29 | require.NoError(t, err) 30 | require.Len(t, node.(*yaml.Node).Content, 0) 31 | }) 32 | 33 | t.Run("SimpleString", func(t *testing.T) { 34 | t.Parallel() 35 | 36 | var workspaceName serpent.String 37 | 38 | os := serpent.OptionSet{ 39 | serpent.Option{ 40 | Name: "Workspace Name", 41 | Value: &workspaceName, 42 | Default: "billie", 43 | Description: "The workspace's name.", 44 | Group: &serpent.Group{YAML: "names"}, 45 | YAML: "workspaceName", 46 | }, 47 | } 48 | 49 | err := os.SetDefaults() 50 | require.NoError(t, err) 51 | 52 | n, err := os.MarshalYAML() 53 | require.NoError(t, err) 54 | // Visually inspect for now. 
55 | byt, err := yaml.Marshal(n) 56 | require.NoError(t, err) 57 | t.Logf("Raw YAML:\n%s", string(byt)) 58 | }) 59 | } 60 | 61 | func TestOptionSet_YAMLUnknownOptions(t *testing.T) { 62 | t.Parallel() 63 | os := serpent.OptionSet{ 64 | { 65 | Name: "Workspace Name", 66 | Default: "billie", 67 | Description: "The workspace's name.", 68 | YAML: "workspaceName", 69 | Value: new(serpent.String), 70 | }, 71 | } 72 | 73 | const yamlDoc = `something: else` 74 | err := yaml.Unmarshal([]byte(yamlDoc), &os) 75 | require.Error(t, err) 76 | require.Empty(t, os[0].Value.String()) 77 | 78 | os[0].YAML = "something" 79 | 80 | err = yaml.Unmarshal([]byte(yamlDoc), &os) 81 | require.NoError(t, err) 82 | 83 | require.Equal(t, "else", os[0].Value.String()) 84 | } 85 | 86 | // TestOptionSet_YAMLIsomorphism tests that the YAML representations of an 87 | // OptionSet converts to the same OptionSet when read back in. 88 | func TestOptionSet_YAMLIsomorphism(t *testing.T) { 89 | t.Parallel() 90 | // This is used to form a generic. 
91 | //nolint:unused 92 | type kid struct { 93 | Name string `yaml:"name"` 94 | Age int `yaml:"age"` 95 | } 96 | 97 | for _, tc := range []struct { 98 | name string 99 | os serpent.OptionSet 100 | zeroValue func() pflag.Value 101 | }{ 102 | { 103 | name: "SimpleString", 104 | os: serpent.OptionSet{ 105 | { 106 | Name: "Workspace Name", 107 | Default: "billie", 108 | Description: "The workspace's name.", 109 | Group: &serpent.Group{YAML: "names"}, 110 | YAML: "workspaceName", 111 | }, 112 | }, 113 | zeroValue: func() pflag.Value { 114 | return serpent.StringOf(new(string)) 115 | }, 116 | }, 117 | { 118 | name: "Array", 119 | os: serpent.OptionSet{ 120 | { 121 | YAML: "names", 122 | Default: "jill,jack,joan", 123 | }, 124 | }, 125 | zeroValue: func() pflag.Value { 126 | return serpent.StringArrayOf(&[]string{}) 127 | }, 128 | }, 129 | { 130 | name: "ComplexObject", 131 | os: serpent.OptionSet{ 132 | { 133 | YAML: "kids", 134 | Default: `- name: jill 135 | age: 12 136 | - name: jack 137 | age: 13`, 138 | }, 139 | }, 140 | zeroValue: func() pflag.Value { 141 | return &serpent.Struct[[]kid]{} 142 | }, 143 | }, 144 | { 145 | name: "DeepGroup", 146 | os: serpent.OptionSet{ 147 | { 148 | YAML: "names", 149 | Default: "jill,jack,joan", 150 | Group: &serpent.Group{YAML: "kids", Parent: &serpent.Group{YAML: "family"}}, 151 | }, 152 | }, 153 | zeroValue: func() pflag.Value { 154 | return serpent.StringArrayOf(&[]string{}) 155 | }, 156 | }, 157 | } { 158 | tc := tc 159 | t.Run(tc.name, func(t *testing.T) { 160 | t.Parallel() 161 | 162 | // Set initial values. 
163 | for i := range tc.os { 164 | tc.os[i].Value = tc.zeroValue() 165 | } 166 | err := tc.os.SetDefaults() 167 | require.NoError(t, err) 168 | 169 | y, err := tc.os.MarshalYAML() 170 | require.NoError(t, err) 171 | 172 | toByt, err := yaml.Marshal(y) 173 | require.NoError(t, err) 174 | 175 | t.Logf("Raw YAML:\n%s", string(toByt)) 176 | 177 | var y2 yaml.Node 178 | err = yaml.Unmarshal(toByt, &y2) 179 | require.NoError(t, err) 180 | 181 | os2 := slices.Clone(tc.os) 182 | for i := range os2 { 183 | os2[i].Value = tc.zeroValue() 184 | os2[i].ValueSource = serpent.ValueSourceNone 185 | } 186 | 187 | // os2 values should be zeroed whereas tc.os should be 188 | // set to defaults. 189 | // This check makes sure we aren't mixing pointers. 190 | require.NotEqual(t, tc.os, os2) 191 | err = os2.UnmarshalYAML(&y2) 192 | require.NoError(t, err) 193 | 194 | want := tc.os 195 | for i := range want { 196 | want[i].ValueSource = serpent.ValueSourceYAML 197 | } 198 | 199 | require.Equal(t, tc.os, os2) 200 | }) 201 | } 202 | } 203 | -------------------------------------------------------------------------------- /completion/all.go: -------------------------------------------------------------------------------- 1 | package completion 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "io/fs" 9 | "os" 10 | "os/user" 11 | "path/filepath" 12 | "runtime" 13 | "strings" 14 | "text/template" 15 | 16 | "github.com/coder/serpent" 17 | 18 | "github.com/natefinch/atomic" 19 | ) 20 | 21 | const ( 22 | completionStartTemplate = `# ============ BEGIN {{.Name}} COMPLETION ============` 23 | completionEndTemplate = `# ============ END {{.Name}} COMPLETION ==============` 24 | ) 25 | 26 | type Shell interface { 27 | Name() string 28 | InstallPath() (string, error) 29 | WriteCompletion(io.Writer) error 30 | ProgramName() string 31 | } 32 | 33 | const ( 34 | ShellBash string = "bash" 35 | ShellFish string = "fish" 36 | ShellZsh string = "zsh" 37 | ShellPowershell string = "powershell" 
38 | ) 39 | 40 | func ShellByName(shell, programName string) (Shell, error) { 41 | switch shell { 42 | case ShellBash: 43 | return Bash(runtime.GOOS, programName), nil 44 | case ShellFish: 45 | return Fish(runtime.GOOS, programName), nil 46 | case ShellZsh: 47 | return Zsh(runtime.GOOS, programName), nil 48 | case ShellPowershell: 49 | return Powershell(runtime.GOOS, programName), nil 50 | default: 51 | return nil, fmt.Errorf("unsupported shell %q", shell) 52 | } 53 | } 54 | 55 | func ShellOptions(choice *string) *serpent.Enum { 56 | return serpent.EnumOf(choice, ShellBash, ShellFish, ShellZsh, ShellPowershell) 57 | } 58 | 59 | func DetectUserShell(programName string) (Shell, error) { 60 | // Attempt to get the SHELL environment variable first 61 | if shell := os.Getenv("SHELL"); shell != "" { 62 | return ShellByName(filepath.Base(shell), programName) 63 | } 64 | 65 | // Fallback: Look up the current user and parse /etc/passwd 66 | currentUser, err := user.Current() 67 | if err != nil { 68 | return nil, err 69 | } 70 | 71 | // Open and parse /etc/passwd 72 | passwdFile, err := os.ReadFile("/etc/passwd") 73 | if err != nil { 74 | return nil, err 75 | } 76 | 77 | lines := strings.Split(string(passwdFile), "\n") 78 | for _, line := range lines { 79 | if strings.HasPrefix(line, currentUser.Username+":") { 80 | parts := strings.Split(line, ":") 81 | if len(parts) > 6 { 82 | return ShellByName(filepath.Base(parts[6]), programName) // The shell is typically the 7th field 83 | } 84 | } 85 | } 86 | 87 | return nil, fmt.Errorf("default shell not found") 88 | } 89 | 90 | func writeConfig( 91 | w io.Writer, 92 | cfgTemplate string, 93 | programName string, 94 | ) error { 95 | tmpl, err := template.New("script").Parse(cfgTemplate) 96 | if err != nil { 97 | return fmt.Errorf("parse template: %w", err) 98 | } 99 | 100 | err = tmpl.Execute( 101 | w, 102 | map[string]string{ 103 | "Name": programName, 104 | }, 105 | ) 106 | if err != nil { 107 | return fmt.Errorf("execute template: 
%w", err) 108 | } 109 | 110 | return nil 111 | } 112 | 113 | func InstallShellCompletion(shell Shell) error { 114 | path, err := shell.InstallPath() 115 | if err != nil { 116 | return fmt.Errorf("get install path: %w", err) 117 | } 118 | var headerBuf bytes.Buffer 119 | err = writeConfig(&headerBuf, completionStartTemplate, shell.ProgramName()) 120 | if err != nil { 121 | return fmt.Errorf("generate header: %w", err) 122 | } 123 | 124 | var footerBytes bytes.Buffer 125 | err = writeConfig(&footerBytes, completionEndTemplate, shell.ProgramName()) 126 | if err != nil { 127 | return fmt.Errorf("generate footer: %w", err) 128 | } 129 | 130 | err = os.MkdirAll(filepath.Dir(path), 0o755) 131 | if err != nil { 132 | return fmt.Errorf("create directories: %w", err) 133 | } 134 | 135 | f, err := os.ReadFile(path) 136 | if err != nil && !errors.Is(err, fs.ErrNotExist) { 137 | return fmt.Errorf("read ssh config failed: %w", err) 138 | } 139 | 140 | before, after, err := templateConfigSplit(headerBuf.Bytes(), footerBytes.Bytes(), f) 141 | if err != nil { 142 | return err 143 | } 144 | 145 | outBuf := bytes.Buffer{} 146 | _, _ = outBuf.Write(before) 147 | if len(before) > 0 { 148 | _, _ = outBuf.Write([]byte("\n")) 149 | } 150 | _, _ = outBuf.Write(headerBuf.Bytes()) 151 | err = shell.WriteCompletion(&outBuf) 152 | if err != nil { 153 | return fmt.Errorf("generate completion: %w", err) 154 | } 155 | _, _ = outBuf.Write(footerBytes.Bytes()) 156 | _, _ = outBuf.Write([]byte("\n")) 157 | _, _ = outBuf.Write(after) 158 | 159 | err = atomic.WriteFile(path, &outBuf) 160 | if err != nil { 161 | return fmt.Errorf("write completion: %w", err) 162 | } 163 | 164 | return nil 165 | } 166 | 167 | func templateConfigSplit(header, footer, data []byte) (before, after []byte, err error) { 168 | startCount := bytes.Count(data, header) 169 | endCount := bytes.Count(data, footer) 170 | if startCount > 1 || endCount > 1 { 171 | return nil, nil, fmt.Errorf("Malformed config file: multiple config 
sections") 172 | } 173 | 174 | startIndex := bytes.Index(data, header) 175 | endIndex := bytes.Index(data, footer) 176 | if startIndex == -1 && endIndex != -1 { 177 | return data, nil, fmt.Errorf("Malformed config file: missing completion header") 178 | } 179 | if startIndex != -1 && endIndex == -1 { 180 | return data, nil, fmt.Errorf("Malformed config file: missing completion footer") 181 | } 182 | if startIndex != -1 && endIndex != -1 { 183 | if startIndex > endIndex { 184 | return data, nil, fmt.Errorf("Malformed config file: completion header after footer") 185 | } 186 | // Include leading and trailing newline, if present 187 | start := startIndex 188 | if start > 0 { 189 | start-- 190 | } 191 | end := endIndex + len(footer) 192 | if end < len(data) { 193 | end++ 194 | } 195 | return data[:start], data[end:], nil 196 | } 197 | return data, nil, nil 198 | } 199 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Creative Commons Legal Code 2 | 3 | CC0 1.0 Universal 4 | 5 | CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE 6 | LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN 7 | ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS 8 | INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES 9 | REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS 10 | PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM 11 | THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED 12 | HEREUNDER. 13 | 14 | Statement of Purpose 15 | 16 | The laws of most jurisdictions throughout the world automatically confer 17 | exclusive Copyright and Related Rights (defined below) upon the creator 18 | and subsequent owner(s) (each and all, an "owner") of an original work of 19 | authorship and/or a database (each, a "Work"). 
20 | 21 | Certain owners wish to permanently relinquish those rights to a Work for 22 | the purpose of contributing to a commons of creative, cultural and 23 | scientific works ("Commons") that the public can reliably and without fear 24 | of later claims of infringement build upon, modify, incorporate in other 25 | works, reuse and redistribute as freely as possible in any form whatsoever 26 | and for any purposes, including without limitation commercial purposes. 27 | These owners may contribute to the Commons to promote the ideal of a free 28 | culture and the further production of creative, cultural and scientific 29 | works, or to gain reputation or greater distribution for their Work in 30 | part through the use and efforts of others. 31 | 32 | For these and/or other purposes and motivations, and without any 33 | expectation of additional consideration or compensation, the person 34 | associating CC0 with a Work (the "Affirmer"), to the extent that he or she 35 | is an owner of Copyright and Related Rights in the Work, voluntarily 36 | elects to apply CC0 to the Work and publicly distribute the Work under its 37 | terms, with knowledge of his or her Copyright and Related Rights in the 38 | Work and the meaning and intended legal effect of CC0 on those rights. 39 | 40 | 1. Copyright and Related Rights. A Work made available under CC0 may be 41 | protected by copyright and related or neighboring rights ("Copyright and 42 | Related Rights"). Copyright and Related Rights include, but are not 43 | limited to, the following: 44 | 45 | i. the right to reproduce, adapt, distribute, perform, display, 46 | communicate, and translate a Work; 47 | ii. moral rights retained by the original author(s) and/or performer(s); 48 | iii. publicity and privacy rights pertaining to a person's image or 49 | likeness depicted in a Work; 50 | iv. rights protecting against unfair competition in regards to a Work, 51 | subject to the limitations in paragraph 4(a), below; 52 | v. 
rights protecting the extraction, dissemination, use and reuse of data 53 | in a Work; 54 | vi. database rights (such as those arising under Directive 96/9/EC of the 55 | European Parliament and of the Council of 11 March 1996 on the legal 56 | protection of databases, and under any national implementation 57 | thereof, including any amended or successor version of such 58 | directive); and 59 | vii. other similar, equivalent or corresponding rights throughout the 60 | world based on applicable law or treaty, and any national 61 | implementations thereof. 62 | 63 | 2. Waiver. To the greatest extent permitted by, but not in contravention 64 | of, applicable law, Affirmer hereby overtly, fully, permanently, 65 | irrevocably and unconditionally waives, abandons, and surrenders all of 66 | Affirmer's Copyright and Related Rights and associated claims and causes 67 | of action, whether now known or unknown (including existing as well as 68 | future claims and causes of action), in the Work (i) in all territories 69 | worldwide, (ii) for the maximum duration provided by applicable law or 70 | treaty (including future time extensions), (iii) in any current or future 71 | medium and for any number of copies, and (iv) for any purpose whatsoever, 72 | including without limitation commercial, advertising or promotional 73 | purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each 74 | member of the public at large and to the detriment of Affirmer's heirs and 75 | successors, fully intending that such Waiver shall not be subject to 76 | revocation, rescission, cancellation, termination, or any other legal or 77 | equitable action to disrupt the quiet enjoyment of the Work by the public 78 | as contemplated by Affirmer's express Statement of Purpose. 79 | 80 | 3. Public License Fallback. 
Should any part of the Waiver for any reason 81 | be judged legally invalid or ineffective under applicable law, then the 82 | Waiver shall be preserved to the maximum extent permitted taking into 83 | account Affirmer's express Statement of Purpose. In addition, to the 84 | extent the Waiver is so judged Affirmer hereby grants to each affected 85 | person a royalty-free, non transferable, non sublicensable, non exclusive, 86 | irrevocable and unconditional license to exercise Affirmer's Copyright and 87 | Related Rights in the Work (i) in all territories worldwide, (ii) for the 88 | maximum duration provided by applicable law or treaty (including future 89 | time extensions), (iii) in any current or future medium and for any number 90 | of copies, and (iv) for any purpose whatsoever, including without 91 | limitation commercial, advertising or promotional purposes (the 92 | "License"). The License shall be deemed effective as of the date CC0 was 93 | applied by Affirmer to the Work. Should any part of the License for any 94 | reason be judged legally invalid or ineffective under applicable law, such 95 | partial invalidity or ineffectiveness shall not invalidate the remainder 96 | of the License, and in such case Affirmer hereby affirms that he or she 97 | will not (i) exercise any of his or her remaining Copyright and Related 98 | Rights in the Work or (ii) assert any associated claims and causes of 99 | action with respect to the Work, in either case contrary to Affirmer's 100 | express Statement of Purpose. 101 | 102 | 4. Limitations and Disclaimers. 103 | 104 | a. No trademark or patent rights held by Affirmer are waived, abandoned, 105 | surrendered, licensed or otherwise affected by this document. 106 | b. 
Affirmer offers the Work as-is and makes no representations or 107 | warranties of any kind concerning the Work, express, implied, 108 | statutory or otherwise, including without limitation warranties of 109 | title, merchantability, fitness for a particular purpose, non 110 | infringement, or the absence of latent or other defects, accuracy, or 111 | the present or absence of errors, whether or not discoverable, all to 112 | the greatest extent permissible under applicable law. 113 | c. Affirmer disclaims responsibility for clearing rights of other persons 114 | that may apply to the Work or any use thereof, including without 115 | limitation any person's Copyright and Related Rights in the Work. 116 | Further, Affirmer disclaims responsibility for obtaining any necessary 117 | consents, permissions or other rights required for any use of the 118 | Work. 119 | d. Affirmer understands and acknowledges that Creative Commons is not a 120 | party to this document and has no duty or obligation with respect to 121 | this CC0 or use of the Work. 
-------------------------------------------------------------------------------- /option_test.go: -------------------------------------------------------------------------------- 1 | package serpent_test 2 | 3 | import ( 4 | "encoding/json" 5 | "regexp" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/require" 9 | 10 | serpent "github.com/coder/serpent" 11 | ) 12 | 13 | func TestOptionSet_ParseFlags(t *testing.T) { 14 | t.Parallel() 15 | 16 | t.Run("SimpleString", func(t *testing.T) { 17 | t.Parallel() 18 | 19 | var workspaceName serpent.String 20 | 21 | os := serpent.OptionSet{ 22 | serpent.Option{ 23 | Name: "Workspace Name", 24 | Value: &workspaceName, 25 | Flag: "workspace-name", 26 | FlagShorthand: "n", 27 | }, 28 | } 29 | 30 | var err error 31 | err = os.FlagSet().Parse([]string{"--workspace-name", "foo"}) 32 | require.NoError(t, err) 33 | require.EqualValues(t, "foo", workspaceName) 34 | 35 | err = os.FlagSet().Parse([]string{"-n", "f"}) 36 | require.NoError(t, err) 37 | require.EqualValues(t, "f", workspaceName) 38 | }) 39 | 40 | t.Run("StringArray", func(t *testing.T) { 41 | t.Parallel() 42 | 43 | var names serpent.StringArray 44 | 45 | os := serpent.OptionSet{ 46 | serpent.Option{ 47 | Name: "name", 48 | Value: &names, 49 | Flag: "name", 50 | FlagShorthand: "n", 51 | }, 52 | } 53 | 54 | err := os.SetDefaults() 55 | require.NoError(t, err) 56 | 57 | err = os.FlagSet().Parse([]string{"--name", "foo", "--name", "bar"}) 58 | require.NoError(t, err) 59 | require.EqualValues(t, []string{"foo", "bar"}, names) 60 | }) 61 | 62 | t.Run("ExtraFlags", func(t *testing.T) { 63 | t.Parallel() 64 | 65 | var workspaceName serpent.String 66 | 67 | os := serpent.OptionSet{ 68 | serpent.Option{ 69 | Name: "Workspace Name", 70 | Value: &workspaceName, 71 | }, 72 | } 73 | 74 | err := os.FlagSet().Parse([]string{"--some-unknown", "foo"}) 75 | require.Error(t, err) 76 | }) 77 | 78 | t.Run("RegexValid", func(t *testing.T) { 79 | t.Parallel() 80 | 81 | var regexpString 
serpent.Regexp 82 | 83 | os := serpent.OptionSet{ 84 | serpent.Option{ 85 | Name: "RegexpString", 86 | Value: ®expString, 87 | Flag: "regexp-string", 88 | }, 89 | } 90 | 91 | err := os.FlagSet().Parse([]string{"--regexp-string", "$test^"}) 92 | require.NoError(t, err) 93 | }) 94 | 95 | t.Run("RegexInvalid", func(t *testing.T) { 96 | t.Parallel() 97 | 98 | var regexpString serpent.Regexp 99 | 100 | os := serpent.OptionSet{ 101 | serpent.Option{ 102 | Name: "RegexpString", 103 | Value: ®expString, 104 | Flag: "regexp-string", 105 | }, 106 | } 107 | 108 | err := os.FlagSet().Parse([]string{"--regexp-string", "(("}) 109 | require.Error(t, err) 110 | }) 111 | } 112 | 113 | func TestOptionSet_ParseEnv(t *testing.T) { 114 | t.Parallel() 115 | 116 | t.Run("SimpleString", func(t *testing.T) { 117 | t.Parallel() 118 | 119 | var workspaceName serpent.String 120 | 121 | os := serpent.OptionSet{ 122 | serpent.Option{ 123 | Name: "Workspace Name", 124 | Value: &workspaceName, 125 | Env: "WORKSPACE_NAME", 126 | }, 127 | } 128 | 129 | err := os.ParseEnv([]serpent.EnvVar{ 130 | {Name: "WORKSPACE_NAME", Value: "foo"}, 131 | }) 132 | require.NoError(t, err) 133 | require.EqualValues(t, "foo", workspaceName) 134 | }) 135 | 136 | t.Run("EmptyValue", func(t *testing.T) { 137 | t.Parallel() 138 | 139 | var workspaceName serpent.String 140 | 141 | os := serpent.OptionSet{ 142 | serpent.Option{ 143 | Name: "Workspace Name", 144 | Value: &workspaceName, 145 | Default: "defname", 146 | Env: "WORKSPACE_NAME", 147 | }, 148 | } 149 | 150 | err := os.SetDefaults() 151 | require.NoError(t, err) 152 | 153 | err = os.ParseEnv(serpent.ParseEnviron([]string{"CODER_WORKSPACE_NAME="}, "CODER_")) 154 | require.NoError(t, err) 155 | require.EqualValues(t, "defname", workspaceName) 156 | }) 157 | 158 | t.Run("StringSlice", func(t *testing.T) { 159 | t.Parallel() 160 | 161 | var actual serpent.StringArray 162 | expected := []string{"foo", "bar", "baz"} 163 | 164 | os := serpent.OptionSet{ 165 | 
serpent.Option{ 166 | Name: "name", 167 | Value: &actual, 168 | Env: "NAMES", 169 | }, 170 | } 171 | 172 | err := os.SetDefaults() 173 | require.NoError(t, err) 174 | 175 | err = os.ParseEnv([]serpent.EnvVar{ 176 | {Name: "NAMES", Value: "foo,bar,baz"}, 177 | }) 178 | require.NoError(t, err) 179 | require.EqualValues(t, expected, actual) 180 | }) 181 | 182 | t.Run("StructMapStringString", func(t *testing.T) { 183 | t.Parallel() 184 | 185 | var actual serpent.Struct[map[string]string] 186 | expected := map[string]string{"foo": "bar", "baz": "zap"} 187 | 188 | os := serpent.OptionSet{ 189 | serpent.Option{ 190 | Name: "labels", 191 | Value: &actual, 192 | Env: "LABELS", 193 | }, 194 | } 195 | 196 | err := os.SetDefaults() 197 | require.NoError(t, err) 198 | 199 | err = os.ParseEnv([]serpent.EnvVar{ 200 | {Name: "LABELS", Value: `{"foo":"bar","baz":"zap"}`}, 201 | }) 202 | require.NoError(t, err) 203 | require.EqualValues(t, expected, actual.Value) 204 | }) 205 | 206 | t.Run("Homebrew", func(t *testing.T) { 207 | t.Parallel() 208 | 209 | var agentToken serpent.String 210 | 211 | os := serpent.OptionSet{ 212 | serpent.Option{ 213 | Name: "Agent Token", 214 | Value: &agentToken, 215 | Env: "AGENT_TOKEN", 216 | }, 217 | } 218 | 219 | err := os.ParseEnv([]serpent.EnvVar{ 220 | {Name: "HOMEBREW_AGENT_TOKEN", Value: "foo"}, 221 | }) 222 | require.NoError(t, err) 223 | require.EqualValues(t, "foo", agentToken) 224 | }) 225 | } 226 | 227 | func TestOptionSet_JsonMarshal(t *testing.T) { 228 | t.Parallel() 229 | 230 | // This unit test ensures if the source optionset is missing the option 231 | // and cannot determine the type, it will not panic. The unmarshal will 232 | // succeed with a best effort. 
233 | t.Run("MissingSrcOption", func(t *testing.T) { 234 | t.Parallel() 235 | 236 | var str serpent.String = "something" 237 | var arr serpent.StringArray = []string{"foo", "bar"} 238 | opts := serpent.OptionSet{ 239 | serpent.Option{ 240 | Name: "StringOpt", 241 | Value: &str, 242 | }, 243 | serpent.Option{ 244 | Name: "ArrayOpt", 245 | Value: &arr, 246 | }, 247 | } 248 | data, err := json.Marshal(opts) 249 | require.NoError(t, err, "marshal option set") 250 | 251 | tgt := serpent.OptionSet{} 252 | err = json.Unmarshal(data, &tgt) 253 | require.NoError(t, err, "unmarshal option set") 254 | for i := range opts { 255 | compareOptionsExceptValues(t, opts[i], tgt[i]) 256 | require.Empty(t, tgt[i].Value.String(), "unknown value types are empty") 257 | } 258 | }) 259 | 260 | t.Run("RegexCase", func(t *testing.T) { 261 | t.Parallel() 262 | 263 | val := serpent.Regexp(*regexp.MustCompile(".*")) 264 | opts := serpent.OptionSet{ 265 | serpent.Option{ 266 | Name: "Regex", 267 | Value: &val, 268 | Default: ".*", 269 | }, 270 | } 271 | data, err := json.Marshal(opts) 272 | require.NoError(t, err, "marshal option set") 273 | 274 | var foundVal serpent.Regexp 275 | newOpts := serpent.OptionSet{ 276 | serpent.Option{ 277 | Name: "Regex", 278 | Value: &foundVal, 279 | }, 280 | } 281 | err = json.Unmarshal(data, &newOpts) 282 | require.NoError(t, err, "unmarshal option set") 283 | 284 | require.EqualValues(t, opts[0].Value.String(), newOpts[0].Value.String()) 285 | }) 286 | } 287 | 288 | func compareOptionsExceptValues(t *testing.T, exp, found serpent.Option) { 289 | t.Helper() 290 | 291 | require.Equalf(t, exp.Name, found.Name, "option name %q", exp.Name) 292 | require.Equalf(t, exp.Description, found.Description, "option description %q", exp.Name) 293 | require.Equalf(t, exp.Required, found.Required, "option required %q", exp.Name) 294 | require.Equalf(t, exp.Flag, found.Flag, "option flag %q", exp.Name) 295 | require.Equalf(t, exp.FlagShorthand, found.FlagShorthand, "option flag 
shorthand %q", exp.Name) 296 | require.Equalf(t, exp.Env, found.Env, "option env %q", exp.Name) 297 | require.Equalf(t, exp.YAML, found.YAML, "option yaml %q", exp.Name) 298 | require.Equalf(t, exp.Default, found.Default, "option default %q", exp.Name) 299 | require.Equalf(t, exp.ValueSource, found.ValueSource, "option value source %q", exp.Name) 300 | require.Equalf(t, exp.Hidden, found.Hidden, "option hidden %q", exp.Name) 301 | require.Equalf(t, exp.Annotations, found.Annotations, "option annotations %q", exp.Name) 302 | require.Equalf(t, exp.Group, found.Group, "option group %q", exp.Name) 303 | // UseInstead is the same comparison problem, just check the length 304 | require.Equalf(t, len(exp.UseInstead), len(found.UseInstead), "option use instead %q", exp.Name) 305 | } 306 | -------------------------------------------------------------------------------- /yaml.go: -------------------------------------------------------------------------------- 1 | package serpent 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "strings" 7 | 8 | "github.com/mitchellh/go-wordwrap" 9 | "github.com/spf13/pflag" 10 | "golang.org/x/xerrors" 11 | "gopkg.in/yaml.v3" 12 | ) 13 | 14 | var ( 15 | _ yaml.Marshaler = new(OptionSet) 16 | _ yaml.Unmarshaler = new(OptionSet) 17 | ) 18 | 19 | // deepMapNode returns the mapping node at the given path, 20 | // creating it if it doesn't exist. 21 | func deepMapNode(n *yaml.Node, path []string, headComment string) *yaml.Node { 22 | if len(path) == 0 { 23 | return n 24 | } 25 | 26 | // Name is every two nodes. 27 | for i := 0; i < len(n.Content)-1; i += 2 { 28 | if n.Content[i].Value == path[0] { 29 | // Found matching name, recurse. 30 | return deepMapNode(n.Content[i+1], path[1:], headComment) 31 | } 32 | } 33 | 34 | // Not found, create it. 
35 | nameNode := yaml.Node{ 36 | Kind: yaml.ScalarNode, 37 | Value: path[0], 38 | HeadComment: headComment, 39 | } 40 | valueNode := yaml.Node{ 41 | Kind: yaml.MappingNode, 42 | } 43 | n.Content = append(n.Content, &nameNode) 44 | n.Content = append(n.Content, &valueNode) 45 | return deepMapNode(&valueNode, path[1:], headComment) 46 | } 47 | 48 | // MarshalYAML converts the option set to a YAML node, that can be 49 | // converted into bytes via yaml.Marshal. 50 | // 51 | // The node is returned to enable post-processing higher up in 52 | // the stack. 53 | // 54 | // It is isomorphic with FromYAML. 55 | func (optSet *OptionSet) MarshalYAML() (any, error) { 56 | root := yaml.Node{ 57 | Kind: yaml.MappingNode, 58 | } 59 | 60 | for _, opt := range *optSet { 61 | if opt.YAML == "" { 62 | continue 63 | } 64 | 65 | defValue := opt.Default 66 | if defValue == "" { 67 | defValue = "" 68 | } 69 | comment := wordwrap.WrapString( 70 | fmt.Sprintf("%s\n(default: %s, type: %s)", opt.Description, defValue, opt.Value.Type()), 71 | 80, 72 | ) 73 | nameNode := yaml.Node{ 74 | Kind: yaml.ScalarNode, 75 | Value: opt.YAML, 76 | HeadComment: comment, 77 | } 78 | 79 | _, isValidator := opt.Value.(interface{ Underlying() pflag.Value }) 80 | var valueNode yaml.Node 81 | if opt.Value == nil { 82 | valueNode = yaml.Node{ 83 | Kind: yaml.ScalarNode, 84 | Value: "null", 85 | } 86 | } else if m, ok := opt.Value.(yaml.Marshaler); ok && !isValidator { 87 | // Validators do a wrap, and should be handled by the else statement. 88 | v, err := m.MarshalYAML() 89 | if err != nil { 90 | return nil, xerrors.Errorf( 91 | "marshal %q: %w", opt.Name, err, 92 | ) 93 | } 94 | valueNode, ok = v.(yaml.Node) 95 | if !ok { 96 | return nil, xerrors.Errorf( 97 | "marshal %q: unexpected underlying type %T", 98 | opt.Name, v, 99 | ) 100 | } 101 | } else { 102 | // The all-other types case. 103 | // 104 | // A bit of a hack, we marshal and then unmarshal to get 105 | // the underlying node. 
106 | byt, err := yaml.Marshal(opt.Value) 107 | if err != nil { 108 | return nil, xerrors.Errorf( 109 | "marshal %q: %w", opt.Name, err, 110 | ) 111 | } 112 | 113 | var docNode yaml.Node 114 | err = yaml.Unmarshal(byt, &docNode) 115 | if err != nil { 116 | return nil, xerrors.Errorf( 117 | "unmarshal %q: %w", opt.Name, err, 118 | ) 119 | } 120 | if len(docNode.Content) != 1 { 121 | return nil, xerrors.Errorf( 122 | "unmarshal %q: expected one node, got %d", 123 | opt.Name, len(docNode.Content), 124 | ) 125 | } 126 | 127 | valueNode = *docNode.Content[0] 128 | } 129 | var group []string 130 | for _, g := range opt.Group.Ancestry() { 131 | if g.YAML == "" { 132 | return nil, xerrors.Errorf( 133 | "group yaml name is empty for %q, groups: %+v", 134 | opt.Name, 135 | opt.Group, 136 | ) 137 | } 138 | group = append(group, g.YAML) 139 | } 140 | var groupDesc string 141 | if opt.Group != nil { 142 | groupDesc = wordwrap.WrapString(opt.Group.Description, 80) 143 | } 144 | parentValueNode := deepMapNode( 145 | &root, group, 146 | groupDesc, 147 | ) 148 | parentValueNode.Content = append( 149 | parentValueNode.Content, 150 | &nameNode, 151 | &valueNode, 152 | ) 153 | } 154 | return &root, nil 155 | } 156 | 157 | // mapYAMLNodes converts parent into a map with keys of form "group.subgroup.option" 158 | // and values as the corresponding YAML nodes. 
159 | func mapYAMLNodes(parent *yaml.Node) (map[string]*yaml.Node, error) { 160 | if parent.Kind != yaml.MappingNode { 161 | return nil, xerrors.Errorf("expected mapping node, got type %v", parent.Kind) 162 | } 163 | if len(parent.Content)%2 != 0 { 164 | return nil, xerrors.Errorf("expected an even number of k/v pairs, got %d", len(parent.Content)) 165 | } 166 | var ( 167 | key string 168 | m = make(map[string]*yaml.Node, len(parent.Content)/2) 169 | merr error 170 | ) 171 | for i, child := range parent.Content { 172 | if i%2 == 0 { 173 | if child.Kind != yaml.ScalarNode { 174 | // We return immediately because the rest of the code is bound to fail 175 | // if we don't know to expect a key or a value. 176 | return nil, xerrors.Errorf("expected scalar node for key, got type %v", child.Kind) 177 | } 178 | key = child.Value 179 | continue 180 | } 181 | 182 | // We don't know if this is a grouped simple option or complex option, 183 | // so we store both "key" and "group.key". Since we're storing pointers, 184 | // the additional memory is of little concern. 185 | m[key] = child 186 | if child.Kind != yaml.MappingNode { 187 | continue 188 | } 189 | 190 | sub, err := mapYAMLNodes(child) 191 | if err != nil { 192 | merr = errors.Join(merr, xerrors.Errorf("mapping node %q: %w", key, err)) 193 | continue 194 | } 195 | for k, v := range sub { 196 | m[key+"."+k] = v 197 | } 198 | } 199 | // NOTE(review): merr collects errors from nested mappings but is never returned. Those mappings may be complex option values rather than groups, so confirm that dropping the errors is intended before returning merr here. 200 | return m, nil 201 | } 202 | // setFromYAMLNode sets o.Value from n and marks o.ValueSource as YAML (even if decoding then fails). Values implementing yaml.Unmarshaler decode the node themselves. 203 | func (o *Option) setFromYAMLNode(n *yaml.Node) error { 204 | o.ValueSource = ValueSourceYAML 205 | if um, ok := o.Value.(yaml.Unmarshaler); ok { 206 | return um.UnmarshalYAML(n) 207 | } 208 | 209 | switch n.Kind { 210 | case yaml.ScalarNode: 211 | return o.Value.Set(n.Value) // NOTE(review): panics if o.Value is nil; only the SequenceNode case below guards against a nil Value. 212 | case yaml.SequenceNode: 213 | // We treat empty values as nil for consistency with other option 214 | // mechanisms.
215 | if len(n.Content) == 0 { 216 | if o.Value == nil { 217 | return nil 218 | } 219 | return o.Value.Set("") 220 | } 221 | return n.Decode(o.Value) 222 | case yaml.MappingNode: 223 | return xerrors.Errorf("mapping nodes must implement yaml.Unmarshaler") 224 | default: 225 | return xerrors.Errorf("unexpected node kind %v", n.Kind) 226 | } 227 | } 228 | 229 | // UnmarshalYAML converts the given YAML node into the option set. 230 | // It is the counterpart of MarshalYAML. 231 | func (optSet *OptionSet) UnmarshalYAML(rootNode *yaml.Node) error { 232 | // The rootNode will be a DocumentNode if it's read from a file. We do 233 | // not support multiple documents in a single file. 234 | if rootNode.Kind == yaml.DocumentNode { 235 | if len(rootNode.Content) != 1 { 236 | return xerrors.Errorf("expected one node in document, got %d", len(rootNode.Content)) 237 | } 238 | rootNode = rootNode.Content[0] 239 | } 240 | 241 | yamlNodes, err := mapYAMLNodes(rootNode) 242 | if err != nil { 243 | return xerrors.Errorf("mapping nodes: %w", err) 244 | } 245 | 246 | matchedNodes := make(map[string]*yaml.Node, len(yamlNodes)) 247 | 248 | var merr error 249 | for i := range *optSet { 250 | opt := &(*optSet)[i] 251 | if opt.YAML == "" { 252 | continue 253 | } 254 | var group []string 255 | for _, g := range opt.Group.Ancestry() { 256 | if g.YAML == "" { 257 | return xerrors.Errorf( 258 | "group yaml name is empty for %q, groups: %+v", 259 | opt.Name, 260 | opt.Group, 261 | ) 262 | } 263 | group = append(group, g.YAML) 264 | delete(yamlNodes, strings.Join(group, ".")) // Group keys are containers, not options, so never report them as unknown. 265 | } 266 | 267 | key := strings.Join(append(group, opt.YAML), ".") 268 | node, ok := yamlNodes[key] 269 | if !ok { 270 | continue 271 | } 272 | // The node counts as matched even when another source already set the value; such values are left untouched. 273 | matchedNodes[key] = node 274 | if opt.ValueSource != ValueSourceNone { 275 | continue 276 | } 277 | if err := opt.setFromYAMLNode(node); err != nil { 278 | merr = errors.Join(merr, xerrors.Errorf("setting %q: %w", opt.YAML, err)) 279 | } 280 | } 281 | 282 | // Remove all matched
nodes and their descendants from yamlNodes so we 283 | // can accurately report unknown options. 284 | for k := range yamlNodes { 285 | var key string 286 | for _, part := range strings.Split(k, ".") { 287 | if key != "" { 288 | key += "." 289 | } 290 | key += part 291 | if _, ok := matchedNodes[key]; ok { 292 | delete(yamlNodes, k) 293 | } 294 | } 295 | } 296 | for k := range yamlNodes { 297 | merr = errors.Join(merr, xerrors.Errorf("unknown option %q", k)) 298 | } 299 | 300 | return merr 301 | } 302 | -------------------------------------------------------------------------------- /help.go: -------------------------------------------------------------------------------- 1 | package serpent 2 | 3 | import ( 4 | "bufio" 5 | _ "embed" 6 | "flag" 7 | "fmt" 8 | "os" 9 | "regexp" 10 | "sort" 11 | "strings" 12 | "sync" 13 | "text/tabwriter" 14 | "text/template" 15 | 16 | "github.com/mitchellh/go-wordwrap" 17 | "github.com/muesli/termenv" 18 | "golang.org/x/crypto/ssh/terminal" 19 | "golang.org/x/xerrors" 20 | 21 | "github.com/coder/pretty" 22 | ) 23 | 24 | //go:embed help.tpl 25 | var helpTemplateRaw string 26 | // optionGroup is a named set of options that help output renders together. 27 | type optionGroup struct { 28 | Name string 29 | Description string 30 | Options OptionSet 31 | } 32 | // ttyWidth returns the width of the terminal attached to file descriptor 0, or 80 if it cannot be determined. 33 | func ttyWidth() int { 34 | width, _, err := terminal.GetSize(0) 35 | if err != nil { 36 | return 80 37 | } 38 | return width 39 | } 40 | 41 | // wrapTTY wraps a string to the width of the terminal, or 80 if no terminal 42 | // is detected. 43 | func wrapTTY(s string) string { 44 | return wordwrap.WrapString(s, uint(ttyWidth())) 45 | } 46 | 47 | var ( 48 | helpColorProfile termenv.Profile 49 | helpColorOnce sync.Once 50 | ) 51 | 52 | // helpColor returns a termenv color for s, using the color profile detected for stdout (or a colorless profile under go test). 53 | func helpColor(s string) termenv.Color { 54 | helpColorOnce.Do(func() { 55 | helpColorProfile = termenv.NewOutput(os.Stdout).ColorProfile() 56 | if flag.Lookup("test.v") != nil { 57 | // Use a consistent colorless profile in tests so that results 58 | // are deterministic.
59 | helpColorProfile = termenv.Ascii 60 | } 61 | }) 62 | return helpColorProfile.Color(s) 63 | } 64 | 65 | // prettyHeader formats a header string with consistent styling. 66 | // It uppercases the text, adds a colon, and applies the header color. 67 | func prettyHeader(s string) string { 68 | headerFg := pretty.FgColor(helpColor("#337CA0")) 69 | s = strings.ToUpper(s) 70 | txt := pretty.String(s, ":") 71 | headerFg.Format(txt) 72 | return txt.String() 73 | } 74 | // defaultHelpTemplate is help.tpl parsed together with the helper functions the template calls. 75 | var defaultHelpTemplate = func() *template.Template { 76 | var ( 77 | optionFg = pretty.FgColor( 78 | helpColor("#04A777"), 79 | ) 80 | ) 81 | return template.Must( 82 | template.New("usage").Funcs( 83 | template.FuncMap{ 84 | "wrapTTY": func(s string) string { 85 | return wrapTTY(s) 86 | }, 87 | "trimNewline": func(s string) string { 88 | return strings.TrimSuffix(s, "\n") 89 | }, 90 | "keyword": func(s string) string { 91 | txt := pretty.String(s) 92 | optionFg.Format(txt) 93 | return txt.String() 94 | }, 95 | "prettyHeader": prettyHeader, 96 | "typeHelper": func(opt *Option) string { 97 | switch v := opt.Value.(type) { 98 | case *Enum: 99 | return strings.Join(v.Choices, "|") 100 | case *EnumArray: 101 | return fmt.Sprintf("[%s]", strings.Join(v.Choices, "|")) 102 | default: 103 | return v.Type() 104 | } 105 | }, 106 | "joinStrings": func(s []string) string { 107 | return strings.Join(s, ", ") 108 | }, 109 | "indent": func(body string, spaces int) string { 110 | twidth := ttyWidth() 111 | 112 | spacing := strings.Repeat(" ", spaces) 113 | 114 | wrapLim := twidth - len(spacing) 115 | body = wordwrap.WrapString(body, uint(wrapLim)) 116 | 117 | sc := bufio.NewScanner(strings.NewReader(body)) 118 | 119 | var sb strings.Builder 120 | for sc.Scan() { 121 | // Remove existing indent, if any. 122 | // line = strings.TrimSpace(line) 123 | // Use spaces so we can easily calculate wrapping.
124 | _, _ = sb.WriteString(spacing) 125 | _, _ = sb.Write(sc.Bytes()) 126 | _, _ = sb.WriteString("\n") 127 | } 128 | return sb.String() 129 | }, 130 | "rootCommandName": func(cmd *Command) string { 131 | return strings.Split(cmd.FullName(), " ")[0] 132 | }, 133 | "formatSubcommand": func(cmd *Command) string { 134 | // Minimize padding by finding the longest neighboring name. 135 | maxNameLength := len(cmd.Name()) 136 | if parent := cmd.Parent; parent != nil { 137 | for _, c := range parent.Children { 138 | if len(c.Name()) > maxNameLength { 139 | maxNameLength = len(c.Name()) 140 | } 141 | } 142 | } 143 | 144 | var sb strings.Builder 145 | _, _ = fmt.Fprintf( 146 | &sb, "%s%s%s", 147 | strings.Repeat(" ", 4), cmd.Name(), strings.Repeat(" ", maxNameLength-len(cmd.Name())+4), 148 | ) 149 | 150 | // This is the point at which indentation begins if there's a 151 | // next line. 152 | descStart := sb.Len() 153 | 154 | twidth := ttyWidth() 155 | 156 | for i, line := range strings.Split( 157 | wordwrap.WrapString(cmd.Short, uint(twidth-descStart)), "\n", 158 | ) { 159 | if i > 0 { 160 | _, _ = sb.WriteString(strings.Repeat(" ", descStart)) 161 | } 162 | _, _ = sb.WriteString(line) 163 | _, _ = sb.WriteString("\n") 164 | } 165 | 166 | return sb.String() 167 | }, 168 | "envName": func(opt Option) string { 169 | if opt.Env == "" { 170 | return "" 171 | } 172 | return opt.Env 173 | }, 174 | "flagName": func(opt Option) string { 175 | return opt.Flag 176 | }, 177 | 178 | "isDeprecated": func(opt Option) bool { 179 | return len(opt.UseInstead) > 0 180 | }, 181 | "useInstead": func(opt Option) string { 182 | var sb strings.Builder 183 | for i, s := range opt.UseInstead { 184 | if i > 0 { 185 | if i == len(opt.UseInstead)-1 { 186 | _, _ = sb.WriteString(" and ") 187 | } else { 188 | _, _ = sb.WriteString(", ") 189 | } 190 | } 191 | if s.Flag != "" { 192 | _, _ = sb.WriteString("--") 193 | _, _ = sb.WriteString(s.Flag) 194 | } else if s.FlagShorthand != "" { 195 | _, _ =
sb.WriteString("-") 196 | _, _ = sb.WriteString(s.FlagShorthand) 197 | } else if s.Env != "" { 198 | _, _ = sb.WriteString("$") 199 | _, _ = sb.WriteString(s.Env) 200 | } else { 201 | _, _ = sb.WriteString(s.Name) 202 | } 203 | } 204 | return sb.String() 205 | }, 206 | "formatGroupDescription": func(s string) string { 207 | s = strings.ReplaceAll(s, "\n", "") 208 | s = s + "\n" 209 | s = wrapTTY(s) 210 | return s 211 | }, 212 | "visibleChildren": func(cmd *Command) []*Command { 213 | return filterSlice(cmd.Children, func(c *Command) bool { 214 | return !c.Hidden 215 | }) 216 | }, 217 | "optionGroups": func(cmd *Command) []optionGroup { 218 | groups := []optionGroup{{ 219 | // Default group. 220 | Name: "", 221 | Description: "", 222 | }} 223 | 224 | // Sort options lexicographically. Note that this sorts cmd.Options in place. 225 | sort.Slice(cmd.Options, func(i, j int) bool { 226 | return cmd.Options[i].Name < cmd.Options[j].Name 227 | }) 228 | 229 | optionLoop: 230 | for _, opt := range cmd.Options { 231 | if opt.Hidden { 232 | continue 233 | } 234 | 235 | if len(opt.Group.Ancestry()) == 0 { 236 | // Just add option to default group. 237 | groups[0].Options = append(groups[0].Options, opt) 238 | continue 239 | } 240 | 241 | groupName := opt.Group.FullName() 242 | 243 | for i, foundGroup := range groups { 244 | if foundGroup.Name != groupName { 245 | continue 246 | } 247 | groups[i].Options = append(groups[i].Options, opt) 248 | continue optionLoop 249 | } 250 | 251 | groups = append(groups, optionGroup{ 252 | Name: groupName, 253 | Description: opt.Group.Description, 254 | Options: OptionSet{opt}, 255 | }) 256 | } 257 | sort.Slice(groups, func(i, j int) bool { 258 | // Sort groups lexicographically.
259 | return groups[i].Name < groups[j].Name 260 | }) 261 | 262 | return filterSlice(groups, func(g optionGroup) bool { 263 | return len(g.Options) > 0 264 | }) 265 | }, 266 | }, 267 | ).Parse(helpTemplateRaw), 268 | ) 269 | }() 270 | // filterSlice returns, in order, the elements of s for which f returns true. 271 | func filterSlice[T any](s []T, f func(T) bool) []T { 272 | var r []T 273 | for _, v := range s { 274 | if f(v) { 275 | r = append(r, v) 276 | } 277 | } 278 | return r 279 | } 280 | 281 | // newlineLimiter makes working with Go templates more bearable. Without this, 282 | // modifying the template is a slow toil of counting newlines and constantly 283 | // checking that a change to one command's help doesn't break another. 284 | type newlineLimiter struct { 285 | // w is not an interface since we call WriteByte once per byte, 286 | // and the devirtualization overhead is significant. 287 | w *bufio.Writer 288 | limit int 289 | // newLineCounter counts newlines seen since the last non-space byte. 290 | newLineCounter int 291 | } 292 | 293 | // isSpace is based on unicode.IsSpace, but only checks single bytes (ASCII whitespace plus 0x85 and 0xA0). 294 | func isSpace(b byte) bool { 295 | switch b { 296 | case '\t', '\n', '\v', '\f', '\r', ' ', 0x85, 0xA0: 297 | return true 298 | } 299 | return false 300 | } 301 | // Write copies p to lm.w, dropping carriage returns and any newline beyond lm.limit in a row (whitespace between newlines does not break the run). On error it returns the number of bytes of p consumed so far. 302 | func (lm *newlineLimiter) Write(p []byte) (int, error) { 303 | for i, b := range p { 304 | switch { 305 | case b == '\r': 306 | // Carriage returns can sneak into `help.tpl` when `git clone` 307 | // is configured to automatically convert line endings.
308 | continue 309 | case b == '\n': 310 | lm.newLineCounter++ 311 | if lm.newLineCounter > lm.limit { 312 | continue 313 | } 314 | case !isSpace(b): 315 | lm.newLineCounter = 0 316 | } 317 | err := lm.w.WriteByte(b) 318 | if err != nil { 319 | return i, err 320 | } 321 | } 322 | return len(p), nil 323 | } 324 | // usageWantsArgRe matches a Use string that declares an argument placeholder such as "<name>". 325 | var usageWantsArgRe = regexp.MustCompile(`<.*>`) 326 | // UnknownSubcommandError is returned by the DefaultHelpFn handler when the invocation has leftover args. 327 | type UnknownSubcommandError struct { 328 | Args []string 329 | } 330 | // Error implements the error interface. 331 | func (e *UnknownSubcommandError) Error() string { 332 | return fmt.Sprintf("unknown subcommand %q", strings.Join(e.Args, " ")) 333 | } 334 | 335 | // DefaultHelpFn returns a function that generates usage (help) 336 | // output for a given command. 337 | func DefaultHelpFn() HandlerFunc { 338 | return func(inv *Invocation) error { 339 | // We use stdout for help and not stderr since there's no straightforward 340 | // way to distinguish between a user error and a help request. 341 | // 342 | // We buffer writes to stdout because the newlineLimiter writes one 343 | // byte at a time. 344 | outBuf := bufio.NewWriter(inv.Stdout) 345 | out := newlineLimiter{w: outBuf, limit: 2} 346 | tw := tabwriter.NewWriter(&out, 0, 0, 2, ' ', 0) 347 | err := defaultHelpTemplate.Execute(tw, inv.Command) 348 | if err != nil { 349 | return xerrors.Errorf("execute template: %w", err) 350 | } 351 | err = tw.Flush() 352 | if err != nil { 353 | return err 354 | } 355 | err = outBuf.Flush() 356 | if err != nil { 357 | return err 358 | } 359 | if len(inv.Args) > 0 && !usageWantsArgRe.MatchString(inv.Command.Use) { 360 | _, _ = fmt.Fprintf(inv.Stderr, "---\nerror: unknown subcommand %q\n", inv.Args[0]) 361 | } 362 | if len(inv.Args) > 0 { 363 | // Return an error so that exit status is non-zero when 364 | // a subcommand is not found.
365 | return &UnknownSubcommandError{Args: inv.Args} 366 | } 367 | return nil 368 | } 369 | } 370 | -------------------------------------------------------------------------------- /completion_test.go: -------------------------------------------------------------------------------- 1 | package serpent_test 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "os" 7 | "path/filepath" 8 | "strings" 9 | "testing" 10 | 11 | serpent "github.com/coder/serpent" 12 | "github.com/coder/serpent/completion" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | func TestCompletion(t *testing.T) { 17 | t.Parallel() 18 | 19 | cmd := func() *serpent.Command { return sampleCommand(t) } 20 | 21 | t.Run("SubcommandList", func(t *testing.T) { 22 | t.Parallel() 23 | i := cmd().Invoke("") 24 | i.Environ.Set(serpent.CompletionModeEnv, "1") 25 | io := fakeIO(i) 26 | err := i.Run() 27 | require.NoError(t, err) 28 | require.Equal(t, "altfile\nfile\nrequired-flag\ntoupper\n", io.Stdout.String()) 29 | }) 30 | 31 | t.Run("SubcommandNoPartial", func(t *testing.T) { 32 | t.Parallel() 33 | i := cmd().Invoke("f") 34 | i.Environ.Set(serpent.CompletionModeEnv, "1") 35 | io := fakeIO(i) 36 | err := i.Run() 37 | require.NoError(t, err) 38 | require.Equal(t, "altfile\nfile\nrequired-flag\ntoupper\n", io.Stdout.String()) 39 | }) 40 | 41 | t.Run("SubcommandComplete", func(t *testing.T) { 42 | t.Parallel() 43 | i := cmd().Invoke("required-flag") 44 | i.Environ.Set(serpent.CompletionModeEnv, "1") 45 | io := fakeIO(i) 46 | err := i.Run() 47 | require.NoError(t, err) 48 | require.Equal(t, "required-flag\n", io.Stdout.String()) 49 | }) 50 | 51 | t.Run("ListFlags", func(t *testing.T) { 52 | t.Parallel() 53 | i := cmd().Invoke("required-flag", "-") 54 | i.Environ.Set(serpent.CompletionModeEnv, "1") 55 | io := fakeIO(i) 56 | err := i.Run() 57 | require.NoError(t, err) 58 | require.Equal(t, "--req-array\n--req-bool\n--req-enum\n--req-enum-array\n--req-string\n", io.Stdout.String()) 59 | }) 60 | 61 |
t.Run("ListFlagsAfterArg", func(t *testing.T) { 62 | t.Parallel() 63 | i := cmd().Invoke("altfile", "-") 64 | i.Environ.Set(serpent.CompletionModeEnv, "1") 65 | io := fakeIO(i) 66 | err := i.Run() 67 | require.NoError(t, err) 68 | require.Equal(t, "doesntexist.go\n--extra\n", io.Stdout.String()) 69 | }) 70 | 71 | t.Run("FlagExhaustive", func(t *testing.T) { 72 | t.Parallel() 73 | i := cmd().Invoke("required-flag", "--req-bool", "--req-string", "foo bar", "--req-array", "asdf", "--req-array", "qwerty", "-") 74 | i.Environ.Set(serpent.CompletionModeEnv, "1") 75 | io := fakeIO(i) 76 | err := i.Run() 77 | require.NoError(t, err) 78 | require.Equal(t, "--req-array\n--req-enum\n--req-enum-array\n", io.Stdout.String()) 79 | }) 80 | 81 | t.Run("FlagShorthand", func(t *testing.T) { 82 | t.Parallel() 83 | i := cmd().Invoke("required-flag", "-b", "-s", "foo bar", "-a", "asdf", "-") 84 | i.Environ.Set(serpent.CompletionModeEnv, "1") 85 | io := fakeIO(i) 86 | err := i.Run() 87 | require.NoError(t, err) 88 | require.Equal(t, "--req-array\n--req-enum\n--req-enum-array\n", io.Stdout.String()) 89 | }) 90 | 91 | t.Run("NoOptDefValueFlag", func(t *testing.T) { 92 | t.Parallel() 93 | i := cmd().Invoke("--verbose", "-") 94 | i.Environ.Set(serpent.CompletionModeEnv, "1") 95 | io := fakeIO(i) 96 | err := i.Run() 97 | require.NoError(t, err) 98 | require.Equal(t, "--prefix\n", io.Stdout.String()) 99 | }) 100 | 101 | t.Run("EnumOK", func(t *testing.T) { 102 | t.Parallel() 103 | i := cmd().Invoke("required-flag", "--req-enum", "") 104 | i.Environ.Set(serpent.CompletionModeEnv, "1") 105 | io := fakeIO(i) 106 | err := i.Run() 107 | require.NoError(t, err) 108 | require.Equal(t, "foo\nbar\nqux\n", io.Stdout.String()) 109 | }) 110 | 111 | t.Run("EnumEqualsOK", func(t *testing.T) { 112 | t.Parallel() 113 | i := cmd().Invoke("required-flag", "--req-enum", "--req-enum=") 114 | i.Environ.Set(serpent.CompletionModeEnv, "1") 115 | io := fakeIO(i) 116 | err := i.Run() 117 | require.NoError(t, err) 118
| require.Equal(t, "--req-enum=foo\n--req-enum=bar\n--req-enum=qux\n", io.Stdout.String()) 119 | }) 120 | 121 | t.Run("EnumEqualsBeginQuotesOK", func(t *testing.T) { 122 | t.Parallel() 123 | i := cmd().Invoke("required-flag", "--req-enum", "--req-enum=\"") 124 | i.Environ.Set(serpent.CompletionModeEnv, "1") 125 | io := fakeIO(i) 126 | err := i.Run() 127 | require.NoError(t, err) 128 | require.Equal(t, "--req-enum=foo\n--req-enum=bar\n--req-enum=qux\n", io.Stdout.String()) 129 | }) 130 | 131 | t.Run("EnumArrayOK", func(t *testing.T) { 132 | t.Parallel() 133 | i := cmd().Invoke("required-flag", "--req-enum-array", "") 134 | i.Environ.Set(serpent.CompletionModeEnv, "1") 135 | io := fakeIO(i) 136 | err := i.Run() 137 | require.NoError(t, err) 138 | require.Equal(t, "foo\nbar\nqux\n", io.Stdout.String()) 139 | }) 140 | 141 | t.Run("EnumArrayEqualsOK", func(t *testing.T) { 142 | t.Parallel() 143 | i := cmd().Invoke("required-flag", "--req-enum-array=") 144 | i.Environ.Set(serpent.CompletionModeEnv, "1") 145 | io := fakeIO(i) 146 | err := i.Run() 147 | require.NoError(t, err) 148 | require.Equal(t, "--req-enum-array=foo\n--req-enum-array=bar\n--req-enum-array=qux\n", io.Stdout.String()) 149 | }) 150 | 151 | t.Run("EnumArrayEqualsBeginQuotesOK", func(t *testing.T) { 152 | t.Parallel() 153 | i := cmd().Invoke("required-flag", "--req-enum-array=\"") 154 | i.Environ.Set(serpent.CompletionModeEnv, "1") 155 | io := fakeIO(i) 156 | err := i.Run() 157 | require.NoError(t, err) 158 | require.Equal(t, "--req-enum-array=foo\n--req-enum-array=bar\n--req-enum-array=qux\n", io.Stdout.String()) 159 | }) 160 | 161 | } 162 | 163 | func TestFileCompletion(t *testing.T) { 164 | t.Parallel() 165 | 166 | cmd := func() *serpent.Command { return sampleCommand(t) } 167 | 168 | t.Run("DirOK", func(t *testing.T) { 169 | t.Parallel() 170 | tempDir := t.TempDir() 171 | i := cmd().Invoke("file", tempDir) 172 | i.Environ.Set(serpent.CompletionModeEnv, "1") 173 | io := fakeIO(i) 174 | err := i.Run()
175 | require.NoError(t, err) 176 | require.Equal(t, fmt.Sprintf("%s%c\n", tempDir, os.PathSeparator), io.Stdout.String()) 177 | }) 178 | 179 | t.Run("EmptyDirOK", func(t *testing.T) { 180 | t.Parallel() 181 | tempDir := t.TempDir() + string(os.PathSeparator) 182 | i := cmd().Invoke("file", tempDir) 183 | i.Environ.Set(serpent.CompletionModeEnv, "1") 184 | io := fakeIO(i) 185 | err := i.Run() 186 | require.NoError(t, err) 187 | require.Equal(t, "", io.Stdout.String()) 188 | }) 189 | 190 | cases := []struct { 191 | name string 192 | realPath string 193 | paths []string 194 | }{ 195 | { 196 | name: "CurDirOK", 197 | realPath: ".", 198 | paths: []string{"", "./", "././"}, 199 | }, 200 | { 201 | name: "PrevDirOK", 202 | realPath: "..", 203 | paths: []string{"../", ".././"}, 204 | }, 205 | { 206 | name: "RootOK", 207 | realPath: "/", 208 | paths: []string{"/", "/././"}, 209 | }, 210 | } 211 | for _, tc := range cases { 212 | tc := tc 213 | t.Run(tc.name, func(t *testing.T) { 214 | t.Parallel() 215 | for _, path := range tc.paths { 216 | i := cmd().Invoke("file", path) 217 | i.Environ.Set(serpent.CompletionModeEnv, "1") 218 | io := fakeIO(i) 219 | err := i.Run() 220 | require.NoError(t, err) 221 | output := strings.Split(io.Stdout.String(), "\n") 222 | output = output[:len(output)-1] 223 | for _, str := range output { 224 | if strings.HasSuffix(str, string(os.PathSeparator)) { 225 | require.DirExists(t, str) 226 | } else { 227 | require.FileExists(t, str) 228 | } 229 | } 230 | files, err := os.ReadDir(tc.realPath) 231 | require.NoError(t, err) 232 | require.Equal(t, len(files), len(output)) 233 | } 234 | }) 235 | } 236 | } 237 | 238 | func TestCompletionInstall(t *testing.T) { 239 | t.Parallel() 240 | 241 | t.Run("InstallingNew", func(t *testing.T) { 242 | dir := t.TempDir() 243 | path := filepath.Join(dir, "fake.sh") 244 | shell := &fakeShell{baseInstallDir: dir, programName: "fake"} 245 | 246 | err := completion.InstallShellCompletion(shell) 247 | require.NoError(t,
err) 248 | contents, err := os.ReadFile(path) 249 | require.NoError(t, err) 250 | require.Equal(t, "# ============ BEGIN fake COMPLETION ============\nFAKE_COMPLETION\n# ============ END fake COMPLETION ==============\n", string(contents)) 251 | }) 252 | 253 | cases := []struct { 254 | name string 255 | input []byte 256 | expected []byte 257 | errMsg string 258 | }{ 259 | { 260 | name: "InstallingAppend", 261 | input: []byte("FAKE_SCRIPT"), 262 | expected: []byte("FAKE_SCRIPT\n# ============ BEGIN fake COMPLETION ============\nFAKE_COMPLETION\n# ============ END fake COMPLETION ==============\n"), 263 | }, 264 | { 265 | name: "InstallReplaceBeginning", 266 | input: []byte("# ============ BEGIN fake COMPLETION ============\nOLD_COMPLETION\n# ============ END fake COMPLETION ==============\nFAKE_SCRIPT\n"), 267 | expected: []byte("# ============ BEGIN fake COMPLETION ============\nFAKE_COMPLETION\n# ============ END fake COMPLETION ==============\nFAKE_SCRIPT\n"), 268 | }, 269 | { 270 | name: "InstallReplaceMiddle", 271 | input: []byte("FAKE_SCRIPT\n# ============ BEGIN fake COMPLETION ============\nOLD_COMPLETION\n# ============ END fake COMPLETION ==============\nFAKE_SCRIPT\n"), 272 | expected: []byte("FAKE_SCRIPT\n# ============ BEGIN fake COMPLETION ============\nFAKE_COMPLETION\n# ============ END fake COMPLETION ==============\nFAKE_SCRIPT\n"), 273 | }, 274 | { 275 | name: "InstallReplaceEnd", 276 | input: []byte("FAKE_SCRIPT\n# ============ BEGIN fake COMPLETION ============\nOLD_COMPLETION\n# ============ END fake COMPLETION ==============\n"), 277 | expected: []byte("FAKE_SCRIPT\n# ============ BEGIN fake COMPLETION ============\nFAKE_COMPLETION\n# ============ END fake COMPLETION ==============\n"), 278 | }, 279 | { 280 | name: "InstallNoFooter", 281 | input: []byte("FAKE_SCRIPT\n# ============ BEGIN fake COMPLETION ============\nOLD_COMPLETION\n"), 282 | errMsg: "missing completion footer", 283 | }, 284 | { 285 | name: "InstallNoHeader", 286 | input:
[]byte("OLD_COMPLETION\n# ============ END fake COMPLETION ==============\n"), 287 | errMsg: "missing completion header", 288 | }, 289 | { 290 | name: "InstallBadOrder", 291 | input: []byte("# ============ END fake COMPLETION ==============\nFAKE_COMPLETION\n# ============ BEGIN fake COMPLETION =============="), 292 | errMsg: "header after footer", 293 | }, 294 | } 295 | 296 | for _, tc := range cases { 297 | tc := tc 298 | t.Run(tc.name, func(t *testing.T) { 299 | dir := t.TempDir() 300 | path := filepath.Join(dir, "fake.sh") 301 | err := os.WriteFile(path, tc.input, 0o644) 302 | require.NoError(t, err) 303 | 304 | shell := &fakeShell{baseInstallDir: dir, programName: "fake"} 305 | err = completion.InstallShellCompletion(shell) 306 | if tc.errMsg != "" { 307 | require.ErrorContains(t, err, tc.errMsg) 308 | return 309 | } else { 310 | require.NoError(t, err) 311 | contents, err := os.ReadFile(path) 312 | require.NoError(t, err) 313 | require.Equal(t, tc.expected, contents) 314 | } 315 | }) 316 | } 317 | } 318 | // fakeShell is a stub completion.Shell whose install path is fake.sh under baseInstallDir. 319 | type fakeShell struct { 320 | baseInstallDir string 321 | programName string 322 | } 323 | 324 | func (f *fakeShell) ProgramName() string { 325 | return f.programName 326 | } 327 | 328 | var _ completion.Shell = &fakeShell{} 329 | 330 | func (f *fakeShell) InstallPath() (string, error) { 331 | return filepath.Join(f.baseInstallDir, "fake.sh"), nil 332 | } 333 | 334 | func (f *fakeShell) Name() string { 335 | return "Fake" 336 | } 337 | 338 | func (f *fakeShell) WriteCompletion(w io.Writer) error { 339 | _, err := w.Write([]byte("\nFAKE_COMPLETION\n")) 340 | return err 341 | } 342 | -------------------------------------------------------------------------------- /option.go: -------------------------------------------------------------------------------- 1 | package serpent 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "os" 7 | "slices" 8 | "strings" 9 | 10 | "github.com/hashicorp/go-multierror" 11 | "github.com/spf13/pflag" 12 |
"golang.org/x/xerrors" 13 | ) 14 | // ValueSource identifies which source (flag, env, YAML, default) set an option's value; ValueSourceNone means none did. 15 | type ValueSource string 16 | 17 | const ( 18 | ValueSourceNone ValueSource = "" 19 | ValueSourceFlag ValueSource = "flag" 20 | ValueSourceEnv ValueSource = "env" 21 | ValueSourceYAML ValueSource = "yaml" 22 | ValueSourceDefault ValueSource = "default" 23 | ) 24 | // valueSourcePriority lists every ValueSource; presumably ordered highest priority first (TODO confirm against its users). 25 | var valueSourcePriority = []ValueSource{ 26 | ValueSourceFlag, 27 | ValueSourceEnv, 28 | ValueSourceYAML, 29 | ValueSourceDefault, 30 | ValueSourceNone, 31 | } 32 | 33 | // Option is a configuration option for a CLI application. 34 | type Option struct { 35 | Name string `json:"name,omitempty"` 36 | Description string `json:"description,omitempty"` 37 | // Required means this value must be set by some means. It requires 38 | // `ValueSource != ValueSourceNone` 39 | // If `Default` is set, then `Required` is ignored. 40 | Required bool `json:"required,omitempty"` 41 | 42 | // Flag is the long name of the flag used to configure this option. If unset, 43 | // flag configuring is disabled. 44 | Flag string `json:"flag,omitempty"` 45 | // FlagShorthand is the one-character shorthand for the flag. If unset, no 46 | // shorthand is used. 47 | FlagShorthand string `json:"flag_shorthand,omitempty"` 48 | 49 | // Env is the environment variable used to configure this option. If unset, 50 | // environment configuring is disabled. 51 | Env string `json:"env,omitempty"` 52 | 53 | // YAML is the YAML key used to configure this option. If unset, YAML 54 | // configuring is disabled. 55 | YAML string `json:"yaml,omitempty"` 56 | 57 | // Default is parsed into Value if set. 58 | Default string `json:"default,omitempty"` 59 | // Value includes the types listed in values.go. 60 | Value pflag.Value `json:"value,omitempty"` 61 | 62 | // Annotations enable extensions to serpent higher up in the stack. It's useful for 63 | // help formatting and documentation generation.
64 | Annotations Annotations `json:"annotations,omitempty"` 65 | 66 | // Group is a group hierarchy that helps organize this option in help, configs 67 | // and other documentation. 68 | Group *Group `json:"group,omitempty"` 69 | 70 | // UseInstead is a list of options that should be used instead of this one. 71 | // The field is used to generate a deprecation warning. 72 | UseInstead []Option `json:"use_instead,omitempty"` 73 | // Hidden omits the option from generated help output. 74 | Hidden bool `json:"hidden,omitempty"` 75 | // ValueSource records which source set Value; ValueSourceNone means it has not been set. 76 | ValueSource ValueSource `json:"value_source,omitempty"` 77 | // CompletionHandler supplies additional shell completions for this option; it is excluded from JSON. 78 | CompletionHandler CompletionHandlerFunc `json:"-"` 79 | } 80 | 81 | // optionNoMethods is just a wrapper around Option so we can defer to the 82 | // default json.Unmarshaler behavior. 83 | type optionNoMethods Option 84 | // UnmarshalJSON decodes data into o, substituting DiscardValue when o.Value is nil. 85 | func (o *Option) UnmarshalJSON(data []byte) error { 86 | // If an option has no values, we have no idea how to unmarshal it. 87 | // So just discard the json data. 88 | if o.Value == nil { 89 | o.Value = &DiscardValue 90 | } 91 | 92 | return json.Unmarshal(data, (*optionNoMethods)(o)) 93 | } 94 | // YAMLPath returns the option's dotted YAML key (group YAML names, then YAML), or "" if YAML is unset. 95 | func (o Option) YAMLPath() string { 96 | if o.YAML == "" { 97 | return "" 98 | } 99 | var gs []string 100 | for _, g := range o.Group.Ancestry() { 101 | gs = append(gs, g.YAML) 102 | } 103 | return strings.Join(append(gs, o.YAML), ".") 104 | } 105 | 106 | // OptionSet is a group of options that can be applied to a command. 107 | type OptionSet []Option 108 | 109 | // UnmarshalJSON implements json.Unmarshaler for OptionSets. Options have an 110 | // interface Value type that cannot handle unmarshalling because the types cannot 111 | // be inferred. Since it is a slice, instantiating the Options first does not 112 | // help. 113 | // 114 | // However, we typically do instantiate the slice to have the correct types. 115 | // So this unmarshaller will attempt to find the named option in the existing 116 | // set, if it cannot, the value is discarded.
If the option exists, the value 117 | // is unmarshalled into the existing option, and replaces the existing option. 118 | // 119 | // The value is discarded if its type cannot be inferred. This behavior just 120 | // feels "safer", although it should never happen if the correct option set 121 | // is passed in. The situation where this could occur is if a client and server 122 | // are on different versions with different options. 123 | func (optSet *OptionSet) UnmarshalJSON(data []byte) error { 124 | dec := json.NewDecoder(bytes.NewBuffer(data)) 125 | // Should be a json array, so consume the starting open bracket. 126 | t, err := dec.Token() 127 | if err != nil { 128 | return xerrors.Errorf("read array open bracket: %w", err) 129 | } 130 | if t != json.Delim('[') { 131 | return xerrors.Errorf("expected array open bracket, got %q", t) 132 | } 133 | 134 | // As long as json elements exist, consume them. The counter is used for 135 | // better errors. 136 | var i int 137 | OptionSetDecodeLoop: 138 | for dec.More() { 139 | var opt Option 140 | // jValue is a placeholder value that allows us to capture the 141 | // raw json for the value to attempt to unmarshal later. 142 | var jValue jsonValue 143 | opt.Value = &jValue 144 | err := dec.Decode(&opt) 145 | if err != nil { 146 | return xerrors.Errorf("decode %d option: %w", i, err) 147 | } 148 | // This counter is used to contextualize errors to show which element of 149 | // the array we failed to decode. It is only used in the error above, as 150 | // if the above works, we can instead use the Option.Name which is more 151 | // descriptive and useful. So increment here for the next decode. 152 | i++ 153 | 154 | // Try to see if the option already exists in the option set. 155 | // If it does, just update the existing option.
156 | for optIndex, have := range *optSet { 157 | if have.Name == opt.Name { 158 | if jValue != nil { 159 | err := json.Unmarshal(jValue, &(*optSet)[optIndex].Value) 160 | if err != nil { 161 | return xerrors.Errorf("decode option %q value: %w", have.Name, err) 162 | } 163 | // Set the opt's value 164 | opt.Value = (*optSet)[optIndex].Value 165 | } else { 166 | // Hopefully the user passed empty values in the option set. There is no easy way 167 | // to tell, and if we do not do this, it breaks json.Marshal if we do it again on 168 | // this new option set. 169 | opt.Value = (*optSet)[optIndex].Value 170 | } 171 | // Override the existing. 172 | (*optSet)[optIndex] = opt 173 | // Go to the next option to decode. 174 | continue OptionSetDecodeLoop 175 | } 176 | } 177 | 178 | // If the option doesn't exist, the value will be discarded. 179 | // We do this because we cannot infer the type of the value. 180 | opt.Value = DiscardValue // NOTE(review): Option.UnmarshalJSON assigns &DiscardValue instead; confirm both forms are the intended pflag.Value. 181 | *optSet = append(*optSet, opt) 182 | } 183 | 184 | t, err = dec.Token() 185 | if err != nil { 186 | return xerrors.Errorf("read array close bracket: %w", err) 187 | } 188 | if t != json.Delim(']') { 189 | return xerrors.Errorf("expected array close bracket, got %q", t) 190 | } 191 | 192 | return nil 193 | } 194 | 195 | // Add adds the given Options to the OptionSet. 196 | func (optSet *OptionSet) Add(opts ...Option) { 197 | *optSet = append(*optSet, opts...) 198 | } 199 | 200 | // Filter returns a new OptionSet containing only the options for which filter returns true. 201 | func (optSet OptionSet) Filter(filter func(opt Option) bool) OptionSet { 202 | cpy := make(OptionSet, 0) 203 | for _, opt := range optSet { 204 | if filter(opt) { 205 | cpy = append(cpy, opt) 206 | } 207 | } 208 | return cpy 209 | } 210 | 211 | // FlagSet returns a pflag.FlagSet for the OptionSet.
212 | func (optSet *OptionSet) FlagSet() *pflag.FlagSet { 213 | if optSet == nil { 214 | return &pflag.FlagSet{} 215 | } 216 | 217 | fs := pflag.NewFlagSet("", pflag.ContinueOnError) 218 | for _, opt := range *optSet { 219 | if opt.Flag == "" { 220 | continue 221 | } 222 | var noOptDefValue string 223 | { 224 | no, ok := opt.Value.(NoOptDefValuer) 225 | if ok { 226 | noOptDefValue = no.NoOptDefValue() 227 | } 228 | } 229 | // Options without a Value still get a flag, backed by DiscardValue. 230 | val := opt.Value 231 | if val == nil { 232 | val = DiscardValue 233 | } 234 | 235 | fs.AddFlag(&pflag.Flag{ 236 | Name: opt.Flag, 237 | Shorthand: opt.FlagShorthand, 238 | Usage: opt.Description, 239 | Value: val, 240 | DefValue: "", 241 | Changed: false, 242 | Deprecated: "", 243 | NoOptDefVal: noOptDefValue, 244 | Hidden: opt.Hidden, 245 | }) 246 | } 247 | fs.Usage = func() { 248 | _, _ = os.Stderr.WriteString("Override (*FlagSet).Usage() to print help text.\n") 249 | } 250 | return fs 251 | } 252 | 253 | // ParseEnv parses the given environment variables into the OptionSet. 254 | // Use EnvsWithPrefix to filter out prefixes. An option whose env var is set but whose Value is nil is reported as an error. 255 | func (optSet *OptionSet) ParseEnv(vs []EnvVar) error { 256 | if optSet == nil { 257 | return nil 258 | } 259 | 260 | var merr *multierror.Error 261 | 262 | // We parse environment variables first instead of using a nested loop to 263 | // avoid N*M complexity when there are a lot of options and environment 264 | // variables. 265 | envs := make(map[string]string) 266 | for _, v := range vs { 267 | envs[v.Name] = v.Value 268 | } 269 | 270 | for i, opt := range *optSet { 271 | if opt.Env == "" { 272 | continue 273 | } 274 | 275 | envVal, ok := envs[opt.Env] 276 | if !ok { 277 | // Homebrew strips all environment variables that do not start with `HOMEBREW_`. 278 | // This prevented using brew to invoke the Coder agent, because the environment 279 | // variables do not get passed down.
280 | // 281 | // A customer wanted to use their custom tap inside a workspace, which was failing 282 | // because the agent lacked the environment variables to authenticate with Git. 283 | envVal, ok = envs[`HOMEBREW_`+opt.Env] 284 | } 285 | // Currently, empty values are treated as if the environment variable is 286 | // unset. This behavior is technically not correct as there is no 287 | // way for a user to change a Default value to an empty string from 288 | // the environment. Unfortunately, we have old configuration files 289 | // that rely on the faulty behavior. 290 | // 291 | // TODO: We should remove this hack in May 2023, when deployments 292 | // have had months to migrate to the new behavior. 293 | if !ok || envVal == "" { 294 | continue 295 | } 296 | if opt.Value == nil { // Report (as SetDefaults does) rather than panic in Set below. 297 | merr = multierror.Append(merr, xerrors.Errorf("parse %q: no Value field set", opt.Name)) 298 | continue 299 | } 300 | (*optSet)[i].ValueSource = ValueSourceEnv 301 | if err := opt.Value.Set(envVal); err != nil { 302 | merr = multierror.Append(merr, xerrors.Errorf("parse %q: %w", opt.Name, err)) 303 | } 304 | } 305 | return merr.ErrorOrNil() 306 | } 307 | 308 | // SetDefaults sets the default values for each Option, skipping values 309 | // that already have a value source. 310 | func (optSet *OptionSet) SetDefaults() error { 311 | if optSet == nil { 312 | return nil 313 | } 314 | 315 | var merr *multierror.Error 316 | 317 | // It's common to have multiple options with the same value to 318 | // handle deprecation. We group the options by value so that we 319 | // don't let other options overwrite user input. 320 | groupByValue := make(map[pflag.Value][]*Option) 321 | for i := range *optSet { 322 | opt := &(*optSet)[i] 323 | if opt.Value == nil { 324 | merr = multierror.Append( 325 | merr, 326 | xerrors.Errorf( 327 | "parse %q: no Value field set\nFull opt: %+v", 328 | opt.Name, opt, 329 | ), 330 | ) 331 | continue 332 | } 333 | groupByValue[opt.Value] = append(groupByValue[opt.Value], opt) 334 | } 335 | 336 | // Sort by valueSourcePriority index (highest priority first), then options that have a Default before those that don't.
337 | sortOptionByValueSourcePriorityOrDefault := func(a, b *Option) int { 338 | if a.ValueSource != b.ValueSource { 339 | return slices.Index(valueSourcePriority, a.ValueSource) - slices.Index(valueSourcePriority, b.ValueSource) 340 | } 341 | if a.Default != b.Default { 342 | if a.Default == "" { 343 | return 1 344 | } 345 | if b.Default == "" { 346 | return -1 347 | } 348 | } 349 | return 0 350 | } 351 | for _, opts := range groupByValue { 352 | // Sort the options by priority and whether or not a default is 353 | // set. This won't affect the value but represents correctness 354 | // from whence the value originated. 355 | slices.SortFunc(opts, sortOptionByValueSourcePriorityOrDefault) 356 | 357 | // If the first option has a value source, then we don't need to 358 | // set the default, but mark the source for all options. 359 | if opts[0].ValueSource != ValueSourceNone { 360 | for _, opt := range opts[1:] { 361 | opt.ValueSource = opts[0].ValueSource 362 | } 363 | continue 364 | } 365 | // Nothing has set this value yet: apply the group's default, rejecting conflicting defaults. 366 | var optWithDefault *Option 367 | for _, opt := range opts { 368 | if opt.Default == "" { 369 | continue 370 | } 371 | if optWithDefault != nil && optWithDefault.Default != opt.Default { 372 | merr = multierror.Append( 373 | merr, 374 | xerrors.Errorf( 375 | "parse %q: multiple defaults set for the same value: %q and %q (%q)", 376 | opt.Name, opt.Default, optWithDefault.Default, optWithDefault.Name, 377 | ), 378 | ) 379 | continue 380 | } 381 | optWithDefault = opt 382 | } 383 | if optWithDefault == nil { 384 | continue 385 | } 386 | if err := optWithDefault.Value.Set(optWithDefault.Default); err != nil { 387 | merr = multierror.Append( 388 | merr, xerrors.Errorf("parse %q: %w", optWithDefault.Name, err), 389 | ) 390 | } 391 | for _, opt := range opts { 392 | opt.ValueSource = ValueSourceDefault 393 | } 394 | } 395 | 396 | return merr.ErrorOrNil() 397 | } 398 | 399 | // ByName returns a pointer to the Option with the given name (aliasing 400 | // optSet's backing array), or nil if no such option exists.
401 | func (optSet OptionSet) ByName(name string) *Option { 402 | for i := range optSet { 403 | if optSet[i].Name == name { 404 | return &optSet[i] 405 | } 406 | } 407 | return nil 408 | } 409 | // ByFlag returns a pointer to the Option whose Flag equals flag, or nil if flag is empty or no option matches. 410 | func (optSet OptionSet) ByFlag(flag string) *Option { 411 | if flag == "" { 412 | return nil 413 | } 414 | for i := range optSet { 415 | opt := &optSet[i] 416 | if opt.Flag == flag { 417 | return opt 418 | } 419 | } 420 | return nil 421 | } 422 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | cdr.dev/slog v1.6.2-0.20240126064726-20367d4aede6 h1:KHblWIE/KHOwQ6lEbMZt6YpcGve2FEZ1sDtrW1Am5UI= 2 | cdr.dev/slog v1.6.2-0.20240126064726-20367d4aede6/go.mod h1:NaoTA7KwopCrnaSb0JXTC0PTp/O/Y83Lndnq0OEV3ZQ= 3 | cloud.google.com/go v0.110.10 h1:LXy9GEO+timppncPIAZoOj3l58LIU9k+kn48AN7IO3Y= 4 | cloud.google.com/go/compute v1.23.3 h1:6sVlXXBmbd7jNX0Ipq0trII3e4n1/MsADLK6a+aiVlk= 5 | cloud.google.com/go/compute v1.23.3/go.mod h1:VCgBUoMnIVIR0CscqQiPJLAG25E3ZRZMzcFZeQ+h8CI= 6 | cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY= 7 | cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA= 8 | cloud.google.com/go/logging v1.8.1 h1:26skQWPeYhvIasWKm48+Eq7oUqdcdbwsCVwz5Ys0FvU= 9 | cloud.google.com/go/logging v1.8.1/go.mod h1:TJjR+SimHwuC8MZ9cjByQulAMgni+RkXeI3wwctHJEI= 10 | cloud.google.com/go/longrunning v0.5.4 h1:w8xEcbZodnA2BbW6sVirkkoC+1gP8wS57EUUgGS0GVg= 11 | cloud.google.com/go/longrunning v0.5.4/go.mod h1:zqNVncI0BOP8ST6XQD1+VcvuShMmq7+xFSzOL++V0dI= 12 | github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= 13 | github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= 14 | github.com/charmbracelet/lipgloss v0.8.0 h1:IS00fk4XAHcf8uZKc3eHeMUTCxUH6NkaTrdyCQk84RU= 15 | github.com/charmbracelet/lipgloss
v0.8.0/go.mod h1:p4eYUZZJ/0oXTuCQKFF8mqyKCz0ja6y+7DniDDw5KKU= 16 | github.com/coder/pretty v0.0.0-20230908205945-e89ba86370e0 h1:3A0ES21Ke+FxEM8CXx9n47SZOKOpgSE1bbJzlE4qPVs= 17 | github.com/coder/pretty v0.0.0-20230908205945-e89ba86370e0/go.mod h1:5UuS2Ts+nTToAMeOjNlnHFkPahrtDkmpydBen/3wgZc= 18 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 19 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 20 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 21 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= 22 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 23 | github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= 24 | github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= 25 | github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= 26 | github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= 27 | github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= 28 | github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= 29 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= 30 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 31 | github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= 32 | github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= 33 | github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= 34 | github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= 35 | github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= 36 
| github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= 37 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 38 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 39 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 40 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 41 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 42 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 43 | github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= 44 | github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= 45 | github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= 46 | github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= 47 | github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= 48 | github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= 49 | github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= 50 | github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= 51 | github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0= 52 | github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0= 53 | github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= 54 | github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8= 55 | github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo= 56 | github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8= 57 | github.com/natefinch/atomic v1.0.1 h1:ZPYKxkqQOx3KZ+RsbnP/YsgvxWQPGxjC0oBt2AhwV0A= 58 | 
github.com/natefinch/atomic v1.0.1/go.mod h1:N/D/ELrljoqDyT3rZrsUmtsuzvHkeB/wWjHV22AZRbM= 59 | github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= 60 | github.com/pion/transport/v2 v2.0.0 h1:bsMYyqHCbkvHwj+eNCFBuxtlKndKfyGI2vaQmM3fIE4= 61 | github.com/pion/transport/v2 v2.0.0/go.mod h1:HS2MEBJTwD+1ZI2eSXSvHJx/HnzQqRy2/LXxt6eVMHc= 62 | github.com/pion/udp v0.1.4 h1:OowsTmu1Od3sD6i3fQUJxJn2fEvJO6L1TidgadtbTI8= 63 | github.com/pion/udp v0.1.4/go.mod h1:G8LDo56HsFwC24LIcnT4YIDU5qcB6NepqqjP0keL2us= 64 | github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= 65 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 66 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= 67 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 68 | github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= 69 | github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= 70 | github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= 71 | github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= 72 | github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= 73 | github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= 74 | github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= 75 | github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= 76 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 77 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 78 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 79 
| github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 80 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 81 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 82 | github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= 83 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 84 | github.com/xhit/go-str2duration/v2 v2.1.0 h1:lxklc02Drh6ynqX+DdPyp5pCKLUQpRT8bp8Ydu2Bstc= 85 | github.com/xhit/go-str2duration/v2 v2.1.0/go.mod h1:ohY8p+0f07DiV6Em5LKB0s2YpLtXVyJfNt1+BlmyAsU= 86 | github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= 87 | go.opentelemetry.io/otel v1.19.0 h1:MuS/TNf4/j4IXsZuJegVzI1cwut7Qc00344rgH7p8bs= 88 | go.opentelemetry.io/otel v1.19.0/go.mod h1:i0QyjOq3UPoTzff0PJB2N66fb4S0+rSbSB15/oyH9fY= 89 | go.opentelemetry.io/otel/metric v1.19.0 h1:aTzpGtV0ar9wlV4Sna9sdJyII5jTVJEvKETPiOKwvpE= 90 | go.opentelemetry.io/otel/metric v1.19.0/go.mod h1:L5rUsV9kM1IxCj1MmSdS+JQAcVm319EUrDVLrt7jqt8= 91 | go.opentelemetry.io/otel/sdk v1.19.0 h1:6USY6zH+L8uMH8L3t1enZPR3WFEmSTADlqldyHtJi3o= 92 | go.opentelemetry.io/otel/sdk v1.19.0/go.mod h1:NedEbbS4w3C6zElbLdPJKOpJQOrGUJ+GfzpjUvI0v1A= 93 | go.opentelemetry.io/otel/trace v1.19.0 h1:DFVQmlVbfVeOuBRrwdtaehRrWiL1JoVs9CPIQ1Dzxpg= 94 | go.opentelemetry.io/otel/trace v1.19.0/go.mod h1:mfaSyvGyEJEI0nyV2I4qhNQnbBOUUmYZpYojqMnX2vo= 95 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 96 | golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= 97 | golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= 98 | golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= 99 | golang.org/x/exp v0.0.0-20240213143201-ec583247a57a 
h1:HinSgX1tJRX3KsL//Gxynpw5CTOAIPhgL4W8PNiIpVE= 100 | golang.org/x/exp v0.0.0-20240213143201-ec583247a57a/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc= 101 | golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= 102 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 103 | golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= 104 | golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= 105 | golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= 106 | golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= 107 | golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= 108 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 109 | golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 110 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 111 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 112 | golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 113 | golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 114 | golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 115 | golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 116 | golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 117 | golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 118 | golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= 119 
| golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 120 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 121 | golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= 122 | golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= 123 | golang.org/x/term v0.17.0 h1:mkTF7LCd6WGJNL3K1Ad7kwxNfYAW6a8a8QqtMblp/4U= 124 | golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= 125 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 126 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 127 | golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= 128 | golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= 129 | golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= 130 | golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= 131 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 132 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 133 | golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= 134 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 135 | golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 h1:+cNy6SZtPcJQH3LJVLOSmiC7MMxXNOb3PU/VUEz+EhU= 136 | golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= 137 | google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ= 138 | google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17/go.mod h1:J7XzRzVy1+IPwWHZUzoD0IccYZIrXILAQpc+Qy9CMhY= 139 | 
google.golang.org/genproto/googleapis/api v0.0.0-20231106174013-bbf56f31fb17 h1:JpwMPBpFN3uKhdaekDpiNlImDdkUAyiJ6ez/uxGaUSo= 140 | google.golang.org/genproto/googleapis/api v0.0.0-20231106174013-bbf56f31fb17/go.mod h1:0xJLfVdJqpAPl8tDg1ujOCGzx6LFLttXT5NhllGOXY4= 141 | google.golang.org/genproto/googleapis/rpc v0.0.0-20231120223509-83a465c0220f h1:ultW7fxlIvee4HYrtnaRPon9HpEgFk5zYpmfMgtKB5I= 142 | google.golang.org/genproto/googleapis/rpc v0.0.0-20231120223509-83a465c0220f/go.mod h1:L9KNLi232K1/xB6f7AlSX692koaRnKaWSR0stBki0Yc= 143 | google.golang.org/grpc v1.61.0 h1:TOvOcuXn30kRao+gfcvsebNEa5iZIiLkisYEkf7R7o0= 144 | google.golang.org/grpc v1.61.0/go.mod h1:VUbo7IFqmF1QtCAstipjG0GIoq49KvMe9+h1jFLBNJs= 145 | google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I= 146 | google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= 147 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 148 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 149 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 150 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 151 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 152 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 153 | -------------------------------------------------------------------------------- /values.go: -------------------------------------------------------------------------------- 1 | package serpent 2 | 3 | import ( 4 | "encoding/csv" 5 | "encoding/json" 6 | "fmt" 7 | "net" 8 | "net/url" 9 | "reflect" 10 | "regexp" 11 | "strconv" 12 | "strings" 13 | "time" 14 | 15 | "github.com/spf13/pflag" 16 | str2duration "github.com/xhit/go-str2duration/v2" 17 | "golang.org/x/xerrors" 18 | 
"gopkg.in/yaml.v3" 19 | ) 20 | 21 | // NoOptDefValuer supplies the value used when the flag is given without an 22 | // argument; FlagSet copies it into pflag.Flag.NoOptDefVal. 23 | // 24 | // This is useful for boolean or otherwise binary flags. 25 | type NoOptDefValuer interface { 26 | NoOptDefValue() string 27 | } 28 | 29 | // Validator is a wrapper around a pflag.Value that runs a validation function 30 | // after the value has been set via Set. 31 | type Validator[T pflag.Value] struct { 32 | Value T 33 | // validate, if non-nil, is called by Set after Value.Set succeeds. 34 | validate func(T) error 35 | } 36 | // Validate wraps opt in a Validator that runs validate after every successful Set. 37 | func Validate[T pflag.Value](opt T, validate func(value T) error) *Validator[T] { 38 | return &Validator[T]{Value: opt, validate: validate} 39 | } 40 | 41 | func (i *Validator[T]) String() string { 42 | return i.Value.String() 43 | } 44 | // Set updates the wrapped value, then validates it; a validation error does not roll back the update. 45 | func (i *Validator[T]) Set(input string) error { 46 | err := i.Value.Set(input) 47 | if err != nil { 48 | return err 49 | } 50 | if i.validate != nil { 51 | err = i.validate(i.Value) 52 | if err != nil { 53 | return err 54 | } 55 | } 56 | return nil 57 | } 58 | 59 | func (i *Validator[T]) Type() string { 60 | return i.Value.Type() 61 | } 62 | // MarshalYAML defers to the wrapped value's yaml.Marshaler if it has one, else returns the value itself. 63 | func (i *Validator[T]) MarshalYAML() (interface{}, error) { 64 | m, ok := any(i.Value).(yaml.Marshaler) 65 | if !ok { 66 | return i.Value, nil 67 | } 68 | return m.MarshalYAML() 69 | } 70 | // UnmarshalYAML decodes straight into the wrapped value. NOTE(review): validate is not run on this path. 71 | func (i *Validator[T]) UnmarshalYAML(n *yaml.Node) error { 72 | return n.Decode(i.Value) 73 | } 74 | 75 | func (i *Validator[T]) MarshalJSON() ([]byte, error) { 76 | return json.Marshal(i.Value) 77 | } 78 | // UnmarshalJSON decodes straight into the wrapped value. NOTE(review): validate is not run on this path. 79 | func (i *Validator[T]) UnmarshalJSON(b []byte) error { 80 | return json.Unmarshal(b, i.Value) 81 | } 82 | // Underlying returns the wrapped pflag.Value. 83 | func (i *Validator[T]) Underlying() pflag.Value { return i.Value } 84 | 85 | // values.go contains a standard set of value types that can be used as 86 | // Option Values.
87 | 88 | type Int64 int64 89 | // Int64Of returns i as an *Int64; setting the result writes through to *i. 90 | func Int64Of(i *int64) *Int64 { 91 | return (*Int64)(i) 92 | } 93 | // Set parses s as a base-10 int64; the parse result is stored even when an error is returned. 94 | func (i *Int64) Set(s string) error { 95 | ii, err := strconv.ParseInt(s, 10, 64) 96 | *i = Int64(ii) 97 | return err 98 | } 99 | 100 | func (i Int64) Value() int64 { 101 | return int64(i) 102 | } 103 | // String formats the full 64-bit value; converting to int first would truncate it on 32-bit platforms. 104 | func (i Int64) String() string { 105 | return strconv.FormatInt(int64(i), 10) 106 | } 107 | 108 | func (Int64) Type() string { 109 | return "int" 110 | } 111 | 112 | type Float64 float64 113 | 114 | func Float64Of(f *float64) *Float64 { 115 | return (*Float64)(f) 116 | } 117 | // Set parses s as a float64; the parse result is stored even when an error is returned. 118 | func (f *Float64) Set(s string) error { 119 | ff, err := strconv.ParseFloat(s, 64) 120 | *f = Float64(ff) 121 | return err 122 | } 123 | 124 | func (f Float64) Value() float64 { 125 | return float64(f) 126 | } 127 | 128 | func (f Float64) String() string { 129 | return strconv.FormatFloat(float64(f), 'f', -1, 64) 130 | } 131 | 132 | func (Float64) Type() string { 133 | return "float64" 134 | } 135 | 136 | type Bool bool 137 | 138 | func BoolOf(b *bool) *Bool { 139 | return (*Bool)(b) 140 | } 141 | // Set parses s with strconv.ParseBool; an empty s sets false without error. 142 | func (b *Bool) Set(s string) error { 143 | if s == "" { 144 | *b = Bool(false) 145 | return nil 146 | } 147 | bb, err := strconv.ParseBool(s) 148 | *b = Bool(bb) 149 | return err 150 | } 151 | // NoOptDefValue makes a bare flag (given without a value) mean true. 152 | func (*Bool) NoOptDefValue() string { 153 | return "true" 154 | } 155 | 156 | func (b Bool) String() string { 157 | return strconv.FormatBool(bool(b)) 158 | } 159 | 160 | func (b Bool) Value() bool { 161 | return bool(b) 162 | } 163 | 164 | func (Bool) Type() string { 165 | return "bool" 166 | } 167 | 168 | type String string 169 | 170 | func StringOf(s *string) *String { 171 | return (*String)(s) 172 | } 173 | 174 | func (*String) NoOptDefValue() string { 175 | return "" 176 | } 177 | 178 | func (s *String) Set(v string) error { 179 | *s = String(v) 180 | return nil 181 | } 182 | 183 | func (s String) String() string { 184 | return string(s) 185 | } 186 | 187 | func (s String) Value()
string { 188 | return string(s) 189 | } 190 | 191 | func (String) Type() string { 192 | return "string" 193 | } 194 | 195 | var ( 196 | _ pflag.SliceValue = &StringArray{} 197 | _ pflag.Value = &StringArray{} 198 | ) 199 | 200 | // StringArray is a slice of strings that implements pflag.Value and pflag.SliceValue. 201 | type StringArray []string 202 | 203 | func StringArrayOf(ss *[]string) *StringArray { 204 | return (*StringArray)(ss) 205 | } 206 | 207 | func (s *StringArray) Append(v string) error { 208 | *s = append(*s, v) 209 | return nil 210 | } 211 | 212 | func (s *StringArray) Replace(vals []string) error { 213 | *s = vals 214 | return nil 215 | } 216 | 217 | func (s *StringArray) GetSlice() []string { 218 | return *s 219 | } 220 | func readAsCSV(v string) ([]string, error) { 221 | return csv.NewReader(strings.NewReader(v)).Read() 222 | } 223 | // writeAsCSV encodes vals as a single CSV record without the trailing newline. 224 | func writeAsCSV(vals []string) string { 225 | var sb strings.Builder 226 | w := csv.NewWriter(&sb) 227 | if err := w.Write(vals); err != nil { 228 | return fmt.Sprintf("error: %s", err) 229 | } 230 | w.Flush() // csv.Writer buffers its output; without Flush, sb is usually still empty. 231 | return strings.TrimSuffix(sb.String(), "\n") 232 | } 233 | // Set parses v as one CSV record and appends its fields; an empty v clears the array. 234 | func (s *StringArray) Set(v string) error { 235 | if v == "" { 236 | *s = nil 237 | return nil 238 | } 239 | ss, err := readAsCSV(v) 240 | if err != nil { 241 | return err 242 | } 243 | *s = append(*s, ss...) 244 | return nil 245 | } 246 | // String renders the array as a single CSV record, e.g. "a,b". 247 | func (s StringArray) String() string { 248 | return writeAsCSV([]string(s)) 249 | } 250 | 251 | func (s StringArray) Value() []string { 252 | return []string(s) 253 | } 254 | 255 | func (StringArray) Type() string { 256 | return "string-array" 257 | } 258 | 259 | type Duration time.Duration 260 | 261 | func DurationOf(d *time.Duration) *Duration { 262 | return (*Duration)(d) 263 | } 264 | // Set accepts str2duration syntax (days, weeks) or time.ParseDuration syntax; on failure the value is reset to 0. 265 | func (d *Duration) Set(v string) error { 266 | // Try [str2duration.ParseDuration] first, which supports days and weeks. 267 | // If it fails, fall back to [time.ParseDuration] for backward compatibility.
268 | dd, err := str2duration.ParseDuration(v) 269 | if err == nil { 270 | *d = Duration(dd) 271 | return nil 272 | } 273 | 274 | // Fallback to standard [time.ParseDuration]. 275 | dd, err = time.ParseDuration(v) 276 | *d = Duration(dd) 277 | return err 278 | } 279 | 280 | func (d *Duration) Value() time.Duration { 281 | return time.Duration(*d) 282 | } 283 | 284 | func (d *Duration) String() string { 285 | return time.Duration(*d).String() 286 | } 287 | 288 | func (Duration) Type() string { 289 | return "duration" 290 | } 291 | 292 | func (d *Duration) MarshalYAML() (interface{}, error) { 293 | return yaml.Node{ 294 | Kind: yaml.ScalarNode, 295 | Value: d.String(), 296 | }, nil 297 | } 298 | 299 | func (d *Duration) UnmarshalYAML(n *yaml.Node) error { 300 | return d.Set(n.Value) 301 | } 302 | 303 | type URL url.URL 304 | 305 | func URLOf(u *url.URL) *URL { 306 | return (*URL)(u) 307 | } 308 | // Set parses v with url.Parse; u is left unchanged on error. 309 | func (u *URL) Set(v string) error { 310 | uu, err := url.Parse(v) 311 | if err != nil { 312 | return err 313 | } 314 | *u = URL(*uu) 315 | return nil 316 | } 317 | 318 | func (u *URL) String() string { 319 | uu := url.URL(*u) 320 | return uu.String() 321 | } 322 | 323 | func (u *URL) MarshalYAML() (interface{}, error) { 324 | return yaml.Node{ 325 | Kind: yaml.ScalarNode, 326 | Value: u.String(), 327 | }, nil 328 | } 329 | 330 | func (u *URL) UnmarshalYAML(n *yaml.Node) error { 331 | return u.Set(n.Value) 332 | } 333 | 334 | func (u *URL) MarshalJSON() ([]byte, error) { 335 | return json.Marshal(u.String()) 336 | } 337 | 338 | func (u *URL) UnmarshalJSON(b []byte) error { 339 | var s string 340 | err := json.Unmarshal(b, &s) 341 | if err != nil { 342 | return err 343 | } 344 | return u.Set(s) 345 | } 346 | 347 | func (*URL) Type() string { 348 | return "url" 349 | } 350 | 351 | func (u *URL) Value() *url.URL { 352 | return (*url.URL)(u) 353 | } 354 | 355 | // HostPort is a host:port pair.
356 | type HostPort struct { 357 | Host string 358 | Port string 359 | } 360 | 361 | func (hp *HostPort) Set(v string) error { 362 | if v == "" { 363 | return xerrors.Errorf("must not be empty") 364 | } 365 | var err error 366 | hp.Host, hp.Port, err = net.SplitHostPort(v) 367 | return err 368 | } 369 | 370 | func (hp *HostPort) String() string { 371 | if hp.Host == "" && hp.Port == "" { 372 | return "" 373 | } 374 | // Warning: net.JoinHostPort must be used over concatenation to support 375 | // IPv6 addresses. 376 | return net.JoinHostPort(hp.Host, hp.Port) 377 | } 378 | 379 | func (hp *HostPort) MarshalJSON() ([]byte, error) { 380 | return json.Marshal(hp.String()) 381 | } 382 | 383 | func (hp *HostPort) UnmarshalJSON(b []byte) error { 384 | var s string 385 | err := json.Unmarshal(b, &s) 386 | if err != nil { 387 | return err 388 | } 389 | if s == "" { 390 | hp.Host = "" 391 | hp.Port = "" 392 | return nil 393 | } 394 | return hp.Set(s) 395 | } 396 | 397 | func (hp *HostPort) MarshalYAML() (interface{}, error) { 398 | return yaml.Node{ 399 | Kind: yaml.ScalarNode, 400 | Value: hp.String(), 401 | }, nil 402 | } 403 | 404 | func (hp *HostPort) UnmarshalYAML(n *yaml.Node) error { 405 | return hp.Set(n.Value) 406 | } 407 | 408 | func (*HostPort) Type() string { 409 | return "host:port" 410 | } 411 | 412 | var ( 413 | _ yaml.Marshaler = new(Struct[struct{}]) 414 | _ yaml.Unmarshaler = new(Struct[struct{}]) 415 | ) 416 | 417 | // Struct is a special value type that encodes an arbitrary struct. 418 | // It implements the flag.Value interface, but in general these values should 419 | // only be accepted via config for ergonomics. 420 | // 421 | // The string encoding type is YAML. 
422 | type Struct[T any] struct { 423 | Value T 424 | } 425 | 426 | //nolint:revive 427 | func (s *Struct[T]) Set(v string) error { 428 | return yaml.Unmarshal([]byte(v), &s.Value) 429 | } 430 | 431 | //nolint:revive 432 | func (s *Struct[T]) String() string { 433 | byt, err := yaml.Marshal(s.Value) 434 | if err != nil { 435 | return "decode failed: " + err.Error() 436 | } 437 | return string(byt) 438 | } 439 | 440 | // nolint:revive 441 | func (s *Struct[T]) MarshalYAML() (interface{}, error) { 442 | var n yaml.Node 443 | err := n.Encode(s.Value) 444 | if err != nil { 445 | return nil, err 446 | } 447 | return n, nil 448 | } 449 | 450 | // nolint:revive 451 | func (s *Struct[T]) UnmarshalYAML(n *yaml.Node) error { 452 | // HACK: for compatibility with flags, we use nil slices instead of empty 453 | // slices. In most cases, nil slices and empty slices are treated 454 | // the same, so this behavior may be removed at some point. 455 | if typ := reflect.TypeOf(s.Value); typ.Kind() == reflect.Slice && len(n.Content) == 0 { 456 | reflect.ValueOf(&s.Value).Elem().Set(reflect.Zero(typ)) 457 | return nil 458 | } 459 | return n.Decode(&s.Value) 460 | } 461 | 462 | //nolint:revive 463 | func (s *Struct[T]) Type() string { 464 | return fmt.Sprintf("struct[%T]", s.Value) 465 | } 466 | 467 | // nolint:revive 468 | func (s *Struct[T]) MarshalJSON() ([]byte, error) { 469 | return json.Marshal(s.Value) 470 | } 471 | 472 | // nolint:revive 473 | func (s *Struct[T]) UnmarshalJSON(b []byte) error { 474 | return json.Unmarshal(b, &s.Value) 475 | } 476 | 477 | // DiscardValue does nothing but implements the pflag.Value interface. 478 | // It's useful in cases where you want to accept an option, but access the 479 | // underlying value directly instead of through the Option methods. 
480 | var DiscardValue discardValue 481 | 482 | type discardValue struct{} 483 | 484 | func (discardValue) Set(string) error { 485 | return nil 486 | } 487 | 488 | func (discardValue) String() string { 489 | return "" 490 | } 491 | 492 | func (discardValue) Type() string { 493 | return "discard" 494 | } 495 | 496 | func (discardValue) UnmarshalJSON([]byte) error { 497 | return nil 498 | } 499 | 500 | // jsonValue is intentionally not exported. It is just used to store the raw JSON 501 | // data for a value to defer it's unmarshal. It implements the pflag.Value to be 502 | // usable in an Option. 503 | type jsonValue json.RawMessage 504 | 505 | func (jsonValue) Set(string) error { 506 | return xerrors.Errorf("json value is read-only") 507 | } 508 | 509 | func (jsonValue) String() string { 510 | return "" 511 | } 512 | 513 | func (jsonValue) Type() string { 514 | return "json" 515 | } 516 | 517 | func (j *jsonValue) UnmarshalJSON(data []byte) error { 518 | if j == nil { 519 | return xerrors.New("json.RawMessage: UnmarshalJSON on nil pointer") 520 | } 521 | *j = append((*j)[0:0], data...) 522 | return nil 523 | } 524 | 525 | var _ pflag.Value = (*Enum)(nil) 526 | 527 | type Enum struct { 528 | Choices []string 529 | Value *string 530 | } 531 | 532 | func EnumOf(v *string, choices ...string) *Enum { 533 | // copy choices to avoid data race during unmarshaling 534 | choices = append([]string{}, choices...) 
535 | return &Enum{ 536 | Choices: choices, 537 | Value: v, 538 | } 539 | } 540 | 541 | func (e *Enum) Set(v string) error { 542 | for _, c := range e.Choices { 543 | if strings.EqualFold(v, c) { 544 | *e.Value = v 545 | return nil 546 | } 547 | } 548 | return xerrors.Errorf("invalid choice: %s, should be one of %v", v, e.Choices) 549 | } 550 | 551 | func (e *Enum) Type() string { 552 | return fmt.Sprintf("enum[%v]", strings.Join(e.Choices, "\\|")) 553 | } 554 | 555 | func (e *Enum) String() string { 556 | return *e.Value 557 | } 558 | 559 | func (e *Enum) MarshalYAML() (interface{}, error) { 560 | return yaml.Node{ 561 | Kind: yaml.ScalarNode, 562 | Value: e.String(), 563 | }, nil 564 | } 565 | 566 | func (e *Enum) UnmarshalYAML(n *yaml.Node) error { 567 | return e.Set(n.Value) 568 | } 569 | 570 | type Regexp regexp.Regexp 571 | 572 | func (r *Regexp) MarshalJSON() ([]byte, error) { 573 | return json.Marshal(r.String()) 574 | } 575 | 576 | func (r *Regexp) UnmarshalJSON(data []byte) error { 577 | var source string 578 | err := json.Unmarshal(data, &source) 579 | if err != nil { 580 | return err 581 | } 582 | 583 | exp, err := regexp.Compile(source) 584 | if err != nil { 585 | return xerrors.Errorf("invalid regex expression: %w", err) 586 | } 587 | *r = Regexp(*exp) 588 | return nil 589 | } 590 | 591 | func (r *Regexp) MarshalYAML() (interface{}, error) { 592 | return yaml.Node{ 593 | Kind: yaml.ScalarNode, 594 | Value: r.String(), 595 | }, nil 596 | } 597 | 598 | func (r *Regexp) UnmarshalYAML(n *yaml.Node) error { 599 | return r.Set(n.Value) 600 | } 601 | 602 | func (r *Regexp) Set(v string) error { 603 | exp, err := regexp.Compile(v) 604 | if err != nil { 605 | return xerrors.Errorf("invalid regex expression: %w", err) 606 | } 607 | *r = Regexp(*exp) 608 | return nil 609 | } 610 | 611 | func (r Regexp) String() string { 612 | return r.Value().String() 613 | } 614 | 615 | func (r *Regexp) Value() *regexp.Regexp { 616 | if r == nil { 617 | return nil 618 | } 619 | 
return (*regexp.Regexp)(r) 620 | } 621 | 622 | func (Regexp) Type() string { 623 | return "regexp" 624 | } 625 | 626 | var _ pflag.Value = (*YAMLConfigPath)(nil) 627 | 628 | // YAMLConfigPath is a special value type that encodes a path to a YAML 629 | // configuration file where options are read from. 630 | type YAMLConfigPath string 631 | 632 | func (p *YAMLConfigPath) Set(v string) error { 633 | *p = YAMLConfigPath(v) 634 | return nil 635 | } 636 | 637 | func (p *YAMLConfigPath) String() string { 638 | return string(*p) 639 | } 640 | 641 | func (*YAMLConfigPath) Type() string { 642 | return "yaml-config-path" 643 | } 644 | 645 | var _ pflag.SliceValue = (*EnumArray)(nil) 646 | var _ pflag.Value = (*EnumArray)(nil) 647 | 648 | type EnumArray struct { 649 | Choices []string 650 | Value *[]string 651 | } 652 | 653 | func (e *EnumArray) Append(s string) error { 654 | for _, c := range e.Choices { 655 | if strings.EqualFold(s, c) { 656 | *e.Value = append(*e.Value, s) 657 | return nil 658 | } 659 | } 660 | return xerrors.Errorf("invalid choice: %s, should be one of %v", s, e.Choices) 661 | } 662 | 663 | func (e *EnumArray) GetSlice() []string { 664 | return *e.Value 665 | } 666 | 667 | func (e *EnumArray) Replace(ss []string) error { 668 | for _, s := range ss { 669 | found := false 670 | for _, c := range e.Choices { 671 | if strings.EqualFold(s, c) { 672 | found = true 673 | break 674 | } 675 | } 676 | if !found { 677 | return xerrors.Errorf("invalid choice: %s, should be one of %v", s, e.Choices) 678 | } 679 | } 680 | *e.Value = ss 681 | return nil 682 | } 683 | 684 | func (e *EnumArray) Set(v string) error { 685 | if v == "" { 686 | *e.Value = nil 687 | return nil 688 | } 689 | ss, err := readAsCSV(v) 690 | if err != nil { 691 | return err 692 | } 693 | for _, s := range ss { 694 | err := e.Append(s) 695 | if err != nil { 696 | return err 697 | } 698 | } 699 | return nil 700 | } 701 | 702 | func (e *EnumArray) String() string { 703 | return writeAsCSV(*e.Value) 
704 | } 705 | 706 | func (e *EnumArray) Type() string { 707 | return fmt.Sprintf("enum-array[%v]", strings.Join(e.Choices, "\\|")) 708 | } 709 | 710 | func EnumArrayOf(v *[]string, choices ...string) *EnumArray { 711 | choices = append([]string{}, choices...) 712 | return &EnumArray{ 713 | Choices: choices, 714 | Value: v, 715 | } 716 | } 717 | -------------------------------------------------------------------------------- /command.go: -------------------------------------------------------------------------------- 1 | package serpent 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "flag" 7 | "fmt" 8 | "io" 9 | "os" 10 | "os/signal" 11 | "strings" 12 | "testing" 13 | "unicode" 14 | 15 | "cdr.dev/slog" 16 | 17 | "github.com/spf13/pflag" 18 | "golang.org/x/exp/constraints" 19 | "golang.org/x/exp/slices" 20 | "golang.org/x/xerrors" 21 | "gopkg.in/yaml.v3" 22 | ) 23 | 24 | // Command describes an executable command. 25 | type Command struct { 26 | // Parent is the direct parent of the command. 27 | // 28 | // It is set automatically when an invokation runs. 29 | Parent *Command 30 | 31 | // Children is a list of direct descendants. 32 | Children []*Command 33 | 34 | // Use is provided in form "command [flags] [args...]". 35 | Use string 36 | 37 | // Aliases is a list of alternative names for the command. 38 | Aliases []string 39 | 40 | // Short is a one-line description of the command. 41 | Short string 42 | 43 | // Hidden determines whether the command should be hidden from help. 44 | Hidden bool 45 | 46 | // Deprecated indicates whether this command is deprecated. 47 | // If empty, the command is not deprecated. 48 | // If set, the value is used as the deprecation message. 49 | Deprecated string `json:"deprecated,omitempty"` 50 | 51 | // RawArgs determines whether the command should receive unparsed arguments. 52 | // No flags are parsed when set, and the command is responsible for parsing 53 | // its own flags. 
54 | RawArgs bool 55 | 56 | // Long is a detailed description of the command, 57 | // presented on its help page. It may contain examples. 58 | Long string 59 | Options OptionSet 60 | Annotations Annotations 61 | 62 | // Middleware is called before the Handler. 63 | // Use Chain() to combine multiple middlewares. 64 | Middleware MiddlewareFunc 65 | Handler HandlerFunc 66 | HelpHandler HandlerFunc 67 | // CompletionHandler is called when the command is run in completion 68 | // mode. If nil, only the default completion handler is used. 69 | // 70 | // Flag and option parsing is best-effort in this mode, so even if an Option 71 | // is "required" it may not be set. 72 | CompletionHandler CompletionHandlerFunc 73 | } 74 | 75 | // AddSubcommands adds the given subcommands, setting their 76 | // Parent field automatically. 77 | func (c *Command) AddSubcommands(cmds ...*Command) { 78 | for _, cmd := range cmds { 79 | cmd.Parent = c 80 | c.Children = append(c.Children, cmd) 81 | } 82 | } 83 | 84 | // Walk calls fn for the command and all its children. 85 | func (c *Command) Walk(fn func(*Command)) { 86 | fn(c) 87 | for _, child := range c.Children { 88 | child.Parent = c 89 | child.Walk(fn) 90 | } 91 | } 92 | 93 | func ascendingSortFn[T constraints.Ordered](a, b T) int { 94 | if a < b { 95 | return -1 96 | } else if a == b { 97 | return 0 98 | } 99 | return 1 100 | } 101 | 102 | // init performs initialization and linting on the command and all its children. 
103 | func (c *Command) init() error { 104 | if c.Use == "" { 105 | c.Use = "unnamed" 106 | } 107 | var merr error 108 | 109 | for i := range c.Options { 110 | opt := &c.Options[i] 111 | if opt.Name == "" { 112 | switch { 113 | case opt.Flag != "": 114 | opt.Name = opt.Flag 115 | case opt.Env != "": 116 | opt.Name = opt.Env 117 | case opt.YAML != "": 118 | opt.Name = opt.YAML 119 | default: 120 | merr = errors.Join(merr, xerrors.Errorf("option must have a Name, Flag, Env or YAML field")) 121 | } 122 | } 123 | if opt.Description != "" { 124 | // Enforce that description uses sentence form. 125 | if unicode.IsLower(rune(opt.Description[0])) { 126 | merr = errors.Join(merr, xerrors.Errorf("option %q description should start with a capital letter", opt.Name)) 127 | } 128 | if !strings.HasSuffix(opt.Description, ".") { 129 | merr = errors.Join(merr, xerrors.Errorf("option %q description should end with a period", opt.Name)) 130 | } 131 | } 132 | } 133 | 134 | slices.SortFunc(c.Options, func(a, b Option) int { 135 | return ascendingSortFn(a.Name, b.Name) 136 | }) 137 | slices.SortFunc(c.Children, func(a, b *Command) int { 138 | return ascendingSortFn(a.Name(), b.Name()) 139 | }) 140 | for _, child := range c.Children { 141 | child.Parent = c 142 | err := child.init() 143 | if err != nil { 144 | merr = errors.Join(merr, xerrors.Errorf("command %v: %w", child.Name(), err)) 145 | } 146 | } 147 | return merr 148 | } 149 | 150 | // Name returns the first word in the Use string. 151 | func (c *Command) Name() string { 152 | return strings.Split(c.Use, " ")[0] 153 | } 154 | 155 | // FullName returns the full invocation name of the command, 156 | // as seen on the command line. 
157 | func (c *Command) FullName() string { 158 | var names []string 159 | if c.Parent != nil { 160 | names = append(names, c.Parent.FullName()) 161 | } 162 | names = append(names, c.Name()) 163 | return strings.Join(names, " ") 164 | } 165 | 166 | // FullName returns usage of the command, preceded 167 | // by the usage of its parents. 168 | func (c *Command) FullUsage() string { 169 | var uses []string 170 | if c.Parent != nil { 171 | uses = append(uses, c.Parent.FullName()) 172 | } 173 | uses = append(uses, c.Use) 174 | return strings.Join(uses, " ") 175 | } 176 | 177 | // FullOptions returns the options of the command and its parents. 178 | func (c *Command) FullOptions() OptionSet { 179 | var opts OptionSet 180 | if c.Parent != nil { 181 | opts = append(opts, c.Parent.FullOptions()...) 182 | } 183 | opts = append(opts, c.Options...) 184 | return opts 185 | } 186 | 187 | // Invoke creates a new invocation of the command, with 188 | // stdio discarded. 189 | // 190 | // The returned invocation is not live until Run() is called. 191 | func (c *Command) Invoke(args ...string) *Invocation { 192 | return &Invocation{ 193 | Command: c, 194 | Args: args, 195 | Stdout: io.Discard, 196 | Stderr: io.Discard, 197 | Stdin: strings.NewReader(""), 198 | Logger: slog.Make(), 199 | } 200 | } 201 | 202 | // Invocation represents an instance of a command being executed. 203 | type Invocation struct { 204 | ctx context.Context 205 | Command *Command 206 | parsedFlags *pflag.FlagSet 207 | 208 | // Args is reduced into the remaining arguments after parsing flags 209 | // during Run. 210 | Args []string 211 | 212 | // Environ is a list of environment variables. Use EnvsWithPrefix to parse 213 | // os.Environ. 
214 | Environ Environ 215 | Stdout io.Writer 216 | Stderr io.Writer 217 | Stdin io.Reader 218 | 219 | // Deprecated 220 | Logger slog.Logger 221 | // Deprecated 222 | Net Net 223 | 224 | // testing 225 | signalNotifyContext func(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) 226 | } 227 | 228 | // WithOS returns the invocation as a main package, filling in the invocation's unset 229 | // fields with OS defaults. 230 | func (inv *Invocation) WithOS() *Invocation { 231 | return inv.with(func(i *Invocation) { 232 | i.Stdout = os.Stdout 233 | i.Stderr = os.Stderr 234 | i.Stdin = os.Stdin 235 | i.Args = os.Args[1:] 236 | i.Environ = ParseEnviron(os.Environ(), "") 237 | i.Net = osNet{} 238 | }) 239 | } 240 | 241 | // WithTestSignalNotifyContext allows overriding the default implementation of SignalNotifyContext. 242 | // This should only be used in testing. 243 | func (inv *Invocation) WithTestSignalNotifyContext( 244 | _ testing.TB, // ensure we only call this from tests 245 | f func(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc), 246 | ) *Invocation { 247 | return inv.with(func(i *Invocation) { 248 | i.signalNotifyContext = f 249 | }) 250 | } 251 | 252 | // SignalNotifyContext is equivalent to signal.NotifyContext, but supports being overridden in 253 | // tests. 254 | func (inv *Invocation) SignalNotifyContext(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) { 255 | if inv.signalNotifyContext == nil { 256 | return signal.NotifyContext(parent, signals...) 257 | } 258 | return inv.signalNotifyContext(parent, signals...) 
259 | } 260 | 261 | func (inv *Invocation) WithTestParsedFlags( 262 | _ testing.TB, // ensure we only call this from tests 263 | parsedFlags *pflag.FlagSet, 264 | ) *Invocation { 265 | return inv.with(func(i *Invocation) { 266 | i.parsedFlags = parsedFlags 267 | }) 268 | } 269 | 270 | func (inv *Invocation) Context() context.Context { 271 | if inv.ctx == nil { 272 | return context.Background() 273 | } 274 | return inv.ctx 275 | } 276 | 277 | func (inv *Invocation) ParsedFlags() *pflag.FlagSet { 278 | if inv.parsedFlags == nil { 279 | panic("flags not parsed, has Run() been called?") 280 | } 281 | return inv.parsedFlags 282 | } 283 | 284 | type runState struct { 285 | allArgs []string 286 | commandDepth int 287 | 288 | flagParseErr error 289 | } 290 | 291 | func copyFlagSetWithout(fs *pflag.FlagSet, without string) *pflag.FlagSet { 292 | fs2 := pflag.NewFlagSet("", pflag.ContinueOnError) 293 | fs2.Usage = func() {} 294 | fs.VisitAll(func(f *pflag.Flag) { 295 | if f.Name == without { 296 | return 297 | } 298 | fs2.AddFlag(f) 299 | }) 300 | return fs2 301 | } 302 | 303 | func (inv *Invocation) CurWords() (prev string, cur string) { 304 | switch len(inv.Args) { 305 | // All the shells we support will supply at least one argument (empty string), 306 | // but we don't want to panic. 307 | case 0: 308 | cur = "" 309 | prev = "" 310 | case 1: 311 | cur = inv.Args[0] 312 | prev = "" 313 | default: 314 | cur = inv.Args[len(inv.Args)-1] 315 | prev = inv.Args[len(inv.Args)-2] 316 | } 317 | return 318 | } 319 | 320 | // run recursively executes the command and its children. 321 | // allArgs is wired through the stack so that global flags can be accepted 322 | // anywhere in the command invocation. 323 | func (inv *Invocation) run(state *runState) error { 324 | if inv.Command.Deprecated != "" { 325 | fmt.Fprintf(inv.Stderr, "%s %q is deprecated!. 
%s\n", 326 | prettyHeader("warning"), 327 | inv.Command.FullName(), 328 | inv.Command.Deprecated, 329 | ) 330 | } 331 | err := inv.Command.Options.ParseEnv(inv.Environ) 332 | if err != nil { 333 | return xerrors.Errorf("parsing env: %w", err) 334 | } 335 | 336 | // Now the fun part, argument parsing! 337 | 338 | children := make(map[string]*Command) 339 | for _, child := range inv.Command.Children { 340 | child.Parent = inv.Command 341 | for _, name := range append(child.Aliases, child.Name()) { 342 | if _, ok := children[name]; ok { 343 | return xerrors.Errorf("duplicate command name: %s", name) 344 | } 345 | children[name] = child 346 | } 347 | } 348 | 349 | if inv.parsedFlags == nil { 350 | inv.parsedFlags = pflag.NewFlagSet(inv.Command.Name(), pflag.ContinueOnError) 351 | // We handle Usage ourselves. 352 | inv.parsedFlags.Usage = func() {} 353 | } 354 | 355 | // If we find a duplicate flag, we want the deeper command's flag to override 356 | // the shallow one. Unfortunately, pflag has no way to remove a flag, so we 357 | // have to create a copy of the flagset without a value. 358 | inv.Command.Options.FlagSet().VisitAll(func(f *pflag.Flag) { 359 | if inv.parsedFlags.Lookup(f.Name) != nil { 360 | inv.parsedFlags = copyFlagSetWithout(inv.parsedFlags, f.Name) 361 | } 362 | inv.parsedFlags.AddFlag(f) 363 | }) 364 | 365 | var parsedArgs []string 366 | 367 | if !inv.Command.RawArgs { 368 | // Flag parsing will fail on intermediate commands in the command tree, 369 | // so we check the error after looking for a child command. 370 | state.flagParseErr = inv.parsedFlags.Parse(state.allArgs) 371 | parsedArgs = inv.parsedFlags.Args() 372 | } 373 | 374 | // Set value sources for flags. 375 | for i, opt := range inv.Command.Options { 376 | if fl := inv.parsedFlags.Lookup(opt.Flag); fl != nil && fl.Changed { 377 | inv.Command.Options[i].ValueSource = ValueSourceFlag 378 | } 379 | } 380 | 381 | // Read YAML configs, if any. 
382 | for _, opt := range inv.Command.Options { 383 | path, ok := opt.Value.(*YAMLConfigPath) 384 | if !ok || path.String() == "" { 385 | continue 386 | } 387 | 388 | byt, err := os.ReadFile(path.String()) 389 | if err != nil { 390 | return xerrors.Errorf("reading yaml: %w", err) 391 | } 392 | 393 | var n yaml.Node 394 | err = yaml.Unmarshal(byt, &n) 395 | if err != nil { 396 | return xerrors.Errorf("decoding yaml: %w", err) 397 | } 398 | 399 | err = inv.Command.Options.UnmarshalYAML(&n) 400 | if err != nil { 401 | return xerrors.Errorf("applying yaml: %w", err) 402 | } 403 | } 404 | 405 | err = inv.Command.Options.SetDefaults() 406 | if err != nil { 407 | return xerrors.Errorf("setting defaults: %w", err) 408 | } 409 | 410 | // Run child command if found (next child only) 411 | // We must do subcommand detection after flag parsing so we don't mistake flag 412 | // values for subcommand names. 413 | if len(parsedArgs) > state.commandDepth { 414 | nextArg := parsedArgs[state.commandDepth] 415 | if child, ok := children[nextArg]; ok { 416 | child.Parent = inv.Command 417 | inv.Command = child 418 | state.commandDepth++ 419 | return inv.run(state) 420 | } 421 | } 422 | 423 | // Outputted completions are not filtered based on the word under the cursor, as every shell we support does this already. 424 | // We only look at the current word to figure out handler to run, or what directory to inspect. 425 | if inv.IsCompletionMode() { 426 | for _, e := range inv.complete() { 427 | fmt.Fprintln(inv.Stdout, e) 428 | } 429 | return nil 430 | } 431 | 432 | ignoreFlagParseErrors := inv.Command.RawArgs 433 | 434 | // Flag parse errors are irrelevant for raw args commands. 435 | if !ignoreFlagParseErrors && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) { 436 | return xerrors.Errorf( 437 | "parsing flags (%v) for %q: %w", 438 | state.allArgs, 439 | inv.Command.FullName(), state.flagParseErr, 440 | ) 441 | } 442 | 443 | // All options should be set. 
Check all required options have sources, 444 | // meaning they were set by the user in some way (env, flag, etc). 445 | var missing []string 446 | for _, opt := range inv.Command.Options { 447 | if opt.Required && opt.ValueSource == ValueSourceNone { 448 | name := opt.Name 449 | // use flag as a fallback if name is empty 450 | if name == "" { 451 | name = opt.Flag 452 | } 453 | missing = append(missing, name) 454 | } 455 | } 456 | // Don't error for missing flags if `--help` was supplied. 457 | if len(missing) > 0 && !inv.IsCompletionMode() && !errors.Is(state.flagParseErr, pflag.ErrHelp) { 458 | return xerrors.Errorf("Missing values for the required flags: %s", strings.Join(missing, ", ")) 459 | } 460 | 461 | if inv.Command.RawArgs { 462 | // If we're at the root command, then the name is omitted 463 | // from the arguments, so we can just use the entire slice. 464 | if state.commandDepth == 0 { 465 | inv.Args = state.allArgs 466 | } else { 467 | argPos, err := findArg(inv.Command.Name(), state.allArgs, inv.parsedFlags) 468 | if err != nil { 469 | panic(err) 470 | } 471 | inv.Args = state.allArgs[argPos+1:] 472 | } 473 | } else { 474 | // In non-raw-arg mode, we want to skip over flags. 
475 | inv.Args = parsedArgs[state.commandDepth:] 476 | } 477 | 478 | mw := inv.Command.Middleware 479 | if mw == nil { 480 | mw = Chain() 481 | } 482 | 483 | ctx := inv.ctx 484 | if ctx == nil { 485 | ctx = context.Background() 486 | } 487 | 488 | ctx, cancel := context.WithCancel(ctx) 489 | defer cancel() 490 | inv = inv.WithContext(ctx) 491 | 492 | if inv.Command.Handler == nil || errors.Is(state.flagParseErr, pflag.ErrHelp) { 493 | if inv.Command.HelpHandler == nil { 494 | return DefaultHelpFn()(inv) 495 | } 496 | return inv.Command.HelpHandler(inv) 497 | } 498 | 499 | err = mw(inv.Command.Handler)(inv) 500 | if err != nil { 501 | return &RunCommandError{ 502 | Cmd: inv.Command, 503 | Err: err, 504 | } 505 | } 506 | return nil 507 | } 508 | 509 | type RunCommandError struct { 510 | Cmd *Command 511 | Err error 512 | } 513 | 514 | func (e *RunCommandError) Unwrap() error { 515 | return e.Err 516 | } 517 | 518 | func (e *RunCommandError) Error() string { 519 | return fmt.Sprintf("running command %q: %+v", e.Cmd.FullName(), e.Err) 520 | } 521 | 522 | // findArg returns the index of the first occurrence of arg in args, skipping 523 | // over all flags. 524 | func findArg(want string, args []string, fs *pflag.FlagSet) (int, error) { 525 | for i := 0; i < len(args); i++ { 526 | arg := args[i] 527 | if !strings.HasPrefix(arg, "-") { 528 | if arg == want { 529 | return i, nil 530 | } 531 | continue 532 | } 533 | 534 | // This is a flag! 535 | if strings.Contains(arg, "=") { 536 | // The flag contains the value in the same arg, just skip. 537 | continue 538 | } 539 | 540 | // We need to check if NoOptValue is set, then we should not wait 541 | // for the next arg to be the value. 
542 | f := fs.Lookup(strings.TrimLeft(arg, "-")) 543 | if f == nil { 544 | return -1, xerrors.Errorf("unknown flag: %s", arg) 545 | } 546 | if f.NoOptDefVal != "" { 547 | continue 548 | } 549 | 550 | if i == len(args)-1 { 551 | return -1, xerrors.Errorf("flag %s requires a value", arg) 552 | } 553 | 554 | // Skip the value. 555 | i++ 556 | } 557 | 558 | return -1, xerrors.Errorf("arg %s not found", want) 559 | } 560 | 561 | // Run executes the command. 562 | // If two command share a flag name, the first command wins. 563 | // 564 | //nolint:revive 565 | func (inv *Invocation) Run() (err error) { 566 | err = inv.Command.init() 567 | if err != nil { 568 | return xerrors.Errorf("initializing command: %w", err) 569 | } 570 | 571 | defer func() { 572 | // Pflag is panicky, so additional context is helpful in tests. 573 | if flag.Lookup("test.v") == nil { 574 | return 575 | } 576 | if r := recover(); r != nil { 577 | err = xerrors.Errorf("panic recovered for %s: %v", inv.Command.FullName(), r) 578 | panic(err) 579 | } 580 | }() 581 | // We close Stdin to prevent deadlocks, e.g. when the command 582 | // has ended but an io.Copy is still reading from Stdin. 583 | defer func() { 584 | if inv.Stdin == nil { 585 | return 586 | } 587 | rc, ok := inv.Stdin.(io.ReadCloser) 588 | if !ok { 589 | return 590 | } 591 | e := rc.Close() 592 | err = errors.Join(err, e) 593 | }() 594 | err = inv.run(&runState{ 595 | allArgs: inv.Args, 596 | }) 597 | return err 598 | } 599 | 600 | // WithContext returns a copy of the Invocation with the given context. 601 | func (inv *Invocation) WithContext(ctx context.Context) *Invocation { 602 | return inv.with(func(i *Invocation) { 603 | i.ctx = ctx 604 | }) 605 | } 606 | 607 | // with returns a copy of the Invocation with the given function applied. 
608 | func (inv *Invocation) with(fn func(*Invocation)) *Invocation { 609 | i2 := *inv 610 | fn(&i2) 611 | return &i2 612 | } 613 | 614 | func (inv *Invocation) complete() []string { 615 | prev, cur := inv.CurWords() 616 | 617 | // If the current word is a flag 618 | if strings.HasPrefix(cur, "--") { 619 | flagParts := strings.Split(cur, "=") 620 | flagName := flagParts[0][2:] 621 | // If it's an equals flag 622 | if len(flagParts) == 2 { 623 | if out := inv.completeFlag(flagName); out != nil { 624 | for i, o := range out { 625 | out[i] = fmt.Sprintf("--%s=%s", flagName, o) 626 | } 627 | return out 628 | } 629 | } else if out := inv.Command.Options.ByFlag(flagName); out != nil { 630 | // If the current word is a valid flag, auto-complete it so the 631 | // shell moves the cursor 632 | return []string{cur} 633 | } 634 | } 635 | // If the previous word is a flag, then we're writing it's value 636 | // and we should check it's handler 637 | if strings.HasPrefix(prev, "--") { 638 | word := prev[2:] 639 | if out := inv.completeFlag(word); out != nil { 640 | return out 641 | } 642 | } 643 | // If the current word is the command, move the shell cursor 644 | if inv.Command.Name() == cur { 645 | return []string{inv.Command.Name()} 646 | } 647 | var completions []string 648 | 649 | if inv.Command.CompletionHandler != nil { 650 | completions = append(completions, inv.Command.CompletionHandler(inv)...) 651 | } 652 | 653 | completions = append(completions, DefaultCompletionHandler(inv)...) 
654 | 655 | return completions 656 | } 657 | 658 | func (inv *Invocation) completeFlag(word string) []string { 659 | opt := inv.Command.Options.ByFlag(word) 660 | if opt == nil { 661 | return nil 662 | } 663 | if opt.CompletionHandler != nil { 664 | return opt.CompletionHandler(inv) 665 | } 666 | enum, ok := opt.Value.(*Enum) 667 | if ok { 668 | return enum.Choices 669 | } 670 | enumArr, ok := opt.Value.(*EnumArray) 671 | if ok { 672 | return enumArr.Choices 673 | } 674 | return nil 675 | } 676 | 677 | // MiddlewareFunc returns the next handler in the chain, 678 | // or nil if there are no more. 679 | type MiddlewareFunc func(next HandlerFunc) HandlerFunc 680 | 681 | func chain(ms ...MiddlewareFunc) MiddlewareFunc { 682 | return MiddlewareFunc(func(next HandlerFunc) HandlerFunc { 683 | if len(ms) > 0 { 684 | return chain(ms[1:]...)(ms[0](next)) 685 | } 686 | return next 687 | }) 688 | } 689 | 690 | // Chain returns a Handler that first calls middleware in order. 691 | // 692 | //nolint:revive 693 | func Chain(ms ...MiddlewareFunc) MiddlewareFunc { 694 | // We need to reverse the array to provide top-to-bottom execution 695 | // order when defining a command. 696 | reversed := make([]MiddlewareFunc, len(ms)) 697 | for i := range ms { 698 | reversed[len(ms)-1-i] = ms[i] 699 | } 700 | return chain(reversed...) 701 | } 702 | 703 | func RequireNArgs(want int) MiddlewareFunc { 704 | return RequireRangeArgs(want, want) 705 | } 706 | 707 | // RequireRangeArgs returns a Middleware that requires the number of arguments 708 | // to be between start and end (inclusive). If end is -1, then the number of 709 | // arguments must be at least start. 
710 | func RequireRangeArgs(start, end int) MiddlewareFunc { 711 | if start < 0 { 712 | panic("start must be >= 0") 713 | } 714 | return func(next HandlerFunc) HandlerFunc { 715 | return func(i *Invocation) error { 716 | got := len(i.Args) 717 | switch { 718 | case start == end && got != start: 719 | switch start { 720 | case 0: 721 | if len(i.Command.Children) > 0 { 722 | return xerrors.Errorf("unrecognized subcommand %q", i.Args[0]) 723 | } 724 | return xerrors.Errorf("wanted no args but got %v %v", got, i.Args) 725 | default: 726 | return xerrors.Errorf( 727 | "wanted %v args but got %v %v", 728 | start, 729 | got, 730 | i.Args, 731 | ) 732 | } 733 | case start > 0 && end == -1: 734 | switch { 735 | case got < start: 736 | return xerrors.Errorf( 737 | "wanted at least %v args but got %v", 738 | start, 739 | got, 740 | ) 741 | default: 742 | return next(i) 743 | } 744 | case start > end: 745 | panic("start must be <= end") 746 | case got < start || got > end: 747 | return xerrors.Errorf( 748 | "wanted between %v and %v args but got %v", 749 | start, end, 750 | got, 751 | ) 752 | default: 753 | return next(i) 754 | } 755 | } 756 | } 757 | } 758 | 759 | // HandlerFunc handles an Invocation of a command. 760 | type HandlerFunc func(i *Invocation) error 761 | 762 | type CompletionHandlerFunc func(i *Invocation) []string 763 | -------------------------------------------------------------------------------- /command_test.go: -------------------------------------------------------------------------------- 1 | package serpent_test 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "os" 8 | "strings" 9 | "testing" 10 | 11 | "github.com/stretchr/testify/require" 12 | "golang.org/x/xerrors" 13 | 14 | serpent "github.com/coder/serpent" 15 | "github.com/coder/serpent/completion" 16 | ) 17 | 18 | // ioBufs is the standard input, output, and error for a command. 
19 | type ioBufs struct { 20 | Stdin bytes.Buffer 21 | Stdout bytes.Buffer 22 | Stderr bytes.Buffer 23 | } 24 | 25 | // fakeIO sets Stdin, Stdout, and Stderr to buffers. 26 | func fakeIO(i *serpent.Invocation) *ioBufs { 27 | var b ioBufs 28 | i.Stdout = &b.Stdout 29 | i.Stderr = &b.Stderr 30 | i.Stdin = &b.Stdin 31 | return &b 32 | } 33 | 34 | func sampleCommand(t *testing.T) *serpent.Command { 35 | t.Helper() 36 | var ( 37 | verbose bool 38 | lower bool 39 | prefix string 40 | reqBool bool 41 | reqStr string 42 | reqArr []string 43 | reqEnumArr []string 44 | fileArr []string 45 | enumStr string 46 | ) 47 | enumChoices := []string{"foo", "bar", "qux"} 48 | return &serpent.Command{ 49 | Use: "root [subcommand]", 50 | Options: serpent.OptionSet{ 51 | serpent.Option{ 52 | Name: "verbose", 53 | Flag: "verbose", 54 | Default: "false", 55 | Value: serpent.BoolOf(&verbose), 56 | }, 57 | serpent.Option{ 58 | Name: "verbose-old", 59 | Flag: "verbode-old", 60 | Value: serpent.BoolOf(&verbose), 61 | }, 62 | serpent.Option{ 63 | Name: "prefix", 64 | Flag: "prefix", 65 | Value: serpent.StringOf(&prefix), 66 | }, 67 | }, 68 | Children: []*serpent.Command{ 69 | { 70 | Use: "required-flag --req-bool=true --req-string=foo", 71 | Short: "Example with required flags", 72 | Options: serpent.OptionSet{ 73 | serpent.Option{ 74 | Name: "req-bool", 75 | Flag: "req-bool", 76 | FlagShorthand: "b", 77 | Value: serpent.BoolOf(&reqBool), 78 | Required: true, 79 | }, 80 | serpent.Option{ 81 | Name: "req-string", 82 | Flag: "req-string", 83 | FlagShorthand: "s", 84 | Value: serpent.Validate(serpent.StringOf(&reqStr), func(value *serpent.String) error { 85 | ok := strings.Contains(value.String(), " ") 86 | if !ok { 87 | return xerrors.Errorf("string must contain a space") 88 | } 89 | return nil 90 | }), 91 | Required: true, 92 | }, 93 | serpent.Option{ 94 | Name: "req-enum", 95 | Flag: "req-enum", 96 | Value: serpent.EnumOf(&enumStr, enumChoices...), 97 | }, 98 | serpent.Option{ 99 | Name: 
"req-array", 100 | Flag: "req-array", 101 | FlagShorthand: "a", 102 | Value: serpent.StringArrayOf(&reqArr), 103 | }, 104 | serpent.Option{ 105 | Name: "req-enum-array", 106 | Flag: "req-enum-array", 107 | Value: serpent.EnumArrayOf(&reqEnumArr, enumChoices...), 108 | }, 109 | }, 110 | HelpHandler: func(i *serpent.Invocation) error { 111 | _, _ = i.Stdout.Write([]byte("help text.png")) 112 | return nil 113 | }, 114 | Handler: func(i *serpent.Invocation) error { 115 | _, _ = i.Stdout.Write([]byte(fmt.Sprintf("%s-%t", reqStr, reqBool))) 116 | return nil 117 | }, 118 | }, 119 | { 120 | Use: "toupper [word]", 121 | Short: "Converts a word to upper case", 122 | Middleware: serpent.Chain( 123 | serpent.RequireNArgs(1), 124 | ), 125 | Aliases: []string{"up"}, 126 | Options: serpent.OptionSet{ 127 | serpent.Option{ 128 | Name: "lower", 129 | Flag: "lower", 130 | Value: serpent.BoolOf(&lower), 131 | }, 132 | }, 133 | Handler: func(i *serpent.Invocation) error { 134 | _, _ = i.Stdout.Write([]byte(prefix)) 135 | w := i.Args[0] 136 | if lower { 137 | w = strings.ToLower(w) 138 | } else { 139 | w = strings.ToUpper(w) 140 | } 141 | _, _ = i.Stdout.Write( 142 | []byte( 143 | w, 144 | ), 145 | ) 146 | if verbose { 147 | _, _ = i.Stdout.Write([]byte("!!!")) 148 | } 149 | return nil 150 | }, 151 | }, 152 | { 153 | Use: "file ", 154 | Handler: func(inv *serpent.Invocation) error { 155 | return nil 156 | }, 157 | CompletionHandler: completion.FileHandler(func(info os.FileInfo) bool { 158 | return true 159 | }), 160 | Middleware: serpent.RequireNArgs(1), 161 | }, 162 | { 163 | Use: "altfile", 164 | Handler: func(inv *serpent.Invocation) error { 165 | return nil 166 | }, 167 | Options: serpent.OptionSet{ 168 | { 169 | Name: "extra", 170 | Flag: "extra", 171 | Description: "Extra files.", 172 | Value: serpent.StringArrayOf(&fileArr), 173 | }, 174 | }, 175 | CompletionHandler: func(i *serpent.Invocation) []string { 176 | return []string{"doesntexist.go"} 177 | }, 178 | }, 179 | }, 180 | } 
181 | } 182 | 183 | func TestCommand(t *testing.T) { 184 | t.Parallel() 185 | 186 | cmd := func() *serpent.Command { return sampleCommand(t) } 187 | 188 | t.Run("SimpleOK", func(t *testing.T) { 189 | t.Parallel() 190 | i := cmd().Invoke("toupper", "hello") 191 | io := fakeIO(i) 192 | err := i.Run() 193 | require.NoError(t, err) 194 | require.Equal(t, "HELLO", io.Stdout.String()) 195 | }) 196 | 197 | t.Run("Alias", func(t *testing.T) { 198 | t.Parallel() 199 | i := cmd().Invoke( 200 | "up", "hello", 201 | ) 202 | io := fakeIO(i) 203 | err := i.Run() 204 | require.NoError(t, err) 205 | 206 | require.Equal(t, "HELLO", io.Stdout.String()) 207 | }) 208 | 209 | t.Run("BadArgs", func(t *testing.T) { 210 | t.Parallel() 211 | i := cmd().Invoke( 212 | "toupper", 213 | ) 214 | io := fakeIO(i) 215 | err := i.Run() 216 | require.Empty(t, io.Stdout.String()) 217 | require.Error(t, err) 218 | }) 219 | 220 | t.Run("NoSubcommand", func(t *testing.T) { 221 | t.Parallel() 222 | i := cmd().Invoke( 223 | "na", 224 | ) 225 | io := fakeIO(i) 226 | err := i.Run() 227 | require.Error(t, err) 228 | require.Contains(t, io.Stderr.String(), "unknown subcommand") 229 | }) 230 | 231 | t.Run("UnknownFlags", func(t *testing.T) { 232 | t.Parallel() 233 | i := cmd().Invoke( 234 | "toupper", "--unknown", 235 | ) 236 | io := fakeIO(i) 237 | err := i.Run() 238 | require.Empty(t, io.Stdout.String()) 239 | require.Error(t, err) 240 | }) 241 | 242 | t.Run("Verbose", func(t *testing.T) { 243 | t.Parallel() 244 | i := cmd().Invoke( 245 | "--verbose", "toupper", "hello", 246 | ) 247 | io := fakeIO(i) 248 | require.NoError(t, i.Run()) 249 | require.Equal(t, "HELLO!!!", io.Stdout.String()) 250 | }) 251 | 252 | t.Run("Verbose=", func(t *testing.T) { 253 | t.Parallel() 254 | i := cmd().Invoke( 255 | "--verbose=true", "toupper", "hello", 256 | ) 257 | io := fakeIO(i) 258 | require.NoError(t, i.Run()) 259 | require.Equal(t, "HELLO!!!", io.Stdout.String()) 260 | }) 261 | 262 | t.Run("PrefixSpace", func(t 
*testing.T) { 263 | t.Parallel() 264 | i := cmd().Invoke( 265 | "--prefix", "conv: ", "toupper", "hello", 266 | ) 267 | io := fakeIO(i) 268 | require.NoError(t, i.Run()) 269 | require.Equal(t, "conv: HELLO", io.Stdout.String()) 270 | }) 271 | 272 | t.Run("GlobalFlagsAnywhere", func(t *testing.T) { 273 | t.Parallel() 274 | i := cmd().Invoke( 275 | "toupper", "--prefix", "conv: ", "hello", "--verbose", 276 | ) 277 | io := fakeIO(i) 278 | require.NoError(t, i.Run()) 279 | require.Equal(t, "conv: HELLO!!!", io.Stdout.String()) 280 | }) 281 | 282 | t.Run("LowerVerbose", func(t *testing.T) { 283 | t.Parallel() 284 | i := cmd().Invoke( 285 | "toupper", "--verbose", "hello", "--lower", 286 | ) 287 | io := fakeIO(i) 288 | require.NoError(t, i.Run()) 289 | require.Equal(t, "hello!!!", io.Stdout.String()) 290 | }) 291 | 292 | t.Run("ParsedFlags", func(t *testing.T) { 293 | t.Parallel() 294 | i := cmd().Invoke( 295 | "toupper", "--verbose", "hello", "--lower", 296 | ) 297 | _ = fakeIO(i) 298 | require.NoError(t, i.Run()) 299 | require.Equal(t, 300 | "true", 301 | i.ParsedFlags().Lookup("verbose").Value.String(), 302 | ) 303 | }) 304 | 305 | t.Run("NoDeepChild", func(t *testing.T) { 306 | t.Parallel() 307 | i := cmd().Invoke( 308 | "root", "level", "level", "toupper", "--verbose", "hello", "--lower", 309 | ) 310 | fio := fakeIO(i) 311 | require.Error(t, i.Run(), fio.Stdout.String()) 312 | }) 313 | 314 | t.Run("RequiredFlagsMissing", func(t *testing.T) { 315 | t.Parallel() 316 | i := cmd().Invoke( 317 | "required-flag", 318 | ) 319 | fio := fakeIO(i) 320 | err := i.Run() 321 | require.Error(t, err, fio.Stdout.String()) 322 | require.ErrorContains(t, err, "Missing values") 323 | }) 324 | 325 | t.Run("RequiredFlagsMissingWithHelp", func(t *testing.T) { 326 | t.Parallel() 327 | i := cmd().Invoke( 328 | "required-flag", 329 | "--help", 330 | ) 331 | fio := fakeIO(i) 332 | err := i.Run() 333 | require.NoError(t, err) 334 | require.Contains(t, fio.Stdout.String(), "help text.png") 335 
| }) 336 | 337 | t.Run("RequiredFlagsMissingBool", func(t *testing.T) { 338 | t.Parallel() 339 | i := cmd().Invoke( 340 | "required-flag", "--req-string", "foo bar", 341 | ) 342 | fio := fakeIO(i) 343 | err := i.Run() 344 | require.Error(t, err, fio.Stdout.String()) 345 | require.ErrorContains(t, err, "Missing values for the required flags: req-bool") 346 | }) 347 | 348 | t.Run("RequiredFlagsMissingString", func(t *testing.T) { 349 | t.Parallel() 350 | i := cmd().Invoke( 351 | "required-flag", "--req-bool", "true", 352 | ) 353 | fio := fakeIO(i) 354 | err := i.Run() 355 | require.Error(t, err, fio.Stdout.String()) 356 | require.ErrorContains(t, err, "Missing values for the required flags: req-string") 357 | }) 358 | 359 | t.Run("RequiredFlagsInvalid", func(t *testing.T) { 360 | t.Parallel() 361 | i := cmd().Invoke( 362 | "required-flag", "--req-string", "nospace", 363 | ) 364 | fio := fakeIO(i) 365 | err := i.Run() 366 | require.Error(t, err, fio.Stdout.String()) 367 | require.ErrorContains(t, err, "string must contain a space") 368 | }) 369 | 370 | t.Run("RequiredFlagsOK", func(t *testing.T) { 371 | t.Parallel() 372 | i := cmd().Invoke( 373 | "required-flag", "--req-bool", "true", "--req-string", "foo bar", 374 | ) 375 | fio := fakeIO(i) 376 | err := i.Run() 377 | require.NoError(t, err, fio.Stdout.String()) 378 | }) 379 | 380 | t.Run("DeprecatedCommand", func(t *testing.T) { 381 | t.Parallel() 382 | 383 | deprecatedCmd := &serpent.Command{ 384 | Use: "deprecated-cmd", 385 | Deprecated: "This command is deprecated and will be removed in the future.", 386 | Handler: func(i *serpent.Invocation) error { 387 | _, _ = i.Stdout.Write([]byte("Running deprecated command")) 388 | return nil 389 | }, 390 | } 391 | 392 | i := deprecatedCmd.Invoke() 393 | io := fakeIO(i) 394 | err := i.Run() 395 | require.NoError(t, err) 396 | expectedWarning := fmt.Sprintf("WARNING: %q is deprecated!. 
%s\n", deprecatedCmd.Use, deprecatedCmd.Deprecated) 397 | require.Equal(t, io.Stderr.String(), expectedWarning) 398 | require.Contains(t, io.Stdout.String(), "Running deprecated command") 399 | }) 400 | } 401 | 402 | func TestCommand_DeepNest(t *testing.T) { 403 | t.Parallel() 404 | cmd := &serpent.Command{ 405 | Use: "1", 406 | Children: []*serpent.Command{ 407 | { 408 | Use: "2", 409 | Children: []*serpent.Command{ 410 | { 411 | Use: "3", 412 | Handler: func(i *serpent.Invocation) error { 413 | _, _ = i.Stdout.Write([]byte("3")) 414 | return nil 415 | }, 416 | }, 417 | }, 418 | }, 419 | }, 420 | } 421 | inv := cmd.Invoke("2", "3") 422 | stdio := fakeIO(inv) 423 | err := inv.Run() 424 | require.NoError(t, err) 425 | require.Equal(t, "3", stdio.Stdout.String()) 426 | } 427 | 428 | func TestCommand_FlagOverride(t *testing.T) { 429 | t.Parallel() 430 | var flag string 431 | 432 | cmd := &serpent.Command{ 433 | Use: "1", 434 | Options: serpent.OptionSet{ 435 | { 436 | Name: "flag", 437 | Flag: "f", 438 | Value: serpent.DiscardValue, 439 | }, 440 | }, 441 | Children: []*serpent.Command{ 442 | { 443 | Use: "2", 444 | Options: serpent.OptionSet{ 445 | { 446 | Name: "flag", 447 | Flag: "f", 448 | Value: serpent.StringOf(&flag), 449 | }, 450 | }, 451 | Handler: func(i *serpent.Invocation) error { 452 | return nil 453 | }, 454 | }, 455 | }, 456 | } 457 | 458 | err := cmd.Invoke("2", "--f", "mhmm").Run() 459 | require.NoError(t, err) 460 | 461 | require.Equal(t, "mhmm", flag) 462 | } 463 | 464 | func TestCommand_MiddlewareOrder(t *testing.T) { 465 | t.Parallel() 466 | 467 | mw := func(letter string) serpent.MiddlewareFunc { 468 | return func(next serpent.HandlerFunc) serpent.HandlerFunc { 469 | return (func(i *serpent.Invocation) error { 470 | _, _ = i.Stdout.Write([]byte(letter)) 471 | return next(i) 472 | }) 473 | } 474 | } 475 | 476 | cmd := &serpent.Command{ 477 | Use: "toupper [word]", 478 | Short: "Converts a word to upper case", 479 | Middleware: serpent.Chain( 480 | 
mw("A"), 481 | mw("B"), 482 | mw("C"), 483 | ), 484 | Handler: (func(i *serpent.Invocation) error { 485 | return nil 486 | }), 487 | } 488 | 489 | i := cmd.Invoke( 490 | "hello", "world", 491 | ) 492 | io := fakeIO(i) 493 | require.NoError(t, i.Run()) 494 | require.Equal(t, "ABC", io.Stdout.String()) 495 | } 496 | 497 | func TestCommand_RawArgs(t *testing.T) { 498 | t.Parallel() 499 | 500 | cmd := func() *serpent.Command { 501 | return &serpent.Command{ 502 | Use: "root", 503 | Options: serpent.OptionSet{ 504 | { 505 | Name: "password", 506 | Flag: "password", 507 | Value: serpent.StringOf(new(string)), 508 | }, 509 | }, 510 | Children: []*serpent.Command{ 511 | { 512 | Use: "sushi ", 513 | Short: "Throws back raw output", 514 | RawArgs: true, 515 | Handler: (func(i *serpent.Invocation) error { 516 | if v := i.ParsedFlags().Lookup("password").Value.String(); v != "codershack" { 517 | return xerrors.Errorf("password %q is wrong!", v) 518 | } 519 | _, _ = i.Stdout.Write([]byte(strings.Join(i.Args, " "))) 520 | return nil 521 | }), 522 | }, 523 | }, 524 | } 525 | } 526 | 527 | t.Run("OK", func(t *testing.T) { 528 | // Flag parsed before the raw arg command should still work. 529 | t.Parallel() 530 | 531 | i := cmd().Invoke( 532 | "--password", "codershack", "sushi", "hello", "--verbose", "world", 533 | ) 534 | io := fakeIO(i) 535 | require.NoError(t, i.Run()) 536 | require.Equal(t, "hello --verbose world", io.Stdout.String()) 537 | }) 538 | 539 | t.Run("BadFlag", func(t *testing.T) { 540 | // Verbose before the raw arg command should fail. 541 | t.Parallel() 542 | 543 | i := cmd().Invoke( 544 | "--password", "codershack", "--verbose", "sushi", "hello", "world", 545 | ) 546 | io := fakeIO(i) 547 | require.Error(t, i.Run()) 548 | require.Empty(t, io.Stdout.String()) 549 | }) 550 | 551 | t.Run("NoPassword", func(t *testing.T) { 552 | // Flag parsed before the raw arg command should still work. 
553 | t.Parallel() 554 | i := cmd().Invoke( 555 | "sushi", "hello", "--verbose", "world", 556 | ) 557 | _ = fakeIO(i) 558 | require.Error(t, i.Run()) 559 | }) 560 | } 561 | 562 | func TestCommand_RootRaw(t *testing.T) { 563 | t.Parallel() 564 | cmd := &serpent.Command{ 565 | RawArgs: true, 566 | Handler: func(i *serpent.Invocation) error { 567 | _, _ = i.Stdout.Write([]byte(strings.Join(i.Args, " "))) 568 | return nil 569 | }, 570 | } 571 | 572 | inv := cmd.Invoke("hello", "--verbose", "--friendly") 573 | stdio := fakeIO(inv) 574 | err := inv.Run() 575 | require.NoError(t, err) 576 | 577 | require.Equal(t, "hello --verbose --friendly", stdio.Stdout.String()) 578 | } 579 | 580 | func TestCommand_HyphenHyphen(t *testing.T) { 581 | t.Parallel() 582 | var verbose bool 583 | cmd := &serpent.Command{ 584 | Handler: (func(i *serpent.Invocation) error { 585 | _, _ = i.Stdout.Write([]byte(strings.Join(i.Args, " "))) 586 | if verbose { 587 | return xerrors.New("verbose should not be true because flag after --") 588 | } 589 | return nil 590 | }), 591 | Options: serpent.OptionSet{ 592 | { 593 | Name: "verbose", 594 | Flag: "verbose", 595 | Value: serpent.BoolOf(&verbose), 596 | }, 597 | }, 598 | } 599 | 600 | inv := cmd.Invoke("--", "--verbose", "--friendly") 601 | stdio := fakeIO(inv) 602 | err := inv.Run() 603 | require.NoError(t, err) 604 | 605 | require.Equal(t, "--verbose --friendly", stdio.Stdout.String()) 606 | } 607 | 608 | func TestCommand_ContextCancels(t *testing.T) { 609 | t.Parallel() 610 | 611 | var gotCtx context.Context 612 | 613 | cmd := &serpent.Command{ 614 | Handler: (func(i *serpent.Invocation) error { 615 | gotCtx = i.Context() 616 | if err := gotCtx.Err(); err != nil { 617 | return xerrors.Errorf("unexpected context error: %w", i.Context().Err()) 618 | } 619 | return nil 620 | }), 621 | } 622 | 623 | err := cmd.Invoke().Run() 624 | require.NoError(t, err) 625 | 626 | require.Error(t, gotCtx.Err()) 627 | } 628 | 629 | func TestCommand_Help(t *testing.T) { 
630 | t.Parallel() 631 | 632 | cmd := func() *serpent.Command { 633 | return &serpent.Command{ 634 | Use: "root", 635 | HelpHandler: (func(i *serpent.Invocation) error { 636 | _, _ = i.Stdout.Write([]byte("abdracadabra")) 637 | return nil 638 | }), 639 | Handler: (func(i *serpent.Invocation) error { 640 | return xerrors.New("should not be called") 641 | }), 642 | } 643 | } 644 | 645 | t.Run("DefaultHandler", func(t *testing.T) { 646 | t.Parallel() 647 | 648 | c := cmd() 649 | c.HelpHandler = nil 650 | err := c.Invoke("--help").Run() 651 | require.NoError(t, err) 652 | }) 653 | 654 | t.Run("Long", func(t *testing.T) { 655 | t.Parallel() 656 | 657 | inv := cmd().Invoke("--help") 658 | stdio := fakeIO(inv) 659 | err := inv.Run() 660 | require.NoError(t, err) 661 | 662 | require.Contains(t, stdio.Stdout.String(), "abdracadabra") 663 | }) 664 | 665 | t.Run("Short", func(t *testing.T) { 666 | t.Parallel() 667 | 668 | inv := cmd().Invoke("-h") 669 | stdio := fakeIO(inv) 670 | err := inv.Run() 671 | require.NoError(t, err) 672 | 673 | require.Contains(t, stdio.Stdout.String(), "abdracadabra") 674 | }) 675 | } 676 | 677 | func TestCommand_SliceFlags(t *testing.T) { 678 | t.Parallel() 679 | 680 | cmd := func(want ...string) *serpent.Command { 681 | var got []string 682 | return &serpent.Command{ 683 | Use: "root", 684 | Options: serpent.OptionSet{ 685 | { 686 | Name: "arr", 687 | Flag: "arr", 688 | Default: "bad,bad,bad", 689 | Value: serpent.StringArrayOf(&got), 690 | }, 691 | }, 692 | Handler: (func(i *serpent.Invocation) error { 693 | require.Equal(t, want, got) 694 | return nil 695 | }), 696 | } 697 | } 698 | 699 | err := cmd("good", "good", "good").Invoke("--arr", "good", "--arr", "good", "--arr", "good").Run() 700 | require.NoError(t, err) 701 | 702 | err = cmd("bad", "bad", "bad").Invoke().Run() 703 | require.NoError(t, err) 704 | } 705 | 706 | func TestCommand_EmptySlice(t *testing.T) { 707 | t.Parallel() 708 | 709 | cmd := func(want ...string) *serpent.Command { 710 
| var got []string 711 | return &serpent.Command{ 712 | Use: "root", 713 | Options: serpent.OptionSet{ 714 | { 715 | Name: "arr", 716 | Flag: "arr", 717 | Default: "def,def,def", 718 | Env: "ARR", 719 | Value: serpent.StringArrayOf(&got), 720 | }, 721 | }, 722 | Handler: (func(i *serpent.Invocation) error { 723 | require.Equal(t, want, got) 724 | return nil 725 | }), 726 | } 727 | } 728 | 729 | // Base-case, uses default. 730 | err := cmd("def", "def", "def").Invoke().Run() 731 | require.NoError(t, err) 732 | 733 | // Empty-env uses default, too. 734 | inv := cmd("def", "def", "def").Invoke() 735 | inv.Environ.Set("ARR", "") 736 | require.NoError(t, err) 737 | 738 | // Reset to nothing at all via flag. 739 | inv = cmd().Invoke("--arr", "") 740 | inv.Environ.Set("ARR", "cant see") 741 | err = inv.Run() 742 | require.NoError(t, err) 743 | 744 | // Reset to a specific value with flag. 745 | inv = cmd("great").Invoke("--arr", "great") 746 | inv.Environ.Set("ARR", "") 747 | err = inv.Run() 748 | require.NoError(t, err) 749 | } 750 | 751 | func TestCommand_DefaultsOverride(t *testing.T) { 752 | t.Parallel() 753 | 754 | test := func(name string, want string, fn func(t *testing.T, inv *serpent.Invocation)) { 755 | t.Run(name, func(t *testing.T) { 756 | t.Parallel() 757 | 758 | var ( 759 | got string 760 | config serpent.YAMLConfigPath 761 | ) 762 | cmd := &serpent.Command{ 763 | Options: serpent.OptionSet{ 764 | { 765 | Name: "url", 766 | Flag: "url", 767 | Default: "def.com", 768 | Env: "URL", 769 | Value: serpent.StringOf(&got), 770 | YAML: "url", 771 | }, 772 | { 773 | Name: "url-deprecated", 774 | Flag: "url-deprecated", 775 | Env: "URL_DEPRECATED", 776 | Value: serpent.StringOf(&got), 777 | }, 778 | { 779 | Name: "config", 780 | Flag: "config", 781 | Default: "", 782 | Value: &config, 783 | }, 784 | }, 785 | Handler: (func(i *serpent.Invocation) error { 786 | _, _ = fmt.Fprintf(i.Stdout, "%s", got) 787 | return nil 788 | }), 789 | } 790 | 791 | inv := cmd.Invoke() 792 
| stdio := fakeIO(inv) 793 | fn(t, inv) 794 | err := inv.Run() 795 | require.NoError(t, err) 796 | require.Equal(t, want, stdio.Stdout.String()) 797 | }) 798 | } 799 | 800 | test("DefaultOverNothing", "def.com", func(t *testing.T, inv *serpent.Invocation) {}) 801 | 802 | test("FlagOverDefault", "good.com", func(t *testing.T, inv *serpent.Invocation) { 803 | inv.Args = []string{"--url", "good.com"} 804 | }) 805 | 806 | test("EnvOverDefault", "good.com", func(t *testing.T, inv *serpent.Invocation) { 807 | inv.Environ.Set("URL", "good.com") 808 | }) 809 | 810 | test("FlagOverEnv", "good.com", func(t *testing.T, inv *serpent.Invocation) { 811 | inv.Environ.Set("URL", "bad.com") 812 | inv.Args = []string{"--url", "good.com"} 813 | }) 814 | 815 | test("FlagOverYAML", "good.com", func(t *testing.T, inv *serpent.Invocation) { 816 | fi, err := os.CreateTemp(t.TempDir(), "config.yaml") 817 | require.NoError(t, err) 818 | defer fi.Close() 819 | 820 | _, err = fi.WriteString("url: bad.com") 821 | require.NoError(t, err) 822 | 823 | inv.Args = []string{"--config", fi.Name(), "--url", "good.com"} 824 | }) 825 | 826 | test("EnvOverYAML", "good.com", func(t *testing.T, inv *serpent.Invocation) { 827 | fi, err := os.CreateTemp(t.TempDir(), "config.yaml") 828 | require.NoError(t, err) 829 | defer fi.Close() 830 | 831 | _, err = fi.WriteString("url: bad.com") 832 | require.NoError(t, err) 833 | 834 | inv.Environ.Set("URL", "good.com") 835 | }) 836 | 837 | test("YAMLOverDefault", "good.com", func(t *testing.T, inv *serpent.Invocation) { 838 | fi, err := os.CreateTemp(t.TempDir(), "config.yaml") 839 | require.NoError(t, err) 840 | defer fi.Close() 841 | 842 | _, err = fi.WriteString("url: good.com") 843 | require.NoError(t, err) 844 | 845 | inv.Args = []string{"--config", fi.Name()} 846 | }) 847 | 848 | test("AltFlagOverDefault", "good.com", func(t *testing.T, inv *serpent.Invocation) { 849 | inv.Args = []string{"--url-deprecated", "good.com"} 850 | }) 851 | } 852 | 853 | func 
TestCommand_OptionsWithSharedValue(t *testing.T) { 854 | t.Parallel() 855 | 856 | var got string 857 | makeCmd := func(def, altDef string) *serpent.Command { 858 | got = "" 859 | return &serpent.Command{ 860 | Options: serpent.OptionSet{ 861 | { 862 | Name: "url", 863 | Flag: "url", 864 | Env: "URL", 865 | Default: def, 866 | Value: serpent.StringOf(&got), 867 | }, 868 | { 869 | Name: "alt-url", 870 | Flag: "alt-url", 871 | Env: "ALT_URL", 872 | Default: altDef, 873 | Value: serpent.StringOf(&got), 874 | }, 875 | }, 876 | Handler: (func(i *serpent.Invocation) error { 877 | return nil 878 | }), 879 | } 880 | } 881 | 882 | // Check proper value propagation. 883 | err := makeCmd("def.com", "def.com").Invoke().Run() 884 | require.NoError(t, err, "default values are same") 885 | require.Equal(t, "def.com", got) 886 | 887 | err = makeCmd("def.com", "").Invoke().Run() 888 | require.NoError(t, err, "other default value is empty") 889 | require.Equal(t, "def.com", got) 890 | 891 | err = makeCmd("def.com", "").Invoke("--url", "sup").Run() 892 | require.NoError(t, err) 893 | require.Equal(t, "sup", got) 894 | 895 | err = makeCmd("def.com", "").Invoke("--alt-url", "hup").Run() 896 | require.NoError(t, err) 897 | require.Equal(t, "hup", got) 898 | 899 | // Both flags are given, last wins. 900 | err = makeCmd("def.com", "").Invoke("--url", "sup", "--alt-url", "hup").Run() 901 | require.NoError(t, err) 902 | require.Equal(t, "hup", got) 903 | 904 | // Both flags are given, last wins #2. 905 | err = makeCmd("", "def.com").Invoke("--alt-url", "hup", "--url", "sup").Run() 906 | require.NoError(t, err) 907 | require.Equal(t, "sup", got) 908 | 909 | // Both flags are given, option type priority wins. 910 | inv := makeCmd("def.com", "").Invoke("--alt-url", "hup") 911 | inv.Environ.Set("URL", "sup") 912 | err = inv.Run() 913 | require.NoError(t, err) 914 | require.Equal(t, "hup", got) 915 | 916 | // Both flags are given, option type priority wins #2. 
917 | inv = makeCmd("", "def.com").Invoke("--url", "sup") 918 | inv.Environ.Set("ALT_URL", "hup") 919 | err = inv.Run() 920 | require.NoError(t, err) 921 | require.Equal(t, "sup", got) 922 | 923 | // Catch invalid configuration. 924 | err = makeCmd("def.com", "alt-def.com").Invoke().Run() 925 | require.Error(t, err, "default values are different") 926 | } 927 | --------------------------------------------------------------------------------