├── .github └── workflows │ └── ci.yaml ├── .gitignore ├── LICENSE ├── README.md ├── command.go ├── doc.go ├── examples └── cmd │ ├── echo │ └── main.go │ └── task │ ├── main.go │ └── tasks.go ├── go.mod ├── go.sum ├── parse.go ├── parse_test.go ├── pkg ├── suggest │ ├── suggest.go │ └── suggest_test.go └── textutil │ ├── textutil.go │ └── textutil_test.go ├── run.go ├── run_test.go ├── state.go ├── state_test.go └── usage.go /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | types: [opened, synchronize, reopened] 9 | 10 | jobs: 11 | build: 12 | name: Build and test 13 | runs-on: ubuntu-latest 14 | strategy: 15 | matrix: 16 | # go-version: ['oldstable', 'stable', '1.23.0-rc.2'] 17 | go-version: ["oldstable", "stable"] 18 | env: 19 | VERBOSE: 1 20 | 21 | steps: 22 | - name: Checkout code 23 | uses: actions/checkout@v4 24 | - name: Set up Go 25 | uses: actions/setup-go@v5 26 | with: 27 | go-version: ${{ matrix.go-version }} 28 | - name: Install tparse 29 | run: go install github.com/mfridman/tparse@main 30 | - name: Build 31 | run: go build -v . 32 | - name: Run tests 33 | shell: bash 34 | run: | 35 | go test $(go list ./... 
| grep -v 'examples') -count=1 -v -json -cover \ 36 | | tparse -all -follow -sort=elapsed -trimpath=auto 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | internal/ 2 | tmp/ 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Michael Fridman 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # cli 2 | 3 | [![GoDoc](https://godoc.org/github.com/mfridman/cli?status.svg)](https://pkg.go.dev/github.com/mfridman/cli#pkg-index) 4 | [![CI](https://github.com/mfridman/cli/actions/workflows/ci.yaml/badge.svg)](https://github.com/mfridman/cli/actions/workflows/ci.yaml) 5 | 6 | A Go package for building CLI applications. Extends the standard library's `flag` package to support 7 | [flags anywhere](https://mfridman.com/blog/2024/allowing-flags-anywhere-on-the-cli/) in command 8 | arguments. 9 | 10 | ## Features 11 | 12 | The **bare minimum** to build a CLI application while leveraging the standard library's `flag` 13 | package. 14 | 15 | - Nested subcommands for organizing complex CLIs 16 | - Flexible flag parsing, allowing flags anywhere 17 | - Subcommands inherit flags from parent commands 18 | - Type-safe flag access 19 | - Automatic generation of help text and usage information 20 | - Suggestions for misspelled or incomplete commands 21 | 22 | ### But why? 23 | 24 | This package is intentionally minimal. It aims to be a building block for CLI applications that want 25 | to leverage the standard library's `flag` package while providing a bit more structure and 26 | flexibility. 27 | 28 | - Build maintainable command-line tools quickly 29 | - Focus on application logic rather than framework complexity 30 | - Extend functionality **only when needed** 31 | 32 | Sometimes less is more. While other frameworks offer extensive features, this package focuses on 33 | core functionality. 
34 | 35 | ## Installation 36 | 37 | ```bash 38 | go get github.com/mfridman/cli@latest 39 | ``` 40 | 41 | Required go version: 1.21 or higher 42 | 43 | ## Quick Start 44 | 45 | Here's a simple example of a CLI application that echoes back the input with a required `-c` flag to 46 | capitalize the output: 47 | 48 | ```go 49 | root := &cli.Command{ 50 | Name: "echo", 51 | Usage: "echo [flags] ...", 52 | ShortHelp: "echo is a simple command that prints the provided text", 53 | Flags: cli.FlagsFunc(func(f *flag.FlagSet) { 54 | // Add a flag to capitalize the input 55 | f.Bool("c", false, "capitalize the input") 56 | }), 57 | FlagsMetadata: []cli.FlagMetadata{ 58 | {Name: "c", Required: true}, 59 | }, 60 | Exec: func(ctx context.Context, s *cli.State) error { 61 | if len(s.Args) == 0 { 62 | return errors.New("must provide text to echo, see --help") 63 | } 64 | output := strings.Join(s.Args, " ") 65 | // If -c flag is set, capitalize the output 66 | if cli.GetFlag[bool](s, "c") { 67 | output = strings.ToUpper(output) 68 | } 69 | fmt.Fprintln(s.Stdout, output) 70 | return nil 71 | }, 72 | } 73 | if err := cli.Parse(root, os.Args[1:]); err != nil { 74 | if errors.Is(err, flag.ErrHelp) { 75 | fmt.Fprintf(os.Stdout, "%s\n", cli.DefaultUsage(root)) 76 | return 77 | } 78 | fmt.Fprintf(os.Stderr, "error: %v\n", err) 79 | os.Exit(1) 80 | } 81 | if err := cli.Run(context.Background(), root, nil); err != nil { 82 | fmt.Fprintf(os.Stderr, "error: %v\n", err) 83 | os.Exit(1) 84 | } 85 | ``` 86 | 87 | ## Command Structure 88 | 89 | Each command is represented by a `Command` struct: 90 | 91 | ```go 92 | type Command struct { 93 | Name string // Required 94 | Usage string 95 | ShortHelp string 96 | UsageFunc func(*Command) string 97 | Flags *flag.FlagSet 98 | FlagsMetadata []FlagMetadata 99 | SubCommands []*Command 100 | Exec func(ctx context.Context, s *State) error 101 | } 102 | ``` 103 | 104 | The `Name` field is the command's name and is **required**. 
105 | 106 | The `Usage` and `ShortHelp` fields are used to generate help text. Nice-to-have but not required. 107 | 108 | The `Flags` field is a `*flag.FlagSet` that defines the command's flags. 109 | 110 | > [!TIP] 111 | > 112 | > There's a convenience function `FlagsFunc` that allows you to define flags inline: 113 | 114 | ```go 115 | root := &cli.Command{ 116 | Flags: cli.FlagsFunc(func(f *flag.FlagSet) { 117 | f.Bool("verbose", false, "enable verbose output") 118 | f.String("output", "", "output file") 119 | f.Int("count", 0, "number of items") 120 | }), 121 | FlagsMetadata: []cli.FlagMetadata{ 122 | {Name: "output", Required: true}, 123 | }, 124 | } 125 | ``` 126 | 127 | The optional `FlagsMetadata` field is a way to extend defined flags. The `flag` package alone is a 128 | bit limiting, so we add this to provide the most common features, such as handling of required 129 | flags. 130 | 131 | The `SubCommands` field is a list of `*Command` structs that represent subcommands. This allows you 132 | to organize CLI applications into a hierarchy of commands. Each subcommand can have its own flags 133 | and business logic. 134 | 135 | The `Exec` field is a function that is called when the command is executed. This is where you put 136 | business logic.
137 | 138 | ## Flag Access 139 | 140 | Flags can be accessed using the type-safe `GetFlag` function, called inside the `Exec` function: 141 | 142 | ```go 143 | // Access boolean flag 144 | verbose := cli.GetFlag[bool](state, "verbose") 145 | // Access string flag 146 | output := cli.GetFlag[string](state, "output") 147 | // Access integer flag 148 | count := cli.GetFlag[int](state, "count") 149 | ``` 150 | 151 | ### State Inheritance 152 | 153 | Child commands automatically inherit their parent command's flags: 154 | 155 | ```go 156 | // Parent command with a verbose flag 157 | root := cli.Command{ 158 | Name: "root", 159 | Flags: cli.FlagsFunc(func(f *flag.FlagSet) { 160 | f.Bool("verbose", false, "enable verbose mode") 161 | }), 162 | } 163 | 164 | // Child command that can access parent's verbose flag 165 | sub := cli.Command{ 166 | Name: "sub", 167 | Exec: func(ctx context.Context, s *cli.State) error { 168 | verbose := cli.GetFlag[bool](s, "verbose") 169 | if verbose { 170 | fmt.Println("Verbose mode enabled") 171 | } 172 | return nil 173 | }, 174 | } 175 | ``` 176 | 177 | ## Help System 178 | 179 | Help text is automatically generated, but you can customize it by setting the `UsageFunc` field. 
180 | 181 | There is a `DefaultUsage` function that generates a default help text for a command, which is useful 182 | to display when `flag.ErrHelp` is returned from `Parse`: 183 | 184 | ```go 185 | if err := cli.Parse(root, os.Args[1:]); err != nil { 186 | if errors.Is(err, flag.ErrHelp) { 187 | fmt.Fprintf(os.Stdout, "%s\n", cli.DefaultUsage(root)) // Display help text and exit 188 | return 189 | } 190 | fmt.Fprintf(os.Stderr, "error: %v\n", err) 191 | os.Exit(1) 192 | } 193 | ``` 194 | 195 | ## Usage Syntax Conventions 196 | 197 | When reading command usage strings, the following syntax is used: 198 | 199 | | Syntax | Description | 200 | | ------------- | -------------------------- | 201 | | `<required>` | Required argument | 202 | | `[optional]` | Optional argument | 203 | | `<arg>...` | One or more arguments | 204 | | `[arg]...` | Zero or more arguments | 205 | | `(a\|b)` | Must choose one of a or b | 206 | | `[-f <file>]` | Flag with value (optional) | 207 | | `-f <file>` | Flag with value (required) | 208 | 209 | Examples: 210 | 211 | ```bash 212 | # Multiple source files, one destination 213 | mv <source>... <dest> 214 | 215 | # Required flag with value, optional config 216 | build -t <tag> [config]... 217 | 218 | # Subcommands with own flags 219 | docker (run|build) [--file <dockerfile>] 220 | 221 | # Multiple flag values 222 | find <path> [--exclude <pattern>]... 223 | 224 | # Choice between options, required path 225 | chmod (u+x|a+r) <path>... 226 | 227 | # Flag groups with value 228 | kubectl [-n <namespace>] (get|delete) (pod|service) <name> 229 | ``` 230 | 231 | ## Status 232 | 233 | This project is in active development and undergoing changes as the API gets refined. Please open an 234 | issue if you encounter any problems or have suggestions for improvement. 235 | 236 | ## Acknowledgements 237 | 238 | There are many great CLI libraries out there, but I always felt [they were too heavy for my 239 | needs](https://mfridman.com/blog/2021/a-simpler-building-block-for-go-clis/).
240 | 241 | I was inspired by Peter Bourgon's [ff](https://github.com/peterbourgon/ff) library, specifically the 242 | `v3` branch, which was soooo close to what I wanted. But the `v4` branch took a different direction 243 | and I wanted to keep the simplicity of `v3`. This library aims to pick up where `v3` left off. 244 | 245 | ## License 246 | 247 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 248 | -------------------------------------------------------------------------------- /command.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "context" 5 | "flag" 6 | "fmt" 7 | "strings" 8 | 9 | "github.com/mfridman/cli/pkg/suggest" 10 | ) 11 | 12 | // Command represents a CLI command or subcommand within the application's command hierarchy. 13 | type Command struct { 14 | // Name is always a single word representing the command's name. It is used to identify the 15 | // command in the command hierarchy and in help text. 16 | Name string 17 | 18 | // Usage provides the command's full usage pattern. 19 | // 20 | // Example: "cli todo list [flags]" 21 | Usage string 22 | 23 | // ShortHelp is a brief description of the command's purpose. It is displayed in the help text 24 | // when the command is shown. 25 | ShortHelp string 26 | 27 | // UsageFunc is an optional function that can be used to generate a custom usage string for the 28 | // command. It receives the current command and should return a string with the full usage 29 | // pattern. 30 | UsageFunc func(*Command) string 31 | 32 | // Flags holds the command-specific flag definitions. Each command maintains its own flag set 33 | // for parsing arguments. 34 | Flags *flag.FlagSet 35 | // FlagsMetadata is an optional list of flag information to extend the FlagSet with additional 36 | // metadata. This is useful for tracking required flags.
37 | FlagsMetadata []FlagMetadata 38 | 39 | // SubCommands is a list of nested commands that exist under this command. 40 | SubCommands []*Command 41 | 42 | // Exec defines the command's execution logic. It receives the current application [State] and 43 | // returns an error if execution fails. This function is called when [Run] is invoked on the 44 | // command. 45 | Exec func(ctx context.Context, s *State) error 46 | // state holds parse results such as the command path; Parse initializes it on the root command. 47 | state *State 48 | } 49 | 50 | // Path returns the command chain from root to current command, as recorded by Parse. It returns 51 | // nil if the command has no parse state (for example, before Parse has been called). 52 | func (c *Command) Path() []*Command { 53 | if c.state == nil { 54 | return nil 55 | } 56 | return c.state.path 57 | } 58 | // terminal returns the last command in the recorded path, or c itself when no path is recorded. 59 | func (c *Command) terminal() *Command { 60 | if c.state == nil || len(c.state.path) == 0 { 61 | return c 62 | } 63 | // Get the last command in the path - this is our terminal command 64 | return c.state.path[len(c.state.path)-1] 65 | } 66 | 67 | // FlagMetadata holds additional metadata for a flag, such as whether it is required. 68 | type FlagMetadata struct { 69 | // Name is the flag's name. Must match the flag name in the flag set. 70 | Name string 71 | 72 | // Required indicates whether the flag is required. 73 | Required bool 74 | } 75 | 76 | // FlagsFunc is a helper function that creates a new [flag.FlagSet] and applies the given function 77 | // to it. Intended for use in command definitions to simplify flag setup; the returned set has an empty name and uses [flag.ContinueOnError]. Example usage: 78 | // 79 | // cmd.Flags = cli.FlagsFunc(func(f *flag.FlagSet) { 80 | // f.Bool("verbose", false, "enable verbose output") 81 | // f.String("output", "", "output file") 82 | // f.Int("count", 0, "number of items") 83 | // }) 84 | func FlagsFunc(fn func(f *flag.FlagSet)) *flag.FlagSet { 85 | fset := flag.NewFlagSet("", flag.ContinueOnError) 86 | fn(fset) 87 | return fset 88 | } 89 | 90 | // findSubCommand searches for a subcommand by name and returns it if found.
Returns nil if no 91 | // subcommand with the given name exists. Names are compared case-insensitively (strings.EqualFold). 92 | func (c *Command) findSubCommand(name string) *Command { 93 | for _, sub := range c.SubCommands { 94 | if strings.EqualFold(sub.Name, name) { 95 | return sub 96 | } 97 | } 98 | return nil 99 | } 100 | 101 | // formatUnknownCommandError returns an "unknown command" error for unknownCmd. When 102 | // suggest.FindSimilar returns suggestions from c's subcommand names, they are listed one per line. 103 | func (c *Command) formatUnknownCommandError(unknownCmd string) error { 104 | var known []string 105 | for _, sub := range c.SubCommands { 106 | known = append(known, sub.Name) 107 | } 108 | suggestions := suggest.FindSimilar(unknownCmd, known, 3) 109 | if len(suggestions) > 0 { 110 | return fmt.Errorf("unknown command %q. Did you mean one of these?\n\t%s", 111 | unknownCmd, 112 | strings.Join(suggestions, "\n\t")) 113 | } 114 | return fmt.Errorf("unknown command %q", unknownCmd) 115 | } 116 | 117 | // formatFlagName returns name with a single leading dash, e.g. "verbose" becomes "-verbose". 118 | func formatFlagName(name string) string { 119 | return "-" + name 120 | } 121 | 122 | // getCommandPath joins the Name of each command with single spaces, e.g. "todo task add". 123 | func getCommandPath(commands []*Command) string { 124 | var commandPath []string 125 | for _, c := range commands { 126 | commandPath = append(commandPath, c.Name) 127 | } 128 | return strings.Join(commandPath, " ") 129 | } 130 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // Package cli provides a lightweight library for building command-line applications using Go's 2 | // standard library flag package. It extends flag functionality to support flags anywhere in command 3 | // arguments.
4 | // 5 | // Key features: 6 | // - Nested subcommands for organizing complex CLIs 7 | // - Flexible flag parsing, allowing flags anywhere in arguments 8 | // - Parent-to-child flag inheritance 9 | // - Type-safe flag access 10 | // - Automatic help text generation 11 | // - Command suggestions for misspelled inputs 12 | // 13 | // Quick example: 14 | // 15 | // root := &cli.Command{ 16 | // Name: "echo", 17 | // Usage: "echo [flags] ...", 18 | // ShortHelp: "prints the provided text", 19 | // Flags: cli.FlagsFunc(func(f *flag.FlagSet) { 20 | // f.Bool("c", false, "capitalize the input") 21 | // }), 22 | // Exec: func(ctx context.Context, s *cli.State) error { 23 | // output := strings.Join(s.Args, " ") 24 | // if cli.GetFlag[bool](s, "c") { 25 | // output = strings.ToUpper(output) 26 | // } 27 | // fmt.Fprintln(s.Stdout, output) 28 | // return nil 29 | // }, 30 | // } 31 | // 32 | // The package intentionally maintains a minimal API surface to serve as a building block for CLI 33 | // applications while leveraging the standard library's flag package. This approach enables 34 | // developers to build maintainable command-line tools quickly while focusing on application logic 35 | // rather than framework complexity. 
36 | package cli 37 | -------------------------------------------------------------------------------- /examples/cmd/echo/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "flag" 7 | "fmt" 8 | "os" 9 | "strings" 10 | 11 | "github.com/mfridman/cli" 12 | ) 13 | 14 | // main defines the echo command, then parses and runs it; a help request (flag.ErrHelp) prints usage, any other error exits with status 1. 15 | func main() { 16 | root := &cli.Command{ 17 | Name: "echo", 18 | Usage: "echo [flags] ...", 19 | ShortHelp: "echo is a simple command that prints the provided text", 20 | Flags: cli.FlagsFunc(func(f *flag.FlagSet) { 21 | // Add a flag to capitalize the input 22 | f.Bool("c", false, "capitalize the input") 23 | }), 24 | FlagsMetadata: []cli.FlagMetadata{ 25 | {Name: "c", Required: true}, 26 | }, 27 | Exec: func(ctx context.Context, s *cli.State) error { 28 | if len(s.Args) == 0 { 29 | return errors.New("must provide text to echo, see --help") 30 | } 31 | output := strings.Join(s.Args, " ") 32 | // If -c flag is set, capitalize the output 33 | if cli.GetFlag[bool](s, "c") { 34 | output = strings.ToUpper(output) 35 | } 36 | fmt.Fprintln(s.Stdout, output) 37 | return nil 38 | }, 39 | } 40 | if err := cli.Parse(root, os.Args[1:]); err != nil { 41 | if errors.Is(err, flag.ErrHelp) { 42 | fmt.Fprintf(os.Stdout, "%s\n", cli.DefaultUsage(root)) 43 | return 44 | } 45 | fmt.Fprintf(os.Stderr, "error: %v\n", err) 46 | os.Exit(1) 47 | } 48 | if err := cli.Run(context.Background(), root, nil); err != nil { 49 | fmt.Fprintf(os.Stderr, "error: %v\n", err) 50 | os.Exit(1) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /examples/cmd/task/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "errors" 7 | "flag" 8 | "fmt" 9 | "os" 10 | "strconv" 11 | "strings" 12 | "time" 13 | 14 | "github.com/mfridman/cli" 15 | ) 16 | 17 | // todo 18 | // ├── (-verbose, -version); list and task each take -file (required) 19 | //
├── list 20 | // │ ├── today 21 | // │ └── overdue 22 | // │ └── (-tags) 23 | // │ 24 | // └── task 25 | // ├── add 26 | // │ └── (-tags) 27 | // ├── done 28 | // └── remove (-force, -all) 29 | 30 | func main() { 31 | root := &cli.Command{ 32 | Name: "todo", 33 | Usage: "todo [flags]", 34 | ShortHelp: "A simple CLI for managing your tasks", 35 | Flags: cli.FlagsFunc(func(f *flag.FlagSet) { 36 | f.Bool("verbose", false, "enable verbose output") 37 | f.Bool("version", false, "print the version") 38 | }), 39 | Exec: func(ctx context.Context, s *cli.State) error { 40 | if cli.GetFlag[bool](s, "version") { 41 | fmt.Fprintf(s.Stdout, "todo v1.0.0\n") 42 | return nil 43 | } 44 | fmt.Fprintf(s.Stderr, "todo: subcommand required, use --help for more information\n") 45 | return nil 46 | }, 47 | SubCommands: []*cli.Command{ 48 | list(), 49 | task(), 50 | }, 51 | } 52 | 53 | if err := cli.Parse(root, os.Args[1:]); err != nil { 54 | if errors.Is(err, flag.ErrHelp) { 55 | fmt.Fprintf(os.Stdout, "%s\n", cli.DefaultUsage(root)) 56 | return 57 | } 58 | fmt.Fprintf(os.Stderr, "error: %v\n", err) 59 | os.Exit(1) 60 | } 61 | if err := cli.Run(context.Background(), root, nil); err != nil { 62 | fmt.Fprintf(os.Stderr, "error: %v\n", err) 63 | os.Exit(1) 64 | } 65 | } 66 | 67 | func list() *cli.Command { 68 | return &cli.Command{ 69 | Name: "list", 70 | Usage: "todo list [flags]", 71 | ShortHelp: "List tasks", 72 | Flags: cli.FlagsFunc(func(f *flag.FlagSet) { 73 | f.String("file", "", "path to the tasks file") 74 | f.String("tags", "", "filter tasks by tags") 75 | }), 76 | FlagsMetadata: []cli.FlagMetadata{ 77 | {Name: "file", Required: true}, 78 | }, 79 | Exec: func(ctx context.Context, s *cli.State) error { 80 | fmt.Fprintf(s.Stderr, "todo list: subcommand required, use --help for more information\n") 81 | return nil 82 | }, 83 | SubCommands: []*cli.Command{ 84 | listToday(), 85 | listOverdue(), 86 | }, 87 | } 88 | } 89 | 90 | func getTasksFromFile(s *cli.State) (*TaskList, error) { 91 
| file := cli.GetFlag[string](s, "file") 92 | return Load(file) 93 | } 94 | 95 | func listToday() *cli.Command { 96 | return &cli.Command{ 97 | Name: "today", 98 | Usage: "todo list today [flags]", 99 | ShortHelp: "List tasks due today", 100 | Exec: func(ctx context.Context, s *cli.State) error { 101 | tasks, err := getTasksFromFile(s) 102 | if err != nil { 103 | return err 104 | } 105 | today := tasks.ListToday() 106 | if len(today) == 0 { 107 | fmt.Fprintf(s.Stdout, "No tasks due today, enjoy your day!\n") 108 | return nil 109 | } 110 | fmt.Fprintf(s.Stdout, "Tasks due today:\n") 111 | for _, task := range today { 112 | fmt.Fprintf(s.Stdout, " %s\n", task.String()) 113 | } 114 | return nil 115 | }, 116 | } 117 | } 118 | 119 | func listOverdue() *cli.Command { 120 | return &cli.Command{ 121 | Name: "overdue", 122 | Usage: "todo list overdue [flags]", 123 | ShortHelp: "List overdue tasks", 124 | Exec: func(ctx context.Context, s *cli.State) error { 125 | tasks, err := getTasksFromFile(s) 126 | if err != nil { 127 | return err 128 | } 129 | overdue := tasks.ListOverdue() 130 | if len(overdue) == 0 { 131 | fmt.Fprintf(s.Stdout, "No overdue tasks, enjoy your day!\n") 132 | return nil 133 | } 134 | fmt.Fprintf(s.Stdout, "Overdue tasks:\n") 135 | for _, task := range overdue { 136 | fmt.Fprintf(s.Stdout, " %s\n", task.String()) 137 | } 138 | return nil 139 | }, 140 | } 141 | } 142 | 143 | func task() *cli.Command { 144 | return &cli.Command{ 145 | Name: "task", 146 | Usage: "todo task [flags]", 147 | Flags: cli.FlagsFunc(func(f *flag.FlagSet) { 148 | f.String("file", "", "path to the tasks file") 149 | }), 150 | FlagsMetadata: []cli.FlagMetadata{ 151 | {Name: "file", Required: true}, 152 | }, 153 | ShortHelp: "Manage tasks", 154 | SubCommands: []*cli.Command{ 155 | taskAdd(), 156 | taskDone(), 157 | taskRemove(), 158 | }, 159 | } 160 | } 161 | 162 | func taskAdd() *cli.Command { 163 | return &cli.Command{ 164 | Name: "add", 165 | Usage: "todo task add [flags]", 166 | 
ShortHelp: "Add a new task", 167 | Flags: cli.FlagsFunc(func(f *flag.FlagSet) { 168 | f.String("tags", "", "comma-separated list of tags") 169 | }), 170 | Exec: func(ctx context.Context, s *cli.State) error { 171 | var ( 172 | tagsText = cli.GetFlag[string](s, "tags") 173 | file = cli.GetFlag[string](s, "file") 174 | ) 175 | var tags []string 176 | if tagsText != "" { 177 | tags = strings.Split(tagsText, ",") 178 | } 179 | tasks, err := getTasksFromFile(s) 180 | if err != nil { 181 | return err 182 | } 183 | id := tasks.LatestID() + 1 184 | tasks.Add(Task{ 185 | ID: id, 186 | Text: strings.Join(s.Args, " "), 187 | Tags: tags, 188 | Created: time.Now(), 189 | Status: Pending, 190 | }) 191 | if err := Save(file, tasks); err != nil { 192 | 193 | return err 194 | } 195 | fmt.Fprintf(s.Stdout, "Task added with ID %d\n", id) 196 | return nil 197 | }, 198 | } 199 | } 200 | 201 | func taskDone() *cli.Command { 202 | return &cli.Command{ 203 | Name: "done", 204 | Usage: "todo task done [flags]", 205 | ShortHelp: "Mark a task as done", 206 | Exec: func(ctx context.Context, s *cli.State) error { 207 | if len(s.Args) == 0 { 208 | return errors.New("task ID required") 209 | } 210 | tasks, err := getTasksFromFile(s) 211 | if err != nil { 212 | return err 213 | } 214 | id := s.Args[0] 215 | parsedID, err := strconv.Atoi(id) 216 | if err != nil { 217 | return fmt.Errorf("invalid task ID: %w", err) 218 | } 219 | return tasks.Done(parsedID) 220 | }, 221 | } 222 | } 223 | 224 | func taskRemove() *cli.Command { 225 | return &cli.Command{ 226 | Name: "remove", 227 | Usage: "todo task remove [flags]", 228 | ShortHelp: "Remove a task", 229 | Flags: cli.FlagsFunc(func(f *flag.FlagSet) { 230 | f.Bool("force", false, "force removal without confirmation") 231 | f.Bool("all", false, "remove all tasks") 232 | }), 233 | Exec: func(ctx context.Context, s *cli.State) error { 234 | var ( 235 | force = cli.GetFlag[bool](s, "force") 236 | all = cli.GetFlag[bool](s, "all") 237 | file = 
cli.GetFlag[string](s, "file") 238 | ) 239 | if len(s.Args) == 0 && !all { 240 | return errors.New("task ID required, or use --all to remove all tasks") 241 | } 242 | if all { 243 | if !force { 244 | 245 | reader := bufio.NewReader(os.Stdin) 246 | fmt.Print("Are you sure you want to clear all tasks? (y/N): ") 247 | response, err := reader.ReadString('\n') 248 | if err != nil { 249 | return fmt.Errorf("failed to read input: %w", err) 250 | } 251 | response = strings.TrimSpace(strings.ToLower(response)) 252 | if response != "y" { 253 | fmt.Fprintf(s.Stdout, "Operation cancelled\n") 254 | return nil 255 | } 256 | } 257 | // add a confirmation prompt 258 | return Save(file, &TaskList{}) 259 | } 260 | return nil 261 | }, 262 | } 263 | } 264 | -------------------------------------------------------------------------------- /examples/cmd/task/tasks.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "os" 7 | "path/filepath" 8 | "strings" 9 | "time" 10 | ) 11 | 12 | type Task struct { 13 | ID int `json:"id,omitempty"` 14 | Text string `json:"text,omitempty"` 15 | Tags []string `json:"tags,omitempty"` 16 | Created time.Time `json:"created,omitempty"` 17 | Status Status `json:"status,omitempty"` 18 | } 19 | 20 | func (t *Task) String() string { 21 | return fmt.Sprintf("%d: %s (%s, %s) [%s]", t.ID, t.Text, t.Created.Format("2006-01-02"), t.Status, strings.Join(t.Tags, ",")) 22 | } 23 | 24 | type TaskList struct { 25 | Tasks []Task `json:"tasks,omitempty"` 26 | } 27 | 28 | func (l *TaskList) LatestID() int { 29 | var id int 30 | for _, t := range l.Tasks { 31 | if t.ID > id { 32 | id = t.ID 33 | } 34 | } 35 | return id 36 | } 37 | 38 | type Status string 39 | 40 | const ( 41 | Pending Status = "pending" 42 | Done Status = "done" 43 | ) 44 | 45 | func (l *TaskList) Add(t Task) { 46 | l.Tasks = append(l.Tasks, t) 47 | } 48 | 49 | func (l *TaskList) Remove(id int) error { 50 | for 
i, t := range l.Tasks { 51 | if t.ID == id { 52 | l.Tasks = append(l.Tasks[:i], l.Tasks[i+1:]...) 53 | return nil 54 | } 55 | } 56 | return fmt.Errorf("task with ID %d not found", id) 57 | } 58 | 59 | func (l *TaskList) List() []Task { 60 | return l.Tasks 61 | } 62 | 63 | func (l *TaskList) ListToday() []Task { 64 | var tasks []Task 65 | for _, t := range l.Tasks { 66 | if t.Created.Day() == time.Now().Day() { 67 | tasks = append(tasks, t) 68 | } 69 | } 70 | return tasks 71 | } 72 | 73 | func (l *TaskList) ListOverdue() []Task { 74 | var tasks []Task 75 | for _, t := range l.Tasks { 76 | if t.Created.Before(time.Now()) && t.Status == Pending { 77 | tasks = append(tasks, t) 78 | } 79 | } 80 | return tasks 81 | } 82 | 83 | func (l *TaskList) Done(id int) error { 84 | for i, t := range l.Tasks { 85 | if t.ID == id { 86 | t.Status = Done 87 | l.Tasks[i] = t 88 | return nil 89 | } 90 | } 91 | return fmt.Errorf("task with ID %d not found", id) 92 | } 93 | 94 | func (l *TaskList) Find(id int) (Task, bool) { 95 | for _, t := range l.Tasks { 96 | if t.ID == id { 97 | return t, true 98 | } 99 | } 100 | return Task{}, false 101 | } 102 | 103 | func (l *TaskList) FindByTag(tag string) []Task { 104 | var tasks []Task 105 | for _, task := range l.Tasks { 106 | for _, t := range task.Tags { 107 | if t == tag { 108 | tasks = append(tasks, task) 109 | } 110 | } 111 | } 112 | return tasks 113 | } 114 | 115 | func Save(file string, l *TaskList) error { 116 | data, err := json.MarshalIndent(l, "", " ") 117 | if err != nil { 118 | return fmt.Errorf("failed to marshal task list: %w", err) 119 | } 120 | dir := filepath.Dir(file) 121 | if err := os.MkdirAll(dir, 0755); err != nil { 122 | return fmt.Errorf("failed to create directory %s: %w", dir, err) 123 | } 124 | return os.WriteFile(file, data, 0644) 125 | } 126 | 127 | func Load(file string) (*TaskList, error) { 128 | data, err := os.ReadFile(file) 129 | if err != nil { 130 | if os.IsNotExist(err) { 131 | l := &TaskList{Tasks: 
[]Task{}} 132 | if err := Save(file, l); err != nil { 133 | return nil, fmt.Errorf("failed to save file %s: %w", file, err) 134 | } 135 | return l, nil 136 | } 137 | return nil, fmt.Errorf("failed to read file %s: %w", file, err) 138 | } 139 | var l *TaskList 140 | if err := json.Unmarshal(data, &l); err != nil { 141 | return nil, fmt.Errorf("failed to unmarshal task list: %w", err) 142 | } 143 | return l, nil 144 | } 145 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/mfridman/cli 2 | 3 | go 1.21.0 4 | 5 | toolchain go1.23.2 6 | 7 | require ( 8 | github.com/mfridman/xflag v0.1.0 9 | github.com/stretchr/testify v1.10.0 10 | ) 11 | 12 | require ( 13 | github.com/davecgh/go-spew v1.1.1 // indirect 14 | github.com/pmezard/go-difflib v1.0.0 // indirect 15 | gopkg.in/yaml.v3 v3.0.1 // indirect 16 | ) 17 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/mfridman/xflag v0.1.0 h1:TWZrZwG1QklFX5S4j1vxfF1sZbZeZSGofMwPMLAF29M= 4 | github.com/mfridman/xflag v0.1.0/go.mod h1:/483ywM5ZO5SuMVjrIGquYNE5CzLrj5Ux/LxWWnjRaE= 5 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 6 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 7 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 8 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 9 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 10 | gopkg.in/check.v1 
v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 11 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 12 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 13 | -------------------------------------------------------------------------------- /parse.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "errors" 5 | "flag" 6 | "fmt" 7 | "io" 8 | "regexp" 9 | "slices" 10 | "strconv" 11 | "strings" 12 | 13 | "github.com/mfridman/xflag" 14 | ) 15 | 16 | // Parse traverses the command hierarchy and parses arguments. It returns an error if parsing fails 17 | // at any point. 18 | // 19 | // This function is the main entry point for parsing command-line arguments and should be called 20 | // with the root command and the arguments to parse, typically os.Args[1:]. Once parsing is 21 | // complete, the root command is ready to be executed with the [Run] function. 
22 | func Parse(root *Command, args []string) error { 23 | if root == nil { 24 | return fmt.Errorf("failed to parse: root command is nil") 25 | } 26 | if err := validateCommands(root, nil); err != nil { 27 | return fmt.Errorf("failed to parse: %w", err) 28 | } 29 | 30 | // Initialize or update root state 31 | if root.state == nil { 32 | root.state = &State{ 33 | path: []*Command{root}, 34 | } 35 | } else { 36 | // Reset command path but preserve other state 37 | root.state.path = []*Command{root} 38 | } 39 | // First split args at the -- delimiter if present 40 | var argsToParse []string 41 | var remainingArgs []string 42 | for i, arg := range args { 43 | if arg == "--" { 44 | argsToParse = args[:i] 45 | remainingArgs = args[i+1:] 46 | break 47 | } 48 | } 49 | if argsToParse == nil { 50 | argsToParse = args 51 | } 52 | 53 | current := root 54 | if current.Flags == nil { 55 | current.Flags = flag.NewFlagSet(root.Name, flag.ContinueOnError) 56 | } 57 | var commandChain []*Command 58 | commandChain = append(commandChain, root) 59 | 60 | // Create combined flags with all parent flags 61 | combinedFlags := flag.NewFlagSet(root.Name, flag.ContinueOnError) 62 | combinedFlags.SetOutput(io.Discard) 63 | 64 | // First pass: process commands and build the flag set 65 | i := 0 66 | for i < len(argsToParse) { 67 | arg := argsToParse[i] 68 | 69 | // Skip flags and their values 70 | if strings.HasPrefix(arg, "-") { 71 | // For formats like -flag=x or --flag=x 72 | if strings.Contains(arg, "=") { 73 | i++ 74 | continue 75 | } 76 | 77 | // Check if this flag expects a value 78 | name := strings.TrimLeft(arg, "-") 79 | if f := current.Flags.Lookup(name); f != nil { 80 | if _, isBool := f.Value.(interface{ IsBoolFlag() bool }); !isBool { 81 | // Skip both flag and its value 82 | i += 2 83 | continue 84 | } 85 | } 86 | i++ 87 | continue 88 | } 89 | 90 | // Try to traverse to subcommand 91 | if len(current.SubCommands) > 0 { 92 | if sub := current.findSubCommand(arg); sub != nil { 93 | 
root.state.path = append(slices.Clone(root.state.path), sub) 94 | if sub.Flags == nil { 95 | sub.Flags = flag.NewFlagSet(sub.Name, flag.ContinueOnError) 96 | } 97 | current = sub 98 | commandChain = append(commandChain, sub) 99 | i++ 100 | continue 101 | } 102 | return current.formatUnknownCommandError(arg) 103 | } 104 | break 105 | } 106 | current.Flags.Usage = func() { /* suppress default usage */ } 107 | 108 | // Add the help check here, after we've found the correct command 109 | hasHelp := false 110 | for _, arg := range argsToParse { 111 | if arg == "-h" || arg == "--h" || arg == "-help" || arg == "--help" { 112 | hasHelp = true 113 | break 114 | } 115 | } 116 | 117 | // Add flags in reverse order for proper precedence 118 | for i := len(commandChain) - 1; i >= 0; i-- { 119 | cmd := commandChain[i] 120 | if cmd.Flags != nil { 121 | cmd.Flags.VisitAll(func(f *flag.Flag) { 122 | if combinedFlags.Lookup(f.Name) == nil { 123 | combinedFlags.Var(f.Value, f.Name, f.Usage) 124 | } 125 | }) 126 | } 127 | } 128 | // Make sure to return help only after combining all flags, this way we get the full list of 129 | // flags in the help message! 
130 | if hasHelp { 131 | return flag.ErrHelp 132 | } 133 | 134 | // Let ParseToEnd handle the flag parsing 135 | if err := xflag.ParseToEnd(combinedFlags, argsToParse); err != nil { 136 | return fmt.Errorf("command %q: %w", getCommandPath(root.state.path), err) 137 | } 138 | 139 | // Check required flags 140 | var missingFlags []string 141 | for _, cmd := range commandChain { 142 | if len(cmd.FlagsMetadata) > 0 { 143 | for _, flagMetadata := range cmd.FlagsMetadata { 144 | if !flagMetadata.Required { 145 | continue 146 | } 147 | flag := combinedFlags.Lookup(flagMetadata.Name) 148 | if flag == nil { 149 | return fmt.Errorf("command %q: internal error: required flag %s not found in flag set", getCommandPath(root.state.path), formatFlagName(flagMetadata.Name)) 150 | } 151 | if _, isBool := flag.Value.(interface{ IsBoolFlag() bool }); isBool { 152 | isSet := false 153 | for _, arg := range argsToParse { 154 | if strings.HasPrefix(arg, "-"+flagMetadata.Name) || strings.HasPrefix(arg, "--"+flagMetadata.Name) { 155 | isSet = true 156 | break 157 | } 158 | } 159 | if !isSet { 160 | missingFlags = append(missingFlags, formatFlagName(flagMetadata.Name)) 161 | } 162 | } else if flag.Value.String() == flag.DefValue { 163 | missingFlags = append(missingFlags, formatFlagName(flagMetadata.Name)) 164 | } 165 | } 166 | } 167 | } 168 | if len(missingFlags) > 0 { 169 | msg := "required flag" 170 | if len(missingFlags) > 1 { 171 | msg += "s" 172 | } 173 | return fmt.Errorf("command %q: %s %q not set", getCommandPath(root.state.path), msg, strings.Join(missingFlags, ", ")) 174 | } 175 | 176 | // Skip past command names in remaining args 177 | parsed := combinedFlags.Args() 178 | startIdx := 0 179 | for _, arg := range parsed { 180 | isCommand := false 181 | for _, cmd := range commandChain { 182 | if arg == cmd.Name { 183 | startIdx++ 184 | isCommand = true 185 | break 186 | } 187 | } 188 | if !isCommand { 189 | break 190 | } 191 | } 192 | 193 | // Combine remaining parsed args and 
everything after delimiter 194 | var finalArgs []string 195 | if startIdx < len(parsed) { 196 | finalArgs = append(finalArgs, parsed[startIdx:]...) 197 | } 198 | if len(remainingArgs) > 0 { 199 | finalArgs = append(finalArgs, remainingArgs...) 200 | } 201 | root.state.Args = finalArgs 202 | 203 | if current.Exec == nil { 204 | return fmt.Errorf("command %q: no exec function defined", getCommandPath(root.state.path)) 205 | } 206 | return nil 207 | } 208 | 209 | var validNameRegex = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9_-]*$`) 210 | 211 | func validateName(root *Command) error { 212 | if !validNameRegex.MatchString(root.Name) { 213 | return fmt.Errorf("name must start with a letter and contain only letters, numbers, dashes (-) or underscores (_)") 214 | } 215 | return nil 216 | } 217 | 218 | func validateCommands(root *Command, path []string) error { 219 | if root.Name == "" { 220 | if len(path) == 0 { 221 | return errors.New("root command has no name") 222 | } 223 | return fmt.Errorf("subcommand in path [%s] has no name", strings.Join(path, ", ")) 224 | } 225 | 226 | currentPath := append(path, root.Name) 227 | if err := validateName(root); err != nil { 228 | quoted := make([]string, len(currentPath)) 229 | for i, p := range currentPath { 230 | quoted[i] = strconv.Quote(p) 231 | } 232 | return fmt.Errorf("command [%s]: %w", strings.Join(quoted, ", "), err) 233 | } 234 | 235 | for _, sub := range root.SubCommands { 236 | if err := validateCommands(sub, currentPath); err != nil { 237 | return err 238 | } 239 | } 240 | return nil 241 | } 242 | -------------------------------------------------------------------------------- /parse_test.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "errors" 7 | "flag" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | // testState is a helper struct to hold the 
commands for testing 15 | // 16 | // root --verbose --version 17 | // ├── add --dry-run 18 | // └── nested --force 19 | // └── sub --echo 20 | // └── hello --mandatory-flag=false --another-mandatory-flag some-value 21 | type testState struct { 22 | add *Command 23 | nested, sub, hello *Command 24 | root *Command 25 | } 26 | 27 | func newTestState() testState { 28 | exec := func(ctx context.Context, s *State) error { return errors.New("not implemented") } 29 | add := &Command{ 30 | Name: "add", 31 | Flags: FlagsFunc(func(fset *flag.FlagSet) { 32 | fset.Bool("dry-run", false, "enable dry-run mode") 33 | }), 34 | Exec: exec, 35 | } 36 | sub := &Command{ 37 | Name: "sub", 38 | Flags: FlagsFunc(func(fset *flag.FlagSet) { 39 | fset.String("echo", "", "echo the message") 40 | }), 41 | FlagsMetadata: []FlagMetadata{ 42 | {Name: "echo", Required: false}, // not required 43 | }, 44 | Exec: exec, 45 | } 46 | hello := &Command{ 47 | Name: "hello", 48 | Flags: FlagsFunc(func(fset *flag.FlagSet) { 49 | fset.Bool("mandatory-flag", false, "mandatory flag") 50 | fset.String("another-mandatory-flag", "", "another mandatory flag") 51 | }), 52 | FlagsMetadata: []FlagMetadata{ 53 | {Name: "mandatory-flag", Required: true}, 54 | {Name: "another-mandatory-flag", Required: true}, 55 | }, 56 | Exec: exec, 57 | } 58 | nested := &Command{ 59 | Name: "nested", 60 | Flags: FlagsFunc(func(fset *flag.FlagSet) { 61 | fset.Bool("force", false, "force the operation") 62 | }), 63 | SubCommands: []*Command{sub, hello}, 64 | Exec: exec, 65 | } 66 | root := &Command{ 67 | Name: "todo", 68 | Flags: FlagsFunc(func(fset *flag.FlagSet) { 69 | fset.Bool("verbose", false, "enable verbose mode") 70 | fset.Bool("version", false, "show version") 71 | }), 72 | SubCommands: []*Command{add, nested}, 73 | Exec: exec, 74 | } 75 | return testState{ 76 | add: add, 77 | nested: nested, 78 | sub: sub, 79 | root: root, 80 | hello: hello, 81 | } 82 | } 83 | 84 | func TestParse(t *testing.T) { 85 | t.Parallel() 86 | 87 | 
t.Run("error on parse with no exec", func(t *testing.T) { 88 | t.Parallel() 89 | cmd := &Command{ 90 | Name: "foo", 91 | Exec: func(ctx context.Context, s *State) error { return nil }, 92 | SubCommands: []*Command{ 93 | { 94 | Name: "bar", 95 | Exec: func(ctx context.Context, s *State) error { return nil }, 96 | SubCommands: []*Command{ 97 | { 98 | Name: "baz", 99 | }, 100 | }, 101 | }, 102 | }, 103 | } 104 | err := Parse(cmd, []string{"bar", "baz"}) 105 | require.Error(t, err) 106 | assert.ErrorContains(t, err, `command "foo bar baz": no exec function defined`) 107 | }) 108 | t.Run("parsing errors", func(t *testing.T) { 109 | t.Parallel() 110 | 111 | err := Parse(nil, nil) 112 | require.Error(t, err) 113 | require.Contains(t, err.Error(), "command is nil") 114 | 115 | err = Parse(&Command{}, nil) 116 | require.Error(t, err) 117 | require.Contains(t, err.Error(), "root command has no name") 118 | }) 119 | t.Run("subcommand nil flags", func(t *testing.T) { 120 | t.Parallel() 121 | 122 | err := Parse(&Command{ 123 | Name: "root", 124 | SubCommands: []*Command{{ 125 | Name: "sub", 126 | Exec: func(ctx context.Context, s *State) error { return nil }, 127 | }}, 128 | Exec: func(ctx context.Context, s *State) error { return nil }, 129 | }, []string{"sub"}) 130 | require.NoError(t, err) 131 | }) 132 | t.Run("default flag usage", func(t *testing.T) { 133 | t.Parallel() 134 | 135 | by := bytes.NewBuffer(nil) 136 | root := &Command{ 137 | Name: "root", 138 | Usage: "root [flags]", 139 | Flags: FlagsFunc(func(fset *flag.FlagSet) { 140 | fset.SetOutput(by) 141 | }), 142 | } 143 | err := Parse(root, []string{"--help"}) 144 | require.Error(t, err) 145 | require.ErrorIs(t, err, flag.ErrHelp) 146 | require.Empty(t, by.String()) 147 | }) 148 | t.Run("no flags", func(t *testing.T) { 149 | t.Parallel() 150 | s := newTestState() 151 | 152 | err := Parse(s.root, []string{"add", "item1"}) 153 | require.NoError(t, err) 154 | cmd := getCommand(t, s.root) 155 | 156 | require.Equal(t, 
s.add, cmd) 157 | require.False(t, GetFlag[bool](s.root.state, "dry-run")) 158 | }) 159 | t.Run("unknown flag", func(t *testing.T) { 160 | t.Parallel() 161 | s := newTestState() 162 | 163 | err := Parse(s.root, []string{"add", "--unknown", "item1"}) 164 | require.Error(t, err) 165 | require.Contains(t, err.Error(), `command "todo add": flag provided but not defined: -unknown`) 166 | }) 167 | t.Run("with subcommand flags", func(t *testing.T) { 168 | t.Parallel() 169 | s := newTestState() 170 | 171 | err := Parse(s.root, []string{"add", "--dry-run", "item1"}) 172 | require.NoError(t, err) 173 | cmd := getCommand(t, s.root) 174 | 175 | assert.Equal(t, s.add, cmd) 176 | assert.True(t, GetFlag[bool](s.root.state, "dry-run")) 177 | }) 178 | t.Run("help flag", func(t *testing.T) { 179 | t.Parallel() 180 | s := newTestState() 181 | 182 | err := Parse(s.root, []string{"--help"}) 183 | require.Error(t, err) 184 | require.ErrorIs(t, err, flag.ErrHelp) 185 | }) 186 | t.Run("help flag with subcommand", func(t *testing.T) { 187 | t.Parallel() 188 | s := newTestState() 189 | 190 | err := Parse(s.root, []string{"add", "--help"}) 191 | require.Error(t, err) 192 | require.ErrorIs(t, err, flag.ErrHelp) 193 | }) 194 | t.Run("help flag with subcommand at s.root", func(t *testing.T) { 195 | t.Parallel() 196 | s := newTestState() 197 | 198 | err := Parse(s.root, []string{"--help", "add"}) 199 | require.Error(t, err) 200 | require.ErrorIs(t, err, flag.ErrHelp) 201 | }) 202 | t.Run("help flag with subcommand and other flags", func(t *testing.T) { 203 | t.Parallel() 204 | s := newTestState() 205 | 206 | err := Parse(s.root, []string{"add", "--help", "--dry-run"}) 207 | require.Error(t, err) 208 | require.ErrorIs(t, err, flag.ErrHelp) 209 | }) 210 | t.Run("unknown subcommand", func(t *testing.T) { 211 | t.Parallel() 212 | s := newTestState() 213 | 214 | err := Parse(s.root, []string{"unknown"}) 215 | require.Error(t, err) 216 | require.Contains(t, err.Error(), "unknown command") 217 | }) 218 
| t.Run("flags at multiple levels", func(t *testing.T) { 219 | t.Parallel() 220 | s := newTestState() 221 | 222 | err := Parse(s.root, []string{"add", "--dry-run", "item1", "--verbose"}) 223 | require.NoError(t, err) 224 | cmd := getCommand(t, s.root) 225 | 226 | assert.Equal(t, s.add, cmd) 227 | assert.True(t, GetFlag[bool](s.root.state, "dry-run")) 228 | assert.True(t, GetFlag[bool](s.root.state, "verbose")) 229 | }) 230 | t.Run("nested subcommand and root flag", func(t *testing.T) { 231 | t.Parallel() 232 | s := newTestState() 233 | 234 | err := Parse(s.root, []string{"--verbose", "nested", "sub", "--echo", "hello"}) 235 | require.NoError(t, err) 236 | cmd := getCommand(t, s.root) 237 | 238 | assert.Equal(t, s.sub, cmd) 239 | assert.Equal(t, "hello", GetFlag[string](s.root.state, "echo")) 240 | assert.True(t, GetFlag[bool](s.root.state, "verbose")) 241 | }) 242 | t.Run("nested subcommand with mixed flags", func(t *testing.T) { 243 | t.Parallel() 244 | s := newTestState() 245 | 246 | err := Parse(s.root, []string{"nested", "sub", "--echo", "hello", "--verbose"}) 247 | require.NoError(t, err) 248 | cmd := getCommand(t, s.root) 249 | 250 | assert.Equal(t, s.sub, cmd) 251 | assert.Equal(t, "hello", GetFlag[string](s.root.state, "echo")) 252 | assert.True(t, GetFlag[bool](s.root.state, "verbose")) 253 | }) 254 | t.Run("end of options delimiter", func(t *testing.T) { 255 | t.Parallel() 256 | s := newTestState() 257 | 258 | err := Parse(s.root, []string{"--verbose", "--", "nested", "sub", "--echo", "hello"}) 259 | require.NoError(t, err) 260 | cmd := getCommand(t, s.root) 261 | 262 | assert.Equal(t, s.root, cmd) 263 | assert.Equal(t, []string{"nested", "sub", "--echo", "hello"}, s.root.state.Args) 264 | assert.True(t, GetFlag[bool](s.root.state, "verbose")) 265 | }) 266 | t.Run("flags and args", func(t *testing.T) { 267 | t.Parallel() 268 | s := newTestState() 269 | 270 | err := Parse(s.root, []string{"add", "item1", "--dry-run", "item2"}) 271 | require.NoError(t, err) 
272 | cmd := getCommand(t, s.root) 273 | 274 | assert.Equal(t, s.add, cmd) 275 | assert.True(t, GetFlag[bool](s.root.state, "dry-run")) 276 | assert.Equal(t, []string{"item1", "item2"}, s.root.state.Args) 277 | }) 278 | t.Run("nested subcommand with flags and args", func(t *testing.T) { 279 | t.Parallel() 280 | s := newTestState() 281 | 282 | err := Parse(s.root, []string{"nested", "sub", "--echo", "hello", "world"}) 283 | require.NoError(t, err) 284 | cmd := getCommand(t, s.root) 285 | 286 | assert.Equal(t, s.sub, cmd) 287 | assert.Equal(t, "hello", GetFlag[string](s.root.state, "echo")) 288 | assert.Equal(t, []string{"world"}, s.root.state.Args) 289 | }) 290 | t.Run("subcommand flags not available in parent", func(t *testing.T) { 291 | t.Parallel() 292 | s := newTestState() 293 | 294 | err := Parse(s.root, []string{"--dry-run"}) 295 | require.Error(t, err) 296 | require.ErrorContains(t, err, "flag provided but not defined") 297 | }) 298 | t.Run("parent flags inherited in subcommand", func(t *testing.T) { 299 | t.Parallel() 300 | s := newTestState() 301 | 302 | err := Parse(s.root, []string{"nested", "sub", "--force"}) 303 | require.NoError(t, err) 304 | cmd := getCommand(t, s.root) 305 | 306 | assert.Equal(t, s.sub, cmd) 307 | assert.True(t, GetFlag[bool](s.root.state, "force")) 308 | }) 309 | t.Run("unrelated subcommand flags not inherited in other subcommands", func(t *testing.T) { 310 | t.Parallel() 311 | s := newTestState() 312 | 313 | err := Parse(s.root, []string{"nested", "sub", "--dry-run"}) 314 | require.Error(t, err) 315 | require.ErrorContains(t, err, "flag provided but not defined") 316 | }) 317 | t.Run("empty name in subcommand", func(t *testing.T) { 318 | t.Parallel() 319 | s := newTestState() 320 | s.sub.Name = "" 321 | 322 | err := Parse(s.root, nil) 323 | require.Error(t, err) 324 | require.ErrorContains(t, err, `subcommand in path [todo, nested] has no name`) 325 | }) 326 | t.Run("required flag", func(t *testing.T) { 327 | t.Parallel() 328 | { 
329 | s := newTestState() 330 | err := Parse(s.root, []string{"nested", "hello"}) 331 | require.Error(t, err) 332 | require.ErrorContains(t, err, `command "todo nested hello": required flags "-mandatory-flag, -another-mandatory-flag" not set`) 333 | } 334 | { 335 | // Correct type - true 336 | s := newTestState() 337 | err := Parse(s.root, []string{"nested", "hello", "--mandatory-flag=true", "--another-mandatory-flag", "some-value"}) 338 | require.NoError(t, err) 339 | cmd := getCommand(t, s.root) 340 | 341 | assert.Equal(t, s.hello, cmd) 342 | require.True(t, GetFlag[bool](s.root.state, "mandatory-flag")) 343 | } 344 | { 345 | // Correct type - false 346 | s := newTestState() 347 | err := Parse(s.root, []string{"nested", "hello", "--mandatory-flag=false", "--another-mandatory-flag=some-value"}) 348 | require.NoError(t, err) 349 | cmd := s.root.terminal() 350 | assert.Equal(t, s.hello, cmd) 351 | require.False(t, GetFlag[bool](s.root.state, "mandatory-flag")) 352 | } 353 | { 354 | // Incorrect type 355 | s := newTestState() 356 | err := Parse(s.root, []string{"nested", "hello", "--mandatory-flag=not-a-bool"}) 357 | require.Error(t, err) 358 | require.ErrorContains(t, err, `command "todo nested hello": invalid boolean value "not-a-bool" for -mandatory-flag: parse error`) 359 | } 360 | }) 361 | t.Run("unknown required flag set by cli author", func(t *testing.T) { 362 | t.Parallel() 363 | cmd := &Command{ 364 | Name: "root", 365 | FlagsMetadata: []FlagMetadata{ 366 | {Name: "some-other-flag", Required: true}, 367 | }, 368 | } 369 | err := Parse(cmd, nil) 370 | require.Error(t, err) 371 | // TODO(mf): consider improving this error message so it's obvious that a "required" flag 372 | // was set by the cli author but not registered in the flag set 373 | require.ErrorContains(t, err, `command "root": internal error: required flag -some-other-flag not found in flag set`) 374 | }) 375 | t.Run("space in command name", func(t *testing.T) { 376 | t.Parallel() 377 | cmd := 
&Command{ 378 | Name: "root", 379 | SubCommands: []*Command{ 380 | {Name: "sub command"}, 381 | }, 382 | } 383 | err := Parse(cmd, nil) 384 | require.Error(t, err) 385 | require.ErrorContains(t, err, `failed to parse: command ["root", "sub command"]: name must start with a letter and contain only letters, numbers, dashes (-) or underscores (_)`) 386 | }) 387 | t.Run("dash in command name", func(t *testing.T) { 388 | t.Parallel() 389 | cmd := &Command{ 390 | Name: "root", 391 | Exec: func(ctx context.Context, s *State) error { return nil }, 392 | SubCommands: []*Command{ 393 | {Name: "sub-command"}, 394 | }, 395 | } 396 | err := Parse(cmd, nil) 397 | require.NoError(t, err) 398 | }) 399 | } 400 | 401 | func getCommand(t *testing.T, c *Command) *Command { 402 | require.NotNil(t, c) 403 | require.NotNil(t, c.state) 404 | require.NotEmpty(t, c.state.path) 405 | terminal := c.terminal() 406 | require.NotNil(t, terminal) 407 | return terminal 408 | } 409 | -------------------------------------------------------------------------------- /pkg/suggest/suggest.go: -------------------------------------------------------------------------------- 1 | package suggest 2 | 3 | import ( 4 | "sort" 5 | "strings" 6 | ) 7 | 8 | // threshold is the minimum similarity score required for a string to be considered similar. 9 | const threshold = 0.5 10 | 11 | // FindSimilar returns a list of similar strings to the target string from a list of candidates. 
12 | func FindSimilar(target string, candidates []string, maxResults int) []string { 13 | // Early returns for invalid inputs 14 | if target == "" || maxResults <= 0 { 15 | return []string{} 16 | } 17 | 18 | suggestions := make([]struct { 19 | name string 20 | score float64 21 | }, 0, len(candidates)) 22 | 23 | // Calculate similarity scores 24 | for _, name := range candidates { 25 | score := calculateSimilarity(target, name) 26 | if score > threshold { // Only include reasonably similar commands 27 | suggestions = append(suggestions, struct { 28 | name string 29 | score float64 30 | }{name, score}) 31 | } 32 | } 33 | 34 | sort.Slice(suggestions, func(i, j int) bool { 35 | if suggestions[i].score == suggestions[j].score { 36 | return suggestions[i].name < suggestions[j].name 37 | } 38 | return suggestions[i].score > suggestions[j].score 39 | }) 40 | 41 | // Get top N suggestions 42 | result := make([]string, 0, maxResults) 43 | for i := 0; i < len(suggestions) && i < maxResults; i++ { 44 | result = append(result, suggestions[i].name) 45 | } 46 | 47 | return result 48 | } 49 | 50 | func calculateSimilarity(a, b string) float64 { 51 | a = strings.ToLower(a) 52 | b = strings.ToLower(b) 53 | 54 | // Perfect match 55 | if a == b { 56 | return 1.0 57 | } 58 | // Prefix match bonus 59 | if strings.HasPrefix(b, a) { 60 | return 0.9 61 | } 62 | // Calculate Levenshtein distance 63 | distance := levenshteinDistance(a, b) 64 | maxLen := float64(max(len(a), len(b))) 65 | 66 | // Convert distance to similarity score (0 to 1) 67 | similarity := 1.0 - float64(distance)/maxLen 68 | 69 | return similarity 70 | } 71 | 72 | func levenshteinDistance(a, b string) int { 73 | if len(a) == 0 { 74 | return len(b) 75 | } 76 | if len(b) == 0 { 77 | return len(a) 78 | } 79 | 80 | matrix := make([][]int, len(a)+1) 81 | for i := range matrix { 82 | matrix[i] = make([]int, len(b)+1) 83 | } 84 | 85 | for i := 0; i <= len(a); i++ { 86 | matrix[i][0] = i 87 | } 88 | for j := 0; j <= len(b); j++ { 
89 | matrix[0][j] = j 90 | } 91 | 92 | for i := 1; i <= len(a); i++ { 93 | for j := 1; j <= len(b); j++ { 94 | cost := 1 95 | if a[i-1] == b[j-1] { 96 | cost = 0 97 | } 98 | matrix[i][j] = min( 99 | matrix[i-1][j]+1, // deletion 100 | min(matrix[i][j-1]+1, // insertion 101 | matrix[i-1][j-1]+cost)) // substitution 102 | } 103 | } 104 | 105 | return matrix[len(a)][len(b)] 106 | } 107 | -------------------------------------------------------------------------------- /pkg/suggest/suggest_test.go: -------------------------------------------------------------------------------- 1 | package suggest 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestFindSimilar(t *testing.T) { 10 | tests := []struct { 11 | name string 12 | target string 13 | candidates []string 14 | maxResults int 15 | expected []string 16 | }{ 17 | { 18 | name: "exact match", 19 | target: "hello", 20 | candidates: []string{"hello", "world", "help"}, 21 | maxResults: 2, 22 | expected: []string{"hello", "help"}, 23 | }, 24 | { 25 | name: "empty target", 26 | target: "", 27 | candidates: []string{"hello", "world"}, 28 | maxResults: 2, 29 | expected: []string{}, 30 | }, 31 | { 32 | name: "no matches", 33 | target: "xyz", 34 | candidates: []string{"hello", "world"}, 35 | maxResults: 2, 36 | expected: []string{}, 37 | }, 38 | { 39 | name: "invalid max results", 40 | target: "hello", 41 | candidates: []string{"hello", "world"}, 42 | maxResults: -1, 43 | expected: []string{}, 44 | }, 45 | } 46 | 47 | for _, tt := range tests { 48 | t.Run(tt.name, func(t *testing.T) { 49 | result := FindSimilar(tt.target, tt.candidates, tt.maxResults) 50 | assert.Equal(t, tt.expected, result) 51 | }) 52 | } 53 | } 54 | 55 | func TestCalculateSimilarity(t *testing.T) { 56 | t.Parallel() 57 | 58 | tests := []struct { 59 | name string 60 | a string 61 | b string 62 | expected float64 63 | }{ 64 | { 65 | name: "perfect match", 66 | a: "hello", 67 | b: "hello", 68 | expected: 1.0, 69 | 
}, 70 | { 71 | name: "perfect match with different case", 72 | a: "Hello", 73 | b: "hello", 74 | expected: 1.0, 75 | }, 76 | { 77 | name: "prefix match", 78 | a: "hel", 79 | b: "hello", 80 | expected: 0.9, 81 | }, 82 | { 83 | name: "one character difference", 84 | a: "hello", 85 | b: "hello1", 86 | expected: 0.9, // prefix match case 87 | }, 88 | { 89 | name: "completely different strings", 90 | a: "hello", 91 | b: "world", 92 | expected: 0.2, // Based on Levenshtein distance of 4 with max length 5 93 | }, 94 | { 95 | name: "empty strings", 96 | a: "", 97 | b: "", 98 | expected: 1.0, 99 | }, 100 | { 101 | name: "one empty string", 102 | a: "hello", 103 | b: "", 104 | expected: 0.0, 105 | }, 106 | } 107 | 108 | for _, tt := range tests { 109 | t.Run(tt.name, func(t *testing.T) { 110 | result := calculateSimilarity(tt.a, tt.b) 111 | assert.InDelta(t, tt.expected, result, 0.001, "similarity mismatch for %q and %q", tt.a, tt.b) 112 | }) 113 | } 114 | } 115 | 116 | func TestLevenshteinDistance(t *testing.T) { 117 | t.Parallel() 118 | 119 | tests := []struct { 120 | name string 121 | a string 122 | b string 123 | expected int 124 | }{ 125 | { 126 | name: "identical strings", 127 | a: "hello", 128 | b: "hello", 129 | expected: 0, 130 | }, 131 | { 132 | name: "one character difference", 133 | a: "hello", 134 | b: "hallo", 135 | expected: 1, 136 | }, 137 | { 138 | name: "addition", 139 | a: "hello", 140 | b: "hello1", 141 | expected: 1, 142 | }, 143 | { 144 | name: "deletion", 145 | a: "hello", 146 | b: "hell", 147 | expected: 1, 148 | }, 149 | { 150 | name: "empty first string", 151 | a: "", 152 | b: "hello", 153 | expected: 5, 154 | }, 155 | { 156 | name: "empty second string", 157 | a: "hello", 158 | b: "", 159 | expected: 5, 160 | }, 161 | { 162 | name: "both empty strings", 163 | a: "", 164 | b: "", 165 | expected: 0, 166 | }, 167 | { 168 | name: "completely different strings", 169 | a: "hello", 170 | b: "world", 171 | expected: 4, 172 | }, 173 | } 174 | 175 | for _, 
tt := range tests { 176 | t.Run(tt.name, func(t *testing.T) { 177 | result := levenshteinDistance(tt.a, tt.b) 178 | assert.Equal(t, tt.expected, result, "distance mismatch for %q and %q", tt.a, tt.b) 179 | }) 180 | } 181 | } 182 | -------------------------------------------------------------------------------- /pkg/textutil/textutil.go: -------------------------------------------------------------------------------- 1 | package textutil 2 | 3 | import "strings" 4 | 5 | func Wrap(text string, width int) []string { 6 | words := strings.Fields(text) 7 | var ( 8 | lines []string 9 | currentLine []string 10 | currentLength int 11 | ) 12 | for _, word := range words { 13 | if currentLength+len(word)+1 > width { 14 | if len(currentLine) > 0 { 15 | lines = append(lines, strings.Join(currentLine, " ")) 16 | currentLine = []string{word} 17 | currentLength = len(word) 18 | } else { 19 | lines = append(lines, word) 20 | } 21 | } else { 22 | currentLine = append(currentLine, word) 23 | if currentLength == 0 { 24 | currentLength = len(word) 25 | } else { 26 | currentLength += len(word) + 1 27 | } 28 | } 29 | } 30 | if len(currentLine) > 0 { 31 | lines = append(lines, strings.Join(currentLine, " ")) 32 | } 33 | return lines 34 | } 35 | -------------------------------------------------------------------------------- /pkg/textutil/textutil_test.go: -------------------------------------------------------------------------------- 1 | package textutil 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestWrapText(t *testing.T) { 10 | tests := []struct { 11 | name string 12 | text string 13 | width int 14 | expected []string 15 | }{ 16 | { 17 | name: "simple wrap", 18 | text: "hello world", 19 | width: 5, 20 | expected: []string{"hello", "world"}, 21 | }, 22 | { 23 | name: "no wrap needed", 24 | text: "hello", 25 | width: 10, 26 | expected: []string{"hello"}, 27 | }, 28 | { 29 | name: "multiple wraps", 30 | text: "this is a long text that 
needs wrapping", 31 | width: 10, 32 | expected: []string{"this is a", "long text", "that needs", "wrapping"}, 33 | }, 34 | { 35 | name: "empty string", 36 | text: "", 37 | width: 10, 38 | expected: nil, 39 | }, 40 | { 41 | name: "single word longer than width", 42 | text: "supercalifragilistic", 43 | width: 10, 44 | expected: []string{"supercalifragilistic"}, 45 | }, 46 | { 47 | name: "multiple spaces", 48 | text: "hello world", 49 | width: 20, 50 | expected: []string{"hello world"}, 51 | }, 52 | } 53 | 54 | for _, tt := range tests { 55 | t.Run(tt.name, func(t *testing.T) { 56 | result := Wrap(tt.text, tt.width) 57 | assert.EqualValues(t, tt.expected, result, "wrapped text mismatch for input %q with width %d", tt.text, tt.width) 58 | }) 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /run.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "os" 9 | ) 10 | 11 | // RunOptions specifies options for running a command. 12 | type RunOptions struct { 13 | // Stdin, Stdout, and Stderr are the standard input, output, and error streams for the command. 14 | // If any of these are nil, the command will use the default streams ([os.Stdin], [os.Stdout], 15 | // and [os.Stderr], respectively). 16 | Stdin io.Reader 17 | Stdout, Stderr io.Writer 18 | } 19 | 20 | // Run executes the current command. It returns an error if the command has not been parsed or if 21 | // the command has no execution function. 22 | // 23 | // The options parameter may be nil, in which case default values are used. See [RunOptions] for 24 | // more details. 
25 | func Run(ctx context.Context, root *Command, options *RunOptions) error { 26 | if root == nil { 27 | return errors.New("root command is nil") 28 | } 29 | if root.state == nil || len(root.state.path) == 0 { 30 | return errors.New("command has not been parsed") 31 | } 32 | cmd := root.terminal() 33 | if cmd == nil { 34 | // This should never happen, but if it does, it's likely a bug in the Parse function. 35 | return errors.New("no terminal command found") 36 | } 37 | 38 | options = checkAndSetRunOptions(options) 39 | updateState(root.state, options) 40 | 41 | return run(ctx, cmd, root.state) 42 | } 43 | 44 | func run(ctx context.Context, cmd *Command, state *State) (retErr error) { 45 | defer func() { 46 | if r := recover(); r != nil { 47 | switch err := r.(type) { 48 | case error: 49 | retErr = fmt.Errorf("internal: %v", err) 50 | default: 51 | retErr = fmt.Errorf("recover: %v", r) 52 | } 53 | } 54 | }() 55 | return cmd.Exec(ctx, state) 56 | } 57 | 58 | func updateState(s *State, opt *RunOptions) { 59 | if s.Stdin == nil { 60 | s.Stdin = opt.Stdin 61 | } 62 | if s.Stdout == nil { 63 | s.Stdout = opt.Stdout 64 | } 65 | if s.Stderr == nil { 66 | s.Stderr = opt.Stderr 67 | } 68 | } 69 | 70 | func checkAndSetRunOptions(opt *RunOptions) *RunOptions { 71 | if opt == nil { 72 | opt = &RunOptions{} 73 | } 74 | if opt.Stdin == nil { 75 | opt.Stdin = os.Stdin 76 | } 77 | if opt.Stdout == nil { 78 | opt.Stdout = os.Stdout 79 | } 80 | if opt.Stderr == nil { 81 | opt.Stderr = os.Stderr 82 | } 83 | return opt 84 | } 85 | -------------------------------------------------------------------------------- /run_test.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "flag" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestRun(t *testing.T) { 13 | t.Parallel() 14 | 15 | t.Run("print version", func(t *testing.T) { 16 | t.Parallel() 17 | 18 | root := 
&Command{ 19 | Name: "printer", 20 | Usage: "printer [flags] [command]", 21 | SubCommands: []*Command{ 22 | { 23 | Name: "version", 24 | Usage: "show version", 25 | Exec: func(ctx context.Context, s *State) error { 26 | _, _ = s.Stdout.Write([]byte("1.0.0\n")) 27 | return nil 28 | }, 29 | }, 30 | }, 31 | Exec: func(ctx context.Context, s *State) error { return nil }, 32 | } 33 | err := Parse(root, []string{"version"}) 34 | require.NoError(t, err) 35 | 36 | output := bytes.NewBuffer(nil) 37 | require.NoError(t, err) 38 | err = Run(context.Background(), root, &RunOptions{Stdout: output}) 39 | require.NoError(t, err) 40 | require.Equal(t, "1.0.0\n", output.String()) 41 | }) 42 | 43 | t.Run("parse and run", func(t *testing.T) { 44 | t.Parallel() 45 | var count int 46 | 47 | root := &Command{ 48 | Name: "count", 49 | Usage: "count [flags] [command]", 50 | Flags: FlagsFunc(func(f *flag.FlagSet) { 51 | f.Bool("dry-run", false, "dry run") 52 | }), 53 | Exec: func(ctx context.Context, s *State) error { 54 | if !GetFlag[bool](s, "dry-run") { 55 | count++ 56 | } 57 | return nil 58 | }, 59 | } 60 | err := Parse(root, nil) 61 | require.NoError(t, err) 62 | // Run the command 3 times 63 | for i := 0; i < 3; i++ { 64 | err := Run(context.Background(), root, nil) 65 | require.NoError(t, err) 66 | } 67 | require.Equal(t, 3, count) 68 | // Run with dry-run flag 69 | err = Parse(root, []string{"--dry-run"}) 70 | require.NoError(t, err) 71 | err = Run(context.Background(), root, nil) 72 | require.NoError(t, err) 73 | require.Equal(t, 3, count) 74 | }) 75 | t.Run("typo suggestion", func(t *testing.T) { 76 | t.Parallel() 77 | 78 | root := &Command{ 79 | Name: "count", 80 | Usage: "count [flags] [command]", 81 | SubCommands: []*Command{ 82 | { 83 | Name: "version", 84 | Usage: "show version", 85 | Exec: func(ctx context.Context, s *State) error { 86 | _, _ = s.Stdout.Write([]byte("1.0.0\n")) 87 | return nil 88 | }, 89 | }, 90 | }, 91 | Exec: func(ctx context.Context, s *State) error { 
return nil }, 92 | } 93 | 94 | err := Parse(root, []string{"verzion"}) 95 | require.Error(t, err) 96 | require.Contains(t, err.Error(), `unknown command "verzion". Did you mean one of these?`) 97 | require.Contains(t, err.Error(), ` version`) 98 | }) 99 | } 100 | -------------------------------------------------------------------------------- /state.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "io" 7 | ) 8 | 9 | // State holds command information during Exec function execution, allowing child commands to access 10 | // parent flags. Use [GetFlag] to get flag values across the command hierarchy. 11 | type State struct { 12 | // Args contains the remaining arguments after flag parsing. 13 | Args []string 14 | 15 | // Standard I/O streams. 16 | Stdin io.Reader 17 | Stdout, Stderr io.Writer 18 | 19 | // path is the command hierarchy from the root command to the current command. The root command 20 | // is the first element in the path, and the terminal command is the last element. 21 | path []*Command 22 | } 23 | 24 | // GetFlag retrieves a flag value by name from the command hierarchy. It first checks the current 25 | // command's flags, then walks up through parent commands. 26 | // 27 | // If the flag doesn't exist or if the type doesn't match the requested type T an error will be 28 | // raised in the Run function. This is an internal error and should never happen in normal usage. 29 | // This ensures flag-related programming errors are caught early during development. 
30 | // 31 | // verbose := GetFlag[bool](state, "verbose") 32 | // count := GetFlag[int](state, "count") 33 | // path := GetFlag[string](state, "path") 34 | func GetFlag[T any](s *State, name string) T { 35 | // Try to find the flag in each command's flag set, starting from the current command 36 | for i := len(s.path) - 1; i >= 0; i-- { 37 | cmd := s.path[i] 38 | if cmd.Flags == nil { 39 | continue 40 | } 41 | 42 | if f := cmd.Flags.Lookup(name); f != nil { 43 | if getter, ok := f.Value.(flag.Getter); ok { 44 | value := getter.Get() 45 | if v, ok := value.(T); ok { 46 | return v 47 | } 48 | err := fmt.Errorf("type mismatch for flag %q in command %q: registered %T, requested %T", 49 | formatFlagName(name), 50 | getCommandPath(s.path), 51 | value, 52 | *new(T), 53 | ) 54 | // Flag exists but type doesn't match - this is an internal error 55 | panic(err) 56 | } 57 | } 58 | } 59 | 60 | // If flag not found anywhere in hierarchy, panic with helpful message 61 | err := fmt.Errorf("flag %q not found in command %q flag set", 62 | formatFlagName(name), 63 | getCommandPath(s.path), 64 | ) 65 | panic(err) 66 | } 67 | -------------------------------------------------------------------------------- /state_test.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "flag" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func TestGetFlag(t *testing.T) { 12 | t.Parallel() 13 | 14 | t.Run("flag not found", func(t *testing.T) { 15 | cmd := &Command{ 16 | Name: "root", 17 | Flags: flag.NewFlagSet("root", flag.ContinueOnError), 18 | } 19 | state := &State{ 20 | path: []*Command{cmd}, 21 | } 22 | defer func() { 23 | r := recover() 24 | require.NotNil(t, r) 25 | err, ok := r.(error) 26 | require.True(t, ok) 27 | assert.ErrorContains(t, err, `flag "-version" not found in command "root" flag set`) 28 | }() 29 | // Panic because author tried to access a 
flag that doesn't exist in any of the commands 30 | _ = GetFlag[string](state, "version") 31 | }) 32 | t.Run("flag type mismatch", func(t *testing.T) { 33 | cmd := &Command{ 34 | Name: "root", 35 | Flags: FlagsFunc(func(f *flag.FlagSet) { f.String("version", "1.0.0", "show version") }), 36 | } 37 | state := &State{ 38 | path: []*Command{cmd}, 39 | } 40 | defer func() { 41 | r := recover() 42 | require.NotNil(t, r) 43 | err, ok := r.(error) 44 | require.True(t, ok) 45 | assert.ErrorContains(t, err, `type mismatch for flag "-version" in command "root": registered string, requested int`) 46 | }() 47 | // Panic because author tried to access a registered flag with the wrong type 48 | _ = GetFlag[int](state, "version") 49 | }) 50 | } 51 | -------------------------------------------------------------------------------- /usage.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "cmp" 5 | "flag" 6 | "fmt" 7 | "slices" 8 | "strings" 9 | 10 | "github.com/mfridman/cli/pkg/textutil" 11 | ) 12 | 13 | // DefaultUsage returns the default usage string for the command hierarchy. It is used when the 14 | // command does not provide a custom usage function. The usage string includes the command's short 15 | // help, usage pattern, available subcommands, and flags. 
16 | func DefaultUsage(root *Command) string { 17 | if root == nil { 18 | return "" 19 | } 20 | 21 | // Get terminal command from state 22 | terminalCmd := root.terminal() 23 | 24 | var b strings.Builder 25 | 26 | if terminalCmd.UsageFunc != nil { 27 | return terminalCmd.UsageFunc(terminalCmd) 28 | } 29 | 30 | if terminalCmd.ShortHelp != "" { 31 | b.WriteString(terminalCmd.ShortHelp) 32 | b.WriteString("\n\n") 33 | } 34 | 35 | b.WriteString("Usage:\n") 36 | if terminalCmd.Usage != "" { 37 | b.WriteString(" " + terminalCmd.Usage + "\n") 38 | } else { 39 | usage := terminalCmd.Name 40 | if root.state != nil && len(root.state.path) > 0 { 41 | usage = getCommandPath(root.state.path) 42 | } 43 | if terminalCmd.Flags != nil { 44 | usage += " [flags]" 45 | } 46 | if len(terminalCmd.SubCommands) > 0 { 47 | usage += " " 48 | } 49 | b.WriteString(" " + usage + "\n") 50 | } 51 | b.WriteString("\n") 52 | 53 | if len(terminalCmd.SubCommands) > 0 { 54 | b.WriteString("Available Commands:\n") 55 | sortedCommands := slices.Clone(terminalCmd.SubCommands) 56 | slices.SortFunc(sortedCommands, func(a, b *Command) int { 57 | return cmp.Compare(a.Name, b.Name) 58 | }) 59 | 60 | maxNameLen := 0 61 | for _, sub := range sortedCommands { 62 | if len(sub.Name) > maxNameLen { 63 | maxNameLen = len(sub.Name) 64 | } 65 | } 66 | 67 | nameWidth := maxNameLen + 4 68 | wrapWidth := 80 - nameWidth 69 | 70 | for _, sub := range sortedCommands { 71 | if sub.ShortHelp == "" { 72 | fmt.Fprintf(&b, " %s\n", sub.Name) 73 | continue 74 | } 75 | 76 | lines := textutil.Wrap(sub.ShortHelp, wrapWidth) 77 | padding := strings.Repeat(" ", maxNameLen-len(sub.Name)+4) 78 | fmt.Fprintf(&b, " %s%s%s\n", sub.Name, padding, lines[0]) 79 | 80 | indentPadding := strings.Repeat(" ", nameWidth+2) 81 | for _, line := range lines[1:] { 82 | fmt.Fprintf(&b, "%s%s\n", indentPadding, line) 83 | } 84 | } 85 | b.WriteString("\n") 86 | } 87 | 88 | var flags []flagInfo 89 | if root.state != nil && len(root.state.path) > 0 { 90 | 
for i, cmd := range root.state.path { 91 | if cmd.Flags == nil { 92 | continue 93 | } 94 | isGlobal := i < len(root.state.path)-1 95 | cmd.Flags.VisitAll(func(f *flag.Flag) { 96 | flags = append(flags, flagInfo{ 97 | name: "-" + f.Name, 98 | usage: f.Usage, 99 | defval: f.DefValue, 100 | global: isGlobal, 101 | }) 102 | }) 103 | } 104 | } 105 | 106 | if len(flags) > 0 { 107 | slices.SortFunc(flags, func(a, b flagInfo) int { 108 | return cmp.Compare(a.name, b.name) 109 | }) 110 | 111 | maxFlagLen := 0 112 | for _, f := range flags { 113 | if len(f.name) > maxFlagLen { 114 | maxFlagLen = len(f.name) 115 | } 116 | } 117 | 118 | hasLocal := false 119 | hasGlobal := false 120 | for _, f := range flags { 121 | if f.global { 122 | hasGlobal = true 123 | } else { 124 | hasLocal = true 125 | } 126 | } 127 | 128 | if hasLocal { 129 | b.WriteString("Flags:\n") 130 | writeFlagSection(&b, flags, maxFlagLen, false) 131 | b.WriteString("\n") 132 | } 133 | 134 | if hasGlobal { 135 | b.WriteString("Global Flags:\n") 136 | writeFlagSection(&b, flags, maxFlagLen, true) 137 | b.WriteString("\n") 138 | } 139 | } 140 | 141 | if len(terminalCmd.SubCommands) > 0 { 142 | cmdName := terminalCmd.Name 143 | if root.state != nil && len(root.state.path) > 0 { 144 | cmdName = getCommandPath(root.state.path) 145 | } 146 | fmt.Fprintf(&b, "Use \"%s [command] --help\" for more information about a command.\n", cmdName) 147 | } 148 | 149 | return strings.TrimRight(b.String(), "\n") 150 | } 151 | 152 | // writeFlagSection handles the formatting of flag descriptions 153 | func writeFlagSection(b *strings.Builder, flags []flagInfo, maxLen int, global bool) { 154 | nameWidth := maxLen + 4 155 | wrapWidth := 80 - nameWidth 156 | 157 | for _, f := range flags { 158 | if f.global != global { 159 | continue 160 | } 161 | 162 | description := f.usage 163 | if f.defval != "" { 164 | description += fmt.Sprintf(" (default: %s)", f.defval) 165 | } 166 | 167 | lines := textutil.Wrap(description, wrapWidth) 168 | 
padding := strings.Repeat(" ", maxLen-len(f.name)+4) 169 | fmt.Fprintf(b, " %s%s%s\n", f.name, padding, lines[0]) 170 | 171 | indentPadding := strings.Repeat(" ", nameWidth+2) 172 | for _, line := range lines[1:] { 173 | fmt.Fprintf(b, "%s%s\n", indentPadding, line) 174 | } 175 | } 176 | } 177 | 178 | type flagInfo struct { 179 | name string 180 | usage string 181 | defval string 182 | global bool 183 | } 184 | --------------------------------------------------------------------------------