├── .github ├── CODE-OF-CONDUCT.md ├── CONTRIBUTING.md └── workflows │ ├── codeql-analysis.yml │ ├── lint.yml │ └── test.yml ├── .gitignore ├── .golangci.yml ├── LICENSE ├── README.md ├── example_gh_test.go ├── gh.go ├── gh_test.go ├── go.mod ├── go.sum ├── internal ├── git │ ├── git.go │ ├── git_test.go │ ├── remote.go │ ├── remote_test.go │ ├── url.go │ └── url_test.go ├── set │ ├── string_set.go │ └── string_set_test.go ├── testutils │ └── config_stub.go └── yamlmap │ ├── yaml_map.go │ └── yaml_map_test.go └── pkg ├── api ├── cache.go ├── cache_test.go ├── client_options.go ├── client_options_test.go ├── errors.go ├── errors_test.go ├── graphql_client.go ├── graphql_client_test.go ├── http_client.go ├── http_client_test.go ├── log_formatter.go ├── rest_client.go └── rest_client_test.go ├── asciisanitizer ├── sanitizer.go └── sanitizer_test.go ├── auth ├── auth.go └── auth_test.go ├── browser ├── browser.go └── browser_test.go ├── config ├── config.go ├── config_test.go └── errors.go ├── jq ├── jq.go └── jq_test.go ├── jsonpretty ├── format.go └── format_test.go ├── markdown ├── markdown.go └── markdown_test.go ├── prompter ├── mock.go ├── prompter.go └── prompter_test.go ├── repository ├── repository.go └── repository_test.go ├── ssh ├── ssh.go └── ssh_test.go ├── tableprinter ├── table.go └── table_test.go ├── template ├── template.go └── template_test.go ├── term ├── console.go ├── console_windows.go ├── env.go └── env_test.go ├── text ├── text.go └── text_test.go └── x ├── color ├── accessibility.go ├── accessibility_test.go └── color.go ├── markdown ├── accessibility.go ├── accessibility_test.go └── markdown.go └── x.go /.github/CODE-OF-CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | 
size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, caste, color, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 
45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | opensource@github.com. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. 
No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.1, available at 119 | [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. 120 | 121 | Community Impact Guidelines were inspired by 122 | [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. 123 | 124 | For answers to common questions about this code of conduct, see the FAQ at 125 | [https://www.contributor-covenant.org/faq][FAQ]. Translations are available 126 | at [https://www.contributor-covenant.org/translations][translations]. 
127 | 128 | [homepage]: https://www.contributor-covenant.org 129 | [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html 130 | [Mozilla CoC]: https://github.com/mozilla/diversity 131 | [FAQ]: https://www.contributor-covenant.org/faq 132 | [translations]: https://www.contributor-covenant.org/translations 133 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## Contributing 2 | 3 | Hi! Thanks for your interest in contributing to the GitHub CLI Module! 4 | 5 | We accept pull requests for bug fixes and features where we've discussed the approach in an issue and given the go-ahead for a community member to work on it. We'd also love to hear about ideas for new features as issues. 6 | 7 | Please do: 8 | 9 | * Check existing issues to verify that the [bug][bug issues] or [feature request][feature request issues] has not already been submitted. 10 | * Open an issue if things aren't working as expected. 11 | * Open an issue to propose a significant change. 12 | * Open a pull request to fix a bug. 13 | * Open a pull request to fix documentation. 14 | * Open a pull request for any issue labelled [`help wanted`][hw] or [`good first issue`][gfi]. 15 | 16 | Please avoid: 17 | 18 | * Opening pull requests for issues marked `needs-design`, `needs-investigation`, or `blocked`. 19 | * Opening pull requests for any issue marked `core`. These issues require additional context from 20 | the core CLI team at GitHub and any external pull requests will not be accepted. 21 | 22 | ## Submitting a pull request 23 | 24 | 1. Create a new branch: `git checkout -b my-branch-name` 25 | 1. Make your change, add tests, and ensure tests pass 26 | 1. Submit a pull request: `gh pr create --web` 27 | 28 | Contributions to this project are [released][legal] to the public under the [project's open source license][license]. 
29 | 30 | Please note that this project adheres to a [Contributor Code of Conduct][code-of-conduct]. By participating in this project you agree to abide by its terms. 31 | 32 | ## Resources 33 | 34 | - [How to Contribute to Open Source][] 35 | - [Using Pull Requests][] 36 | - [GitHub Help][] 37 | 38 | 39 | [bug issues]: https://github.com/cli/go-gh/issues?q=is%3Aopen+is%3Aissue+label%3Abug 40 | [feature request issues]: https://github.com/cli/go-gh/issues?q=is%3Aopen+is%3Aissue+label%3Aenhancement 41 | [hw]: https://github.com/cli/go-gh/labels/help%20wanted 42 | [gfi]: https://github.com/cli/go-gh/labels/good%20first%20issue 43 | [legal]: https://docs.github.com/en/free-pro-team@latest/github/site-policy/github-terms-of-service#6-contributions-under-repository-license 44 | [license]: ../LICENSE 45 | [code-of-conduct]: ./CODE-OF-CONDUCT.md 46 | [How to Contribute to Open Source]: https://opensource.guide/how-to-contribute/ 47 | [Using Pull Requests]: https://docs.github.com/en/free-pro-team@latest/github/collaborating-with-issues-and-pull-requests/about-pull-requests 48 | [GitHub Help]: https://docs.github.com/ 49 | -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | name: Code Scanning 2 | on: 3 | push: 4 | branches: [trunk] 5 | pull_request: 6 | branches: [trunk] 7 | schedule: 8 | - cron: "0 0 * * 0" 9 | permissions: 10 | actions: read 11 | contents: read 12 | security-events: write 13 | jobs: 14 | codeql: 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - name: Checkout repository 19 | uses: actions/checkout@v4 20 | 21 | - name: Initialize CodeQL 22 | uses: github/codeql-action/init@v2 23 | with: 24 | languages: go 25 | queries: security-and-quality 26 | 27 | - name: Perform CodeQL Analysis 28 | uses: github/codeql-action/analyze@v2 29 | 
-------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | on: [push, pull_request] 3 | permissions: 4 | contents: read 5 | jobs: 6 | lint: 7 | runs-on: ubuntu-latest 8 | 9 | steps: 10 | - name: Checkout repository 11 | uses: actions/checkout@v4 12 | 13 | - name: Set up Go 14 | uses: actions/setup-go@v5 15 | with: 16 | go-version-file: go.mod 17 | 18 | - name: Check dependencies 19 | run: | 20 | go mod tidy 21 | git diff --exit-code go.mod 22 | 23 | - name: Lint 24 | uses: golangci/golangci-lint-action@v6 25 | with: 26 | version: v1.64 27 | problem-matchers: true 28 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | on: [push, pull_request] 3 | permissions: 4 | contents: read 5 | jobs: 6 | test: 7 | strategy: 8 | fail-fast: false 9 | matrix: 10 | os: [ubuntu-latest, windows-latest, macos-latest] 11 | 12 | runs-on: ${{ matrix.os }} 13 | 14 | steps: 15 | - name: Checkout repository 16 | uses: actions/checkout@v4 17 | 18 | - name: Set up Go 19 | uses: actions/setup-go@v5 20 | with: 21 | go-version-file: go.mod 22 | 23 | - name: Run tests 24 | run: go test -v ./... 
25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | linters: 2 | enable: 3 | - gofmt 4 | - godot 5 | 6 | linters-settings: 7 | godot: 8 | # comments to be checked: `declarations`, `toplevel`, or `all` 9 | scope: declarations 10 | # check that each sentence starts with a capital letter 11 | capital: true 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 GitHub Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Go library for the GitHub CLI 2 | 3 | `go-gh` is a collection of Go modules to make authoring [GitHub CLI extensions][extensions] easier. 4 | 5 | Modules from this library will obey GitHub CLI conventions by default: 6 | 7 | - [`repository.Current()`](https://pkg.go.dev/github.com/cli/go-gh/v2/pkg/repository#Current) respects the value of the `GH_REPO` environment variable and reads from git remote configuration as a fallback. 8 | 9 | - GitHub API requests will be authenticated using the same mechanism as `gh`, i.e., using the values of the `GH_TOKEN` and `GH_HOST` environment variables and falling back to the user's stored OAuth token. 10 | 11 | - [Terminal capabilities](https://pkg.go.dev/github.com/cli/go-gh/v2/pkg/term) are determined by taking environment variables `GH_FORCE_TTY`, `NO_COLOR`, `CLICOLOR`, etc. into account. 12 | 13 | - Generating [table](https://pkg.go.dev/github.com/cli/go-gh/v2/pkg/tableprinter) or [Go template](https://pkg.go.dev/github.com/cli/go-gh/v2/pkg/template) output uses the same engine as `gh`. 14 | 15 | - The [`browser`](https://pkg.go.dev/github.com/cli/go-gh/v2/pkg/browser) module activates the user's preferred web browser. 16 | 17 | ## Usage 18 | 19 | See the full `go-gh` [reference documentation](https://pkg.go.dev/github.com/cli/go-gh/v2) for more information. 20 | 21 | ```golang 22 | package main 23 | 24 | import ( 25 | "fmt" 26 | "log" 27 | "github.com/cli/go-gh/v2" 28 | "github.com/cli/go-gh/v2/pkg/api" 29 | ) 30 | 31 | func main() { 32 | // These examples assume `gh` is installed and has been authenticated. 33 | 34 | // Shell out to a gh command and read its output.
35 | issueList, _, err := gh.Exec("issue", "list", "--repo", "cli/cli", "--limit", "5") 36 | if err != nil { 37 | log.Fatal(err) 38 | } 39 | fmt.Println(issueList.String()) 40 | 41 | // Use an API client to retrieve repository tags. 42 | client, err := api.DefaultRESTClient() 43 | if err != nil { 44 | log.Fatal(err) 45 | } 46 | response := []struct{ 47 | Name string 48 | }{} 49 | err = client.Get("repos/cli/cli/tags", &response) 50 | if err != nil { 51 | log.Fatal(err) 52 | } 53 | fmt.Println(response) 54 | } 55 | ``` 56 | 57 | See [examples][] for more demonstrations of usage. 58 | 59 | ## Contributing 60 | 61 | If anything feels off, or if you feel that some functionality is missing, please check out our [contributing docs][contributing]. There you will find instructions for sharing your feedback and for submitting pull requests to the project. Thank you! 62 | 63 | [extensions]: https://docs.github.com/en/github-cli/github-cli/creating-github-cli-extensions 64 | [examples]: ./example_gh_test.go 65 | [contributing]: ./.github/CONTRIBUTING.md 66 | -------------------------------------------------------------------------------- /gh.go: -------------------------------------------------------------------------------- 1 | // Package gh is a library for CLI Go applications to help interface with the gh CLI tool, 2 | // and the GitHub API. 3 | // 4 | // Note that the examples in this package assume gh and git are installed. They do not run in 5 | // the Go Playground used by pkg.go.dev. 6 | package gh 7 | 8 | import ( 9 | "bytes" 10 | "context" 11 | "fmt" 12 | "io" 13 | "os" 14 | "os/exec" 15 | 16 | "github.com/cli/safeexec" 17 | ) 18 | 19 | // Exec invokes a gh command in a subprocess and captures the output and error streams. 
20 | func Exec(args ...string) (stdout, stderr bytes.Buffer, err error) { 21 | ghExe, err := Path() 22 | if err != nil { 23 | return 24 | } 25 | err = run(context.Background(), ghExe, nil, nil, &stdout, &stderr, args) 26 | return 27 | } 28 | 29 | // ExecContext is like Exec, but the subprocess is bound to ctx and is killed if ctx is done before it exits. 30 | func ExecContext(ctx context.Context, args ...string) (stdout, stderr bytes.Buffer, err error) { 31 | ghExe, err := Path() 32 | if err != nil { 33 | return 34 | } 35 | err = run(ctx, ghExe, nil, nil, &stdout, &stderr, args) 36 | return 37 | } 38 | 39 | // ExecInteractive invokes a gh command in a subprocess with its stdin, stdout, and stderr streams connected to 40 | // those of the parent process. This is suitable for running gh commands with interactive prompts. 41 | func ExecInteractive(ctx context.Context, args ...string) error { 42 | ghExe, err := Path() 43 | if err != nil { 44 | return err 45 | } 46 | return run(ctx, ghExe, nil, os.Stdin, os.Stdout, os.Stderr, args) 47 | } 48 | 49 | // Path returns the location of the gh executable. If the GH_PATH environment variable is set, its value is returned as-is. 50 | // Otherwise Path searches for an executable named "gh" in the directories named by the PATH environment variable; if found, the result is an absolute path. 51 | func Path() (string, error) { 52 | if ghExe := os.Getenv("GH_PATH"); ghExe != "" { 53 | return ghExe, nil 54 | } 55 | return safeexec.LookPath("gh") 56 | } 57 | // run executes ghExe with args, bound to ctx, using the given standard streams; a nil env leaves cmd.Env unset so the child inherits the parent's environment. 58 | func run(ctx context.Context, ghExe string, env []string, stdin io.Reader, stdout, stderr io.Writer, args []string) error { 59 | cmd := exec.CommandContext(ctx, ghExe, args...)
60 | cmd.Stdin = stdin 61 | cmd.Stdout = stdout 62 | cmd.Stderr = stderr 63 | if env != nil { 64 | cmd.Env = env 65 | } 66 | if err := cmd.Run(); err != nil { 67 | return fmt.Errorf("gh execution failed: %w", err) 68 | } 69 | return nil 70 | } 71 | -------------------------------------------------------------------------------- /gh_test.go: -------------------------------------------------------------------------------- 1 | package gh 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "os" 8 | "testing" 9 | "time" 10 | 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | // TestHelperProcess is not a real test: with GH_WANT_HELPER_PROCESS=1 it stands in for gh, printing os.Args[3:] (the arguments after "--") to stdout, or writing an error to stderr and exiting 1 when the last argument is "error". 14 | func TestHelperProcess(t *testing.T) { 15 | if os.Getenv("GH_WANT_HELPER_PROCESS") != "1" { 16 | return 17 | } 18 | if err := func(args []string) error { 19 | if args[len(args)-1] == "error" { 20 | return fmt.Errorf("process exited with error") 21 | } 22 | fmt.Fprintf(os.Stdout, "%v", args) 23 | return nil 24 | }(os.Args[3:]); err != nil { 25 | fmt.Fprint(os.Stderr, err) 26 | os.Exit(1) 27 | } 28 | os.Exit(0) 29 | } 30 | // TestHelperProcessLongRunning stands in for a gh command that sleeps for 10 seconds, so context cancellation can be exercised. 31 | func TestHelperProcessLongRunning(t *testing.T) { 32 | if os.Getenv("GH_WANT_HELPER_PROCESS") != "1" { 33 | return 34 | } 35 | args := os.Args[3:] 36 | fmt.Fprintf(os.Stdout, "%v", args) 37 | fmt.Fprint(os.Stderr, "going to sleep...") 38 | time.Sleep(10 * time.Second) 39 | fmt.Fprint(os.Stderr, "...going to exit") 40 | os.Exit(0) 41 | } 42 | // TestRun re-executes the test binary as a fake gh. NOTE(review): -test.run takes an unanchored regexp, so "TestHelperProcess" also matches TestHelperProcessLongRunning; this appears to work only because TestHelperProcess runs first and exits the process. 43 | func TestRun(t *testing.T) { 44 | var stdout, stderr bytes.Buffer 45 | err := run(context.TODO(), os.Args[0], []string{"GH_WANT_HELPER_PROCESS=1"}, nil, &stdout, &stderr, 46 | []string{"-test.run=TestHelperProcess", "--", "gh", "issue", "list"}) 47 | assert.NoError(t, err) 48 | assert.Equal(t, "[gh issue list]", stdout.String()) 49 | assert.Equal(t, "", stderr.String()) 50 | } 51 | // TestRunError checks that a failing gh command is reported as a wrapped error while its stderr is still captured. 52 | func TestRunError(t *testing.T) { 53 | var stdout, stderr bytes.Buffer 54 | err := run(context.TODO(), os.Args[0], []string{"GH_WANT_HELPER_PROCESS=1"}, nil, &stdout, &stderr, 55 | []string{"-test.run=TestHelperProcess", "--", "gh",
"error"}) 56 | assert.EqualError(t, err, "gh execution failed: exit status 1") 57 | assert.Equal(t, "", stdout.String()) 58 | assert.Equal(t, "process exited with error", stderr.String()) 59 | } 60 | // TestRunInteractiveContextCanceled checks that run reports the context error when ctx's deadline has already passed. 61 | func TestRunInteractiveContextCanceled(t *testing.T) { 62 | // pass current time to ensure that deadline has already passed 63 | ctx, cancel := context.WithDeadline(context.Background(), time.Now()) 64 | cancel() 65 | err := run(ctx, os.Args[0], []string{"GH_WANT_HELPER_PROCESS=1"}, nil, nil, nil, 66 | []string{"-test.run=TestHelperProcessLongRunning", "--", "gh", "issue", "list"}) 67 | assert.EqualError(t, err, "gh execution failed: context deadline exceeded") 68 | } 69 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/cli/go-gh/v2 2 | 3 | go 1.23.0 4 | 5 | require ( 6 | github.com/AlecAivazis/survey/v2 v2.3.7 7 | github.com/MakeNowJust/heredoc v1.0.0 8 | github.com/Masterminds/sprig/v3 v3.3.0 9 | github.com/alecthomas/chroma/v2 v2.14.0 10 | github.com/charmbracelet/glamour v0.9.2-0.20250319212134-549f544650e3 11 | github.com/charmbracelet/lipgloss v1.1.1-0.20250319133953-166f707985bc 12 | github.com/cli/browser v1.3.0 13 | github.com/cli/safeexec v1.0.0 14 | github.com/cli/shurcooL-graphql v0.0.4 15 | github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 16 | github.com/henvic/httpretty v0.0.6 17 | github.com/itchyny/gojq v0.12.15 18 | github.com/leaanthony/go-ansi-parser v1.6.1 19 | github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d 20 | github.com/muesli/reflow v0.3.0 21 | github.com/muesli/termenv v0.16.0 22 | github.com/stretchr/testify v1.7.0 23 | github.com/thlib/go-timezone-local v0.0.0-20210907160436-ef149e42d28e 24 | golang.org/x/sys v0.31.0 25 | golang.org/x/term v0.30.0 26 | golang.org/x/text v0.23.0 27 | gopkg.in/h2non/gock.v1 v1.1.2 28 | gopkg.in/yaml.v3 v3.0.1 29 | ) 30 | 31 | require (
dario.cat/mergo v1.0.1 // indirect 33 | github.com/Masterminds/goutils v1.1.1 // indirect 34 | github.com/Masterminds/semver/v3 v3.3.0 // indirect 35 | github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect 36 | github.com/aymerick/douceur v0.2.0 // indirect 37 | github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect 38 | github.com/charmbracelet/x/ansi v0.8.0 // indirect 39 | github.com/charmbracelet/x/cellbuf v0.0.13 // indirect 40 | github.com/charmbracelet/x/term v0.2.1 // indirect 41 | github.com/davecgh/go-spew v1.1.1 // indirect 42 | github.com/dlclark/regexp2 v1.11.0 // indirect 43 | github.com/google/uuid v1.6.0 // indirect 44 | github.com/gorilla/css v1.0.1 // indirect 45 | github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect 46 | github.com/huandu/xstrings v1.5.0 // indirect 47 | github.com/itchyny/timefmt-go v0.1.5 // indirect 48 | github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect 49 | github.com/lucasb-eyer/go-colorful v1.2.0 // indirect 50 | github.com/mattn/go-colorable v0.1.13 // indirect 51 | github.com/mattn/go-isatty v0.0.20 // indirect 52 | github.com/mattn/go-runewidth v0.0.16 // indirect 53 | github.com/microcosm-cc/bluemonday v1.0.27 // indirect 54 | github.com/mitchellh/copystructure v1.2.0 // indirect 55 | github.com/mitchellh/reflectwalk v1.0.2 // indirect 56 | github.com/pmezard/go-difflib v1.0.0 // indirect 57 | github.com/rivo/uniseg v0.4.7 // indirect 58 | github.com/shopspring/decimal v1.4.0 // indirect 59 | github.com/spf13/cast v1.7.0 // indirect 60 | github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect 61 | github.com/yuin/goldmark v1.7.8 // indirect 62 | github.com/yuin/goldmark-emoji v1.0.5 // indirect 63 | golang.org/x/crypto v0.35.0 // indirect 64 | golang.org/x/net v0.36.0 // indirect 65 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect 66 | ) 67 | 
-------------------------------------------------------------------------------- /internal/git/git.go: -------------------------------------------------------------------------------- 1 | package git 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "os/exec" 7 | 8 | "github.com/cli/safeexec" 9 | ) 10 | // Exec runs git with args and returns its captured stdout and stderr; it fails if no git executable can be found on PATH. 11 | func Exec(args ...string) (stdOut, stdErr bytes.Buffer, err error) { 12 | path, err := path() 13 | if err != nil { 14 | err = fmt.Errorf("could not find git executable in PATH. error: %w", err) 15 | return 16 | } 17 | return run(path, nil, args...) 18 | } 19 | // path locates the git executable on PATH via safeexec.LookPath. 20 | func path() (string, error) { 21 | return safeexec.LookPath("git") 22 | } 23 | // run executes the binary at path with args, capturing stdout and stderr. A non-nil env replaces the child's environment; on failure the returned error includes the captured stderr. 24 | func run(path string, env []string, args ...string) (stdOut, stdErr bytes.Buffer, err error) { 25 | cmd := exec.Command(path, args...) 26 | cmd.Stdout = &stdOut 27 | cmd.Stderr = &stdErr 28 | if env != nil { 29 | cmd.Env = env 30 | } 31 | err = cmd.Run() 32 | if err != nil { 33 | err = fmt.Errorf("failed to run git: %s. error: %w", stdErr.String(), err) 34 | return 35 | } 36 | return 37 | } 38 | -------------------------------------------------------------------------------- /internal/git/git_test.go: -------------------------------------------------------------------------------- 1 | package git 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | // TestHelperProcess is not a real test: with GH_WANT_HELPER_PROCESS=1 it stands in for git, printing os.Args[3:] (the arguments after "--") to stdout, or writing an error to stderr and exiting 1 when the last argument is "error". 11 | func TestHelperProcess(t *testing.T) { 12 | if os.Getenv("GH_WANT_HELPER_PROCESS") != "1" { 13 | return 14 | } 15 | if err := func(args []string) error { 16 | if args[len(args)-1] == "error" { 17 | return fmt.Errorf("process exited with error") 18 | } 19 | fmt.Fprintf(os.Stdout, "%v", args) 20 | return nil 21 | }(os.Args[3:]); err != nil { 22 | fmt.Fprint(os.Stderr, err) 23 | os.Exit(1) 24 | } 25 | os.Exit(0) 26 | } 27 | // TestRun re-executes the test binary as a fake git and checks the captured streams. 28 | func TestRun(t *testing.T) { 29 | stdOut, stdErr, err := run(os.Args[0], 30 | []string{"GH_WANT_HELPER_PROCESS=1"}, 31 | "-test.run=TestHelperProcess", "--", "git", "status") 32 | assert.NoError(t,
err) 33 | assert.Equal(t, "[git status]", stdOut.String()) 34 | assert.Equal(t, "", stdErr.String()) 35 | } 36 | // TestRunError checks that a failing command's stderr is included in the returned error and is still captured. 37 | func TestRunError(t *testing.T) { 38 | stdOut, stdErr, err := run(os.Args[0], 39 | []string{"GH_WANT_HELPER_PROCESS=1"}, 40 | "-test.run=TestHelperProcess", "--", "git", "status", "error") 41 | assert.EqualError(t, err, "failed to run git: process exited with error. error: exit status 1") 42 | assert.Equal(t, "", stdOut.String()) 43 | assert.Equal(t, "process exited with error", stdErr.String()) 44 | } 45 | -------------------------------------------------------------------------------- /internal/git/remote.go: -------------------------------------------------------------------------------- 1 | package git 2 | 3 | import ( 4 | "net/url" 5 | "regexp" 6 | "sort" 7 | "strings" 8 | ) 9 | // remoteRE matches one line of `git remote -v` output, capturing the remote name, its URL, and whether it is the push or fetch URL. 10 | var remoteRE = regexp.MustCompile(`(.+)\s+(.+)\s+\((push|fetch)\)`) 11 | // RemoteSet is an ordered list of git remotes; it implements sort.Interface. 12 | type RemoteSet []*Remote 13 | // Remote is a git remote. Host, Owner, and Repo are derived from its URLs; Resolved holds its gh-resolved git config value, if any. 14 | type Remote struct { 15 | Name string 16 | FetchURL *url.URL 17 | PushURL *url.URL 18 | Resolved string 19 | Host string 20 | Owner string 21 | Repo string 22 | } 23 | // Len, Swap, and Less implement sort.Interface, ordering remotes by descending remoteNameSortScore. 24 | func (r RemoteSet) Len() int { return len(r) } 25 | func (r RemoteSet) Swap(i, j int) { r[i], r[j] = r[j], r[i] } 26 | func (r RemoteSet) Less(i, j int) bool { 27 | return remoteNameSortScore(r[i].Name) > remoteNameSortScore(r[j].Name) 28 | } 29 | // remoteNameSortScore ranks well-known remote names case-insensitively: upstream > github > origin > any other name. 30 | func remoteNameSortScore(name string) int { 31 | switch strings.ToLower(name) { 32 | case "upstream": 33 | return 3 34 | case "github": 35 | return 2 36 | case "origin": 37 | return 1 38 | default: 39 | return 0 40 | } 41 | } 42 | // Remotes returns the remotes of the git repository in the current working directory, with Resolved filled in from git config and upstream, github, and origin ordered first. A stable sort keeps equally ranked remotes in the order git listed them. 43 | func Remotes() (RemoteSet, error) { 44 | list, err := listRemotes() 45 | if err != nil { 46 | return nil, err 47 | } 48 | remotes := parseRemotes(list) 49 | setResolvedRemotes(remotes) 50 | sort.Stable(remotes) 51 | return remotes, nil 52 | } 53 | // FilterByHosts returns the remotes whose Host matches any of hosts (case-insensitively), preserving their order.
55 | func (rs RemoteSet) FilterByHosts(hosts []string) RemoteSet { 56 | filtered := make(RemoteSet, 0) 57 | for _, remote := range rs { 58 | for _, host := range hosts { 59 | if strings.EqualFold(remote.Host, host) { 60 | filtered = append(filtered, remote) 61 | break 62 | } 63 | } 64 | } 65 | return filtered 66 | } 67 | // listRemotes returns the output of `git remote -v`, split into lines. 68 | func listRemotes() ([]string, error) { 69 | stdOut, _, err := Exec("remote", "-v") 70 | if err != nil { 71 | return nil, err 72 | } 73 | return toLines(stdOut.String()), nil 74 | } 75 | // parseRemotes builds a RemoteSet from `git remote -v` lines, merging consecutive lines that share a remote name. Lines that do not match remoteRE, or whose URL fails to parse, are skipped. 76 | func parseRemotes(gitRemotes []string) RemoteSet { 77 | remotes := RemoteSet{} 78 | for _, r := range gitRemotes { 79 | match := remoteRE.FindStringSubmatch(r) 80 | if match == nil { 81 | continue 82 | } 83 | name := strings.TrimSpace(match[1]) 84 | urlStr := strings.TrimSpace(match[2]) 85 | urlType := strings.TrimSpace(match[3]) 86 | // Note: `url` below shadows the net/url package for the rest of this loop body. 87 | url, err := ParseURL(urlStr) 88 | if err != nil { 89 | continue 90 | } 91 | host, owner, repo, _ := RepoInfoFromURL(url) 92 | // Reuse the previous Remote when this line repeats its name (e.g. a push line following the fetch line). 93 | var rem *Remote 94 | if len(remotes) > 0 { 95 | rem = remotes[len(remotes)-1] 96 | if name != rem.Name { 97 | rem = nil 98 | } 99 | } 100 | if rem == nil { 101 | rem = &Remote{Name: name} 102 | remotes = append(remotes, rem) 103 | } 104 | // A fetch line always sets Host/Owner/Repo; a push line only fills in fields that are still empty. 105 | switch urlType { 106 | case "fetch": 107 | rem.FetchURL = url 108 | rem.Host = host 109 | rem.Owner = owner 110 | rem.Repo = repo 111 | case "push": 112 | rem.PushURL = url 113 | if rem.Host == "" { 114 | rem.Host = host 115 | } 116 | if rem.Owner == "" { 117 | rem.Owner = owner 118 | } 119 | if rem.Repo == "" { 120 | rem.Repo = repo 121 | } 122 | } 123 | } 124 | return remotes 125 | } 126 | // setResolvedRemotes sets Resolved on each remote from the `remote.<name>.gh-resolved` entries in git config; if the git command fails, remotes are left unchanged. 127 | func setResolvedRemotes(remotes RemoteSet) { 128 | stdOut, _, err := Exec("config", "--get-regexp", `^remote\..*\.gh-resolved$`) 129 | if err != nil { 130 | return 131 | } 132 | for _, l := range toLines(stdOut.String()) { 133 | parts := strings.SplitN(l, " ", 2) 134 | if len(parts) < 2 { 135 | continue 136 | } 137 | rp := strings.SplitN(parts[0], ".", 3) 138
| if len(rp) < 2 { 139 | continue 140 | } 141 | name := strings.TrimSuffix(strings.Join(rp[1:], "."), ".gh-resolved") // remote names may contain dots, so rejoin everything after "remote." before stripping the key suffix 142 | for _, r := range remotes { 143 | if r.Name == name { 144 | r.Resolved = parts[1] 145 | break 146 | } 147 | } 148 | } 149 | } 150 | // toLines splits command output on newlines after dropping one trailing newline; empty output yields a single empty string. 151 | func toLines(output string) []string { 152 | lines := strings.TrimSuffix(output, "\n") 153 | return strings.Split(lines, "\n") 154 | } 155 | -------------------------------------------------------------------------------- /internal/git/remote_test.go: -------------------------------------------------------------------------------- 1 | package git 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | // TestRemotes appends remotes to a fresh repository's .git/config and checks that Remotes orders them and fills in gh-resolved values. 11 | func TestRemotes(t *testing.T) { 12 | tempDir := t.TempDir() 13 | oldWd, _ := os.Getwd() 14 | assert.NoError(t, os.Chdir(tempDir)) 15 | t.Cleanup(func() { _ = os.Chdir(oldWd) }) 16 | _, _, err := Exec("init", "--quiet") 17 | assert.NoError(t, err) 18 | gitDir := filepath.Join(tempDir, ".git") 19 | remoteFile := filepath.Join(gitDir, "config") 20 | remotes := ` 21 | [remote "origin"] 22 | url = git@example.com:monalisa/origin.git 23 | [remote "test"] 24 | url = git://github.com/hubot/test.git 25 | gh-resolved = other 26 | [remote "upstream"] 27 | url = https://github.com/monalisa/upstream.git 28 | gh-resolved = base 29 | [remote "github"] 30 | url = git@github.com:hubot/github.git 31 | ` 32 | f, err := os.OpenFile(remoteFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0755) 33 | assert.NoError(t, err) 34 | _, err = f.Write([]byte(remotes)) 35 | assert.NoError(t, err) 36 | err = f.Close() 37 | assert.NoError(t, err) 38 | rs, err := Remotes() 39 | assert.NoError(t, err) 40 | assert.Equal(t, 4, len(rs)) 41 | assert.Equal(t, "upstream", rs[0].Name) 42 | assert.Equal(t, "base", rs[0].Resolved) 43 | assert.Equal(t, "github", rs[1].Name) 44 | assert.Equal(t, "", rs[1].Resolved) 45 | assert.Equal(t, "origin", rs[2].Name) 46 | assert.Equal(t, "", rs[2].Resolved) 47 | assert.Equal(t, "test", rs[3].Name) 48 |
assert.Equal(t, "other", rs[3].Resolved) 49 | } 50 | 51 | func TestParseRemotes(t *testing.T) { 52 | remoteList := []string{ 53 | "mona\tgit@github.com:monalisa/myfork.git (fetch)", 54 | "origin\thttps://github.com/monalisa/octo-cat.git (fetch)", 55 | "origin\thttps://github.com/monalisa/octo-cat-push.git (push)", 56 | "upstream\thttps://example.com/nowhere.git (fetch)", 57 | "upstream\thttps://github.com/hubot/tools (push)", 58 | "zardoz\thttps://example.com/zed.git (push)", 59 | "koke\tgit://github.com/koke/grit.git (fetch)", 60 | "koke\tgit://github.com/koke/grit.git (push)", 61 | } 62 | 63 | r := parseRemotes(remoteList) 64 | assert.Equal(t, 5, len(r)) 65 | 66 | assert.Equal(t, "mona", r[0].Name) 67 | assert.Equal(t, "ssh://git@github.com/monalisa/myfork.git", r[0].FetchURL.String()) 68 | assert.Nil(t, r[0].PushURL) 69 | assert.Equal(t, "github.com", r[0].Host) 70 | assert.Equal(t, "monalisa", r[0].Owner) 71 | assert.Equal(t, "myfork", r[0].Repo) 72 | 73 | assert.Equal(t, "origin", r[1].Name) 74 | assert.Equal(t, "/monalisa/octo-cat.git", r[1].FetchURL.Path) 75 | assert.Equal(t, "/monalisa/octo-cat-push.git", r[1].PushURL.Path) 76 | assert.Equal(t, "github.com", r[1].Host) 77 | assert.Equal(t, "monalisa", r[1].Owner) 78 | assert.Equal(t, "octo-cat", r[1].Repo) 79 | 80 | assert.Equal(t, "upstream", r[2].Name) 81 | assert.Equal(t, "example.com", r[2].FetchURL.Host) 82 | assert.Equal(t, "github.com", r[2].PushURL.Host) 83 | assert.Equal(t, "github.com", r[2].Host) 84 | assert.Equal(t, "hubot", r[2].Owner) 85 | assert.Equal(t, "tools", r[2].Repo) 86 | 87 | assert.Equal(t, "zardoz", r[3].Name) 88 | assert.Nil(t, r[3].FetchURL) 89 | assert.Equal(t, "https://example.com/zed.git", r[3].PushURL.String()) 90 | assert.Equal(t, "", r[3].Host) 91 | assert.Equal(t, "", r[3].Owner) 92 | assert.Equal(t, "", r[3].Repo) 93 | 94 | assert.Equal(t, "koke", r[4].Name) 95 | assert.Equal(t, "/koke/grit.git", r[4].FetchURL.Path) 96 | assert.Equal(t, "/koke/grit.git", r[4].PushURL.Path) 
97 | assert.Equal(t, "github.com", r[4].Host) 98 | assert.Equal(t, "koke", r[4].Owner) 99 | assert.Equal(t, "grit", r[4].Repo) 100 | } 101 | -------------------------------------------------------------------------------- /internal/git/url.go: -------------------------------------------------------------------------------- 1 | package git 2 | 3 | import ( 4 | "fmt" 5 | "net/url" 6 | "strings" 7 | ) 8 | 9 | func IsURL(u string) bool { 10 | return strings.HasPrefix(u, "git@") || isSupportedProtocol(u) 11 | } 12 | 13 | func isSupportedProtocol(u string) bool { 14 | return strings.HasPrefix(u, "ssh:") || 15 | strings.HasPrefix(u, "git+ssh:") || 16 | strings.HasPrefix(u, "git:") || 17 | strings.HasPrefix(u, "http:") || 18 | strings.HasPrefix(u, "git+https:") || 19 | strings.HasPrefix(u, "https:") 20 | } 21 | 22 | func isPossibleProtocol(u string) bool { 23 | return isSupportedProtocol(u) || 24 | strings.HasPrefix(u, "ftp:") || 25 | strings.HasPrefix(u, "ftps:") || 26 | strings.HasPrefix(u, "file:") 27 | } 28 | 29 | // ParseURL normalizes git remote urls. 30 | func ParseURL(rawURL string) (u *url.URL, err error) { 31 | if !isPossibleProtocol(rawURL) && 32 | strings.ContainsRune(rawURL, ':') && 33 | // Not a Windows path. 34 | !strings.ContainsRune(rawURL, '\\') { 35 | // Support scp-like syntax for ssh protocol. 36 | rawURL = "ssh://" + strings.Replace(rawURL, ":", "/", 1) 37 | } 38 | 39 | u, err = url.Parse(rawURL) 40 | if err != nil { 41 | return 42 | } 43 | 44 | if u.Scheme == "git+ssh" { 45 | u.Scheme = "ssh" 46 | } 47 | 48 | if u.Scheme == "git+https" { 49 | u.Scheme = "https" 50 | } 51 | 52 | if u.Scheme != "ssh" { 53 | return 54 | } 55 | 56 | if strings.HasPrefix(u.Path, "//") { 57 | u.Path = strings.TrimPrefix(u.Path, "/") 58 | } 59 | 60 | if idx := strings.Index(u.Host, ":"); idx >= 0 { 61 | u.Host = u.Host[0:idx] 62 | } 63 | 64 | return 65 | } 66 | 67 | // Extract GitHub repository information from a git remote URL. 
68 | func RepoInfoFromURL(u *url.URL) (host string, owner string, name string, err error) { 69 | if u.Hostname() == "" { 70 | return "", "", "", fmt.Errorf("no hostname detected") 71 | } 72 | 73 | parts := strings.SplitN(strings.Trim(u.Path, "/"), "/", 3) 74 | if len(parts) != 2 { 75 | return "", "", "", fmt.Errorf("invalid path: %s", u.Path) 76 | } 77 | 78 | return normalizeHostname(u.Hostname()), parts[0], strings.TrimSuffix(parts[1], ".git"), nil 79 | } 80 | 81 | func normalizeHostname(h string) string { 82 | return strings.ToLower(strings.TrimPrefix(h, "www.")) 83 | } 84 | -------------------------------------------------------------------------------- /internal/git/url_test.go: -------------------------------------------------------------------------------- 1 | package git 2 | 3 | import ( 4 | "net/url" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestIsURL(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | url string 14 | want bool 15 | }{ 16 | { 17 | name: "scp-like", 18 | url: "git@example.com:owner/repo", 19 | want: true, 20 | }, 21 | { 22 | name: "scp-like with no user", 23 | url: "example.com:owner/repo", 24 | want: false, 25 | }, 26 | { 27 | name: "ssh", 28 | url: "ssh://git@example.com/owner/repo", 29 | want: true, 30 | }, 31 | { 32 | name: "git", 33 | url: "git://example.com/owner/repo", 34 | want: true, 35 | }, 36 | { 37 | name: "git with extension", 38 | url: "git://example.com/owner/repo.git", 39 | want: true, 40 | }, 41 | { 42 | name: "git+ssh", 43 | url: "git+ssh://git@example.com/owner/repo.git", 44 | want: true, 45 | }, 46 | { 47 | name: "git+https", 48 | url: "git+https://example.com/owner/repo.git", 49 | want: true, 50 | }, 51 | { 52 | name: "http", 53 | url: "http://example.com/owner/repo.git", 54 | want: true, 55 | }, 56 | { 57 | name: "https", 58 | url: "https://example.com/owner/repo.git", 59 | want: true, 60 | }, 61 | { 62 | name: "no protocol", 63 | url: "example.com/owner/repo", 64 | want: 
false, 65 | }, 66 | } 67 | for _, tt := range tests { 68 | t.Run(tt.name, func(t *testing.T) { 69 | assert.Equal(t, tt.want, IsURL(tt.url)) 70 | }) 71 | } 72 | } 73 | 74 | func TestParseURL(t *testing.T) { 75 | type url struct { 76 | Scheme string 77 | User string 78 | Host string 79 | Path string 80 | } 81 | 82 | tests := []struct { 83 | name string 84 | url string 85 | want url 86 | wantErr bool 87 | }{ 88 | { 89 | name: "HTTPS", 90 | url: "https://example.com/owner/repo.git", 91 | want: url{ 92 | Scheme: "https", 93 | User: "", 94 | Host: "example.com", 95 | Path: "/owner/repo.git", 96 | }, 97 | }, 98 | { 99 | name: "HTTP", 100 | url: "http://example.com/owner/repo.git", 101 | want: url{ 102 | Scheme: "http", 103 | User: "", 104 | Host: "example.com", 105 | Path: "/owner/repo.git", 106 | }, 107 | }, 108 | { 109 | name: "git", 110 | url: "git://example.com/owner/repo.git", 111 | want: url{ 112 | Scheme: "git", 113 | User: "", 114 | Host: "example.com", 115 | Path: "/owner/repo.git", 116 | }, 117 | }, 118 | { 119 | name: "ssh", 120 | url: "ssh://git@example.com/owner/repo.git", 121 | want: url{ 122 | Scheme: "ssh", 123 | User: "git", 124 | Host: "example.com", 125 | Path: "/owner/repo.git", 126 | }, 127 | }, 128 | { 129 | name: "ssh with port", 130 | url: "ssh://git@example.com:443/owner/repo.git", 131 | want: url{ 132 | Scheme: "ssh", 133 | User: "git", 134 | Host: "example.com", 135 | Path: "/owner/repo.git", 136 | }, 137 | }, 138 | { 139 | name: "git+ssh", 140 | url: "git+ssh://example.com/owner/repo.git", 141 | want: url{ 142 | Scheme: "ssh", 143 | User: "", 144 | Host: "example.com", 145 | Path: "/owner/repo.git", 146 | }, 147 | }, 148 | { 149 | name: "git+https", 150 | url: "git+https://example.com/owner/repo.git", 151 | want: url{ 152 | Scheme: "https", 153 | User: "", 154 | Host: "example.com", 155 | Path: "/owner/repo.git", 156 | }, 157 | }, 158 | { 159 | name: "scp-like", 160 | url: "git@example.com:owner/repo.git", 161 | want: url{ 162 | Scheme: "ssh", 
163 | User: "git", 164 | Host: "example.com", 165 | Path: "/owner/repo.git", 166 | }, 167 | }, 168 | { 169 | name: "scp-like, leading slash", 170 | url: "git@example.com:/owner/repo.git", 171 | want: url{ 172 | Scheme: "ssh", 173 | User: "git", 174 | Host: "example.com", 175 | Path: "/owner/repo.git", 176 | }, 177 | }, 178 | { 179 | name: "file protocol", 180 | url: "file:///example.com/owner/repo.git", 181 | want: url{ 182 | Scheme: "file", 183 | User: "", 184 | Host: "", 185 | Path: "/example.com/owner/repo.git", 186 | }, 187 | }, 188 | { 189 | name: "file path", 190 | url: "/example.com/owner/repo.git", 191 | want: url{ 192 | Scheme: "", 193 | User: "", 194 | Host: "", 195 | Path: "/example.com/owner/repo.git", 196 | }, 197 | }, 198 | { 199 | name: "Windows file path", 200 | url: "C:\\example.com\\owner\\repo.git", 201 | want: url{ 202 | Scheme: "c", 203 | User: "", 204 | Host: "", 205 | Path: "", 206 | }, 207 | }, 208 | } 209 | for _, tt := range tests { 210 | t.Run(tt.name, func(t *testing.T) { 211 | u, err := ParseURL(tt.url) 212 | if tt.wantErr { 213 | assert.Error(t, err) 214 | return 215 | } 216 | assert.NoError(t, err) 217 | assert.Equal(t, tt.want.Scheme, u.Scheme) 218 | assert.Equal(t, tt.want.User, u.User.Username()) 219 | assert.Equal(t, tt.want.Host, u.Host) 220 | assert.Equal(t, tt.want.Path, u.Path) 221 | }) 222 | } 223 | } 224 | 225 | func TestRepoInfoFromURL(t *testing.T) { 226 | tests := []struct { 227 | name string 228 | input string 229 | wantHost string 230 | wantOwner string 231 | wantRepo string 232 | wantErr bool 233 | wantErrMsg string 234 | }{ 235 | { 236 | name: "github.com URL", 237 | input: "https://github.com/monalisa/octo-cat.git", 238 | wantHost: "github.com", 239 | wantOwner: "monalisa", 240 | wantRepo: "octo-cat", 241 | }, 242 | { 243 | name: "github.com URL with trailing slash", 244 | input: "https://github.com/monalisa/octo-cat/", 245 | wantHost: "github.com", 246 | wantOwner: "monalisa", 247 | wantRepo: "octo-cat", 248 | }, 
249 | { 250 | name: "www.github.com URL", 251 | input: "http://www.GITHUB.com/monalisa/octo-cat.git", 252 | wantHost: "github.com", 253 | wantOwner: "monalisa", 254 | wantRepo: "octo-cat", 255 | }, 256 | { 257 | name: "too many path components", 258 | input: "https://github.com/monalisa/octo-cat/pulls", 259 | wantErr: true, 260 | wantErrMsg: "invalid path: /monalisa/octo-cat/pulls", 261 | }, 262 | { 263 | name: "non-GitHub hostname", 264 | input: "https://example.com/one/two", 265 | wantHost: "example.com", 266 | wantOwner: "one", 267 | wantRepo: "two", 268 | }, 269 | { 270 | name: "filesystem path", 271 | input: "/path/to/file", 272 | wantErr: true, 273 | wantErrMsg: "no hostname detected", 274 | }, 275 | { 276 | name: "filesystem path with scheme", 277 | input: "file:///path/to/file", 278 | wantErr: true, 279 | wantErrMsg: "no hostname detected", 280 | }, 281 | { 282 | name: "github.com SSH URL", 283 | input: "ssh://github.com/monalisa/octo-cat.git", 284 | wantHost: "github.com", 285 | wantOwner: "monalisa", 286 | wantRepo: "octo-cat", 287 | }, 288 | { 289 | name: "github.com HTTPS+SSH URL", 290 | input: "https+ssh://github.com/monalisa/octo-cat.git", 291 | wantHost: "github.com", 292 | wantOwner: "monalisa", 293 | wantRepo: "octo-cat", 294 | }, 295 | { 296 | name: "github.com git URL", 297 | input: "git://github.com/monalisa/octo-cat.git", 298 | wantHost: "github.com", 299 | wantOwner: "monalisa", 300 | wantRepo: "octo-cat", 301 | }, 302 | } 303 | 304 | for _, tt := range tests { 305 | t.Run(tt.name, func(t *testing.T) { 306 | u, err := url.Parse(tt.input) 307 | assert.NoError(t, err) 308 | host, owner, repo, err := RepoInfoFromURL(u) 309 | if tt.wantErr { 310 | assert.EqualError(t, err, tt.wantErrMsg) 311 | return 312 | } 313 | assert.NoError(t, err) 314 | assert.Equal(t, tt.wantHost, host) 315 | assert.Equal(t, tt.wantOwner, owner) 316 | assert.Equal(t, tt.wantRepo, repo) 317 | }) 318 | } 319 | } 320 | 
-------------------------------------------------------------------------------- /internal/set/string_set.go: -------------------------------------------------------------------------------- 1 | package set 2 | 3 | var exists = struct{}{} 4 | 5 | type stringSet struct { 6 | v []string 7 | m map[string]struct{} 8 | } 9 | 10 | func NewStringSet() *stringSet { 11 | s := &stringSet{} 12 | s.m = make(map[string]struct{}) 13 | s.v = []string{} 14 | return s 15 | } 16 | 17 | func (s *stringSet) Add(value string) { 18 | if s.Contains(value) { 19 | return 20 | } 21 | s.m[value] = exists 22 | s.v = append(s.v, value) 23 | } 24 | 25 | func (s *stringSet) AddValues(values []string) { 26 | for _, v := range values { 27 | s.Add(v) 28 | } 29 | } 30 | 31 | func (s *stringSet) Remove(value string) { 32 | if !s.Contains(value) { 33 | return 34 | } 35 | delete(s.m, value) 36 | s.v = sliceWithout(s.v, value) 37 | } 38 | 39 | func sliceWithout(s []string, v string) []string { 40 | idx := -1 41 | for i, item := range s { 42 | if item == v { 43 | idx = i 44 | break 45 | } 46 | } 47 | if idx < 0 { 48 | return s 49 | } 50 | return append(s[:idx], s[idx+1:]...) 
51 | } 52 | 53 | func (s *stringSet) RemoveValues(values []string) { 54 | for _, v := range values { 55 | s.Remove(v) 56 | } 57 | } 58 | 59 | func (s *stringSet) Contains(value string) bool { 60 | _, c := s.m[value] 61 | return c 62 | } 63 | 64 | func (s *stringSet) Len() int { 65 | return len(s.m) 66 | } 67 | 68 | func (s *stringSet) ToSlice() []string { 69 | return s.v 70 | } 71 | -------------------------------------------------------------------------------- /internal/set/string_set_test.go: -------------------------------------------------------------------------------- 1 | package set 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func Test_StringSlice_ToSlice(t *testing.T) { 10 | s := NewStringSet() 11 | s.Add("one") 12 | s.Add("two") 13 | s.Add("three") 14 | s.Add("two") 15 | assert.Equal(t, []string{"one", "two", "three"}, s.ToSlice()) 16 | } 17 | 18 | func Test_StringSlice_Remove(t *testing.T) { 19 | s := NewStringSet() 20 | s.Add("one") 21 | s.Add("two") 22 | s.Add("three") 23 | s.Remove("two") 24 | assert.Equal(t, []string{"one", "three"}, s.ToSlice()) 25 | assert.False(t, s.Contains("two")) 26 | assert.Equal(t, 2, s.Len()) 27 | } 28 | -------------------------------------------------------------------------------- /internal/testutils/config_stub.go: -------------------------------------------------------------------------------- 1 | package testutils 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/cli/go-gh/v2/pkg/config" 7 | ) 8 | 9 | // StubConfig replaces the config.Read function with a function that returns a config object 10 | // created from the given config string. It also sets up a cleanup function that restores the 11 | // original config.Read function. 
12 | func StubConfig(t *testing.T, cfgStr string) { 13 | t.Helper() 14 | old := config.Read 15 | config.Read = func(_ *config.Config) (*config.Config, error) { 16 | return config.ReadFromString(cfgStr), nil 17 | } 18 | t.Cleanup(func() { 19 | config.Read = old 20 | }) 21 | } 22 | -------------------------------------------------------------------------------- /internal/yamlmap/yaml_map.go: -------------------------------------------------------------------------------- 1 | // Package yamlmap is a wrapper of gopkg.in/yaml.v3 for interacting 2 | // with yaml data as if it were a map. 3 | package yamlmap 4 | 5 | import ( 6 | "errors" 7 | 8 | "gopkg.in/yaml.v3" 9 | ) 10 | 11 | const ( 12 | modified = "modifed" 13 | ) 14 | 15 | type Map struct { 16 | *yaml.Node 17 | } 18 | 19 | var ErrNotFound = errors.New("not found") 20 | var ErrInvalidYaml = errors.New("invalid yaml") 21 | var ErrInvalidFormat = errors.New("invalid format") 22 | 23 | func StringValue(value string) *Map { 24 | return &Map{&yaml.Node{ 25 | Kind: yaml.ScalarNode, 26 | Tag: "!!str", 27 | Value: value, 28 | }} 29 | } 30 | 31 | func MapValue() *Map { 32 | return &Map{&yaml.Node{ 33 | Kind: yaml.MappingNode, 34 | Tag: "!!map", 35 | }} 36 | } 37 | 38 | func NullValue() *Map { 39 | return &Map{&yaml.Node{ 40 | Kind: yaml.ScalarNode, 41 | Tag: "!!null", 42 | }} 43 | } 44 | 45 | func Unmarshal(data []byte) (*Map, error) { 46 | var root yaml.Node 47 | err := yaml.Unmarshal(data, &root) 48 | if err != nil { 49 | return nil, ErrInvalidYaml 50 | } 51 | if len(root.Content) == 0 { 52 | return MapValue(), nil 53 | } 54 | if root.Content[0].Kind != yaml.MappingNode { 55 | return nil, ErrInvalidFormat 56 | } 57 | return &Map{root.Content[0]}, nil 58 | } 59 | 60 | func Marshal(m *Map) ([]byte, error) { 61 | return yaml.Marshal(m.Node) 62 | } 63 | 64 | func (m *Map) AddEntry(key string, value *Map) { 65 | keyNode := &yaml.Node{ 66 | Kind: yaml.ScalarNode, 67 | Tag: "!!str", 68 | Value: key, 69 | } 70 | m.Content = 
append(m.Content, keyNode, value.Node) 71 | m.SetModified() 72 | } 73 | 74 | func (m *Map) Empty() bool { 75 | return len(m.Content) == 0 76 | } 77 | 78 | func (m *Map) FindEntry(key string) (*Map, error) { 79 | // Note: The content slice of a yamlMap looks like [key1, value1, key2, value2, ...]. 80 | // When iterating over the content slice we only want to compare the keys of the yamlMap. 81 | for i, v := range m.Content { 82 | if i%2 != 0 { 83 | continue 84 | } 85 | if v.Value == key { 86 | if i+1 < len(m.Content) { 87 | return &Map{m.Content[i+1]}, nil 88 | } 89 | } 90 | } 91 | return nil, ErrNotFound 92 | } 93 | 94 | func (m *Map) Keys() []string { 95 | // Note: The content slice of a yamlMap looks like [key1, value1, key2, value2, ...]. 96 | // When iterating over the content slice we only want to select the keys of the yamlMap. 97 | keys := []string{} 98 | for i, v := range m.Content { 99 | if i%2 != 0 { 100 | continue 101 | } 102 | keys = append(keys, v.Value) 103 | } 104 | return keys 105 | } 106 | 107 | func (m *Map) RemoveEntry(key string) error { 108 | // Note: The content slice of a yamlMap looks like [key1, value1, key2, value2, ...]. 109 | // When iterating over the content slice we only want to compare the keys of the yamlMap. 110 | // If we find they key to remove, remove the key and its value from the content slice. 111 | found, skipNext := false, false 112 | newContent := []*yaml.Node{} 113 | for i, v := range m.Content { 114 | if skipNext { 115 | skipNext = false 116 | continue 117 | } 118 | if i%2 != 0 || v.Value != key { 119 | newContent = append(newContent, v) 120 | } else { 121 | found = true 122 | skipNext = true 123 | m.SetModified() 124 | } 125 | } 126 | if !found { 127 | return ErrNotFound 128 | } 129 | m.Content = newContent 130 | return nil 131 | } 132 | 133 | func (m *Map) SetEntry(key string, value *Map) { 134 | // Note: The content slice of a yamlMap looks like [key1, value1, key2, value2, ...]. 
135 | // When iterating over the content slice we only want to compare the keys of the yamlMap. 136 | // If we find they key to set, set the next item in the content slice to the new value. 137 | m.SetModified() 138 | for i, v := range m.Content { 139 | if i%2 != 0 || v.Value != key { 140 | continue 141 | } 142 | if v.Value == key { 143 | if i+1 < len(m.Content) { 144 | m.Content[i+1] = value.Node 145 | return 146 | } 147 | } 148 | } 149 | m.AddEntry(key, value) 150 | } 151 | 152 | // Note: This is a hack to introduce the concept of modified/unmodified 153 | // on top of gopkg.in/yaml.v3. This works by setting the Value property 154 | // of a MappingNode to a specific value and then later checking if the 155 | // node's Value property is that specific value. When a MappingNode gets 156 | // output as a string the Value property is not used, thus changing it 157 | // has no impact for our purposes. 158 | func (m *Map) SetModified() { 159 | // Can not mark a non-mapping node as modified 160 | if m.Node.Kind != yaml.MappingNode && m.Node.Tag == "!!null" { 161 | m.Node.Kind = yaml.MappingNode 162 | m.Node.Tag = "!!map" 163 | } 164 | if m.Node.Kind == yaml.MappingNode { 165 | m.Node.Value = modified 166 | } 167 | } 168 | 169 | // Traverse map using BFS to set all nodes as unmodified. 170 | func (m *Map) SetUnmodified() { 171 | i := 0 172 | queue := []*yaml.Node{m.Node} 173 | for { 174 | if i > (len(queue) - 1) { 175 | break 176 | } 177 | q := queue[i] 178 | i = i + 1 179 | if q.Kind != yaml.MappingNode { 180 | continue 181 | } 182 | q.Value = "" 183 | queue = append(queue, q.Content...) 184 | } 185 | } 186 | 187 | // Traverse map using BFS to searach for any nodes that have been modified. 
188 | func (m *Map) IsModified() bool { 189 | i := 0 190 | queue := []*yaml.Node{m.Node} 191 | for { 192 | if i > (len(queue) - 1) { 193 | break 194 | } 195 | q := queue[i] 196 | i = i + 1 197 | if q.Kind != yaml.MappingNode { 198 | continue 199 | } 200 | if q.Value == modified { 201 | return true 202 | } 203 | queue = append(queue, q.Content...) 204 | } 205 | return false 206 | } 207 | 208 | func (m *Map) String() string { 209 | data, err := Marshal(m) 210 | if err != nil { 211 | return "" 212 | } 213 | return string(data) 214 | } 215 | -------------------------------------------------------------------------------- /internal/yamlmap/yaml_map_test.go: -------------------------------------------------------------------------------- 1 | package yamlmap 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestMapAddEntry(t *testing.T) { 10 | tests := []struct { 11 | name string 12 | key string 13 | value string 14 | wantValue string 15 | wantLength int 16 | }{ 17 | { 18 | name: "add entry with key that is not present", 19 | key: "notPresent", 20 | value: "test1", 21 | wantValue: "test1", 22 | wantLength: 10, 23 | }, 24 | { 25 | name: "add entry with key that is already present", 26 | key: "erroneous", 27 | value: "test2", 28 | wantValue: "same", 29 | wantLength: 10, 30 | }, 31 | } 32 | 33 | for _, tt := range tests { 34 | m := testMap() 35 | t.Run(tt.name, func(t *testing.T) { 36 | m.AddEntry(tt.key, StringValue(tt.value)) 37 | entry, err := m.FindEntry(tt.key) 38 | assert.NoError(t, err) 39 | assert.Equal(t, tt.wantValue, entry.Value) 40 | assert.Equal(t, tt.wantLength, len(m.Content)) 41 | assert.True(t, m.IsModified()) 42 | }) 43 | } 44 | } 45 | 46 | func TestMapEmpty(t *testing.T) { 47 | m := blankMap() 48 | assert.Equal(t, true, m.Empty()) 49 | m.AddEntry("test", StringValue("test")) 50 | assert.Equal(t, false, m.Empty()) 51 | } 52 | 53 | func TestMapFindEntry(t *testing.T) { 54 | tests := []struct { 55 | name string 56 
| key string 57 | output string 58 | wantErr bool 59 | }{ 60 | { 61 | name: "find key", 62 | key: "valid", 63 | output: "present", 64 | }, 65 | { 66 | name: "find key that is not present", 67 | key: "invalid", 68 | wantErr: true, 69 | }, 70 | { 71 | name: "find key with blank value", 72 | key: "blank", 73 | output: "", 74 | }, 75 | { 76 | name: "find key that has same content as a value", 77 | key: "same", 78 | output: "logical", 79 | }, 80 | } 81 | 82 | for _, tt := range tests { 83 | m := testMap() 84 | t.Run(tt.name, func(t *testing.T) { 85 | out, err := m.FindEntry(tt.key) 86 | if tt.wantErr { 87 | assert.EqualError(t, err, "not found") 88 | assert.False(t, m.IsModified()) 89 | return 90 | } 91 | assert.NoError(t, err) 92 | assert.Equal(t, tt.output, out.Value) 93 | assert.False(t, m.IsModified()) 94 | }) 95 | } 96 | } 97 | 98 | func TestMapFindEntryModified(t *testing.T) { 99 | m := testMap() 100 | entry, err := m.FindEntry("valid") 101 | assert.NoError(t, err) 102 | assert.Equal(t, "present", entry.Value) 103 | entry.Value = "test" 104 | assert.Equal(t, "test", entry.Value) 105 | entry2, err := m.FindEntry("valid") 106 | assert.NoError(t, err) 107 | assert.Equal(t, "test", entry2.Value) 108 | } 109 | 110 | func TestMapKeys(t *testing.T) { 111 | tests := []struct { 112 | name string 113 | m *Map 114 | wantKeys []string 115 | }{ 116 | { 117 | name: "keys for full map", 118 | m: testMap(), 119 | wantKeys: []string{"valid", "erroneous", "blank", "same"}, 120 | }, 121 | { 122 | name: "keys for empty map", 123 | m: blankMap(), 124 | wantKeys: []string{}, 125 | }, 126 | } 127 | for _, tt := range tests { 128 | t.Run(tt.name, func(t *testing.T) { 129 | keys := tt.m.Keys() 130 | assert.Equal(t, tt.wantKeys, keys) 131 | assert.False(t, tt.m.IsModified()) 132 | }) 133 | } 134 | } 135 | 136 | func TestMapRemoveEntry(t *testing.T) { 137 | tests := []struct { 138 | name string 139 | key string 140 | wantLength int 141 | wantErr bool 142 | }{ 143 | { 144 | name: "remove 
key", 145 | key: "erroneous", 146 | wantLength: 6, 147 | }, 148 | { 149 | name: "remove key that is not present", 150 | key: "invalid", 151 | wantLength: 8, 152 | wantErr: true, 153 | }, 154 | { 155 | name: "remove key that has same content as a value", 156 | key: "same", 157 | wantLength: 6, 158 | }, 159 | } 160 | 161 | for _, tt := range tests { 162 | m := testMap() 163 | t.Run(tt.name, func(t *testing.T) { 164 | err := m.RemoveEntry(tt.key) 165 | if tt.wantErr { 166 | assert.EqualError(t, err, "not found") 167 | assert.False(t, m.IsModified()) 168 | } else { 169 | assert.NoError(t, err) 170 | assert.True(t, m.IsModified()) 171 | } 172 | assert.Equal(t, tt.wantLength, len(m.Content)) 173 | _, err = m.FindEntry(tt.key) 174 | assert.EqualError(t, err, "not found") 175 | }) 176 | } 177 | } 178 | 179 | func TestMapSetEntry(t *testing.T) { 180 | tests := []struct { 181 | name string 182 | key string 183 | value *Map 184 | wantLength int 185 | }{ 186 | { 187 | name: "sets key that is not present", 188 | key: "not", 189 | value: StringValue("present"), 190 | wantLength: 10, 191 | }, 192 | { 193 | name: "sets key that is present", 194 | key: "erroneous", 195 | value: StringValue("not same"), 196 | wantLength: 8, 197 | }, 198 | } 199 | for _, tt := range tests { 200 | m := testMap() 201 | t.Run(tt.name, func(t *testing.T) { 202 | m.SetEntry(tt.key, tt.value) 203 | assert.True(t, m.IsModified()) 204 | assert.Equal(t, tt.wantLength, len(m.Content)) 205 | e, err := m.FindEntry(tt.key) 206 | assert.NoError(t, err) 207 | assert.Equal(t, tt.value.Value, e.Value) 208 | }) 209 | } 210 | } 211 | 212 | func TestUnmarshal(t *testing.T) { 213 | tests := []struct { 214 | name string 215 | data []byte 216 | wantErr string 217 | wantEmpty bool 218 | }{ 219 | { 220 | name: "valid yaml", 221 | data: []byte(`{test: "data"}`), 222 | }, 223 | { 224 | name: "empty yaml", 225 | data: []byte(``), 226 | wantEmpty: true, 227 | }, 228 | { 229 | name: "invalid yaml", 230 | data: []byte(`{test: `), 
231 | wantErr: "invalid yaml", 232 | }, 233 | { 234 | name: "invalid format", 235 | data: []byte(`data`), 236 | wantErr: "invalid format", 237 | }, 238 | } 239 | for _, tt := range tests { 240 | t.Run(tt.name, func(t *testing.T) { 241 | m, err := Unmarshal(tt.data) 242 | if tt.wantErr != "" { 243 | assert.EqualError(t, err, tt.wantErr) 244 | assert.Nil(t, m) 245 | return 246 | } 247 | assert.NoError(t, err) 248 | assert.Equal(t, tt.wantEmpty, m.Empty()) 249 | }) 250 | } 251 | } 252 | 253 | func testMap() *Map { 254 | var data = ` 255 | valid: present 256 | erroneous: same 257 | blank: 258 | same: logical 259 | ` 260 | m, _ := Unmarshal([]byte(data)) 261 | return m 262 | } 263 | 264 | func blankMap() *Map { 265 | return MapValue() 266 | } 267 | -------------------------------------------------------------------------------- /pkg/api/cache.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "crypto/sha256" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "net/http" 11 | "os" 12 | "path/filepath" 13 | "strings" 14 | "sync" 15 | "time" 16 | ) 17 | 18 | type cache struct { 19 | dir string 20 | ttl time.Duration 21 | } 22 | 23 | type cacheRoundTripper struct { 24 | fs fileStorage 25 | rt http.RoundTripper 26 | } 27 | 28 | type fileStorage struct { 29 | dir string 30 | ttl time.Duration 31 | mu *sync.RWMutex 32 | } 33 | 34 | type readCloser struct { 35 | io.Reader 36 | io.Closer 37 | } 38 | 39 | func isCacheableRequest(req *http.Request) bool { 40 | if strings.EqualFold(req.Method, "GET") || strings.EqualFold(req.Method, "HEAD") { 41 | return true 42 | } 43 | 44 | if strings.EqualFold(req.Method, "POST") && (req.URL.Path == "/graphql" || req.URL.Path == "/api/graphql") { 45 | return true 46 | } 47 | 48 | return false 49 | } 50 | 51 | func isCacheableResponse(res *http.Response) bool { 52 | return res.StatusCode < 500 && res.StatusCode != 403 53 | } 54 | 55 | func cacheKey(req *http.Request) 
(string, error) { 56 | h := sha256.New() 57 | fmt.Fprintf(h, "%s:", req.Method) 58 | fmt.Fprintf(h, "%s:", req.URL.String()) 59 | fmt.Fprintf(h, "%s:", req.Header.Get("Accept")) 60 | fmt.Fprintf(h, "%s:", req.Header.Get("Authorization")) 61 | 62 | if req.Body != nil { 63 | var bodyCopy io.ReadCloser 64 | req.Body, bodyCopy = copyStream(req.Body) 65 | defer bodyCopy.Close() 66 | if _, err := io.Copy(h, bodyCopy); err != nil { 67 | return "", err 68 | } 69 | } 70 | 71 | digest := h.Sum(nil) 72 | return fmt.Sprintf("%x", digest), nil 73 | } 74 | 75 | func (c cache) RoundTripper(rt http.RoundTripper) http.RoundTripper { 76 | fs := fileStorage{ 77 | dir: c.dir, 78 | ttl: c.ttl, 79 | mu: &sync.RWMutex{}, 80 | } 81 | return cacheRoundTripper{fs: fs, rt: rt} 82 | } 83 | 84 | func (crt cacheRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { 85 | reqDir, reqTTL := requestCacheOptions(req) 86 | 87 | if crt.fs.ttl == 0 && reqTTL == 0 { 88 | return crt.rt.RoundTrip(req) 89 | } 90 | 91 | if !isCacheableRequest(req) { 92 | return crt.rt.RoundTrip(req) 93 | } 94 | 95 | origDir := crt.fs.dir 96 | if reqDir != "" { 97 | crt.fs.dir = reqDir 98 | } 99 | origTTL := crt.fs.ttl 100 | if reqTTL != 0 { 101 | crt.fs.ttl = reqTTL 102 | } 103 | 104 | key, keyErr := cacheKey(req) 105 | if keyErr == nil { 106 | if res, err := crt.fs.read(key); err == nil { 107 | res.Request = req 108 | return res, nil 109 | } 110 | } 111 | 112 | res, err := crt.rt.RoundTrip(req) 113 | if err == nil && keyErr == nil && isCacheableResponse(res) { 114 | _ = crt.fs.store(key, res) 115 | } 116 | 117 | crt.fs.dir = origDir 118 | crt.fs.ttl = origTTL 119 | 120 | return res, err 121 | } 122 | 123 | // Allow an individual request to override cache options. 
124 | func requestCacheOptions(req *http.Request) (string, time.Duration) { 125 | var dur time.Duration 126 | dir := req.Header.Get("X-GH-CACHE-DIR") 127 | ttl := req.Header.Get("X-GH-CACHE-TTL") 128 | if ttl != "" { 129 | dur, _ = time.ParseDuration(ttl) 130 | } 131 | return dir, dur 132 | } 133 | 134 | func (fs *fileStorage) filePath(key string) string { 135 | if len(key) >= 6 { 136 | return filepath.Join(fs.dir, key[0:2], key[2:4], key[4:]) 137 | } 138 | return filepath.Join(fs.dir, key) 139 | } 140 | 141 | func (fs *fileStorage) read(key string) (*http.Response, error) { 142 | cacheFile := fs.filePath(key) 143 | 144 | fs.mu.RLock() 145 | defer fs.mu.RUnlock() 146 | 147 | f, err := os.Open(cacheFile) 148 | if err != nil { 149 | return nil, err 150 | } 151 | defer f.Close() 152 | 153 | stat, err := f.Stat() 154 | if err != nil { 155 | return nil, err 156 | } 157 | 158 | age := time.Since(stat.ModTime()) 159 | if age > fs.ttl { 160 | return nil, errors.New("cache expired") 161 | } 162 | 163 | body := &bytes.Buffer{} 164 | _, err = io.Copy(body, f) 165 | if err != nil { 166 | return nil, err 167 | } 168 | 169 | res, err := http.ReadResponse(bufio.NewReader(body), nil) 170 | return res, err 171 | } 172 | 173 | func (fs *fileStorage) store(key string, res *http.Response) (storeErr error) { 174 | cacheFile := fs.filePath(key) 175 | 176 | fs.mu.Lock() 177 | defer fs.mu.Unlock() 178 | 179 | if storeErr = os.MkdirAll(filepath.Dir(cacheFile), 0755); storeErr != nil { 180 | return 181 | } 182 | 183 | var f *os.File 184 | if f, storeErr = os.OpenFile(cacheFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600); storeErr != nil { 185 | return 186 | } 187 | 188 | defer func() { 189 | if err := f.Close(); storeErr == nil && err != nil { 190 | storeErr = err 191 | } 192 | }() 193 | 194 | var origBody io.ReadCloser 195 | if res.Body != nil { 196 | origBody, res.Body = copyStream(res.Body) 197 | defer res.Body.Close() 198 | } 199 | 200 | storeErr = res.Write(f) 201 | if origBody != nil { 
202 | res.Body = origBody 203 | } 204 | 205 | return 206 | } 207 | 208 | func copyStream(r io.ReadCloser) (io.ReadCloser, io.ReadCloser) { 209 | b := &bytes.Buffer{} 210 | nr := io.TeeReader(r, b) 211 | return io.NopCloser(b), &readCloser{ 212 | Reader: nr, 213 | Closer: r, 214 | } 215 | } 216 | -------------------------------------------------------------------------------- /pkg/api/cache_test.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "path/filepath" 9 | "testing" 10 | "time" 11 | 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func TestCacheResponse(t *testing.T) { 16 | counter := 0 17 | fakeHTTP := tripper{ 18 | roundTrip: func(req *http.Request) (*http.Response, error) { 19 | counter += 1 20 | body := fmt.Sprintf("%d: %s %s", counter, req.Method, req.URL.String()) 21 | status := 200 22 | if req.URL.Path == "/error" { 23 | status = 500 24 | } 25 | return &http.Response{ 26 | StatusCode: status, 27 | Body: io.NopCloser(bytes.NewBufferString(body)), 28 | }, nil 29 | }, 30 | } 31 | 32 | cacheDir := filepath.Join(t.TempDir(), "gh-cli-cache") 33 | 34 | httpClient, err := NewHTTPClient( 35 | ClientOptions{ 36 | Host: "github.com", 37 | AuthToken: "token", 38 | Transport: fakeHTTP, 39 | EnableCache: true, 40 | CacheDir: cacheDir, 41 | LogIgnoreEnv: true, 42 | }, 43 | ) 44 | assert.NoError(t, err) 45 | 46 | do := func(method, url string, body io.Reader) (string, error) { 47 | req, err := http.NewRequest(method, url, body) 48 | if err != nil { 49 | return "", err 50 | } 51 | res, err := httpClient.Do(req) 52 | if err != nil { 53 | return "", err 54 | } 55 | defer res.Body.Close() 56 | resBody, err := io.ReadAll(res.Body) 57 | if err != nil { 58 | err = fmt.Errorf("ReadAll: %w", err) 59 | } 60 | return string(resBody), err 61 | } 62 | 63 | var res string 64 | 65 | res, err = do("GET", "http://example.com/path", nil) 66 | 
assert.NoError(t, err) 67 | assert.Equal(t, "1: GET http://example.com/path", res) 68 | res, err = do("GET", "http://example.com/path", nil) 69 | assert.NoError(t, err) 70 | assert.Equal(t, "1: GET http://example.com/path", res) 71 | 72 | res, err = do("GET", "http://example.com/path2", nil) 73 | assert.NoError(t, err) 74 | assert.Equal(t, "2: GET http://example.com/path2", res) 75 | 76 | res, err = do("POST", "http://example.com/path2", nil) 77 | assert.NoError(t, err) 78 | assert.Equal(t, "3: POST http://example.com/path2", res) 79 | 80 | res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`)) 81 | assert.NoError(t, err) 82 | assert.Equal(t, "4: POST http://example.com/graphql", res) 83 | res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`)) 84 | assert.NoError(t, err) 85 | assert.Equal(t, "4: POST http://example.com/graphql", res) 86 | 87 | res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello2`)) 88 | assert.NoError(t, err) 89 | assert.Equal(t, "5: POST http://example.com/graphql", res) 90 | 91 | res, err = do("GET", "http://example.com/error", nil) 92 | assert.NoError(t, err) 93 | assert.Equal(t, "6: GET http://example.com/error", res) 94 | res, err = do("GET", "http://example.com/error", nil) 95 | assert.NoError(t, err) 96 | assert.Equal(t, "7: GET http://example.com/error", res) 97 | } 98 | 99 | func TestCacheResponseRequestCacheOptions(t *testing.T) { 100 | counter := 0 101 | fakeHTTP := tripper{ 102 | roundTrip: func(req *http.Request) (*http.Response, error) { 103 | counter += 1 104 | body := fmt.Sprintf("%d: %s %s", counter, req.Method, req.URL.String()) 105 | status := 200 106 | if req.URL.Path == "/error" { 107 | status = 500 108 | } 109 | return &http.Response{ 110 | StatusCode: status, 111 | Body: io.NopCloser(bytes.NewBufferString(body)), 112 | }, nil 113 | }, 114 | } 115 | 116 | cacheDir := filepath.Join(t.TempDir(), "gh-cli-cache") 117 | 118 | httpClient, err := 
NewHTTPClient( 119 | ClientOptions{ 120 | Host: "github.com", 121 | AuthToken: "token", 122 | Transport: fakeHTTP, 123 | EnableCache: false, 124 | CacheDir: cacheDir, 125 | LogIgnoreEnv: true, 126 | }, 127 | ) 128 | assert.NoError(t, err) 129 | 130 | do := func(method, url string, body io.Reader) (string, error) { 131 | req, err := http.NewRequest(method, url, body) 132 | if err != nil { 133 | return "", err 134 | } 135 | req.Header.Set("X-GH-CACHE-DIR", cacheDir) 136 | req.Header.Set("X-GH-CACHE-TTL", "1h") 137 | res, err := httpClient.Do(req) 138 | if err != nil { 139 | return "", err 140 | } 141 | defer res.Body.Close() 142 | resBody, err := io.ReadAll(res.Body) 143 | if err != nil { 144 | err = fmt.Errorf("ReadAll: %w", err) 145 | } 146 | return string(resBody), err 147 | } 148 | 149 | var res string 150 | 151 | res, err = do("GET", "http://example.com/path", nil) 152 | assert.NoError(t, err) 153 | assert.Equal(t, "1: GET http://example.com/path", res) 154 | res, err = do("GET", "http://example.com/path", nil) 155 | assert.NoError(t, err) 156 | assert.Equal(t, "1: GET http://example.com/path", res) 157 | 158 | res, err = do("GET", "http://example.com/path2", nil) 159 | assert.NoError(t, err) 160 | assert.Equal(t, "2: GET http://example.com/path2", res) 161 | 162 | res, err = do("POST", "http://example.com/path2", nil) 163 | assert.NoError(t, err) 164 | assert.Equal(t, "3: POST http://example.com/path2", res) 165 | 166 | res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`)) 167 | assert.NoError(t, err) 168 | assert.Equal(t, "4: POST http://example.com/graphql", res) 169 | res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`)) 170 | assert.NoError(t, err) 171 | assert.Equal(t, "4: POST http://example.com/graphql", res) 172 | 173 | res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello2`)) 174 | assert.NoError(t, err) 175 | assert.Equal(t, "5: POST http://example.com/graphql", 
res) 176 | 177 | res, err = do("GET", "http://example.com/error", nil) 178 | assert.NoError(t, err) 179 | assert.Equal(t, "6: GET http://example.com/error", res) 180 | res, err = do("GET", "http://example.com/error", nil) 181 | assert.NoError(t, err) 182 | assert.Equal(t, "7: GET http://example.com/error", res) 183 | } 184 | 185 | func TestRequestCacheOptions(t *testing.T) { 186 | req, err := http.NewRequest("GET", "some/url", nil) 187 | assert.NoError(t, err) 188 | req.Header.Set("X-GH-CACHE-DIR", "some/dir/path") 189 | req.Header.Set("X-GH-CACHE-TTL", "1h") 190 | dir, ttl := requestCacheOptions(req) 191 | assert.Equal(t, dir, "some/dir/path") 192 | assert.Equal(t, ttl, time.Hour) 193 | } 194 | -------------------------------------------------------------------------------- /pkg/api/client_options.go: -------------------------------------------------------------------------------- 1 | // Package api is a set of types for interacting with the GitHub API. 2 | package api 3 | 4 | import ( 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "time" 9 | 10 | "github.com/cli/go-gh/v2/pkg/auth" 11 | "github.com/cli/go-gh/v2/pkg/config" 12 | ) 13 | 14 | // ClientOptions holds available options to configure API clients. 15 | type ClientOptions struct { 16 | // AuthToken is the authorization token that will be used 17 | // to authenticate against API endpoints. 18 | AuthToken string 19 | 20 | // CacheDir is the directory to use for cached API requests. 21 | // Default is the same directory that gh uses for caching. 22 | CacheDir string 23 | 24 | // CacheTTL is the time that cached API requests are valid for. 25 | // Default is 24 hours. 26 | CacheTTL time.Duration 27 | 28 | // EnableCache specifies if API requests will be cached or not. 29 | // Default is no caching. 30 | EnableCache bool 31 | 32 | // Headers are the headers that will be sent with every API request. 33 | // Default headers set are Accept, Content-Type, Time-Zone, and User-Agent. 
34 | // Default headers will be overridden by keys specified in Headers. 35 | Headers map[string]string 36 | 37 | // Host is the default host that API requests will be sent to. 38 | Host string 39 | 40 | // Log specifies a writer to write API request logs to. Default is to respect the GH_DEBUG environment 41 | // variable, and no logging otherwise. 42 | Log io.Writer 43 | 44 | // LogIgnoreEnv disables respecting the GH_DEBUG environment variable. This can be useful in test mode 45 | // or when the extension already offers its own controls for logging to the user. 46 | LogIgnoreEnv bool 47 | 48 | // LogColorize enables colorized logging to Log for display in a terminal. 49 | // Default is no coloring. 50 | LogColorize bool 51 | 52 | // LogVerboseHTTP enables logging HTTP headers and bodies to Log. 53 | // Default is only logging request URLs and response statuses. 54 | LogVerboseHTTP bool 55 | 56 | // SkipDefaultHeaders disables setting of the default headers. 57 | SkipDefaultHeaders bool 58 | 59 | // Timeout specifies a time limit for each API request. 60 | // Default is no timeout. 61 | Timeout time.Duration 62 | 63 | // Transport specifies the mechanism by which individual API requests are made. 64 | // If both Transport and UnixDomainSocket are specified then Transport takes 65 | // precedence. Due to this behavior any value set for Transport needs to manually 66 | // handle routing to UnixDomainSocket if necessary. Generally, setting Transport 67 | // should be reserved for testing purposes. 68 | // Default is http.DefaultTransport. 69 | Transport http.RoundTripper 70 | 71 | // UnixDomainSocket specifies the Unix domain socket address by which individual 72 | // API requests will be routed. If specifed, this will form the base of the API 73 | // request transport chain. 74 | // Default is no socket address. 
75 | UnixDomainSocket string 76 | } 77 | 78 | func optionsNeedResolution(opts ClientOptions) bool { 79 | if opts.Host == "" { 80 | return true 81 | } 82 | if opts.AuthToken == "" { 83 | return true 84 | } 85 | if opts.UnixDomainSocket == "" && opts.Transport == nil { 86 | return true 87 | } 88 | return false 89 | } 90 | 91 | func resolveOptions(opts ClientOptions) (ClientOptions, error) { 92 | cfg, _ := config.Read(nil) 93 | if opts.Host == "" { 94 | opts.Host, _ = auth.DefaultHost() 95 | } 96 | if opts.AuthToken == "" { 97 | opts.AuthToken, _ = auth.TokenForHost(opts.Host) 98 | if opts.AuthToken == "" { 99 | return ClientOptions{}, fmt.Errorf("authentication token not found for host %s", opts.Host) 100 | } 101 | } 102 | if opts.UnixDomainSocket == "" && cfg != nil { 103 | opts.UnixDomainSocket, _ = cfg.Get([]string{"http_unix_socket"}) 104 | } 105 | return opts, nil 106 | } 107 | -------------------------------------------------------------------------------- /pkg/api/client_options_test.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "net/http" 5 | "testing" 6 | 7 | "github.com/cli/go-gh/v2/internal/testutils" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestResolveOptions(t *testing.T) { 12 | testutils.StubConfig(t, testConfigWithSocket()) 13 | 14 | tests := []struct { 15 | name string 16 | opts ClientOptions 17 | wantAuthToken string 18 | wantHost string 19 | wantSocket string 20 | }{ 21 | { 22 | name: "honors consumer provided ClientOptions", 23 | opts: ClientOptions{ 24 | Host: "test.com", 25 | AuthToken: "token_from_opts", 26 | UnixDomainSocket: "socket_from_opts", 27 | }, 28 | wantAuthToken: "token_from_opts", 29 | wantHost: "test.com", 30 | wantSocket: "socket_from_opts", 31 | }, 32 | { 33 | name: "uses config values if there are no consumer provided ClientOptions", 34 | opts: ClientOptions{}, 35 | wantAuthToken: "token", 36 | wantHost: "github.com", 37 | wantSocket: 
"socket", 38 | }, 39 | } 40 | 41 | for _, tt := range tests { 42 | t.Run(tt.name, func(t *testing.T) { 43 | opts, err := resolveOptions(tt.opts) 44 | assert.NoError(t, err) 45 | assert.Equal(t, tt.wantHost, opts.Host) 46 | assert.Equal(t, tt.wantAuthToken, opts.AuthToken) 47 | assert.Equal(t, tt.wantSocket, opts.UnixDomainSocket) 48 | }) 49 | } 50 | } 51 | 52 | func TestOptionsNeedResolution(t *testing.T) { 53 | tests := []struct { 54 | name string 55 | opts ClientOptions 56 | out bool 57 | }{ 58 | { 59 | name: "Host, AuthToken, and UnixDomainSocket specified", 60 | opts: ClientOptions{ 61 | Host: "test.com", 62 | AuthToken: "token", 63 | UnixDomainSocket: "socket", 64 | }, 65 | out: false, 66 | }, 67 | { 68 | name: "Host, AuthToken, and Transport specified", 69 | opts: ClientOptions{ 70 | Host: "test.com", 71 | AuthToken: "token", 72 | Transport: http.DefaultTransport, 73 | }, 74 | out: false, 75 | }, 76 | { 77 | name: "Host, and AuthToken specified", 78 | opts: ClientOptions{ 79 | Host: "test.com", 80 | AuthToken: "token", 81 | }, 82 | out: true, 83 | }, 84 | { 85 | name: "Host, and UnixDomainSocket specified", 86 | opts: ClientOptions{ 87 | Host: "test.com", 88 | UnixDomainSocket: "socket", 89 | }, 90 | out: true, 91 | }, 92 | { 93 | name: "Host, and Transport specified", 94 | opts: ClientOptions{ 95 | Host: "test.com", 96 | Transport: http.DefaultTransport, 97 | }, 98 | out: true, 99 | }, 100 | { 101 | name: "AuthToken, and UnixDomainSocket specified", 102 | opts: ClientOptions{ 103 | AuthToken: "token", 104 | UnixDomainSocket: "socket", 105 | }, 106 | out: true, 107 | }, 108 | { 109 | name: "AuthToken, and Transport specified", 110 | opts: ClientOptions{ 111 | AuthToken: "token", 112 | Transport: http.DefaultTransport, 113 | }, 114 | out: true, 115 | }, 116 | { 117 | name: "Host specified", 118 | opts: ClientOptions{ 119 | Host: "test.com", 120 | }, 121 | out: true, 122 | }, 123 | { 124 | name: "AuthToken specified", 125 | opts: ClientOptions{ 126 | AuthToken: 
"token", 127 | }, 128 | out: true, 129 | }, 130 | { 131 | name: "UnixDomainSocket specified", 132 | opts: ClientOptions{ 133 | UnixDomainSocket: "socket", 134 | }, 135 | out: true, 136 | }, 137 | { 138 | name: "Transport specified", 139 | opts: ClientOptions{ 140 | Transport: http.DefaultTransport, 141 | }, 142 | out: true, 143 | }, 144 | } 145 | 146 | for _, tt := range tests { 147 | t.Run(tt.name, func(t *testing.T) { 148 | assert.Equal(t, tt.out, optionsNeedResolution(tt.opts)) 149 | }) 150 | } 151 | } 152 | 153 | func testConfig() string { 154 | return ` 155 | hosts: 156 | github.com: 157 | user: user1 158 | oauth_token: abc123 159 | git_protocol: ssh 160 | ` 161 | } 162 | 163 | func testConfigWithSocket() string { 164 | return ` 165 | http_unix_socket: socket 166 | hosts: 167 | github.com: 168 | user: user1 169 | oauth_token: token 170 | git_protocol: ssh 171 | ` 172 | } 173 | -------------------------------------------------------------------------------- /pkg/api/errors.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "net/url" 9 | "strings" 10 | ) 11 | 12 | // HTTPError represents an error response from the GitHub API. 13 | type HTTPError struct { 14 | Errors []HTTPErrorItem 15 | Headers http.Header 16 | Message string 17 | RequestURL *url.URL 18 | StatusCode int 19 | } 20 | 21 | // HTTPErrorItem stores additional information about an error response 22 | // returned from the GitHub API. 23 | type HTTPErrorItem struct { 24 | Code string 25 | Field string 26 | Message string 27 | Resource string 28 | } 29 | 30 | // Allow HTTPError to satisfy error interface. 
31 | func (err *HTTPError) Error() string { 32 | if msgs := strings.SplitN(err.Message, "\n", 2); len(msgs) > 1 { 33 | return fmt.Sprintf("HTTP %d: %s (%s)\n%s", err.StatusCode, msgs[0], err.RequestURL, msgs[1]) 34 | } else if err.Message != "" { 35 | return fmt.Sprintf("HTTP %d: %s (%s)", err.StatusCode, err.Message, err.RequestURL) 36 | } 37 | return fmt.Sprintf("HTTP %d (%s)", err.StatusCode, err.RequestURL) 38 | } 39 | 40 | // GraphQLError represents an error response from GitHub GraphQL API. 41 | type GraphQLError struct { 42 | Errors []GraphQLErrorItem 43 | } 44 | 45 | // GraphQLErrorItem stores additional information about an error response 46 | // returned from the GitHub GraphQL API. 47 | type GraphQLErrorItem struct { 48 | Message string 49 | Locations []struct { 50 | Line int 51 | Column int 52 | } 53 | Path []interface{} 54 | Extensions map[string]interface{} 55 | Type string 56 | } 57 | 58 | // Allow GraphQLError to satisfy error interface. 59 | func (gr *GraphQLError) Error() string { 60 | errorMessages := make([]string, 0, len(gr.Errors)) 61 | for _, e := range gr.Errors { 62 | msg := e.Message 63 | if p := e.pathString(); p != "" { 64 | msg = fmt.Sprintf("%s (%s)", msg, p) 65 | } 66 | errorMessages = append(errorMessages, msg) 67 | } 68 | return fmt.Sprintf("GraphQL: %s", strings.Join(errorMessages, ", ")) 69 | } 70 | 71 | // Match determines if the GraphQLError is about a specific type on a specific path. 72 | // If the path argument ends with a ".", it will match all its subpaths. 
73 | func (gr *GraphQLError) Match(expectType, expectPath string) bool { 74 | for _, e := range gr.Errors { 75 | if e.Type != expectType || !matchPath(e.pathString(), expectPath) { 76 | return false 77 | } 78 | } 79 | return true 80 | } 81 | 82 | func (ge GraphQLErrorItem) pathString() string { 83 | var res strings.Builder 84 | for i, v := range ge.Path { 85 | if i > 0 { 86 | res.WriteRune('.') 87 | } 88 | fmt.Fprintf(&res, "%v", v) 89 | } 90 | return res.String() 91 | } 92 | 93 | func matchPath(p, expect string) bool { 94 | if strings.HasSuffix(expect, ".") { 95 | return strings.HasPrefix(p, expect) || p == strings.TrimSuffix(expect, ".") 96 | } 97 | return p == expect 98 | } 99 | 100 | // HandleHTTPError parses a http.Response into a HTTPError. 101 | func HandleHTTPError(resp *http.Response) error { 102 | httpError := &HTTPError{ 103 | Headers: resp.Header, 104 | RequestURL: resp.Request.URL, 105 | StatusCode: resp.StatusCode, 106 | } 107 | 108 | if !jsonTypeRE.MatchString(resp.Header.Get(contentType)) { 109 | httpError.Message = resp.Status 110 | return httpError 111 | } 112 | 113 | body, err := io.ReadAll(resp.Body) 114 | if err != nil { 115 | httpError.Message = err.Error() 116 | return httpError 117 | } 118 | 119 | var parsedBody struct { 120 | Message string `json:"message"` 121 | Errors []json.RawMessage 122 | } 123 | if err := json.Unmarshal(body, &parsedBody); err != nil { 124 | return httpError 125 | } 126 | 127 | var messages []string 128 | if parsedBody.Message != "" { 129 | messages = append(messages, parsedBody.Message) 130 | } 131 | for _, raw := range parsedBody.Errors { 132 | switch raw[0] { 133 | case '"': 134 | var errString string 135 | _ = json.Unmarshal(raw, &errString) 136 | messages = append(messages, errString) 137 | httpError.Errors = append(httpError.Errors, HTTPErrorItem{Message: errString}) 138 | case '{': 139 | var errInfo HTTPErrorItem 140 | _ = json.Unmarshal(raw, &errInfo) 141 | msg := errInfo.Message 142 | if errInfo.Code != "" && 
errInfo.Code != "custom" { 143 | msg = fmt.Sprintf("%s.%s %s", errInfo.Resource, errInfo.Field, errorCodeToMessage(errInfo.Code)) 144 | } 145 | if msg != "" { 146 | messages = append(messages, msg) 147 | } 148 | httpError.Errors = append(httpError.Errors, errInfo) 149 | } 150 | } 151 | httpError.Message = strings.Join(messages, "\n") 152 | 153 | return httpError 154 | } 155 | 156 | // Convert common error codes to human readable messages 157 | // See https://docs.github.com/en/rest/overview/resources-in-the-rest-api#client-errors for more details. 158 | func errorCodeToMessage(code string) string { 159 | switch code { 160 | case "missing", "missing_field": 161 | return "is missing" 162 | case "invalid", "unprocessable": 163 | return "is invalid" 164 | case "already_exists": 165 | return "already exists" 166 | default: 167 | return code 168 | } 169 | } 170 | -------------------------------------------------------------------------------- /pkg/api/errors_test.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestGraphQLErrorMatch(t *testing.T) { 10 | tests := []struct { 11 | name string 12 | error GraphQLError 13 | kind string 14 | path string 15 | wantMatch bool 16 | }{ 17 | { 18 | name: "matches path and type", 19 | error: GraphQLError{Errors: []GraphQLErrorItem{ 20 | {Path: []interface{}{"repository", "issue"}, Type: "NOT_FOUND"}, 21 | }}, 22 | kind: "NOT_FOUND", 23 | path: "repository.issue", 24 | wantMatch: true, 25 | }, 26 | { 27 | name: "matches base path and type", 28 | error: GraphQLError{Errors: []GraphQLErrorItem{ 29 | {Path: []interface{}{"repository", "issue"}, Type: "NOT_FOUND"}, 30 | }}, 31 | kind: "NOT_FOUND", 32 | path: "repository.", 33 | wantMatch: true, 34 | }, 35 | { 36 | name: "does not match path but matches type", 37 | error: GraphQLError{Errors: []GraphQLErrorItem{ 38 | {Path: 
[]interface{}{"repository", "issue"}, Type: "NOT_FOUND"}, 39 | }}, 40 | kind: "NOT_FOUND", 41 | path: "label.title", 42 | wantMatch: false, 43 | }, 44 | { 45 | name: "matches path but not type", 46 | error: GraphQLError{Errors: []GraphQLErrorItem{ 47 | {Path: []interface{}{"repository", "issue"}, Type: "NOT_FOUND"}, 48 | }}, 49 | kind: "UNKNOWN", 50 | path: "repository.issue", 51 | wantMatch: false, 52 | }, 53 | } 54 | 55 | for _, tt := range tests { 56 | t.Run(tt.name, func(t *testing.T) { 57 | assert.Equal(t, tt.wantMatch, tt.error.Match(tt.kind, tt.path)) 58 | }) 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /pkg/api/graphql_client.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "net/http" 11 | "strings" 12 | 13 | "github.com/cli/go-gh/v2/pkg/auth" 14 | graphql "github.com/cli/shurcooL-graphql" 15 | ) 16 | 17 | // GraphQLClient wraps methods for the different types of 18 | // API requests that are supported by the server. 19 | type GraphQLClient struct { 20 | client *graphql.Client 21 | host string 22 | httpClient *http.Client 23 | } 24 | 25 | func DefaultGraphQLClient() (*GraphQLClient, error) { 26 | return NewGraphQLClient(ClientOptions{}) 27 | } 28 | 29 | // GraphQLClient builds a client to send requests to GitHub GraphQL API endpoints. 30 | // As part of the configuration a hostname, auth token, default set of headers, 31 | // and unix domain socket are resolved from the gh environment configuration. 32 | // These behaviors can be overridden using the opts argument. 
33 | func NewGraphQLClient(opts ClientOptions) (*GraphQLClient, error) { 34 | if optionsNeedResolution(opts) { 35 | var err error 36 | opts, err = resolveOptions(opts) 37 | if err != nil { 38 | return nil, err 39 | } 40 | } 41 | 42 | httpClient, err := NewHTTPClient(opts) 43 | if err != nil { 44 | return nil, err 45 | } 46 | 47 | endpoint := graphQLEndpoint(opts.Host) 48 | 49 | return &GraphQLClient{ 50 | client: graphql.NewClient(endpoint, httpClient), 51 | host: endpoint, 52 | httpClient: httpClient, 53 | }, nil 54 | } 55 | 56 | // DoWithContext executes a GraphQL query request. 57 | // The response is populated into the response argument. 58 | func (c *GraphQLClient) DoWithContext(ctx context.Context, query string, variables map[string]interface{}, response interface{}) error { 59 | reqBody, err := json.Marshal(map[string]interface{}{"query": query, "variables": variables}) 60 | if err != nil { 61 | return err 62 | } 63 | 64 | req, err := http.NewRequestWithContext(ctx, "POST", c.host, bytes.NewBuffer(reqBody)) 65 | if err != nil { 66 | return err 67 | } 68 | 69 | resp, err := c.httpClient.Do(req) 70 | if err != nil { 71 | return err 72 | } 73 | defer resp.Body.Close() 74 | 75 | success := resp.StatusCode >= 200 && resp.StatusCode < 300 76 | if !success { 77 | return HandleHTTPError(resp) 78 | } 79 | 80 | if resp.StatusCode == http.StatusNoContent { 81 | return nil 82 | } 83 | 84 | body, err := io.ReadAll(resp.Body) 85 | if err != nil { 86 | return err 87 | } 88 | 89 | gr := graphQLResponse{Data: response} 90 | err = json.Unmarshal(body, &gr) 91 | if err != nil { 92 | return err 93 | } 94 | 95 | if len(gr.Errors) > 0 { 96 | return &GraphQLError{Errors: gr.Errors} 97 | } 98 | 99 | return nil 100 | } 101 | 102 | // Do wraps DoWithContext using context.Background. 
103 | func (c *GraphQLClient) Do(query string, variables map[string]interface{}, response interface{}) error { 104 | return c.DoWithContext(context.Background(), query, variables, response) 105 | } 106 | 107 | // MutateWithContext executes a GraphQL mutation request. 108 | // The mutation string is derived from the mutation argument, and the 109 | // response is populated into it. 110 | // The mutation argument should be a pointer to struct that corresponds 111 | // to the GitHub GraphQL schema. 112 | // Provided input will be set as a variable named input. 113 | func (c *GraphQLClient) MutateWithContext(ctx context.Context, name string, m interface{}, variables map[string]interface{}) error { 114 | err := c.client.MutateNamed(ctx, name, m, variables) 115 | var graphQLErrs graphql.Errors 116 | if err != nil && errors.As(err, &graphQLErrs) { 117 | items := make([]GraphQLErrorItem, len(graphQLErrs)) 118 | for i, e := range graphQLErrs { 119 | items[i] = GraphQLErrorItem{ 120 | Message: e.Message, 121 | Locations: e.Locations, 122 | Path: e.Path, 123 | Extensions: e.Extensions, 124 | Type: e.Type, 125 | } 126 | } 127 | err = &GraphQLError{items} 128 | } 129 | return err 130 | } 131 | 132 | // Mutate wraps MutateWithContext using context.Background. 133 | func (c *GraphQLClient) Mutate(name string, m interface{}, variables map[string]interface{}) error { 134 | return c.MutateWithContext(context.Background(), name, m, variables) 135 | } 136 | 137 | // QueryWithContext executes a GraphQL query request, 138 | // The query string is derived from the query argument, and the 139 | // response is populated into it. 140 | // The query argument should be a pointer to struct that corresponds 141 | // to the GitHub GraphQL schema. 
142 | func (c *GraphQLClient) QueryWithContext(ctx context.Context, name string, q interface{}, variables map[string]interface{}) error { 143 | err := c.client.QueryNamed(ctx, name, q, variables) 144 | var graphQLErrs graphql.Errors 145 | if err != nil && errors.As(err, &graphQLErrs) { 146 | items := make([]GraphQLErrorItem, len(graphQLErrs)) 147 | for i, e := range graphQLErrs { 148 | items[i] = GraphQLErrorItem{ 149 | Message: e.Message, 150 | Locations: e.Locations, 151 | Path: e.Path, 152 | Extensions: e.Extensions, 153 | Type: e.Type, 154 | } 155 | } 156 | err = &GraphQLError{items} 157 | } 158 | return err 159 | } 160 | 161 | // Query wraps QueryWithContext using context.Background. 162 | func (c *GraphQLClient) Query(name string, q interface{}, variables map[string]interface{}) error { 163 | return c.QueryWithContext(context.Background(), name, q, variables) 164 | } 165 | 166 | type graphQLResponse struct { 167 | Data interface{} 168 | Errors []GraphQLErrorItem 169 | } 170 | 171 | func graphQLEndpoint(host string) string { 172 | if isGarage(host) { 173 | return fmt.Sprintf("https://%s/api/graphql", host) 174 | } 175 | host = auth.NormalizeHostname(host) 176 | if auth.IsEnterprise(host) { 177 | return fmt.Sprintf("https://%s/api/graphql", host) 178 | } 179 | if strings.EqualFold(host, localhost) { 180 | return fmt.Sprintf("http://api.%s/graphql", host) 181 | } 182 | return fmt.Sprintf("https://api.%s/graphql", host) 183 | } 184 | -------------------------------------------------------------------------------- /pkg/api/http_client.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "net" 7 | "net/http" 8 | "os" 9 | "regexp" 10 | "runtime/debug" 11 | "strings" 12 | "time" 13 | 14 | "github.com/cli/go-gh/v2/pkg/asciisanitizer" 15 | "github.com/cli/go-gh/v2/pkg/config" 16 | "github.com/cli/go-gh/v2/pkg/term" 17 | "github.com/henvic/httpretty" 18 | 
"github.com/thlib/go-timezone-local/tzlocal" 19 | "golang.org/x/text/transform" 20 | ) 21 | 22 | const ( 23 | accept = "Accept" 24 | authorization = "Authorization" 25 | contentType = "Content-Type" 26 | github = "github.com" 27 | jsonContentType = "application/json; charset=utf-8" 28 | localhost = "github.localhost" 29 | modulePath = "github.com/cli/go-gh" 30 | timeZone = "Time-Zone" 31 | userAgent = "User-Agent" 32 | ) 33 | 34 | var jsonTypeRE = regexp.MustCompile(`[/+]json($|;)`) 35 | 36 | func DefaultHTTPClient() (*http.Client, error) { 37 | return NewHTTPClient(ClientOptions{}) 38 | } 39 | 40 | // HTTPClient builds a client that can be passed to another library. 41 | // As part of the configuration a hostname, auth token, default set of headers, 42 | // and unix domain socket are resolved from the gh environment configuration. 43 | // These behaviors can be overridden using the opts argument. In this instance 44 | // providing opts.Host will not change the destination of your request as it is 45 | // the responsibility of the consumer to configure this. However, if opts.Host 46 | // does not match the request host, the auth token will not be added to the headers. 47 | // This is to protect against the case where tokens could be sent to an arbitrary 48 | // host. 
49 | func NewHTTPClient(opts ClientOptions) (*http.Client, error) { 50 | if optionsNeedResolution(opts) { 51 | var err error 52 | opts, err = resolveOptions(opts) 53 | if err != nil { 54 | return nil, err 55 | } 56 | } 57 | 58 | transport := http.DefaultTransport 59 | 60 | if opts.UnixDomainSocket != "" { 61 | transport = newUnixDomainSocketRoundTripper(opts.UnixDomainSocket) 62 | } 63 | 64 | if opts.Transport != nil { 65 | transport = opts.Transport 66 | } 67 | 68 | transport = newSanitizerRoundTripper(transport) 69 | 70 | if opts.CacheDir == "" { 71 | opts.CacheDir = config.CacheDir() 72 | } 73 | if opts.EnableCache && opts.CacheTTL == 0 { 74 | opts.CacheTTL = time.Hour * 24 75 | } 76 | c := cache{dir: opts.CacheDir, ttl: opts.CacheTTL} 77 | transport = c.RoundTripper(transport) 78 | 79 | if opts.Log == nil && !opts.LogIgnoreEnv { 80 | ghDebug := os.Getenv("GH_DEBUG") 81 | switch ghDebug { 82 | case "", "0", "false", "no": 83 | // no logging 84 | default: 85 | opts.Log = os.Stderr 86 | opts.LogColorize = !term.IsColorDisabled() && term.IsTerminal(os.Stderr) 87 | opts.LogVerboseHTTP = strings.Contains(ghDebug, "api") 88 | } 89 | } 90 | 91 | if opts.Log != nil { 92 | logger := &httpretty.Logger{ 93 | Time: true, 94 | TLS: false, 95 | Colors: opts.LogColorize, 96 | RequestHeader: opts.LogVerboseHTTP, 97 | RequestBody: opts.LogVerboseHTTP, 98 | ResponseHeader: opts.LogVerboseHTTP, 99 | ResponseBody: opts.LogVerboseHTTP, 100 | Formatters: []httpretty.Formatter{&jsonFormatter{colorize: opts.LogColorize}}, 101 | MaxResponseBody: 100000, 102 | } 103 | logger.SetOutput(opts.Log) 104 | logger.SetBodyFilter(func(h http.Header) (skip bool, err error) { 105 | return !inspectableMIMEType(h.Get(contentType)), nil 106 | }) 107 | transport = logger.RoundTripper(transport) 108 | } 109 | 110 | if opts.Headers == nil { 111 | opts.Headers = map[string]string{} 112 | } 113 | if !opts.SkipDefaultHeaders { 114 | resolveHeaders(opts.Headers) 115 | } 116 | transport = 
newHeaderRoundTripper(opts.Host, opts.AuthToken, opts.Headers, transport) 117 | 118 | return &http.Client{Transport: transport, Timeout: opts.Timeout}, nil 119 | } 120 | 121 | func inspectableMIMEType(t string) bool { 122 | return strings.HasPrefix(t, "text/") || 123 | strings.HasPrefix(t, "application/x-www-form-urlencoded") || 124 | jsonTypeRE.MatchString(t) 125 | } 126 | 127 | func isSameDomain(requestHost, domain string) bool { 128 | requestHost = strings.ToLower(requestHost) 129 | domain = strings.ToLower(domain) 130 | return (requestHost == domain) || strings.HasSuffix(requestHost, "."+domain) 131 | } 132 | 133 | func isGarage(host string) bool { 134 | return strings.EqualFold(host, "garage.github.com") 135 | } 136 | 137 | type headerRoundTripper struct { 138 | headers map[string]string 139 | host string 140 | rt http.RoundTripper 141 | } 142 | 143 | func resolveHeaders(headers map[string]string) { 144 | if _, ok := headers[contentType]; !ok { 145 | headers[contentType] = jsonContentType 146 | } 147 | if _, ok := headers[userAgent]; !ok { 148 | headers[userAgent] = "go-gh" 149 | info, ok := debug.ReadBuildInfo() 150 | if ok { 151 | for _, dep := range info.Deps { 152 | if dep.Path == modulePath { 153 | headers[userAgent] += fmt.Sprintf(" %s", dep.Version) 154 | break 155 | } 156 | } 157 | } 158 | } 159 | if _, ok := headers[timeZone]; !ok { 160 | tz := currentTimeZone() 161 | if tz != "" { 162 | headers[timeZone] = tz 163 | } 164 | } 165 | if _, ok := headers[accept]; !ok { 166 | // Preview for PullRequest.mergeStateStatus. 167 | a := "application/vnd.github.merge-info-preview+json" 168 | // Preview for visibility when RESTing repos into an org. 
169 | a += ", application/vnd.github.nebula-preview" 170 | headers[accept] = a 171 | } 172 | } 173 | 174 | func newHeaderRoundTripper(host string, authToken string, headers map[string]string, rt http.RoundTripper) http.RoundTripper { 175 | if _, ok := headers[authorization]; !ok && authToken != "" { 176 | headers[authorization] = fmt.Sprintf("token %s", authToken) 177 | } 178 | if len(headers) == 0 { 179 | return rt 180 | } 181 | return headerRoundTripper{host: host, headers: headers, rt: rt} 182 | } 183 | 184 | func (hrt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { 185 | for k, v := range hrt.headers { 186 | // If the authorization header has been set and the request 187 | // host is not in the same domain that was specified in the ClientOptions 188 | // then do not add the authorization header to the request. 189 | if k == authorization && !isSameDomain(req.URL.Hostname(), hrt.host) { 190 | continue 191 | } 192 | 193 | // If the header is already set in the request, don't overwrite it. 
194 | if req.Header.Get(k) == "" { 195 | req.Header.Set(k, v) 196 | } 197 | } 198 | 199 | return hrt.rt.RoundTrip(req) 200 | } 201 | 202 | func newUnixDomainSocketRoundTripper(socketPath string) http.RoundTripper { 203 | dial := func(network, addr string) (net.Conn, error) { 204 | return net.Dial("unix", socketPath) 205 | } 206 | 207 | return &http.Transport{ 208 | Dial: dial, 209 | DialTLS: dial, 210 | DisableKeepAlives: true, 211 | } 212 | } 213 | 214 | type sanitizerRoundTripper struct { 215 | rt http.RoundTripper 216 | } 217 | 218 | func newSanitizerRoundTripper(rt http.RoundTripper) http.RoundTripper { 219 | return sanitizerRoundTripper{rt: rt} 220 | } 221 | 222 | func (srt sanitizerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { 223 | resp, err := srt.rt.RoundTrip(req) 224 | if err != nil || !jsonTypeRE.MatchString(resp.Header.Get(contentType)) { 225 | return resp, err 226 | } 227 | sanitizedReadCloser := struct { 228 | io.Reader 229 | io.Closer 230 | }{ 231 | Reader: transform.NewReader(resp.Body, &asciisanitizer.Sanitizer{JSON: true}), 232 | Closer: resp.Body, 233 | } 234 | resp.Body = sanitizedReadCloser 235 | return resp, err 236 | } 237 | 238 | func currentTimeZone() string { 239 | tz, err := tzlocal.RuntimeTZ() 240 | if err != nil { 241 | return "" 242 | } 243 | return tz 244 | } 245 | -------------------------------------------------------------------------------- /pkg/api/http_client_test.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "strings" 9 | "testing" 10 | 11 | "github.com/cli/go-gh/v2/internal/testutils" 12 | "github.com/stretchr/testify/assert" 13 | "gopkg.in/h2non/gock.v1" 14 | ) 15 | 16 | func TestHTTPClient(t *testing.T) { 17 | testutils.StubConfig(t, testConfig()) 18 | t.Cleanup(gock.Off) 19 | 20 | gock.New("https://api.github.com"). 21 | Get("/some/test/path"). 
22 | MatchHeader("Authorization", "token abc123"). 23 | Reply(200). 24 | JSON(`{"message": "success"}`) 25 | 26 | client, err := DefaultHTTPClient() 27 | assert.NoError(t, err) 28 | 29 | res, err := client.Get("https://api.github.com/some/test/path") 30 | assert.NoError(t, err) 31 | assert.True(t, gock.IsDone(), printPendingMocks(gock.Pending())) 32 | assert.Equal(t, 200, res.StatusCode) 33 | } 34 | 35 | func TestNewHTTPClient(t *testing.T) { 36 | reflectHTTP := tripper{ 37 | roundTrip: func(req *http.Request) (*http.Response, error) { 38 | header := req.Header.Clone() 39 | body := "{}" 40 | return &http.Response{ 41 | StatusCode: 200, 42 | Header: header, 43 | Body: io.NopCloser(bytes.NewBufferString(body)), 44 | }, nil 45 | }, 46 | } 47 | 48 | tests := []struct { 49 | name string 50 | enableLog bool 51 | log *bytes.Buffer 52 | host string 53 | headers map[string]string 54 | skipHeaders bool 55 | wantHeaders http.Header 56 | }{ 57 | { 58 | name: "sets default headers", 59 | wantHeaders: defaultHeaders(), 60 | }, 61 | { 62 | name: "allows overriding default headers", 63 | headers: map[string]string{ 64 | authorization: "token new_token", 65 | accept: "application/vnd.github.test-preview", 66 | }, 67 | wantHeaders: func() http.Header { 68 | h := defaultHeaders() 69 | h.Set(authorization, "token new_token") 70 | h.Set(accept, "application/vnd.github.test-preview") 71 | return h 72 | }(), 73 | }, 74 | { 75 | name: "allows setting custom headers", 76 | headers: map[string]string{ 77 | "custom": "testing", 78 | }, 79 | wantHeaders: func() http.Header { 80 | h := defaultHeaders() 81 | h.Set("custom", "testing") 82 | return h 83 | }(), 84 | }, 85 | { 86 | name: "allows setting logger", 87 | enableLog: true, 88 | log: &bytes.Buffer{}, 89 | wantHeaders: defaultHeaders(), 90 | }, 91 | { 92 | name: "does not add an authorization header for non-matching host", 93 | host: "notauthorized.com", 94 | wantHeaders: func() http.Header { 95 | h := defaultHeaders() 96 | 
h.Del(authorization) 97 | return h 98 | }(), 99 | }, 100 | { 101 | name: "does not add an authorization header for non-matching host subdomain", 102 | host: "test.company", 103 | wantHeaders: func() http.Header { 104 | h := defaultHeaders() 105 | h.Del(authorization) 106 | return h 107 | }(), 108 | }, 109 | { 110 | name: "adds an authorization header for a matching host", 111 | host: "test.com", 112 | wantHeaders: defaultHeaders(), 113 | }, 114 | { 115 | name: "adds an authorization header if hosts match but differ in case", 116 | host: "TeSt.CoM", 117 | wantHeaders: defaultHeaders(), 118 | }, 119 | { 120 | name: "skips default headers", 121 | skipHeaders: true, 122 | wantHeaders: func() http.Header { 123 | h := defaultHeaders() 124 | h.Del(accept) 125 | h.Del(contentType) 126 | h.Del(timeZone) 127 | h.Del(userAgent) 128 | return h 129 | }(), 130 | }, 131 | } 132 | 133 | for _, tt := range tests { 134 | t.Run(tt.name, func(t *testing.T) { 135 | if tt.host == "" { 136 | tt.host = "test.com" 137 | } 138 | opts := ClientOptions{ 139 | Host: tt.host, 140 | AuthToken: "oauth_token", 141 | Headers: tt.headers, 142 | SkipDefaultHeaders: tt.skipHeaders, 143 | Transport: reflectHTTP, 144 | LogIgnoreEnv: true, 145 | } 146 | if tt.enableLog { 147 | opts.Log = tt.log 148 | } 149 | client, _ := NewHTTPClient(opts) 150 | res, err := client.Get("https://test.com") 151 | assert.NoError(t, err) 152 | assert.Equal(t, tt.wantHeaders, res.Header) 153 | if tt.enableLog { 154 | assert.NotEmpty(t, tt.log) 155 | } 156 | }) 157 | } 158 | } 159 | 160 | type tripper struct { 161 | roundTrip func(*http.Request) (*http.Response, error) 162 | } 163 | 164 | func (tr tripper) RoundTrip(req *http.Request) (*http.Response, error) { 165 | return tr.roundTrip(req) 166 | } 167 | 168 | func defaultHeaders() http.Header { 169 | h := http.Header{} 170 | a := "application/vnd.github.merge-info-preview+json" 171 | a += ", application/vnd.github.nebula-preview" 172 | h.Set(contentType, jsonContentType) 173 
| h.Set(userAgent, "go-gh") 174 | h.Set(authorization, fmt.Sprintf("token %s", "oauth_token")) 175 | h.Set(timeZone, currentTimeZone()) 176 | h.Set(accept, a) 177 | return h 178 | } 179 | 180 | func printPendingMocks(mocks []gock.Mock) string { 181 | paths := []string{} 182 | for _, mock := range mocks { 183 | paths = append(paths, mock.Request().URLStruct.String()) 184 | } 185 | return fmt.Sprintf("%d unmatched mocks: %s", len(paths), strings.Join(paths, ", ")) 186 | } 187 | -------------------------------------------------------------------------------- /pkg/api/log_formatter.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "strings" 9 | 10 | "github.com/cli/go-gh/v2/pkg/jsonpretty" 11 | ) 12 | 13 | type graphqlBody struct { 14 | Query string `json:"query"` 15 | OperationName string `json:"operationName"` 16 | Variables json.RawMessage `json:"variables"` 17 | } 18 | 19 | // jsonFormatter is a httpretty.Formatter that prettifies JSON payloads and GraphQL queries. 
20 | type jsonFormatter struct { 21 | colorize bool 22 | } 23 | 24 | func (f *jsonFormatter) Format(w io.Writer, src []byte) error { 25 | var graphqlQuery graphqlBody 26 | // TODO: find more precise way to detect a GraphQL query from the JSON payload alone 27 | if err := json.Unmarshal(src, &graphqlQuery); err == nil && graphqlQuery.Query != "" && len(graphqlQuery.Variables) > 0 { 28 | colorHighlight := "\x1b[35;1m" 29 | colorReset := "\x1b[m" 30 | if !f.colorize { 31 | colorHighlight = "" 32 | colorReset = "" 33 | } 34 | if _, err := fmt.Fprintf(w, "%sGraphQL query:%s\n%s\n", colorHighlight, colorReset, strings.ReplaceAll(strings.TrimSpace(graphqlQuery.Query), "\t", " ")); err != nil { 35 | return err 36 | } 37 | if _, err := fmt.Fprintf(w, "%sGraphQL variables:%s %s\n", colorHighlight, colorReset, string(graphqlQuery.Variables)); err != nil { 38 | return err 39 | } 40 | return nil 41 | } 42 | return jsonpretty.Format(w, bytes.NewReader(src), " ", f.colorize) 43 | } 44 | 45 | func (f *jsonFormatter) Match(t string) bool { 46 | return jsonTypeRE.MatchString(t) 47 | } 48 | -------------------------------------------------------------------------------- /pkg/api/rest_client.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "strings" 10 | 11 | "github.com/cli/go-gh/v2/pkg/auth" 12 | ) 13 | 14 | // RESTClient wraps methods for the different types of 15 | // API requests that are supported by the server. 16 | type RESTClient struct { 17 | client *http.Client 18 | host string 19 | } 20 | 21 | func DefaultRESTClient() (*RESTClient, error) { 22 | return NewRESTClient(ClientOptions{}) 23 | } 24 | 25 | // RESTClient builds a client to send requests to GitHub REST API endpoints. 
26 | // As part of the configuration a hostname, auth token, default set of headers, 27 | // and unix domain socket are resolved from the gh environment configuration. 28 | // These behaviors can be overridden using the opts argument. 29 | func NewRESTClient(opts ClientOptions) (*RESTClient, error) { 30 | if optionsNeedResolution(opts) { 31 | var err error 32 | opts, err = resolveOptions(opts) 33 | if err != nil { 34 | return nil, err 35 | } 36 | } 37 | 38 | client, err := NewHTTPClient(opts) 39 | if err != nil { 40 | return nil, err 41 | } 42 | 43 | return &RESTClient{ 44 | client: client, 45 | host: opts.Host, 46 | }, nil 47 | } 48 | 49 | // RequestWithContext issues a request with type specified by method to the 50 | // specified path with the specified body. 51 | // The response is returned rather than being populated 52 | // into a response argument. 53 | func (c *RESTClient) RequestWithContext(ctx context.Context, method string, path string, body io.Reader) (*http.Response, error) { 54 | url := restURL(c.host, path) 55 | req, err := http.NewRequestWithContext(ctx, method, url, body) 56 | if err != nil { 57 | return nil, err 58 | } 59 | 60 | resp, err := c.client.Do(req) 61 | if err != nil { 62 | return nil, err 63 | } 64 | 65 | success := resp.StatusCode >= 200 && resp.StatusCode < 300 66 | if !success { 67 | defer resp.Body.Close() 68 | return nil, HandleHTTPError(resp) 69 | } 70 | 71 | return resp, err 72 | } 73 | 74 | // Request wraps RequestWithContext with context.Background. 75 | func (c *RESTClient) Request(method string, path string, body io.Reader) (*http.Response, error) { 76 | return c.RequestWithContext(context.Background(), method, path, body) 77 | } 78 | 79 | // DoWithContext issues a request with type specified by method to the 80 | // specified path with the specified body. 81 | // The response is populated into the response argument. 
82 | func (c *RESTClient) DoWithContext(ctx context.Context, method string, path string, body io.Reader, response interface{}) error { 83 | url := restURL(c.host, path) 84 | req, err := http.NewRequestWithContext(ctx, method, url, body) 85 | if err != nil { 86 | return err 87 | } 88 | 89 | resp, err := c.client.Do(req) 90 | if err != nil { 91 | return err 92 | } 93 | 94 | success := resp.StatusCode >= 200 && resp.StatusCode < 300 95 | if !success { 96 | defer resp.Body.Close() 97 | return HandleHTTPError(resp) 98 | } 99 | 100 | if resp.StatusCode == http.StatusNoContent { 101 | return nil 102 | } 103 | defer resp.Body.Close() 104 | 105 | b, err := io.ReadAll(resp.Body) 106 | if err != nil { 107 | return err 108 | } 109 | 110 | err = json.Unmarshal(b, &response) 111 | if err != nil { 112 | return err 113 | } 114 | 115 | return nil 116 | } 117 | 118 | // Do wraps DoWithContext with context.Background. 119 | func (c *RESTClient) Do(method string, path string, body io.Reader, response interface{}) error { 120 | return c.DoWithContext(context.Background(), method, path, body, response) 121 | } 122 | 123 | // Delete issues a DELETE request to the specified path. 124 | // The response is populated into the response argument. 125 | func (c *RESTClient) Delete(path string, resp interface{}) error { 126 | return c.Do(http.MethodDelete, path, nil, resp) 127 | } 128 | 129 | // Get issues a GET request to the specified path. 130 | // The response is populated into the response argument. 131 | func (c *RESTClient) Get(path string, resp interface{}) error { 132 | return c.Do(http.MethodGet, path, nil, resp) 133 | } 134 | 135 | // Patch issues a PATCH request to the specified path with the specified body. 136 | // The response is populated into the response argument. 
137 | func (c *RESTClient) Patch(path string, body io.Reader, resp interface{}) error { 138 | return c.Do(http.MethodPatch, path, body, resp) 139 | } 140 | 141 | // Post issues a POST request to the specified path with the specified body. 142 | // The response is populated into the response argument. 143 | func (c *RESTClient) Post(path string, body io.Reader, resp interface{}) error { 144 | return c.Do(http.MethodPost, path, body, resp) 145 | } 146 | 147 | // Put issues a PUT request to the specified path with the specified body. 148 | // The response is populated into the response argument. 149 | func (c *RESTClient) Put(path string, body io.Reader, resp interface{}) error { 150 | return c.Do(http.MethodPut, path, body, resp) 151 | } 152 | 153 | func restURL(hostname string, pathOrURL string) string { 154 | if strings.HasPrefix(pathOrURL, "https://") || strings.HasPrefix(pathOrURL, "http://") { 155 | return pathOrURL 156 | } 157 | return restPrefix(hostname) + pathOrURL 158 | } 159 | 160 | func restPrefix(hostname string) string { 161 | if isGarage(hostname) { 162 | return fmt.Sprintf("https://%s/api/v3/", hostname) 163 | } 164 | hostname = auth.NormalizeHostname(hostname) 165 | if auth.IsEnterprise(hostname) { 166 | return fmt.Sprintf("https://%s/api/v3/", hostname) 167 | } 168 | if strings.EqualFold(hostname, localhost) { 169 | return fmt.Sprintf("http://api.%s/", hostname) 170 | } 171 | return fmt.Sprintf("https://api.%s/", hostname) 172 | } 173 | -------------------------------------------------------------------------------- /pkg/asciisanitizer/sanitizer.go: -------------------------------------------------------------------------------- 1 | // Package asciisanitizer implements an ASCII control character sanitizer for UTF-8 strings. 2 | // It will transform ASCII control codes into equivalent inert characters that are safe for display in the terminal. 3 | // Without sanitization these ASCII control characters will be interpreted by the terminal. 
4 | // This behaviour can be used maliciously as an attack vector, especially the ASCII control characters \x1B and \x9B. 5 | package asciisanitizer 6 | 7 | import ( 8 | "bytes" 9 | "errors" 10 | "strings" 11 | "unicode" 12 | "unicode/utf8" 13 | 14 | "golang.org/x/text/transform" 15 | ) 16 | 17 | // Sanitizer implements transform.Transformer interface. 18 | type Sanitizer struct { 19 | // JSON tells the Sanitizer to replace strings that will be transformed 20 | // into control characters when the string is marshaled to JSON. Set to 21 | // true if the string being sanitized represents JSON formatted data. 22 | JSON bool 23 | addEscape bool 24 | } 25 | 26 | // Transform uses a sliding window algorithm to detect C0 and C1 control characters as they are read and replaces 27 | // them with equivalent inert characters. Bytes that are not part of a control character are not modified. 28 | func (t *Sanitizer) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) { 29 | transfer := func(write, read []byte) error { 30 | readLength := len(read) 31 | writeLength := len(write) 32 | if writeLength > len(dst) { 33 | return transform.ErrShortDst 34 | } 35 | copy(dst, write) 36 | nDst += writeLength 37 | dst = dst[writeLength:] 38 | nSrc += readLength 39 | src = src[readLength:] 40 | return nil 41 | } 42 | 43 | for len(src) > 0 { 44 | // When sanitizing JSON strings make sure that we have 6 bytes if available. 45 | if t.JSON && len(src) < 6 && !atEOF { 46 | err = transform.ErrShortSrc 47 | return 48 | } 49 | r, size := utf8.DecodeRune(src) 50 | if r == utf8.RuneError && size < 2 { 51 | if !atEOF { 52 | err = transform.ErrShortSrc 53 | return 54 | } else { 55 | err = errors.New("invalid UTF-8 string") 56 | return 57 | } 58 | } 59 | // Replace C0 and C1 control characters. 
60 | if unicode.IsControl(r) { 61 | if repl, found := mapControlToCaret(r); found { 62 | err = transfer(repl, src[:size]) 63 | if err != nil { 64 | return 65 | } 66 | continue 67 | } 68 | } 69 | // Replace JSON C0 and C1 control characters. 70 | if t.JSON && len(src) >= 6 { 71 | if repl, found := mapJSONControlToCaret(src[:6]); found { 72 | if t.addEscape { 73 | // Add an escape character when necessary to prevent creating 74 | // invalid JSON with our replacements. 75 | repl = append([]byte{'\\'}, repl...) 76 | t.addEscape = false 77 | } 78 | err = transfer(repl, src[:6]) 79 | if err != nil { 80 | return 81 | } 82 | continue 83 | } 84 | } 85 | err = transfer(src[:size], src[:size]) 86 | if err != nil { 87 | return 88 | } 89 | if t.JSON { 90 | if r == '\\' { 91 | t.addEscape = !t.addEscape 92 | } else { 93 | t.addEscape = false 94 | } 95 | } 96 | } 97 | return 98 | } 99 | 100 | // Reset resets the state and allows the Sanitizer to be reused. 101 | func (t *Sanitizer) Reset() { 102 | t.addEscape = false 103 | } 104 | 105 | // mapControlToCaret maps C0 and C1 control characters to their caret notation. 106 | func mapControlToCaret(r rune) ([]byte, bool) { 107 | //\t (09), \n (10), \v (11), \r (13) are safe C0 characters and are not sanitized. 
108 | m := map[rune]string{ 109 | 0: `^@`, 110 | 1: `^A`, 111 | 2: `^B`, 112 | 3: `^C`, 113 | 4: `^D`, 114 | 5: `^E`, 115 | 6: `^F`, 116 | 7: `^G`, 117 | 8: `^H`, 118 | 12: `^L`, 119 | 14: `^N`, 120 | 15: `^O`, 121 | 16: `^P`, 122 | 17: `^Q`, 123 | 18: `^R`, 124 | 19: `^S`, 125 | 20: `^T`, 126 | 21: `^U`, 127 | 22: `^V`, 128 | 23: `^W`, 129 | 24: `^X`, 130 | 25: `^Y`, 131 | 26: `^Z`, 132 | 27: `^[`, 133 | 28: `^\\`, 134 | 29: `^]`, 135 | 30: `^^`, 136 | 31: `^_`, 137 | 128: `^@`, 138 | 129: `^A`, 139 | 130: `^B`, 140 | 131: `^C`, 141 | 132: `^D`, 142 | 133: `^E`, 143 | 134: `^F`, 144 | 135: `^G`, 145 | 136: `^H`, 146 | 137: `^I`, 147 | 138: `^J`, 148 | 139: `^K`, 149 | 140: `^L`, 150 | 141: `^M`, 151 | 142: `^N`, 152 | 143: `^O`, 153 | 144: `^P`, 154 | 145: `^Q`, 155 | 146: `^R`, 156 | 147: `^S`, 157 | 148: `^T`, 158 | 149: `^U`, 159 | 150: `^V`, 160 | 151: `^W`, 161 | 152: `^X`, 162 | 153: `^Y`, 163 | 154: `^Z`, 164 | 155: `^[`, 165 | 156: `^\\`, 166 | 157: `^]`, 167 | 158: `^^`, 168 | 159: `^_`, 169 | } 170 | if c, ok := m[r]; ok { 171 | return []byte(c), true 172 | } 173 | return nil, false 174 | } 175 | 176 | // mapJSONControlToCaret maps JSON C0 and C1 control characters to their caret notation. 177 | // JSON control characters are six byte strings, representing a unicode code point, 178 | // ranging from \u0000 to \u001F and \u0080 to \u009F. 179 | func mapJSONControlToCaret(b []byte) ([]byte, bool) { 180 | if len(b) != 6 { 181 | return nil, false 182 | } 183 | if !bytes.HasPrefix(b, []byte(`\u00`)) { 184 | return nil, false 185 | } 186 | //\t (\u0009), \n (\u000a), \v (\u000b), \r (\u000d) are safe C0 characters and are not sanitized. 
187 | m := map[string]string{ 188 | `\u0000`: `^@`, 189 | `\u0001`: `^A`, 190 | `\u0002`: `^B`, 191 | `\u0003`: `^C`, 192 | `\u0004`: `^D`, 193 | `\u0005`: `^E`, 194 | `\u0006`: `^F`, 195 | `\u0007`: `^G`, 196 | `\u0008`: `^H`, 197 | `\u000c`: `^L`, 198 | `\u000e`: `^N`, 199 | `\u000f`: `^O`, 200 | `\u0010`: `^P`, 201 | `\u0011`: `^Q`, 202 | `\u0012`: `^R`, 203 | `\u0013`: `^S`, 204 | `\u0014`: `^T`, 205 | `\u0015`: `^U`, 206 | `\u0016`: `^V`, 207 | `\u0017`: `^W`, 208 | `\u0018`: `^X`, 209 | `\u0019`: `^Y`, 210 | `\u001a`: `^Z`, 211 | `\u001b`: `^[`, 212 | `\u001c`: `^\\`, 213 | `\u001d`: `^]`, 214 | `\u001e`: `^^`, 215 | `\u001f`: `^_`, 216 | `\u0080`: `^@`, 217 | `\u0081`: `^A`, 218 | `\u0082`: `^B`, 219 | `\u0083`: `^C`, 220 | `\u0084`: `^D`, 221 | `\u0085`: `^E`, 222 | `\u0086`: `^F`, 223 | `\u0087`: `^G`, 224 | `\u0088`: `^H`, 225 | `\u0089`: `^I`, 226 | `\u008a`: `^J`, 227 | `\u008b`: `^K`, 228 | `\u008c`: `^L`, 229 | `\u008d`: `^M`, 230 | `\u008e`: `^N`, 231 | `\u008f`: `^O`, 232 | `\u0090`: `^P`, 233 | `\u0091`: `^Q`, 234 | `\u0092`: `^R`, 235 | `\u0093`: `^S`, 236 | `\u0094`: `^T`, 237 | `\u0095`: `^U`, 238 | `\u0096`: `^V`, 239 | `\u0097`: `^W`, 240 | `\u0098`: `^X`, 241 | `\u0099`: `^Y`, 242 | `\u009a`: `^Z`, 243 | `\u009b`: `^[`, 244 | `\u009c`: `^\\`, 245 | `\u009d`: `^]`, 246 | `\u009e`: `^^`, 247 | `\u009f`: `^_`, 248 | } 249 | if c, ok := m[strings.ToLower(string(b))]; ok { 250 | return []byte(c), true 251 | } 252 | return nil, false 253 | } 254 | -------------------------------------------------------------------------------- /pkg/asciisanitizer/sanitizer_test.go: -------------------------------------------------------------------------------- 1 | package asciisanitizer 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | "testing/iotest" 7 | 8 | "github.com/stretchr/testify/require" 9 | "golang.org/x/text/transform" 10 | ) 11 | 12 | func TestSanitizerTransform(t *testing.T) { 13 | tests := []struct { 14 | name string 15 | json bool 16 | input string 17 
| want string 18 | }{ 19 | { 20 | name: "No control characters", 21 | input: "The quick brown fox jumped over the lazy dog", 22 | want: "The quick brown fox jumped over the lazy dog", 23 | }, 24 | { 25 | name: "JSON sanitization maintains valid JSON", 26 | json: true, 27 | input: `\u001B \\u001B \\\u001B \\\\u001B \\u001B\\u001B`, 28 | want: `^[ \\^[ \\^[ \\\\^[ \\^[\\^[`, 29 | }, 30 | { 31 | name: "JSON C0 control character", 32 | json: true, 33 | input: `0\u0000`, 34 | want: "0^@", 35 | }, 36 | { 37 | name: "JSON C0 control characters", 38 | json: true, 39 | input: `0\u0000 1\u0001 2\u0002 3\u0003 4\u0004 5\u0005 6\u0006 7\u0007 8\u0008 9\u0009 ` + 40 | `A\u000a B\u000b C\u000c D\u000d E\u000e F\u000f ` + 41 | `10\u0010 11\u0011 12\u0012 13\u0013 14\u0014 15\u0015 16\u0016 17\u0017 18\u0018 19\u0019 ` + 42 | `1A\u001a 1B\u001b 1C\u001c 1D\u001d 1E\u001e 1F\u001f`, 43 | want: `0^@ 1^A 2^B 3^C 4^D 5^E 6^F 7^G 8^H 9\u0009 ` + 44 | `A\u000a B\u000b C^L D\u000d E^N F^O ` + 45 | `10^P 11^Q 12^R 13^S 14^T 15^U 16^V 17^W 18^X 19^Y ` + 46 | `1A^Z 1B^[ 1C^\\ 1D^] 1E^^ 1F^_`, 47 | }, 48 | { 49 | name: "JSON C1 control characters", 50 | json: true, 51 | input: `80\u0080 81\u0081 82\u0082 83\u0083 84\u0084 85\u0085 86\u0086 87\u0087 88\u0088 89\u0089 ` + 52 | `8A\u008a 8B\u008b 8C\u008c 8D\u008d 8E\u008e 8F\u008f ` + 53 | `90\u0090 91\u0091 92\u0092 93\u0093 94\u0094 95\u0095 96\u0096 97\u0097 98\u0098 99\u0099 ` + 54 | `9A\u009a 9B\u009b 9C\u009c 9D\u009d 9E\u009e 9F\u009f`, 55 | want: `80^@ 81^A 82^B 83^C 84^D 85^E 86^F 87^G 88^H 89^I ` + 56 | `8A^J 8B^K 8C^L 8D^M 8E^N 8F^O ` + 57 | `90^P 91^Q 92^R 93^S 94^T 95^U 96^V 97^W 98^X 99^Y ` + 58 | `9A^Z 9B^[ 9C^\\ 9D^] 9E^^ 9F^_`, 59 | }, 60 | { 61 | name: "C0 control character", 62 | input: "0\x00", 63 | want: "0^@", 64 | }, 65 | { 66 | name: "C0 control characters", 67 | input: "0\x00 1\x01 2\x02 3\x03 4\x04 5\x05 6\x06 7\x07 8\x08 9\x09 " + 68 | "A\x0A B\x0B C\x0C D\x0D E\x0E F\x0F " + 69 | "10\x10 11\x11 12\x12 13\x13 14\x14 
15\x15 16\x16 17\x17 18\x18 19\x19 " + 70 | "1A\x1A 1B\x1B 1C\x1C 1D\x1D 1E\x1E 1F\x1F", 71 | want: "0^@ 1^A 2^B 3^C 4^D 5^E 6^F 7^G 8^H 9\t " + 72 | "A\n B\v C^L D\r E^N F^O " + 73 | "10^P 11^Q 12^R 13^S 14^T 15^U 16^V 17^W 18^X 19^Y " + 74 | "1A^Z 1B^[ 1C^\\\\ 1D^] 1E^^ 1F^_", 75 | }, 76 | { 77 | name: "C1 control character", 78 | input: "80\xC2\x80", 79 | want: "80^@", 80 | }, 81 | { 82 | name: "C1 control characters", 83 | input: "80\xC2\x80 81\xC2\x81 82\xC2\x82 83\xC2\x83 84\xC2\x84 85\xC2\x85 86\xC2\x86 87\xC2\x87 88\xC2\x88 89\xC2\x89 " + 84 | "8A\xC2\x8A 8B\xC2\x8B 8C\xC2\x8C 8D\xC2\x8D 8E\xC2\x8E 8F\xC2\x8F " + 85 | "90\xC2\x90 91\xC2\x91 92\xC2\x92 93\xC2\x93 94\xC2\x94 95\xC2\x95 96\xC2\x96 97\xC2\x97 98\xC2\x98 99\xC2\x99 " + 86 | "9A\xC2\x9A 9B\xC2\x9B 9C\xC2\x9C 9D\xC2\x9D 9E\xC2\x9E 9F\xC2\x9F", 87 | want: "80^@ 81^A 82^B 83^C 84^D 85^E 86^F 87^G 88^H 89^I " + 88 | "8A^J 8B^K 8C^L 8D^M 8E^N 8F^O " + 89 | "90^P 91^Q 92^R 93^S 94^T 95^U 96^V 97^W 98^X 99^Y " + 90 | "9A^Z 9B^[ 9C^\\\\ 9D^] 9E^^ 9F^_", 91 | }, 92 | } 93 | for _, tt := range tests { 94 | t.Run(tt.name, func(t *testing.T) { 95 | sanitizer := &Sanitizer{JSON: tt.json} 96 | reader := bytes.NewReader([]byte(tt.input)) 97 | transformReader := transform.NewReader(reader, sanitizer) 98 | err := iotest.TestReader(transformReader, []byte(tt.want)) 99 | require.NoError(t, err) 100 | }) 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /pkg/auth/auth.go: -------------------------------------------------------------------------------- 1 | // Package auth is a set of functions for retrieving authentication tokens 2 | // and authenticated hosts. 
3 | package auth 4 | 5 | import ( 6 | "fmt" 7 | "os" 8 | "os/exec" 9 | "strings" 10 | 11 | "github.com/cli/go-gh/v2/internal/set" 12 | "github.com/cli/go-gh/v2/pkg/config" 13 | "github.com/cli/safeexec" 14 | ) 15 | 16 | const ( 17 | codespaces = "CODESPACES" 18 | defaultSource = "default" 19 | ghEnterpriseToken = "GH_ENTERPRISE_TOKEN" 20 | ghHost = "GH_HOST" 21 | ghToken = "GH_TOKEN" 22 | github = "github.com" 23 | githubEnterpriseToken = "GITHUB_ENTERPRISE_TOKEN" 24 | githubToken = "GITHUB_TOKEN" 25 | hostsKey = "hosts" 26 | localhost = "github.localhost" 27 | oauthToken = "oauth_token" 28 | tenancyHost = "ghe.com" // TenancyHost is the domain suffix of a tenancy GitHub instance. 29 | ) 30 | 31 | // TokenForHost retrieves an authentication token and the source of that token for the specified 32 | // host. The source can be either an environment variable, configuration file, or the system 33 | // keyring. In the latter case, this shells out to "gh auth token" to obtain the token. 34 | // 35 | // Returns "", "default" if no applicable token is found. 36 | func TokenForHost(host string) (string, string) { 37 | if token, source := TokenFromEnvOrConfig(host); token != "" { 38 | return token, source 39 | } 40 | 41 | ghExe := os.Getenv("GH_PATH") 42 | if ghExe == "" { 43 | ghExe, _ = safeexec.LookPath("gh") 44 | } 45 | 46 | if ghExe != "" { 47 | if token, source := tokenFromGh(ghExe, host); token != "" { 48 | return token, source 49 | } 50 | } 51 | 52 | return "", defaultSource 53 | } 54 | 55 | // TokenFromEnvOrConfig retrieves an authentication token from environment variables or the config 56 | // file as fallback, but does not support reading the token from system keyring. Most consumers 57 | // should use TokenForHost. 
58 | func TokenFromEnvOrConfig(host string) (string, string) { 59 | cfg, _ := config.Read(nil) 60 | return tokenForHost(cfg, host) 61 | } 62 | 63 | func tokenForHost(cfg *config.Config, host string) (string, string) { 64 | normalizedHost := NormalizeHostname(host) 65 | // This code is currently the exact opposite of IsEnterprise. However, we have chosen 66 | // to write it separately, directly in line, because it is much clearer in the exact 67 | // scenarios that we expect to use GH_TOKEN and GITHUB_TOKEN. 68 | if normalizedHost == github || IsTenancy(normalizedHost) || normalizedHost == localhost { 69 | if token := os.Getenv(ghToken); token != "" { 70 | return token, ghToken 71 | } 72 | 73 | if token := os.Getenv(githubToken); token != "" { 74 | return token, githubToken 75 | } 76 | } else { 77 | if token := os.Getenv(ghEnterpriseToken); token != "" { 78 | return token, ghEnterpriseToken 79 | } 80 | 81 | if token := os.Getenv(githubEnterpriseToken); token != "" { 82 | return token, githubEnterpriseToken 83 | } 84 | } 85 | 86 | // If config is nil, something has failed much earlier and it's probably 87 | // more correct to panic because we don't expect to support anything 88 | // where the config isn't available, but that would be a breaking change, 89 | // so it's worth thinking about carefully, if we wanted to rework this. 
90 | if cfg == nil { 91 | return "", defaultSource 92 | } 93 | 94 | token, err := cfg.Get([]string{hostsKey, normalizedHost, oauthToken}) 95 | if err != nil { 96 | return "", defaultSource 97 | } 98 | 99 | return token, oauthToken 100 | } 101 | 102 | func tokenFromGh(path string, host string) (string, string) { 103 | cmd := exec.Command(path, "auth", "token", "--secure-storage", "--hostname", host) 104 | result, err := cmd.Output() 105 | if err != nil { 106 | return "", "gh" 107 | } 108 | return strings.TrimSpace(string(result)), "gh" 109 | } 110 | 111 | // KnownHosts retrieves a list of hosts that have corresponding 112 | // authentication tokens, either from environment variables 113 | // or from the configuration file. 114 | // Returns an empty string slice if no hosts are found. 115 | func KnownHosts() []string { 116 | cfg, _ := config.Read(nil) 117 | return knownHosts(cfg) 118 | } 119 | 120 | func knownHosts(cfg *config.Config) []string { 121 | hosts := set.NewStringSet() 122 | if host := os.Getenv(ghHost); host != "" { 123 | hosts.Add(host) 124 | } 125 | if token, _ := tokenForHost(cfg, github); token != "" { 126 | hosts.Add(github) 127 | } 128 | if cfg != nil { 129 | keys, err := cfg.Keys([]string{hostsKey}) 130 | if err == nil { 131 | hosts.AddValues(keys) 132 | } 133 | } 134 | return hosts.ToSlice() 135 | } 136 | 137 | // DefaultHost retrieves an authenticated host and the source of host. 138 | // The source can be either an environment variable or from the 139 | // configuration file. 140 | // Returns "github.com", "default" if no viable host is found. 
141 | func DefaultHost() (string, string) { 142 | cfg, _ := config.Read(nil) 143 | return defaultHost(cfg) 144 | } 145 | 146 | func defaultHost(cfg *config.Config) (string, string) { 147 | if host := os.Getenv(ghHost); host != "" { 148 | return host, ghHost 149 | } 150 | if cfg != nil { 151 | keys, err := cfg.Keys([]string{hostsKey}) 152 | if err == nil && len(keys) == 1 { 153 | return keys[0], hostsKey 154 | } 155 | } 156 | return github, defaultSource 157 | } 158 | 159 | // IsEnterprise determines if a provided host is a GitHub Enterprise Server instance, 160 | // rather than GitHub.com, a tenancy GitHub instance, or github.localhost. 161 | func IsEnterprise(host string) bool { 162 | // Note that if you are making changes here, you should also consider making the equivalent 163 | // in tokenForHost, which is the exact opposite of this function. 164 | normalizedHost := NormalizeHostname(host) 165 | return normalizedHost != github && normalizedHost != localhost && !IsTenancy(normalizedHost) 166 | } 167 | 168 | // IsTenancy determines if a provided host is a tenancy GitHub instance, 169 | // rather than GitHub.com or a GitHub Enterprise Server instance. 170 | func IsTenancy(host string) bool { 171 | normalizedHost := NormalizeHostname(host) 172 | return strings.HasSuffix(normalizedHost, "."+tenancyHost) 173 | } 174 | 175 | // NormalizeHostname ensures the host matches the values used throughout 176 | // the rest of the codebase with respect to hostnames. These are github, 177 | // localhost, and tenancyHost. 178 | func NormalizeHostname(host string) string { 179 | hostname := strings.ToLower(host) 180 | if strings.HasSuffix(hostname, "."+github) { 181 | return github 182 | } 183 | if strings.HasSuffix(hostname, "."+localhost) { 184 | return localhost 185 | } 186 | // This has been copied over from the cli/cli NormalizeHostname function 187 | // to ensure compatible behaviour but we don't fully understand when or 188 | // why it would be useful here. 
We can't see what harm will come of 189 | // duplicating the logic. 190 | if before, found := cutSuffix(hostname, "."+tenancyHost); found { 191 | idx := strings.LastIndex(before, ".") 192 | return fmt.Sprintf("%s.%s", before[idx+1:], tenancyHost) 193 | } 194 | return hostname 195 | } 196 | 197 | // Backport strings.CutSuffix from Go 1.20. 198 | func cutSuffix(s, suffix string) (string, bool) { 199 | if !strings.HasSuffix(s, suffix) { 200 | return s, false 201 | } 202 | return s[:len(s)-len(suffix)], true 203 | } 204 | -------------------------------------------------------------------------------- /pkg/browser/browser.go: -------------------------------------------------------------------------------- 1 | // Package browser facilitates opening of URLs in a web browser. 2 | package browser 3 | 4 | import ( 5 | "fmt" 6 | "io" 7 | "net/url" 8 | "os" 9 | "os/exec" 10 | 11 | cliBrowser "github.com/cli/browser" 12 | "github.com/cli/go-gh/v2/pkg/config" 13 | "github.com/cli/safeexec" 14 | "github.com/google/shlex" 15 | ) 16 | 17 | // Browser represents a web browser that can be used to open up URLs. 18 | type Browser struct { 19 | launcher string 20 | stderr io.Writer 21 | stdout io.Writer 22 | } 23 | 24 | // New initializes a Browser. If a launcher is not specified 25 | // one is determined based on environment variables or from the 26 | // configuration file. 27 | // The order of precedence for determining a launcher is: 28 | // - Specified launcher; 29 | // - GH_BROWSER environment variable; 30 | // - browser option from configuration file; 31 | // - BROWSER environment variable. 32 | func New(launcher string, stdout, stderr io.Writer) *Browser { 33 | if launcher == "" { 34 | launcher = resolveLauncher() 35 | } 36 | b := &Browser{ 37 | launcher: launcher, 38 | stderr: stderr, 39 | stdout: stdout, 40 | } 41 | return b 42 | } 43 | 44 | // Browse opens the launcher and navigates to the specified URL. 
45 | func (b *Browser) Browse(url string) error { 46 | return b.browse(url, nil) 47 | } 48 | 49 | func (b *Browser) browse(url string, env []string) error { 50 | // Ensure the URL is supported including the scheme, 51 | // overwrite `url` for use within the function. 52 | urlParsed, err := isPossibleProtocol(url) 53 | if err != nil { 54 | return err 55 | } 56 | 57 | url = urlParsed.String() 58 | 59 | // Use default `gh` browsing module for opening URL if not customized. 60 | if b.launcher == "" { 61 | return cliBrowser.OpenURL(url) 62 | } 63 | 64 | launcherArgs, err := shlex.Split(b.launcher) 65 | if err != nil { 66 | return err 67 | } 68 | launcherExe, err := safeexec.LookPath(launcherArgs[0]) 69 | if err != nil { 70 | return err 71 | } 72 | args := append(launcherArgs[1:], url) 73 | cmd := exec.Command(launcherExe, args...) 74 | cmd.Stdout = b.stdout 75 | cmd.Stderr = b.stderr 76 | if env != nil { 77 | cmd.Env = env 78 | } 79 | return cmd.Run() 80 | } 81 | 82 | func resolveLauncher() string { 83 | if ghBrowser := os.Getenv("GH_BROWSER"); ghBrowser != "" { 84 | return ghBrowser 85 | } 86 | cfg, err := config.Read(nil) 87 | if err == nil { 88 | if cfgBrowser, _ := cfg.Get([]string{"browser"}); cfgBrowser != "" { 89 | return cfgBrowser 90 | } 91 | } 92 | return os.Getenv("BROWSER") 93 | } 94 | 95 | func isSupportedScheme(scheme string) bool { 96 | switch scheme { 97 | case "http", "https", "vscode", "vscode-insiders": 98 | return true 99 | default: 100 | return false 101 | } 102 | } 103 | 104 | func isPossibleProtocol(u string) (*url.URL, error) { 105 | // Parse URL for known supported schemes before handling unknown cases. 106 | urlParsed, err := url.Parse(u) 107 | if err != nil { 108 | return nil, fmt.Errorf("opening unparsable URL is unsupported: %s", u) 109 | } 110 | 111 | if isSupportedScheme(urlParsed.Scheme) { 112 | return urlParsed, nil 113 | } 114 | 115 | // Disallow any unrecognized URL schemes if explicitly present. 
116 | if urlParsed.Scheme != "" { 117 | return nil, fmt.Errorf("opening unsupport URL scheme: %s", u) 118 | } 119 | 120 | // Disallow URLs that match existing files or directories on the filesystem 121 | // as these could be executables or executed by the launcher browser due to 122 | // the file extension and/or associated application. 123 | // 124 | // Symlinks should not be resolved in order to avoid broken links or other 125 | // vulnerabilities trying to resolve them. 126 | if fileInfo, _ := os.Lstat(u); fileInfo != nil { 127 | return nil, fmt.Errorf("opening files or directories is unsupported: %s", u) 128 | } 129 | 130 | // Disallow URLs that match executables found in the user path. 131 | exec, _ := safeexec.LookPath(u) 132 | if exec != "" { 133 | return nil, fmt.Errorf("opening executables is unsupported: %s", u) 134 | } 135 | 136 | // Otherwise, assume HTTP URL using `https` to ensure secure browsing. 137 | urlParsed.Scheme = "https" 138 | return urlParsed, nil 139 | } 140 | -------------------------------------------------------------------------------- /pkg/config/errors.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | // InvalidConfigFileError represents an error when trying to read a config file. 8 | type InvalidConfigFileError struct { 9 | Path string 10 | Err error 11 | } 12 | 13 | // Allow InvalidConfigFileError to satisfy error interface. 14 | func (e *InvalidConfigFileError) Error() string { 15 | return fmt.Sprintf("invalid config file %s: %s", e.Path, e.Err) 16 | } 17 | 18 | // Allow InvalidConfigFileError to be unwrapped. 19 | func (e *InvalidConfigFileError) Unwrap() error { 20 | return e.Err 21 | } 22 | 23 | // KeyNotFoundError represents an error when trying to find a config key 24 | // that does not exist. 25 | type KeyNotFoundError struct { 26 | Key string 27 | } 28 | 29 | // Allow KeyNotFoundError to satisfy error interface. 
30 | func (e *KeyNotFoundError) Error() string { 31 | return fmt.Sprintf("could not find key %q", e.Key) 32 | } 33 | -------------------------------------------------------------------------------- /pkg/jq/jq.go: -------------------------------------------------------------------------------- 1 | // Package jq facilitates processing of JSON strings using jq expressions. 2 | package jq 3 | 4 | import ( 5 | "bytes" 6 | "encoding/json" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "math" 11 | "os" 12 | "strconv" 13 | "strings" 14 | 15 | "github.com/cli/go-gh/v2/pkg/jsonpretty" 16 | "github.com/itchyny/gojq" 17 | ) 18 | 19 | // Evaluate a jq expression against an input and write it to an output. 20 | // Any top-level scalar values produced by the jq expression are written out 21 | // directly, as raw values and not as JSON scalars, similar to how jq --raw 22 | // works. 23 | func Evaluate(input io.Reader, output io.Writer, expr string) error { 24 | return EvaluateFormatted(input, output, expr, "", false) 25 | } 26 | 27 | // Evaluate a jq expression against an input and write it to an output, 28 | // optionally with indentation and colorization. Any top-level scalar values 29 | // produced by the jq expression are written out directly, as raw values and not 30 | // as JSON scalars, similar to how jq --raw works. 
31 | func EvaluateFormatted(input io.Reader, output io.Writer, expr string, indent string, colorize bool) error { 32 | query, err := gojq.Parse(expr) 33 | if err != nil { 34 | var e *gojq.ParseError 35 | if errors.As(err, &e) { 36 | str, line, column := getLineColumn(expr, e.Offset-len(e.Token)) // Offset lies just past the offending token (see the parse-error tests), so step back to its start 37 | return fmt.Errorf( 38 | "failed to parse jq expression (line %d, column %d)\n %s\n %*c %w", 39 | line, column, str, column, '^', err, 40 | ) 41 | } 42 | return err 43 | } 44 | // Give the expression access to the process environment (the tests read env.CODE). 45 | code, err := gojq.Compile( 46 | query, 47 | gojq.WithEnvironLoader(func() []string { 48 | return os.Environ() 49 | })) 50 | if err != nil { 51 | return err 52 | } 53 | // The whole input is buffered and decoded as a single JSON value. 54 | jsonData, err := io.ReadAll(input) 55 | if err != nil { 56 | return err 57 | } 58 | 59 | var responseData interface{} 60 | err = json.Unmarshal(jsonData, &responseData) 61 | if err != nil { 62 | return err 63 | } 64 | // Scalars are printed raw below; everything else goes through enc. 65 | enc := prettyEncoder{ 66 | w: output, 67 | indent: indent, 68 | colorize: colorize, 69 | } 70 | 71 | iter := code.Run(responseData) 72 | for { 73 | v, ok := iter.Next() 74 | if !ok { 75 | break 76 | } 77 | if err, isErr := v.(error); isErr { 78 | var e *gojq.HaltError 79 | if errors.As(err, &e) && e.Value() == nil { // `halt` ends output without error; `halt_error` still fails 80 | break 81 | } 82 | return err 83 | } 84 | if text, e := jsonScalarToString(v); e == nil { 85 | _, err := fmt.Fprintln(output, text) 86 | if err != nil { 87 | return err 88 | } 89 | } else { 90 | if err = enc.Encode(v); err != nil { 91 | return err 92 | } 93 | } 94 | } 95 | 96 | return nil 97 | } 98 | // jsonScalarToString renders a scalar as raw text: strings verbatim, integral float64 values without a fraction, other float64 values rounded to two decimal places, null as an empty string, and booleans as true/false. Any other value (including ints and containers) returns an error, and the caller then JSON-encodes it. 99 | func jsonScalarToString(input interface{}) (string, error) { 100 | switch tt := input.(type) { 101 | case string: 102 | return tt, nil 103 | case float64: 104 | if math.Trunc(tt) == tt { 105 | return strconv.FormatFloat(tt, 'f', 0, 64), nil 106 | } else { 107 | return strconv.FormatFloat(tt, 'f', 2, 64), nil 108 | } 109 | case nil: 110 | return "", nil 111 | case bool: 112 | return fmt.Sprintf("%v", tt), nil 113 | default: 114 | return "", fmt.Errorf("cannot convert type to string: %v",
tt) 115 | } 116 | } 117 | // prettyEncoder writes non-scalar jq results as JSON, one value per line, optionally indented and colorized. 118 | type prettyEncoder struct { 119 | w io.Writer 120 | indent string 121 | colorize bool 122 | } 123 | // Encode marshals v (indented when p.indent is set) and writes it followed by a newline; with colorize it delegates the output to jsonpretty.Format. 124 | func (p prettyEncoder) Encode(v any) error { 125 | var b []byte 126 | var err error 127 | if p.indent == "" { 128 | b, err = json.Marshal(v) 129 | } else { 130 | b, err = json.MarshalIndent(v, "", p.indent) 131 | } 132 | if err != nil { 133 | return err 134 | } 135 | if !p.colorize { 136 | if _, err := p.w.Write(b); err != nil { 137 | return err 138 | } 139 | if _, err := p.w.Write([]byte{'\n'}); err != nil { 140 | return err 141 | } 142 | return nil 143 | } 144 | return jsonpretty.Format(p.w, bytes.NewReader(b), p.indent, true) 145 | } 146 | // getLineColumn maps a byte offset within expr to the text of the line containing it, plus 1-based line and column numbers. 147 | func getLineColumn(expr string, offset int) (string, int, int) { 148 | for line := 1; ; line++ { 149 | index := strings.Index(expr, "\n") 150 | if index < 0 { 151 | return expr, line, offset + 1 152 | } 153 | if index >= offset { 154 | return expr[:index], line, offset + 1 155 | } 156 | expr = expr[index+1:] 157 | offset -= index + 1 158 | } 159 | } 160 | -------------------------------------------------------------------------------- /pkg/jq/jq_test.go: -------------------------------------------------------------------------------- 1 | package jq 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/MakeNowJust/heredoc" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | // TestEvaluateFormatted is table-driven: raw scalars, JSON objects (plain, indented, colorized), env access, halt/halt_error, and parse-error positions. 13 | func TestEvaluateFormatted(t *testing.T) { 14 | t.Setenv("CODE", "code_c") 15 | type args struct { 16 | json io.Reader 17 | expr string 18 | indent string 19 | colorize bool 20 | } 21 | tests := []struct { 22 | name string 23 | args args 24 | wantW string 25 | wantErr bool 26 | wantErrMsg string 27 | }{ 28 | { 29 | name: "simple", 30 | args: args{ 31 | json: strings.NewReader(`{"name":"Mona", "arms":8}`), 32 | expr: `.name`, 33 | indent: "", 34 | colorize: false, 35 | }, 36 | wantW: "Mona\n", 37 | }, 38 | { 39 | name: "multiple queries", 40 | args: args{ 41 | json:
strings.NewReader(`{"name":"Mona", "arms":8}`), 42 | expr: `.name,.arms`, 43 | indent: "", 44 | colorize: false, 45 | }, 46 | wantW: "Mona\n8\n", 47 | }, 48 | { 49 | name: "object as JSON", 50 | args: args{ 51 | json: strings.NewReader(`{"user":{"login":"monalisa"}}`), 52 | expr: `.user`, 53 | indent: "", 54 | colorize: false, 55 | }, 56 | wantW: "{\"login\":\"monalisa\"}\n", 57 | }, 58 | { 59 | name: "object as JSON, indented", 60 | args: args{ 61 | json: strings.NewReader(`{"user":{"login":"monalisa"}}`), 62 | expr: `.user`, 63 | indent: " ", 64 | colorize: false, 65 | }, 66 | wantW: "{\n \"login\": \"monalisa\"\n}\n", 67 | }, 68 | { 69 | name: "object as JSON, indented & colorized", 70 | args: args{ 71 | json: strings.NewReader(`{"user":{"login":"monalisa"}}`), 72 | expr: `.user`, 73 | indent: " ", 74 | colorize: true, 75 | }, 76 | wantW: "\x1b[1;38m{\x1b[m\n" + 77 | " \x1b[1;34m\"login\"\x1b[m\x1b[1;38m:\x1b[m" + 78 | " \x1b[32m\"monalisa\"\x1b[m\n" + 79 | "\x1b[1;38m}\x1b[m\n", 80 | }, 81 | { 82 | name: "empty array", 83 | args: args{ 84 | json: strings.NewReader(`[]`), 85 | expr: `., [], unique`, 86 | indent: "", 87 | colorize: false, 88 | }, 89 | wantW: "[]\n[]\n[]\n", 90 | }, 91 | { 92 | name: "empty array, colorized", 93 | args: args{ 94 | json: strings.NewReader(`[]`), 95 | expr: `.`, 96 | indent: "", 97 | colorize: true, 98 | }, 99 | wantW: "\x1b[1;38m[\x1b[m\x1b[1;38m]\x1b[m\n", 100 | }, 101 | { 102 | name: "complex", 103 | args: args{ 104 | json: strings.NewReader(heredoc.Doc(`[ 105 | { 106 | "title": "First title", 107 | "labels": [{"name":"bug"}, {"name":"help wanted"}] 108 | }, 109 | { 110 | "title": "Second but not last", 111 | "labels": [] 112 | }, 113 | { 114 | "title": "Alas, tis' the end", 115 | "labels": [{}, {"name":"feature"}] 116 | } 117 | ]`)), 118 | expr: `.[] | [.title,(.labels | map(.name) | join(","))] | @tsv`, 119 | indent: "", 120 | colorize: false, 121 | }, 122 | wantW: heredoc.Doc(` 123 | First title bug,help wanted 124 | Second
but not last 125 | Alas, tis' the end ,feature 126 | `), 127 | }, 128 | { 129 | name: "with env var", 130 | args: args{ 131 | json: strings.NewReader(heredoc.Doc(`[ 132 | { 133 | "title": "code_a", 134 | "labels": [{"name":"bug"}, {"name":"help wanted"}] 135 | }, 136 | { 137 | "title": "code_b", 138 | "labels": [] 139 | }, 140 | { 141 | "title": "code_c", 142 | "labels": [{}, {"name":"feature"}] 143 | } 144 | ]`)), 145 | expr: `.[] | select(.title == env.CODE) | .labels`, 146 | indent: " ", 147 | colorize: false, 148 | }, 149 | wantW: "[\n {},\n {\n \"name\": \"feature\"\n }\n]\n", 150 | }, 151 | { 152 | name: "mixing scalars, arrays and objects", 153 | args: args{ 154 | json: strings.NewReader(heredoc.Doc(`[ 155 | "foo", 156 | true, 157 | 42, 158 | [17, 23], 159 | {"foo": "bar"} 160 | ]`)), 161 | expr: `.[]`, 162 | indent: " ", 163 | colorize: true, 164 | }, 165 | wantW: "foo\ntrue\n42\n" + 166 | "\x1b[1;38m[\x1b[m\n" + 167 | " 17\x1b[1;38m,\x1b[m\n" + 168 | " 23\n" + 169 | "\x1b[1;38m]\x1b[m\n" + 170 | "\x1b[1;38m{\x1b[m\n" + 171 | " \x1b[1;34m\"foo\"\x1b[m\x1b[1;38m:\x1b[m" + 172 | " \x1b[32m\"bar\"\x1b[m\n" + 173 | "\x1b[1;38m}\x1b[m\n", 174 | }, 175 | { 176 | name: "halt function", 177 | args: args{ 178 | json: strings.NewReader("{}"), 179 | expr: `1,halt,2`, 180 | }, 181 | wantW: "1\n", 182 | }, 183 | { 184 | name: "halt_error function", 185 | args: args{ 186 | json: strings.NewReader("{}"), 187 | expr: `1,halt_error,2`, 188 | }, 189 | wantW: "1\n", 190 | wantErr: true, 191 | wantErrMsg: "halt error: {}", 192 | }, 193 | { 194 | name: "invalid one-line query", 195 | args: args{ 196 | json: strings.NewReader("{}"), 197 | expr: `[1,2,,3]`, 198 | }, 199 | wantErr: true, 200 | wantErrMsg: `failed to parse jq expression (line 1, column 6) 201 | [1,2,,3] 202 | ^ unexpected token ","`, 203 | }, 204 | { 205 | name: "invalid multi-line query", 206 | args: args{ 207 | json: strings.NewReader("{}"), 208 | expr: `[ 209 | 1,,2 210 | ,3]`, 211 | }, 212 | wantErr: true, 213
| wantErrMsg: `failed to parse jq expression (line 2, column 5) 214 | 1,,2 215 | ^ unexpected token ","`, 216 | }, 217 | { 218 | name: "invalid unterminated query", 219 | args: args{ 220 | json: strings.NewReader("{}"), 221 | expr: `[1,`, 222 | }, 223 | wantErr: true, 224 | wantErrMsg: `failed to parse jq expression (line 1, column 4) 225 | [1, 226 | ^ unexpected EOF`, 227 | }, 228 | } 229 | for _, tt := range tests { 230 | t.Run(tt.name, func(t *testing.T) { 231 | w := &bytes.Buffer{} 232 | err := EvaluateFormatted(tt.args.json, w, tt.args.expr, tt.args.indent, tt.args.colorize) 233 | if tt.wantErr { 234 | assert.Error(t, err) 235 | assert.EqualError(t, err, tt.wantErrMsg) 236 | return 237 | } 238 | assert.NoError(t, err) 239 | assert.Equal(t, tt.wantW, w.String()) 240 | }) 241 | } 242 | } 243 | -------------------------------------------------------------------------------- /pkg/jsonpretty/format.go: -------------------------------------------------------------------------------- 1 | // Package jsonpretty implements a terminal pretty-printer for JSON. 2 | package jsonpretty 3 | 4 | import ( 5 | "bytes" 6 | "encoding/json" 7 | "fmt" 8 | "io" 9 | "strings" 10 | ) 11 | // ANSI SGR escape sequences used when colorize is true. 12 | const ( 13 | colorDelim = "\x1b[1;38m" // bright white 14 | colorKey = "\x1b[1;34m" // bright blue 15 | colorNull = "\x1b[36m" // cyan 16 | colorString = "\x1b[32m" // green 17 | colorBool = "\x1b[33m" // yellow 18 | colorReset = "\x1b[m" 19 | ) 20 | 21 | // Format reads JSON from r and writes a prettified version of it to w, indenting nested values with indent and adding ANSI colors when colorize is true.
22 | func Format(w io.Writer, r io.Reader, indent string, colorize bool) error { 23 | dec := json.NewDecoder(r) 24 | dec.UseNumber() 25 | // UseNumber keeps numbers as json.Number so their original text is re-emitted; c returns an ANSI code only when colorizing. 26 | c := func(ansi string) string { 27 | if !colorize { 28 | return "" 29 | } 30 | return ansi 31 | } 32 | // stack holds the open '{' / '[' delimiters; its depth sets the indentation level. 33 | var idx int 34 | var stack []json.Delim 35 | 36 | for { 37 | t, err := dec.Token() 38 | if err == io.EOF { 39 | break 40 | } 41 | if err != nil { 42 | return err 43 | } 44 | 45 | switch tt := t.(type) { 46 | case json.Delim: 47 | switch tt { 48 | case '{', '[': 49 | stack = append(stack, tt) 50 | idx = 0 51 | if _, err := fmt.Fprint(w, c(colorDelim), tt, c(colorReset)); err != nil { 52 | return err 53 | } 54 | if dec.More() { 55 | if _, err := fmt.Fprint(w, "\n", strings.Repeat(indent, len(stack))); err != nil { 56 | return err 57 | } 58 | } 59 | continue 60 | case '}', ']': 61 | stack = stack[:len(stack)-1] 62 | idx = 0 63 | if _, err := fmt.Fprint(w, c(colorDelim), tt, c(colorReset)); err != nil { 64 | return err 65 | } 66 | } 67 | default: 68 | b, err := marshalJSON(tt) 69 | if err != nil { 70 | return err 71 | } 72 | // Inside an object, scalars alternate key/value; idx resets at every delimiter, so an even idx marks a key. 73 | isKey := len(stack) > 0 && stack[len(stack)-1] == '{' && idx%2 == 0 74 | idx++ 75 | 76 | var color string 77 | if isKey { 78 | color = colorKey 79 | } else if tt == nil { 80 | color = colorNull 81 | } else { 82 | switch t.(type) { 83 | case string: 84 | color = colorString 85 | case bool: 86 | color = colorBool 87 | } 88 | } 89 | 90 | if color != "" { 91 | if _, err := fmt.Fprint(w, c(color)); err != nil { 92 | return err 93 | } 94 | } 95 | if _, err := w.Write(b); err != nil { 96 | return err 97 | } 98 | if color != "" { 99 | if _, err := fmt.Fprint(w, c(colorReset)); err != nil { 100 | return err 101 | } 102 | } 103 | 104 | if isKey { 105 | if _, err := fmt.Fprint(w, c(colorDelim), ":", c(colorReset), " "); err != nil { 106 | return err 107 | } 108 | continue 109 | } 110 | } 111 | // After a value or closing delimiter: a comma if a sibling follows, otherwise a newline indented for the enclosing closing delimiter (or a bare newline at top level). 112 | if dec.More() { 113 | if _, err := fmt.Fprint(w, c(colorDelim), ",", c(colorReset), "\n",
strings.Repeat(indent, len(stack))); err != nil { 114 | return err 115 | } 116 | } else if len(stack) > 0 { 117 | if _, err := fmt.Fprint(w, "\n", strings.Repeat(indent, len(stack)-1)); err != nil { 118 | return err 119 | } 120 | } else { 121 | if _, err := fmt.Fprint(w, "\n"); err != nil { 122 | return err 123 | } 124 | } 125 | } 126 | 127 | return nil 128 | } 129 | 130 | // marshalJSON works like json.Marshal, but with HTML-escaping disabled. 131 | func marshalJSON(v interface{}) ([]byte, error) { 132 | buf := bytes.Buffer{} 133 | enc := json.NewEncoder(&buf) 134 | enc.SetEscapeHTML(false) 135 | if err := enc.Encode(v); err != nil { 136 | return nil, err 137 | } 138 | bb := buf.Bytes() 139 | // omit trailing newline added by json.Encoder 140 | if len(bb) > 0 && bb[len(bb)-1] == '\n' { 141 | return bb[:len(bb)-1], nil 142 | } 143 | return bb, nil 144 | } 145 | -------------------------------------------------------------------------------- /pkg/jsonpretty/format_test.go: -------------------------------------------------------------------------------- 1 | package jsonpretty 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "testing" 7 | ) 8 | // TestWrite is table-driven: blank input, empty and nested objects, colorized and plain output, a bare string, and malformed input. 9 | func TestWrite(t *testing.T) { 10 | type args struct { 11 | r io.Reader 12 | indent string 13 | colorize bool 14 | } 15 | tests := []struct { 16 | name string 17 | args args 18 | wantW string 19 | wantErr bool 20 | }{ 21 | { 22 | name: "blank", 23 | args: args{ 24 | r: bytes.NewBufferString(``), 25 | indent: "", 26 | colorize: true, 27 | }, 28 | wantW: "", 29 | wantErr: false, 30 | }, 31 | { 32 | name: "empty object", 33 | args: args{ 34 | r: bytes.NewBufferString(`{}`), 35 | indent: "", 36 | colorize: true, 37 | }, 38 | wantW: "\x1b[1;38m{\x1b[m\x1b[1;38m}\x1b[m\n", 39 | wantErr: false, 40 | }, 41 | { 42 | name: "nested object", 43 | args: args{ 44 | r: bytes.NewBufferString(`{"hash":{"a":1,"b":2},"array":[3,4]}`), 45 | indent: "\t", 46 | colorize: true, 47 | }, 48 | wantW:
"\x1b[1;38m{\x1b[m\n\t\x1b[1;34m\"hash\"\x1b[m\x1b[1;38m:\x1b[m " + 49 | "\x1b[1;38m{\x1b[m\n\t\t\x1b[1;34m\"a\"\x1b[m\x1b[1;38m:\x1b[m 1\x1b[1;38m,\x1b[m\n\t\t\x1b[1;34m\"b\"\x1b[m\x1b[1;38m:\x1b[m 2\n\t\x1b[1;38m}\x1b[m\x1b[1;38m,\x1b[m" + 50 | "\n\t\x1b[1;34m\"array\"\x1b[m\x1b[1;38m:\x1b[m \x1b[1;38m[\x1b[m\n\t\t3\x1b[1;38m,\x1b[m\n\t\t4\n\t\x1b[1;38m]\x1b[m\n\x1b[1;38m}\x1b[m\n", 51 | wantErr: false, 52 | }, 53 | { 54 | name: "no color", 55 | args: args{ 56 | r: bytes.NewBufferString(`{"hash":{"a":1,"b":2},"array":[3,4]}`), 57 | indent: "\t", 58 | colorize: false, 59 | }, 60 | wantW: "{\n\t\"hash\": {\n\t\t\"a\": 1,\n\t\t\"b\": 2\n\t},\n\t\"array\": [\n\t\t3,\n\t\t4\n\t]\n}\n", 61 | wantErr: false, 62 | }, 63 | { 64 | name: "string", 65 | args: args{ 66 | r: bytes.NewBufferString(`"foo"`), 67 | indent: "", 68 | colorize: true, 69 | }, 70 | wantW: "\x1b[32m\"foo\"\x1b[m\n", 71 | wantErr: false, 72 | }, 73 | { 74 | name: "error", 75 | args: args{ 76 | r: bytes.NewBufferString(`{{`), 77 | indent: "", 78 | colorize: true, 79 | }, 80 | wantW: "\x1b[1;38m{\x1b[m\n", 81 | wantErr: true, 82 | }, 83 | } 84 | for _, tt := range tests { 85 | t.Run(tt.name, func(t *testing.T) { 86 | w := &bytes.Buffer{} 87 | if err := Format(w, tt.args.r, tt.args.indent, tt.args.colorize); (err != nil) != tt.wantErr { 88 | t.Errorf("Write() error = %v, wantErr %v", err, tt.wantErr) 89 | return 90 | } 91 | if w.String() != tt.wantW { 92 | t.Errorf("got: %q, want: %q", w.String(), tt.wantW) 93 | } 94 | }) 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /pkg/markdown/markdown.go: -------------------------------------------------------------------------------- 1 | // Package markdown facilitates rendering markdown in the terminal.
2 | package markdown 3 | 4 | import ( 5 | "os" 6 | "strings" 7 | 8 | "github.com/charmbracelet/glamour" 9 | xcolor "github.com/cli/go-gh/v2/pkg/x/color" 10 | xmarkdown "github.com/cli/go-gh/v2/pkg/x/markdown" 11 | ) 12 | 13 | // WithoutIndentation is a rendering option that removes indentation from the markdown rendering. 14 | func WithoutIndentation() glamour.TermRendererOption { 15 | overrides := []byte(` 16 | { 17 | "document": { 18 | "margin": 0 19 | }, 20 | "code_block": { 21 | "margin": 0 22 | } 23 | }`) 24 | 25 | return glamour.WithStylesFromJSONBytes(overrides) 26 | } 27 | 28 | // WithoutWrap is a rendering option that set the character limit for soft wraping the markdown rendering. 29 | func WithWrap(w int) glamour.TermRendererOption { 30 | return glamour.WithWordWrap(w) 31 | } 32 | 33 | // WithTheme is a rendering option that sets the theme to use while rendering the markdown. 34 | // It can be used in conjunction with [term.Theme]. 35 | // If the environment variable GLAMOUR_STYLE is set, it will take precedence over the provided theme. 36 | func WithTheme(theme string) glamour.TermRendererOption { 37 | style := os.Getenv("GLAMOUR_STYLE") 38 | if style == "" || style == "auto" { 39 | switch theme { 40 | case "light", "dark": 41 | if xcolor.IsAccessibleColorsEnabled() { 42 | return glamour.WithOptions( 43 | glamour.WithStyles(xmarkdown.AccessibleStyleConfig(theme)), 44 | glamour.WithChromaFormatter("terminal16"), 45 | ) 46 | } 47 | style = theme 48 | default: 49 | style = "notty" 50 | } 51 | } 52 | return glamour.WithStylePath(style) 53 | } 54 | 55 | // WithBaseURL is a rendering option that sets the base URL to use when rendering relative URLs. 56 | func WithBaseURL(u string) glamour.TermRendererOption { 57 | return glamour.WithBaseURL(u) 58 | } 59 | 60 | // Render the markdown string according to the specified rendering options. 61 | // By default emoji are rendered and new lines are preserved. 
62 | func Render(text string, opts ...glamour.TermRendererOption) (string, error) { 63 | // Glamour rendering preserves carriage return characters in code blocks, but 64 | // we need to ensure that no such characters are present in the output. 65 | text = strings.ReplaceAll(text, "\r\n", "\n") 66 | opts = append(opts, glamour.WithEmoji(), glamour.WithPreservedNewLines(), glamour.WithTableWrap(false)) 67 | tr, err := glamour.NewTermRenderer(opts...) 68 | if err != nil { 69 | return "", err 70 | } 71 | return tr.Render(text) 72 | } 73 | -------------------------------------------------------------------------------- /pkg/prompter/mock.go: -------------------------------------------------------------------------------- 1 | package prompter 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | // PrompterMock provides stubbed out methods for prompting the user for 12 | // use in tests. PrompterMock has a superset of the methods on Prompter 13 | // so they both can satisfy the same interface. 
//
// A basic example of how PrompterMock can be used:
//
//	type ConfirmPrompter interface {
//		Confirm(string, bool) (bool, error)
//	}
//
//	func PlayGame(prompter ConfirmPrompter) (int, error) {
//		confirm, err := prompter.Confirm("Shall we play a game", true)
//		if err != nil {
//			return 0, err
//		}
//		if confirm {
//			return 1, nil
//		}
//		return 2, nil
//	}
//
//	func TestPlayGame(t *testing.T) {
//		expectedOutcome := 1
//		mock := NewMock(t)
//		mock.RegisterConfirm("Shall we play a game", func(prompt string, defaultValue bool) (bool, error) {
//			return true, nil
//		})
//		outcome, err := PlayGame(mock)
//		if err != nil {
//			t.Fatalf("unexpected error: %v", err)
//		}
//		if outcome != expectedOutcome {
//			t.Errorf("expected %d, got %d", expectedOutcome, outcome)
//		}
//	}
type PrompterMock struct {
	t                *testing.T        // test that receives assertion failures
	selectStubs      []selectStub      // queued Select answers, consumed front to back
	multiSelectStubs []multiSelectStub // queued MultiSelect answers
	inputStubs       []inputStub       // queued Input answers
	passwordStubs    []passwordStub    // queued Password answers
	confirmStubs     []confirmStub     // queued Confirm answers
}

// Each stub type below is one queued expectation: the prompt text it must be
// called with and the function that supplies the answer.
type (
	selectStub struct {
		prompt          string
		expectedOptions []string // options the caller is expected to present
		fn              func(string, string, []string) (int, error)
	}

	multiSelectStub struct {
		prompt          string
		expectedOptions []string // options the caller is expected to present
		fn              func(string, []string, []string) ([]int, error)
	}

	inputStub struct {
		prompt string
		fn     func(string, string) (string, error)
	}

	passwordStub struct {
		prompt string
		fn     func(string) (string, error)
	}

	// confirmStub, unlike the other stubs, uses exported field names.
	confirmStub struct {
		Prompt string
		Fn     func(string, bool) (bool, error)
	}
)

// NewMock instantiates a new PrompterMock whose unconsumed stubs are reported as test failures when t cleans up.
83 | func NewMock(t *testing.T) *PrompterMock { 84 | m := &PrompterMock{ 85 | t: t, 86 | selectStubs: []selectStub{}, 87 | multiSelectStubs: []multiSelectStub{}, 88 | inputStubs: []inputStub{}, 89 | passwordStubs: []passwordStub{}, 90 | confirmStubs: []confirmStub{}, 91 | } 92 | t.Cleanup(m.verify) 93 | return m 94 | } 95 | 96 | // Select prompts the user to select an option from a list of options. 97 | func (m *PrompterMock) Select(prompt, defaultValue string, options []string) (int, error) { 98 | var s selectStub 99 | if len(m.selectStubs) == 0 { 100 | return -1, noSuchPromptErr(prompt) 101 | } 102 | s = m.selectStubs[0] 103 | m.selectStubs = m.selectStubs[1:len(m.selectStubs)] 104 | if s.prompt != prompt { 105 | return -1, noSuchPromptErr(prompt) 106 | } 107 | assertOptions(m.t, s.expectedOptions, options) 108 | return s.fn(prompt, defaultValue, options) 109 | } 110 | 111 | // MultiSelect prompts the user to select multiple options from a list of options. 112 | func (m *PrompterMock) MultiSelect(prompt string, defaultValues, options []string) ([]int, error) { 113 | var s multiSelectStub 114 | if len(m.multiSelectStubs) == 0 { 115 | return []int{}, noSuchPromptErr(prompt) 116 | } 117 | s = m.multiSelectStubs[0] 118 | m.multiSelectStubs = m.multiSelectStubs[1:len(m.multiSelectStubs)] 119 | if s.prompt != prompt { 120 | return []int{}, noSuchPromptErr(prompt) 121 | } 122 | assertOptions(m.t, s.expectedOptions, options) 123 | return s.fn(prompt, defaultValues, options) 124 | } 125 | 126 | // Input prompts the user to input a single-line string. 
127 | func (m *PrompterMock) Input(prompt, defaultValue string) (string, error) { 128 | var s inputStub 129 | if len(m.inputStubs) == 0 { 130 | return "", noSuchPromptErr(prompt) 131 | } 132 | s = m.inputStubs[0] 133 | m.inputStubs = m.inputStubs[1:len(m.inputStubs)] 134 | if s.prompt != prompt { 135 | return "", noSuchPromptErr(prompt) 136 | } 137 | return s.fn(prompt, defaultValue) 138 | } 139 | 140 | // Password prompts the user to input a single-line string without echoing the input. 141 | func (m *PrompterMock) Password(prompt string) (string, error) { 142 | var s passwordStub 143 | if len(m.passwordStubs) == 0 { 144 | return "", noSuchPromptErr(prompt) 145 | } 146 | s = m.passwordStubs[0] 147 | m.passwordStubs = m.passwordStubs[1:len(m.passwordStubs)] 148 | if s.prompt != prompt { 149 | return "", noSuchPromptErr(prompt) 150 | } 151 | return s.fn(prompt) 152 | } 153 | 154 | // Confirm prompts the user to confirm a yes/no question. 155 | func (m *PrompterMock) Confirm(prompt string, defaultValue bool) (bool, error) { 156 | var s confirmStub 157 | if len(m.confirmStubs) == 0 { 158 | return false, noSuchPromptErr(prompt) 159 | } 160 | s = m.confirmStubs[0] 161 | m.confirmStubs = m.confirmStubs[1:len(m.confirmStubs)] 162 | if s.Prompt != prompt { 163 | return false, noSuchPromptErr(prompt) 164 | } 165 | return s.Fn(prompt, defaultValue) 166 | } 167 | 168 | // RegisterSelect records that a Select prompt should be called. 169 | func (m *PrompterMock) RegisterSelect(prompt string, opts []string, stub func(_, _ string, _ []string) (int, error)) { 170 | m.selectStubs = append(m.selectStubs, selectStub{ 171 | prompt: prompt, 172 | expectedOptions: opts, 173 | fn: stub}) 174 | } 175 | 176 | // RegisterMultiSelect records that a MultiSelect prompt should be called. 
177 | func (m *PrompterMock) RegisterMultiSelect(prompt string, d, opts []string, stub func(_ string, _, _ []string) ([]int, error)) { 178 | m.multiSelectStubs = append(m.multiSelectStubs, multiSelectStub{ 179 | prompt: prompt, 180 | expectedOptions: opts, 181 | fn: stub}) 182 | } 183 | 184 | // RegisterInput records that an Input prompt should be called. 185 | func (m *PrompterMock) RegisterInput(prompt string, stub func(_, _ string) (string, error)) { 186 | m.inputStubs = append(m.inputStubs, inputStub{prompt: prompt, fn: stub}) 187 | } 188 | 189 | // RegisterPassword records that a Password prompt should be called. 190 | func (m *PrompterMock) RegisterPassword(prompt string, stub func(string) (string, error)) { 191 | m.passwordStubs = append(m.passwordStubs, passwordStub{prompt: prompt, fn: stub}) 192 | } 193 | 194 | // RegisterConfirm records that a Confirm prompt should be called. 195 | func (m *PrompterMock) RegisterConfirm(prompt string, stub func(_ string, _ bool) (bool, error)) { 196 | m.confirmStubs = append(m.confirmStubs, confirmStub{Prompt: prompt, Fn: stub}) 197 | } 198 | 199 | func (m *PrompterMock) verify() { 200 | errs := []string{} 201 | if len(m.selectStubs) > 0 { 202 | errs = append(errs, "MultiSelect") 203 | } 204 | if len(m.multiSelectStubs) > 0 { 205 | errs = append(errs, "Select") 206 | } 207 | if len(m.inputStubs) > 0 { 208 | errs = append(errs, "Input") 209 | } 210 | if len(m.passwordStubs) > 0 { 211 | errs = append(errs, "Password") 212 | } 213 | if len(m.confirmStubs) > 0 { 214 | errs = append(errs, "Confirm") 215 | } 216 | if len(errs) > 0 { 217 | m.t.Helper() 218 | m.t.Errorf("%d unmatched calls to %s", len(errs), strings.Join(errs, ",")) 219 | } 220 | } 221 | 222 | func noSuchPromptErr(prompt string) error { 223 | return fmt.Errorf("no such prompt '%s'", prompt) 224 | } 225 | 226 | func assertOptions(t *testing.T, expected, actual []string) { 227 | assert.Equal(t, expected, actual) 228 | } 229 | 
-------------------------------------------------------------------------------- /pkg/prompter/prompter.go: -------------------------------------------------------------------------------- 1 | // Package prompter provides various methods for prompting the user with 2 | // questions for input. 3 | package prompter 4 | 5 | import ( 6 | "fmt" 7 | "io" 8 | "strings" 9 | 10 | "github.com/AlecAivazis/survey/v2" 11 | "github.com/cli/go-gh/v2/pkg/text" 12 | ) 13 | 14 | // Prompter provides methods for prompting the user. 15 | type Prompter struct { 16 | stdin FileReader 17 | stdout FileWriter 18 | stderr FileWriter 19 | } 20 | 21 | // FileWriter provides a minimal writable interface for stdout and stderr. 22 | type FileWriter interface { 23 | io.Writer 24 | Fd() uintptr 25 | } 26 | 27 | // FileReader provides a minimal readable interface for stdin. 28 | type FileReader interface { 29 | io.Reader 30 | Fd() uintptr 31 | } 32 | 33 | // New instantiates a new Prompter. 34 | func New(stdin FileReader, stdout FileWriter, stderr FileWriter) *Prompter { 35 | return &Prompter{ 36 | stdin: stdin, 37 | stdout: stdout, 38 | stderr: stderr, 39 | } 40 | } 41 | 42 | // Select prompts the user to select an option from a list of options. 43 | func (p *Prompter) Select(prompt, defaultValue string, options []string) (int, error) { 44 | var result int 45 | q := &survey.Select{ 46 | Message: prompt, 47 | Options: options, 48 | PageSize: 20, 49 | Filter: latinMatchingFilter, 50 | } 51 | if defaultValue != "" { 52 | for _, o := range options { 53 | if o == defaultValue { 54 | q.Default = defaultValue 55 | break 56 | } 57 | } 58 | } 59 | err := p.ask(q, &result) 60 | return result, err 61 | } 62 | 63 | // MultiSelect prompts the user to select multiple options from a list of options. 
64 | func (p *Prompter) MultiSelect(prompt string, defaultValues, options []string) ([]int, error) { 65 | var result []int 66 | q := &survey.MultiSelect{ 67 | Message: prompt, 68 | Options: options, 69 | PageSize: 20, 70 | Filter: latinMatchingFilter, 71 | } 72 | if len(defaultValues) > 0 { 73 | validatedDefault := []string{} 74 | for _, x := range defaultValues { 75 | for _, y := range options { 76 | if x == y { 77 | validatedDefault = append(validatedDefault, x) 78 | } 79 | } 80 | } 81 | q.Default = validatedDefault 82 | } 83 | err := p.ask(q, &result) 84 | return result, err 85 | } 86 | 87 | // Input prompts the user to input a single-line string. 88 | func (p *Prompter) Input(prompt, defaultValue string) (string, error) { 89 | var result string 90 | err := p.ask(&survey.Input{ 91 | Message: prompt, 92 | Default: defaultValue, 93 | }, &result) 94 | return result, err 95 | } 96 | 97 | // Password prompts the user to input a single-line string without echoing the input. 98 | func (p *Prompter) Password(prompt string) (string, error) { 99 | var result string 100 | err := p.ask(&survey.Password{ 101 | Message: prompt, 102 | }, &result) 103 | return result, err 104 | } 105 | 106 | // Confirm prompts the user to confirm a yes/no question. 107 | func (p *Prompter) Confirm(prompt string, defaultValue bool) (bool, error) { 108 | var result bool 109 | err := p.ask(&survey.Confirm{ 110 | Message: prompt, 111 | Default: defaultValue, 112 | }, &result) 113 | return result, err 114 | } 115 | 116 | func (p *Prompter) ask(q survey.Prompt, response interface{}, opts ...survey.AskOpt) error { 117 | opts = append(opts, survey.WithStdio(p.stdin, p.stdout, p.stderr)) 118 | err := survey.AskOne(q, response, opts...) 119 | if err == nil { 120 | return nil 121 | } 122 | return fmt.Errorf("could not prompt: %w", err) 123 | } 124 | 125 | // latinMatchingFilter returns whether the value matches the input filter. 126 | // The strings are compared normalized in case. 
127 | // The filter's diactritics are kept as-is, but the value's are normalized, 128 | // so that a missing diactritic in the filter still returns a result. 129 | func latinMatchingFilter(filter, value string, index int) bool { 130 | filter = strings.ToLower(filter) 131 | value = strings.ToLower(value) 132 | // include this option if it matches. 133 | return strings.Contains(value, filter) || strings.Contains(text.RemoveDiacritics(value), filter) 134 | } 135 | -------------------------------------------------------------------------------- /pkg/prompter/prompter_test.go: -------------------------------------------------------------------------------- 1 | package prompter 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "os" 7 | "testing" 8 | 9 | "github.com/cli/go-gh/v2/pkg/term" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func ExamplePrompter() { 14 | term := term.FromEnv() 15 | in, ok := term.In().(*os.File) 16 | if !ok { 17 | log.Fatal("error casting to file") 18 | } 19 | out, ok := term.Out().(*os.File) 20 | if !ok { 21 | log.Fatal("error casting to file") 22 | } 23 | errOut, ok := term.ErrOut().(*os.File) 24 | if !ok { 25 | log.Fatal("error casting to file") 26 | } 27 | prompter := New(in, out, errOut) 28 | response, err := prompter.Confirm("Shall we play a game", true) 29 | if err != nil { 30 | log.Fatal(err) 31 | } 32 | fmt.Println(response) 33 | } 34 | 35 | func TestLatinMatchingFilter(t *testing.T) { 36 | tests := []struct { 37 | name string 38 | filter string 39 | value string 40 | want bool 41 | }{ 42 | { 43 | name: "exact match no diacritics", 44 | filter: "Mikelis", 45 | value: "Mikelis", 46 | want: true, 47 | }, 48 | { 49 | name: "exact match no diacritics", 50 | filter: "Mikelis", 51 | value: "Mikelis", 52 | want: true, 53 | }, 54 | { 55 | name: "exact match diacritics", 56 | filter: "Miķelis", 57 | value: "Miķelis", 58 | want: true, 59 | }, 60 | { 61 | name: "partial match diacritics", 62 | filter: "Miķe", 63 | value: "Miķelis", 64 | want: 
true, 65 | }, 66 | { 67 | name: "exact match diacritics in value", 68 | filter: "Mikelis", 69 | value: "Miķelis", 70 | want: true, 71 | }, 72 | { 73 | name: "partial match diacritics in filter", 74 | filter: "Miķe", 75 | value: "Miķelis", 76 | want: true, 77 | }, 78 | { 79 | name: "no match when removing diacritics in filter", 80 | filter: "Mielis", 81 | value: "Mikelis", 82 | want: false, 83 | }, 84 | { 85 | name: "no match when removing diacritics in value", 86 | filter: "Mikelis", 87 | value: "Mielis", 88 | want: false, 89 | }, 90 | { 91 | name: "no match diacritics in filter", 92 | filter: "Miķelis", 93 | value: "Mikelis", 94 | want: false, 95 | }, 96 | } 97 | for _, tt := range tests { 98 | t.Run(tt.name, func(t *testing.T) { 99 | assert.Equal(t, latinMatchingFilter(tt.filter, tt.value, 0), tt.want) 100 | }) 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /pkg/repository/repository.go: -------------------------------------------------------------------------------- 1 | // Package repository is a set of types and functions for modeling and 2 | // interacting with GitHub repositories. 3 | package repository 4 | 5 | import ( 6 | "errors" 7 | "fmt" 8 | "os" 9 | "strings" 10 | 11 | "github.com/cli/go-gh/v2/internal/git" 12 | "github.com/cli/go-gh/v2/pkg/auth" 13 | "github.com/cli/go-gh/v2/pkg/ssh" 14 | ) 15 | 16 | // Repository holds information representing a GitHub repository. 17 | type Repository struct { 18 | Host string 19 | Name string 20 | Owner string 21 | } 22 | 23 | // Parse extracts the repository information from the following 24 | // string formats: "OWNER/REPO", "HOST/OWNER/REPO", and a full URL. 25 | // If the format does not specify a host, use the config to determine a host. 
26 | func Parse(s string) (Repository, error) { 27 | var r Repository 28 | 29 | if git.IsURL(s) { 30 | u, err := git.ParseURL(s) 31 | if err != nil { 32 | return r, err 33 | } 34 | 35 | host, owner, name, err := git.RepoInfoFromURL(u) 36 | if err != nil { 37 | return r, err 38 | } 39 | 40 | r.Host = host 41 | r.Name = name 42 | r.Owner = owner 43 | 44 | return r, nil 45 | } 46 | 47 | parts := strings.SplitN(s, "/", 4) 48 | for _, p := range parts { 49 | if len(p) == 0 { 50 | return r, fmt.Errorf(`expected the "[HOST/]OWNER/REPO" format, got %q`, s) 51 | } 52 | } 53 | 54 | switch len(parts) { 55 | case 3: 56 | r.Host = parts[0] 57 | r.Owner = parts[1] 58 | r.Name = parts[2] 59 | return r, nil 60 | case 2: 61 | r.Host, _ = auth.DefaultHost() 62 | r.Owner = parts[0] 63 | r.Name = parts[1] 64 | return r, nil 65 | default: 66 | return r, fmt.Errorf(`expected the "[HOST/]OWNER/REPO" format, got %q`, s) 67 | } 68 | } 69 | 70 | // Parse extracts the repository information from the following 71 | // string formats: "OWNER/REPO", "HOST/OWNER/REPO", and a full URL. 72 | // If the format does not specify a host, use the host provided. 
73 | func ParseWithHost(s, host string) (Repository, error) { 74 | var r Repository 75 | 76 | if git.IsURL(s) { 77 | u, err := git.ParseURL(s) 78 | if err != nil { 79 | return r, err 80 | } 81 | 82 | host, owner, name, err := git.RepoInfoFromURL(u) 83 | if err != nil { 84 | return r, err 85 | } 86 | 87 | r.Host = host 88 | r.Owner = owner 89 | r.Name = name 90 | 91 | return r, nil 92 | } 93 | 94 | parts := strings.SplitN(s, "/", 4) 95 | for _, p := range parts { 96 | if len(p) == 0 { 97 | return r, fmt.Errorf(`expected the "[HOST/]OWNER/REPO" format, got %q`, s) 98 | } 99 | } 100 | 101 | switch len(parts) { 102 | case 3: 103 | r.Host = parts[0] 104 | r.Owner = parts[1] 105 | r.Name = parts[2] 106 | return r, nil 107 | case 2: 108 | r.Host = host 109 | r.Owner = parts[0] 110 | r.Name = parts[1] 111 | return r, nil 112 | default: 113 | return r, fmt.Errorf(`expected the "[HOST/]OWNER/REPO" format, got %q`, s) 114 | } 115 | } 116 | 117 | // Current uses git remotes to determine the GitHub repository 118 | // the current directory is tracking. 
119 | func Current() (Repository, error) { 120 | var r Repository 121 | 122 | override := os.Getenv("GH_REPO") 123 | if override != "" { 124 | return Parse(override) 125 | } 126 | 127 | remotes, err := git.Remotes() 128 | if err != nil { 129 | return r, err 130 | } 131 | if len(remotes) == 0 { 132 | return r, errors.New("unable to determine current repository, no git remotes configured for this repository") 133 | } 134 | 135 | translator := ssh.NewTranslator() 136 | for _, r := range remotes { 137 | if r.FetchURL != nil { 138 | r.FetchURL = translator.Translate(r.FetchURL) 139 | } 140 | if r.PushURL != nil { 141 | r.PushURL = translator.Translate(r.PushURL) 142 | } 143 | } 144 | 145 | hosts := auth.KnownHosts() 146 | 147 | filteredRemotes := remotes.FilterByHosts(hosts) 148 | if len(filteredRemotes) == 0 { 149 | return r, errors.New("unable to determine current repository, none of the git remotes configured for this repository point to a known GitHub host") 150 | } 151 | 152 | rem := filteredRemotes[0] 153 | r.Host = rem.Host 154 | r.Owner = rem.Owner 155 | r.Name = rem.Repo 156 | 157 | return r, nil 158 | } 159 | -------------------------------------------------------------------------------- /pkg/repository/repository_test.go: -------------------------------------------------------------------------------- 1 | package repository 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/cli/go-gh/v2/internal/testutils" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestParse(t *testing.T) { 11 | testutils.StubConfig(t, "") 12 | 13 | tests := []struct { 14 | name string 15 | input string 16 | hostOverride string 17 | wantOwner string 18 | wantName string 19 | wantHost string 20 | wantErr string 21 | }{ 22 | { 23 | name: "OWNER/REPO combo", 24 | input: "OWNER/REPO", 25 | wantHost: "github.com", 26 | wantOwner: "OWNER", 27 | wantName: "REPO", 28 | }, 29 | { 30 | name: "too few elements", 31 | input: "OWNER", 32 | wantErr: `expected the "[HOST/]OWNER/REPO" 
format, got "OWNER"`, 33 | }, 34 | { 35 | name: "too many elements", 36 | input: "a/b/c/d", 37 | wantErr: `expected the "[HOST/]OWNER/REPO" format, got "a/b/c/d"`, 38 | }, 39 | { 40 | name: "blank value", 41 | input: "a/", 42 | wantErr: `expected the "[HOST/]OWNER/REPO" format, got "a/"`, 43 | }, 44 | { 45 | name: "with hostname", 46 | input: "example.org/OWNER/REPO", 47 | wantHost: "example.org", 48 | wantOwner: "OWNER", 49 | wantName: "REPO", 50 | }, 51 | { 52 | name: "full URL", 53 | input: "https://example.org/OWNER/REPO.git", 54 | wantHost: "example.org", 55 | wantOwner: "OWNER", 56 | wantName: "REPO", 57 | }, 58 | { 59 | name: "SSH URL", 60 | input: "git@example.org:OWNER/REPO.git", 61 | wantHost: "example.org", 62 | wantOwner: "OWNER", 63 | wantName: "REPO", 64 | }, 65 | { 66 | name: "OWNER/REPO with default host override", 67 | input: "OWNER/REPO", 68 | hostOverride: "override.com", 69 | wantHost: "override.com", 70 | wantOwner: "OWNER", 71 | wantName: "REPO", 72 | }, 73 | { 74 | name: "HOST/OWNER/REPO with default host override", 75 | input: "example.com/OWNER/REPO", 76 | hostOverride: "override.com", 77 | wantHost: "example.com", 78 | wantOwner: "OWNER", 79 | wantName: "REPO", 80 | }, 81 | } 82 | for _, tt := range tests { 83 | t.Run(tt.name, func(t *testing.T) { 84 | t.Setenv("GH_CONFIG_DIR", "nonexistant") 85 | if tt.hostOverride != "" { 86 | t.Setenv("GH_HOST", tt.hostOverride) 87 | } 88 | r, err := Parse(tt.input) 89 | if tt.wantErr != "" { 90 | assert.EqualError(t, err, tt.wantErr) 91 | return 92 | } 93 | assert.NoError(t, err) 94 | assert.Equal(t, tt.wantHost, r.Host) 95 | assert.Equal(t, tt.wantOwner, r.Owner) 96 | assert.Equal(t, tt.wantName, r.Name) 97 | }) 98 | } 99 | } 100 | 101 | func TestParse_hostFromConfig(t *testing.T) { 102 | var cfgStr = ` 103 | hosts: 104 | enterprise.com: 105 | user: user2 106 | oauth_token: yyyyyyyyyyyyyyyyyyyy 107 | git_protocol: https 108 | ` 109 | testutils.StubConfig(t, cfgStr) 110 | r, err := Parse("OWNER/REPO") 
111 | assert.NoError(t, err) 112 | assert.Equal(t, "enterprise.com", r.Host) 113 | assert.Equal(t, "OWNER", r.Owner) 114 | assert.Equal(t, "REPO", r.Name) 115 | } 116 | 117 | func TestParseWithHost(t *testing.T) { 118 | tests := []struct { 119 | name string 120 | input string 121 | host string 122 | wantOwner string 123 | wantName string 124 | wantHost string 125 | wantErr string 126 | }{ 127 | { 128 | name: "OWNER/REPO combo", 129 | input: "OWNER/REPO", 130 | host: "github.com", 131 | wantHost: "github.com", 132 | wantOwner: "OWNER", 133 | wantName: "REPO", 134 | }, 135 | { 136 | name: "too few elements", 137 | input: "OWNER", 138 | host: "github.com", 139 | wantErr: `expected the "[HOST/]OWNER/REPO" format, got "OWNER"`, 140 | }, 141 | { 142 | name: "too many elements", 143 | input: "a/b/c/d", 144 | host: "github.com", 145 | wantErr: `expected the "[HOST/]OWNER/REPO" format, got "a/b/c/d"`, 146 | }, 147 | { 148 | name: "blank value", 149 | input: "a/", 150 | host: "github.com", 151 | wantErr: `expected the "[HOST/]OWNER/REPO" format, got "a/"`, 152 | }, 153 | { 154 | name: "with hostname", 155 | input: "example.org/OWNER/REPO", 156 | host: "github.com", 157 | wantHost: "example.org", 158 | wantOwner: "OWNER", 159 | wantName: "REPO", 160 | }, 161 | { 162 | name: "full URL", 163 | input: "https://example.org/OWNER/REPO.git", 164 | host: "github.com", 165 | wantHost: "example.org", 166 | wantOwner: "OWNER", 167 | wantName: "REPO", 168 | }, 169 | { 170 | name: "SSH URL", 171 | input: "git@example.org:OWNER/REPO.git", 172 | host: "github.com", 173 | wantHost: "example.org", 174 | wantOwner: "OWNER", 175 | wantName: "REPO", 176 | }, 177 | } 178 | for _, tt := range tests { 179 | t.Run(tt.name, func(t *testing.T) { 180 | r, err := ParseWithHost(tt.input, tt.host) 181 | if tt.wantErr != "" { 182 | assert.EqualError(t, err, tt.wantErr) 183 | return 184 | } 185 | assert.NoError(t, err) 186 | assert.Equal(t, tt.wantHost, r.Host) 187 | assert.Equal(t, tt.wantOwner, r.Owner) 
188 | assert.Equal(t, tt.wantName, r.Name) 189 | }) 190 | } 191 | } 192 | -------------------------------------------------------------------------------- /pkg/ssh/ssh.go: -------------------------------------------------------------------------------- 1 | // Package ssh resolves local SSH hostname aliases. 2 | package ssh 3 | 4 | import ( 5 | "bufio" 6 | "net/url" 7 | "os/exec" 8 | "strings" 9 | "sync" 10 | 11 | "github.com/cli/safeexec" 12 | ) 13 | 14 | type Translator struct { 15 | cacheMap map[string]string 16 | cacheMu sync.RWMutex 17 | sshPath string 18 | sshPathErr error 19 | sshPathMu sync.Mutex 20 | 21 | lookPath func(string) (string, error) 22 | newCommand func(string, ...string) *exec.Cmd 23 | } 24 | 25 | // NewTranslator initializes a new Translator instance. 26 | func NewTranslator() *Translator { 27 | return &Translator{} 28 | } 29 | 30 | // Translate applies applicable SSH hostname aliases to the specified URL and returns the resulting URL. 31 | func (t *Translator) Translate(u *url.URL) *url.URL { 32 | if u.Scheme != "ssh" { 33 | return u 34 | } 35 | resolvedHost, err := t.resolve(u.Hostname()) 36 | if err != nil { 37 | return u 38 | } 39 | if strings.EqualFold(resolvedHost, "ssh.github.com") { 40 | resolvedHost = "github.com" 41 | } 42 | newURL, _ := url.Parse(u.String()) 43 | newURL.Host = resolvedHost 44 | return newURL 45 | } 46 | 47 | func (t *Translator) resolve(hostname string) (string, error) { 48 | t.cacheMu.RLock() 49 | cached, cacheFound := t.cacheMap[strings.ToLower(hostname)] 50 | t.cacheMu.RUnlock() 51 | if cacheFound { 52 | return cached, nil 53 | } 54 | 55 | var sshPath string 56 | t.sshPathMu.Lock() 57 | if t.sshPath == "" && t.sshPathErr == nil { 58 | lookPath := t.lookPath 59 | if lookPath == nil { 60 | lookPath = safeexec.LookPath 61 | } 62 | t.sshPath, t.sshPathErr = lookPath("ssh") 63 | } 64 | if t.sshPathErr != nil { 65 | defer t.sshPathMu.Unlock() 66 | return t.sshPath, t.sshPathErr 67 | } 68 | sshPath = t.sshPath 69 | 
t.sshPathMu.Unlock() 70 | 71 | t.cacheMu.Lock() 72 | defer t.cacheMu.Unlock() 73 | 74 | newCommand := t.newCommand 75 | if newCommand == nil { 76 | newCommand = exec.Command 77 | } 78 | sshCmd := newCommand(sshPath, "-G", hostname) 79 | stdout, err := sshCmd.StdoutPipe() 80 | if err != nil { 81 | return "", err 82 | } 83 | 84 | if err := sshCmd.Start(); err != nil { 85 | return "", err 86 | } 87 | 88 | var resolvedHost string 89 | s := bufio.NewScanner(stdout) 90 | for s.Scan() { 91 | line := s.Text() 92 | parts := strings.SplitN(line, " ", 2) 93 | if len(parts) == 2 && parts[0] == "hostname" { 94 | resolvedHost = parts[1] 95 | } 96 | } 97 | 98 | err = sshCmd.Wait() 99 | if err != nil || resolvedHost == "" { 100 | // handle failures by returning the original hostname unchanged 101 | resolvedHost = hostname 102 | } 103 | 104 | if t.cacheMap == nil { 105 | t.cacheMap = map[string]string{} 106 | } 107 | t.cacheMap[strings.ToLower(hostname)] = resolvedHost 108 | return resolvedHost, nil 109 | } 110 | -------------------------------------------------------------------------------- /pkg/ssh/ssh_test.go: -------------------------------------------------------------------------------- 1 | package ssh 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net/url" 7 | "os" 8 | "os/exec" 9 | "testing" 10 | 11 | "github.com/MakeNowJust/heredoc" 12 | "github.com/cli/safeexec" 13 | ) 14 | 15 | func TestTranslator(t *testing.T) { 16 | if _, err := safeexec.LookPath("ssh"); err != nil { 17 | t.Skip("no ssh found on system") 18 | } 19 | 20 | tests := []struct { 21 | name string 22 | sshConfig string 23 | arg string 24 | want string 25 | }{ 26 | { 27 | name: "translate SSH URL", 28 | sshConfig: heredoc.Doc(` 29 | Host github-* 30 | Hostname github.com 31 | `), 32 | arg: "ssh://git@github-foo/owner/repo.git", 33 | want: "ssh://git@github.com/owner/repo.git", 34 | }, 35 | { 36 | name: "does not translate HTTPS URL", 37 | sshConfig: heredoc.Doc(` 38 | Host github-* 39 | Hostname github.com 40 | 
`), 41 | arg: "https://github-foo/owner/repo.git", 42 | want: "https://github-foo/owner/repo.git", 43 | }, 44 | { 45 | name: "treats ssh.github.com as github.com", 46 | sshConfig: heredoc.Doc(` 47 | Host github.com 48 | Hostname ssh.github.com 49 | `), 50 | arg: "ssh://git@github.com/owner/repo.git", 51 | want: "ssh://git@github.com/owner/repo.git", 52 | }, 53 | } 54 | for _, tt := range tests { 55 | t.Run(tt.name, func(t *testing.T) { 56 | f, err := os.CreateTemp("", "ssh-config.*") 57 | if err != nil { 58 | t.Fatalf("error creating file: %v", err) 59 | } 60 | _, err = f.WriteString(tt.sshConfig) 61 | _ = f.Close() 62 | if err != nil { 63 | t.Fatalf("error writing ssh config: %v", err) 64 | } 65 | 66 | tr := &Translator{ 67 | newCommand: func(exe string, args ...string) *exec.Cmd { 68 | args = append([]string{"-F", f.Name()}, args...) 69 | return exec.Command(exe, args...) 70 | }, 71 | } 72 | u, err := url.Parse(tt.arg) 73 | if err != nil { 74 | t.Fatalf("error parsing URL: %v", err) 75 | } 76 | res := tr.Translate(u) 77 | if got := res.String(); got != tt.want { 78 | t.Errorf("expected %q, got %q", tt.want, got) 79 | } 80 | }) 81 | } 82 | } 83 | 84 | func TestHelperProcess(t *testing.T) { 85 | if os.Getenv("GH_WANT_HELPER_PROCESS") != "1" { 86 | return 87 | } 88 | if err := func(args []string) error { 89 | if len(args) < 3 || args[2] == "error" { 90 | return errors.New("fatal") 91 | } 92 | if args[2] == "empty.io" { 93 | return nil 94 | } 95 | fmt.Fprintf(os.Stdout, "hostname %s\n", args[2]) 96 | return nil 97 | }(os.Args[3:]); err != nil { 98 | fmt.Fprintln(os.Stderr, err) 99 | os.Exit(1) 100 | } 101 | os.Exit(0) 102 | } 103 | 104 | func TestTranslator_caching(t *testing.T) { 105 | countLookPath := 0 106 | countNewCommand := 0 107 | tr := &Translator{ 108 | lookPath: func(s string) (string, error) { 109 | countLookPath++ 110 | return "/path/to/ssh", nil 111 | }, 112 | newCommand: func(exe string, args ...string) *exec.Cmd { 113 | args = 
append([]string{"-test.run=TestHelperProcess", "--", exe}, args...) 114 | c := exec.Command(os.Args[0], args...) 115 | c.Env = []string{"GH_WANT_HELPER_PROCESS=1"} 116 | countNewCommand++ 117 | return c 118 | }, 119 | } 120 | 121 | tests := []struct { 122 | input string 123 | result string 124 | }{ 125 | { 126 | input: "ssh://github1.com/owner/repo.git", 127 | result: "github1.com", 128 | }, 129 | { 130 | input: "ssh://github2.com/owner/repo.git", 131 | result: "github2.com", 132 | }, 133 | { 134 | input: "ssh://empty.io/owner/repo.git", 135 | result: "empty.io", 136 | }, 137 | { 138 | input: "ssh://error/owner/repo.git", 139 | result: "error", 140 | }, 141 | } 142 | for _, tt := range tests { 143 | t.Run(tt.input, func(t *testing.T) { 144 | u, err := url.Parse(tt.input) 145 | if err != nil { 146 | t.Fatalf("error parsing URL: %v", err) 147 | } 148 | if res := tr.Translate(u); res.Host != tt.result { 149 | t.Errorf("expected github.com, got: %q", res.Host) 150 | } 151 | if res := tr.Translate(u); res.Host != tt.result { 152 | t.Errorf("expected github.com, got: %q (second call)", res.Host) 153 | } 154 | }) 155 | } 156 | 157 | if countLookPath != 1 { 158 | t.Errorf("expected lookPath to happen 1 time; actual: %d", countLookPath) 159 | } 160 | if countNewCommand != len(tests) { 161 | t.Errorf("expected ssh command to shell out %d times; actual: %d", len(tests), countNewCommand) 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /pkg/tableprinter/table.go: -------------------------------------------------------------------------------- 1 | // Package tableprinter facilitates rendering column-formatted data to a terminal and TSV-formatted data to 2 | // a script or a file. It is suitable for presenting tabular data in a human-readable format that is 3 | // guaranteed to fit within the given viewport, while at the same time offering the same data in a 4 | // machine-readable format for scripts. 
5 | package tableprinter 6 | 7 | import ( 8 | "fmt" 9 | "io" 10 | 11 | "github.com/cli/go-gh/v2/pkg/text" 12 | ) 13 | 14 | type fieldOption func(*tableField) 15 | 16 | type TablePrinter interface { 17 | AddHeader([]string, ...fieldOption) 18 | AddField(string, ...fieldOption) 19 | EndRow() 20 | Render() error 21 | } 22 | 23 | // WithTruncate overrides the truncation function for the field. The function should transform a string 24 | // argument into a string that fits within the given display width. The default behavior is to truncate the 25 | // value by adding "..." in the end. The truncation function will be called before padding and coloring. 26 | // Pass nil to disable truncation for this value. 27 | func WithTruncate(fn func(int, string) string) fieldOption { 28 | return func(f *tableField) { 29 | f.truncateFunc = fn 30 | } 31 | } 32 | 33 | // WithPadding overrides the padding function for the field. The function should transform a string argument 34 | // into a string that is padded to fit within the given display width. The default behavior is to pad fields 35 | // with spaces except for the last field. The padding function will be called after truncation and before coloring. 36 | // Pass nil to disable padding for this value. 37 | func WithPadding(fn func(int, string) string) fieldOption { 38 | return func(f *tableField) { 39 | f.paddingFunc = fn 40 | } 41 | } 42 | 43 | // WithColor sets the color function for the field. The function should transform a string value by wrapping 44 | // it in ANSI escape codes. The color function will not be used if the table was initialized in non-terminal mode. 45 | // The color function will be called before truncation and padding. 46 | func WithColor(fn func(string) string) fieldOption { 47 | return func(f *tableField) { 48 | f.colorFunc = fn 49 | } 50 | } 51 | 52 | // New initializes a table printer with terminal mode and terminal width. 
When terminal mode is enabled, the 53 | // output will be human-readable, column-formatted to fit available width, and rendered with color support. 54 | // In non-terminal mode, the output is tab-separated and all truncation of values is disabled. 55 | func New(w io.Writer, isTTY bool, maxWidth int) TablePrinter { 56 | if isTTY { 57 | return &ttyTablePrinter{ 58 | out: w, 59 | maxWidth: maxWidth, 60 | } 61 | } 62 | 63 | return &tsvTablePrinter{ 64 | out: w, 65 | } 66 | } 67 | 68 | type tableField struct { 69 | text string 70 | truncateFunc func(int, string) string 71 | paddingFunc func(int, string) string 72 | colorFunc func(string) string 73 | } 74 | 75 | type ttyTablePrinter struct { 76 | out io.Writer 77 | maxWidth int 78 | hasHeaders bool 79 | rows [][]tableField 80 | } 81 | 82 | func (t *ttyTablePrinter) AddHeader(columns []string, opts ...fieldOption) { 83 | if t.hasHeaders { 84 | return 85 | } 86 | 87 | t.hasHeaders = true 88 | for _, column := range columns { 89 | t.AddField(column, opts...) 
90 | } 91 | t.EndRow() 92 | } 93 | 94 | func (t *ttyTablePrinter) AddField(s string, opts ...fieldOption) { 95 | if t.rows == nil { 96 | t.rows = make([][]tableField, 1) 97 | } 98 | rowI := len(t.rows) - 1 99 | field := tableField{ 100 | text: s, 101 | truncateFunc: text.Truncate, 102 | } 103 | for _, opt := range opts { 104 | opt(&field) 105 | } 106 | t.rows[rowI] = append(t.rows[rowI], field) 107 | } 108 | 109 | func (t *ttyTablePrinter) EndRow() { 110 | t.rows = append(t.rows, []tableField{}) 111 | } 112 | 113 | func (t *ttyTablePrinter) Render() error { 114 | if len(t.rows) == 0 { 115 | return nil 116 | } 117 | 118 | delim := " " 119 | numCols := len(t.rows[0]) 120 | colWidths := t.calculateColumnWidths(len(delim)) 121 | 122 | for _, row := range t.rows { 123 | for col, field := range row { 124 | if col > 0 { 125 | _, err := fmt.Fprint(t.out, delim) 126 | if err != nil { 127 | return err 128 | } 129 | } 130 | truncVal := field.text 131 | if field.truncateFunc != nil { 132 | truncVal = field.truncateFunc(colWidths[col], field.text) 133 | } 134 | if field.paddingFunc != nil { 135 | truncVal = field.paddingFunc(colWidths[col], truncVal) 136 | } else if col < numCols-1 { 137 | truncVal = text.PadRight(colWidths[col], truncVal) 138 | } 139 | if field.colorFunc != nil { 140 | truncVal = field.colorFunc(truncVal) 141 | } 142 | _, err := fmt.Fprint(t.out, truncVal) 143 | if err != nil { 144 | return err 145 | } 146 | } 147 | if len(row) > 0 { 148 | _, err := fmt.Fprint(t.out, "\n") 149 | if err != nil { 150 | return err 151 | } 152 | } 153 | } 154 | return nil 155 | } 156 | 157 | func (t *ttyTablePrinter) calculateColumnWidths(delimSize int) []int { 158 | numCols := len(t.rows[0]) 159 | maxColWidths := make([]int, numCols) 160 | colWidths := make([]int, numCols) 161 | 162 | for _, row := range t.rows { 163 | for col, field := range row { 164 | w := text.DisplayWidth(field.text) 165 | if w > maxColWidths[col] { 166 | maxColWidths[col] = w 167 | } 168 | // if this field 
has disabled truncating, ensure that the column is wide enough 169 | if field.truncateFunc == nil && w > colWidths[col] { 170 | colWidths[col] = w 171 | } 172 | } 173 | } 174 | 175 | availWidth := func() int { 176 | setWidths := 0 177 | for col := 0; col < numCols; col++ { 178 | setWidths += colWidths[col] 179 | } 180 | return t.maxWidth - delimSize*(numCols-1) - setWidths 181 | } 182 | numFixedCols := func() int { 183 | fixedCols := 0 184 | for col := 0; col < numCols; col++ { 185 | if colWidths[col] > 0 { 186 | fixedCols++ 187 | } 188 | } 189 | return fixedCols 190 | } 191 | 192 | // set the widths of short columns 193 | if w := availWidth(); w > 0 { 194 | if numFlexColumns := numCols - numFixedCols(); numFlexColumns > 0 { 195 | perColumn := w / numFlexColumns 196 | for col := 0; col < numCols; col++ { 197 | if max := maxColWidths[col]; max < perColumn { 198 | colWidths[col] = max 199 | } 200 | } 201 | } 202 | } 203 | 204 | // truncate long columns to the remaining available width 205 | if numFlexColumns := numCols - numFixedCols(); numFlexColumns > 0 { 206 | perColumn := availWidth() / numFlexColumns 207 | for col := 0; col < numCols; col++ { 208 | if colWidths[col] == 0 { 209 | if max := maxColWidths[col]; max < perColumn { 210 | colWidths[col] = max 211 | } else if perColumn > 0 { 212 | colWidths[col] = perColumn 213 | } 214 | } 215 | } 216 | } 217 | 218 | // add the remainder to truncated columns 219 | if w := availWidth(); w > 0 { 220 | for col := 0; col < numCols; col++ { 221 | d := maxColWidths[col] - colWidths[col] 222 | toAdd := w 223 | if d < toAdd { 224 | toAdd = d 225 | } 226 | colWidths[col] += toAdd 227 | w -= toAdd 228 | if w <= 0 { 229 | break 230 | } 231 | } 232 | } 233 | 234 | return colWidths 235 | } 236 | 237 | type tsvTablePrinter struct { 238 | out io.Writer 239 | currentCol int 240 | } 241 | 242 | func (t *tsvTablePrinter) AddHeader(_ []string, _ ...fieldOption) {} 243 | 244 | func (t *tsvTablePrinter) AddField(text string, _ 
...fieldOption) { 245 | if t.currentCol > 0 { 246 | fmt.Fprint(t.out, "\t") 247 | } 248 | fmt.Fprint(t.out, text) 249 | t.currentCol++ 250 | } 251 | 252 | func (t *tsvTablePrinter) EndRow() { 253 | fmt.Fprint(t.out, "\n") 254 | t.currentCol = 0 255 | } 256 | 257 | func (t *tsvTablePrinter) Render() error { 258 | return nil 259 | } 260 | -------------------------------------------------------------------------------- /pkg/tableprinter/table_test.go: -------------------------------------------------------------------------------- 1 | package tableprinter 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "log" 7 | "os" 8 | "strings" 9 | "testing" 10 | 11 | "github.com/MakeNowJust/heredoc" 12 | ) 13 | 14 | func ExampleTablePrinter() { 15 | // information about the terminal can be obtained using the [pkg/term] package 16 | isTTY := true 17 | termWidth := 14 18 | red := func(s string) string { 19 | return "\x1b[31m" + s + "\x1b[m" 20 | } 21 | 22 | t := New(os.Stdout, isTTY, termWidth) 23 | t.AddField("9", WithTruncate(nil)) 24 | t.AddField("hello") 25 | t.EndRow() 26 | t.AddField("10", WithTruncate(nil)) 27 | t.AddField("long description", WithColor(red)) 28 | t.EndRow() 29 | if err := t.Render(); err != nil { 30 | log.Fatal(err) 31 | } 32 | // stdout now contains: 33 | // 9 hello 34 | // 10 long de... 
35 | } 36 | 37 | func Test_ttyTablePrinter_autoTruncate(t *testing.T) { 38 | buf := bytes.Buffer{} 39 | tp := New(&buf, true, 5) 40 | 41 | tp.AddField("1") 42 | tp.AddField("hello") 43 | tp.EndRow() 44 | tp.AddField("2") 45 | tp.AddField("world") 46 | tp.EndRow() 47 | 48 | err := tp.Render() 49 | if err != nil { 50 | t.Fatalf("unexpected error: %v", err) 51 | } 52 | 53 | expected := "1 he\n2 wo\n" 54 | if buf.String() != expected { 55 | t.Errorf("expected: %q, got: %q", expected, buf.String()) 56 | } 57 | } 58 | 59 | func Test_ttyTablePrinter_WithTruncate(t *testing.T) { 60 | buf := bytes.Buffer{} 61 | tp := New(&buf, true, 15) 62 | 63 | tp.AddField("long SHA", WithTruncate(nil)) 64 | tp.AddField("hello") 65 | tp.EndRow() 66 | tp.AddField("another SHA", WithTruncate(nil)) 67 | tp.AddField("world") 68 | tp.EndRow() 69 | 70 | err := tp.Render() 71 | if err != nil { 72 | t.Fatalf("unexpected error: %v", err) 73 | } 74 | 75 | expected := "long SHA he\nanother SHA wo\n" 76 | if buf.String() != expected { 77 | t.Errorf("expected: %q, got: %q", expected, buf.String()) 78 | } 79 | } 80 | 81 | func Test_ttyTablePrinter_AddHeader(t *testing.T) { 82 | buf := bytes.Buffer{} 83 | tp := New(&buf, true, 80) 84 | 85 | tp.AddHeader([]string{"ONE", "TWO", "THREE"}, WithColor(func(s string) string { 86 | return fmt.Sprintf("\x1b[4m%s\x1b[m", s) 87 | })) 88 | // Subsequent calls to AddHeader are ignored.
89 | tp.AddHeader([]string{"SHOULD", "NOT", "EXIST"}) 90 | 91 | tp.AddField("hello") 92 | tp.AddField("beautiful") 93 | tp.AddField("people") 94 | tp.EndRow() 95 | 96 | err := tp.Render() 97 | if err != nil { 98 | t.Fatalf("unexpected error: %v", err) 99 | } 100 | 101 | expected := heredoc.Docf(` 102 | %[1]s[4mONE %[1]s[m %[1]s[4mTWO %[1]s[m %[1]s[4mTHREE%[1]s[m 103 | hello beautiful people 104 | `, "\x1b") 105 | if buf.String() != expected { 106 | t.Errorf("expected: %q, got: %q", expected, buf.String()) 107 | } 108 | } 109 | 110 | func Test_ttyTablePrinter_WithPadding(t *testing.T) { 111 | buf := bytes.Buffer{} 112 | tp := New(&buf, true, 80) 113 | 114 | // Center the headers. 115 | tp.AddHeader([]string{"A", "B", "C"}, WithPadding(func(width int, s string) string { 116 | left := (width - len(s)) / 2 117 | return strings.Repeat(" ", left) + s + strings.Repeat(" ", width-left-len(s)) 118 | })) 119 | 120 | tp.AddField("hello") 121 | tp.AddField("beautiful") 122 | tp.AddField("people") 123 | tp.EndRow() 124 | 125 | err := tp.Render() 126 | if err != nil { 127 | t.Fatalf("unexpected error: %v", err) 128 | } 129 | 130 | expected := heredoc.Doc(` 131 | A B C 132 | hello beautiful people 133 | `) 134 | if buf.String() != expected { 135 | t.Errorf("expected: %q, got: %q", expected, buf.String()) 136 | } 137 | } 138 | 139 | func Test_tsvTablePrinter(t *testing.T) { 140 | buf := bytes.Buffer{} 141 | tp := New(&buf, false, 0) 142 | 143 | tp.AddField("1") 144 | tp.AddField("hello") 145 | tp.EndRow() 146 | tp.AddField("2") 147 | tp.AddField("world") 148 | tp.EndRow() 149 | 150 | err := tp.Render() 151 | if err != nil { 152 | t.Fatalf("unexpected error: %v", err) 153 | } 154 | 155 | expected := "1\thello\n2\tworld\n" 156 | if buf.String() != expected { 157 | t.Errorf("expected: %q, got: %q", expected, buf.String()) 158 | } 159 | } 160 | 161 | func Test_tsvTablePrinter_AddHeader(t *testing.T) { 162 | buf := bytes.Buffer{} 163 | tp := New(&buf, false, 0) 164 | 165 | // Headers
are not output in TSV output. 166 | tp.AddHeader([]string{"ONE", "TWO", "THREE"}) 167 | 168 | tp.AddField("hello") 169 | tp.AddField("beautiful") 170 | tp.AddField("people") 171 | tp.EndRow() 172 | tp.AddField("1") 173 | tp.AddField("2") 174 | tp.AddField("3") 175 | tp.EndRow() 176 | 177 | err := tp.Render() 178 | if err != nil { 179 | t.Fatalf("unexpected error: %v", err) 180 | } 181 | 182 | expected := "hello\tbeautiful\tpeople\n1\t2\t3\n" 183 | if buf.String() != expected { 184 | t.Errorf("expected: %q, got: %q", expected, buf.String()) 185 | } 186 | } 187 | -------------------------------------------------------------------------------- /pkg/term/console.go: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | // +build !windows 3 | 4 | package term 5 | 6 | import ( 7 | "errors" 8 | "os" 9 | ) 10 | 11 | // enableVirtualTerminalProcessing is a stub on non-Windows builds: it always returns an error, 12 | // so FromEnv never treats stdout as a virtual terminal on these platforms. 13 | func enableVirtualTerminalProcessing(f *os.File) error { 14 | return errors.New("not implemented") 15 | } 16 | 17 | // openTTY opens /dev/tty read-only; Size uses it to measure the terminal when stdout is not one. 18 | func openTTY() (*os.File, error) { 19 | return os.Open("/dev/tty") 20 | } 21 | -------------------------------------------------------------------------------- /pkg/term/console_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | // +build windows 3 | 4 | package term 5 | 6 | import ( 7 | "os" 8 | 9 | "golang.org/x/sys/windows" 10 | ) 11 | 12 | // enableVirtualTerminalProcessing switches on ENABLE_VIRTUAL_TERMINAL_PROCESSING (ANSI escape handling) for the 13 | // console behind f while keeping its other mode flags. It returns an error when the current mode cannot be read, 14 | // e.g. because f is not attached to a console, instead of writing a mode built from zero. 15 | func enableVirtualTerminalProcessing(f *os.File) error { 16 | stdout := windows.Handle(f.Fd()) 17 | 18 | var originalMode uint32 19 | if err := windows.GetConsoleMode(stdout, &originalMode); err != nil { 20 | // Without the current mode, SetConsoleMode below would clear every other console flag. 21 | return err 22 | } 23 | return windows.SetConsoleMode(stdout, originalMode|windows.ENABLE_VIRTUAL_TERMINAL_PROCESSING) 24 | } 25 | 26 | // openTTY opens the console output buffer (CONOUT$); Size uses it to measure the console when stdout is not a terminal. 27 | func openTTY() (*os.File, error) { 28 | return os.Open("CONOUT$") 29 | } 30 | -------------------------------------------------------------------------------- /pkg/term/env.go: -------------------------------------------------------------------------------- 1 | // Package term provides
information about the terminal that the current process is connected to (if any), 2 | // for example measuring the dimensions of the terminal and inspecting whether it's safe to output color. 3 | package term 4 | 5 | import ( 6 | "io" 7 | "os" 8 | "strconv" 9 | "strings" 10 | 11 | "github.com/muesli/termenv" 12 | "golang.org/x/term" 13 | ) 14 | 15 | // Term represents information about the terminal that a process is connected to. 16 | type Term struct { 17 | in *os.File 18 | out *os.File 19 | errOut *os.File 20 | isTTY bool 21 | colorEnabled bool 22 | is256enabled bool 23 | hasTrueColor bool 24 | width int // fixed width from a numeric GH_FORCE_TTY; Size uses it only when > 0 25 | widthPercent int // percentage from a GH_FORCE_TTY value like "80%"; Size applies it to the measured width when > 0 26 | } 27 | 28 | // FromEnv initializes a Term for [os.Stdin], [os.Stdout], and [os.Stderr], consulting stdout and these environment variables: 29 | // - GH_FORCE_TTY: any non-empty value makes stdout count as a terminal; an integer (e.g. "120") also fixes the width and a percentage (e.g. "80%") scales the measured width 30 | // - NO_COLOR: any non-empty value disables color 31 | // - CLICOLOR: "0" disables color 32 | // - CLICOLOR_FORCE: any value other than "" or "0" turns color on even over NO_COLOR/CLICOLOR (not consulted when GH_FORCE_TTY is set) 33 | // - TERM: checked for "256", "24bit", or "truecolor" to detect 256-color and true-color support 34 | // - COLORTERM: checked the same way as TERM 35 | func FromEnv() Term { 36 | var stdoutIsTTY bool 37 | var isColorEnabled bool 38 | var termWidthOverride int 39 | var termWidthPercentage int 40 | 41 | spec := os.Getenv("GH_FORCE_TTY") 42 | if spec != "" { 43 | stdoutIsTTY = true 44 | isColorEnabled = !IsColorDisabled() 45 | 46 | if w, err := strconv.Atoi(spec); err == nil { 47 | termWidthOverride = w 48 | } else if strings.HasSuffix(spec, "%") { 49 | if p, err := strconv.Atoi(spec[:len(spec)-1]); err == nil { 50 | termWidthPercentage = p 51 | } 52 | } 53 | } else { 54 | stdoutIsTTY = IsTerminal(os.Stdout) 55 | isColorEnabled = IsColorForced() || (!IsColorDisabled() && stdoutIsTTY) 56 | } 57 | 58 | isVirtualTerminal := false // set only if enableVirtualTerminalProcessing succeeds; the non-Windows stub always fails 59 | if stdoutIsTTY { 60 | if err := enableVirtualTerminalProcessing(os.Stdout); err == nil { 61 | isVirtualTerminal = true 62 | } 63 | } 64 | 65 | return Term{ 66 | in: os.Stdin, 67 | out: os.Stdout, 68 | errOut: os.Stderr, 69 | isTTY: stdoutIsTTY, 70 | colorEnabled: isColorEnabled, 71 | is256enabled: isVirtualTerminal || is256ColorSupported(), // an enabled virtual terminal is assumed to handle 256 and true color 72 | hasTrueColor: isVirtualTerminal || isTrueColorSupported(), 73 | width: termWidthOverride, 74 | widthPercent:
termWidthPercentage, 75 | } 76 | } 77 | 78 | // In is the reader reading from standard input. 79 | func (t Term) In() io.Reader { 80 | return t.in 81 | } 82 | 83 | // Out is the writer writing to standard output. 84 | func (t Term) Out() io.Writer { 85 | return t.out 86 | } 87 | 88 | // ErrOut is the writer writing to standard error. 89 | func (t Term) ErrOut() io.Writer { 90 | return t.errOut 91 | } 92 | 93 | // IsTerminalOutput returns true if standard output is connected to a terminal (FromEnv also reports true whenever GH_FORCE_TTY is set). 94 | func (t Term) IsTerminalOutput() bool { 95 | return t.isTTY 96 | } 97 | 98 | // IsColorEnabled reports whether it's safe to output ANSI color sequences, depending on IsTerminalOutput 99 | // and environment variables. 100 | func (t Term) IsColorEnabled() bool { 101 | return t.colorEnabled 102 | } 103 | 104 | // Is256ColorSupported reports whether the terminal advertises ANSI 256 color codes. 105 | func (t Term) Is256ColorSupported() bool { 106 | return t.is256enabled 107 | } 108 | 109 | // IsTrueColorSupported reports whether the terminal advertises support for ANSI true color sequences. 110 | func (t Term) IsTrueColorSupported() bool { 111 | return t.hasTrueColor 112 | } 113 | 114 | // Size returns the width and height of the terminal that the current process is attached to. 115 | // A width fixed via GH_FORCE_TTY is returned with height -1 and no error; if no terminal can be opened both values are -1; any other error from measuring the terminal is returned together with the values that measurement produced.
116 | func (t Term) Size() (int, int, error) { 117 | if t.width > 0 { // width fixed via GH_FORCE_TTY; the height is not known 118 | return t.width, -1, nil 119 | } 120 | 121 | ttyOut := t.out // when stdout is not a terminal, fall back to measuring the terminal returned by openTTY 122 | if ttyOut == nil || !IsTerminal(ttyOut) { 123 | if f, err := openTTY(); err == nil { 124 | defer f.Close() 125 | ttyOut = f 126 | } else { 127 | return -1, -1, err 128 | } 129 | } 130 | 131 | width, height, err := terminalSize(ttyOut) 132 | if err == nil && t.widthPercent > 0 { // GH_FORCE_TTY="<n>%": report that share of the measured width 133 | return int(float64(width) * float64(t.widthPercent) / 100), height, nil 134 | } 135 | 136 | return width, height, err 137 | } 138 | 139 | // Theme returns "none" when color output is disabled, otherwise "dark" or "light" according to termenv's detection of the terminal background color. 140 | func (t Term) Theme() string { 141 | if !t.IsColorEnabled() { 142 | return "none" 143 | } 144 | if termenv.HasDarkBackground() { 145 | return "dark" 146 | } 147 | return "light" 148 | } 149 | 150 | // IsTerminal reports whether f is connected to a terminal. 151 | func IsTerminal(f *os.File) bool { 152 | return term.IsTerminal(int(f.Fd())) 153 | } 154 | 155 | func terminalSize(f *os.File) (int, int, error) { // width and height of the terminal behind f, as reported by golang.org/x/term 156 | return term.GetSize(int(f.Fd())) 157 | } 158 | 159 | // IsColorDisabled returns true if environment variables NO_COLOR or CLICOLOR prohibit usage of color codes 160 | // in terminal output: NO_COLOR is set to any non-empty value, or CLICOLOR is "0". 161 | func IsColorDisabled() bool { 162 | return os.Getenv("NO_COLOR") != "" || os.Getenv("CLICOLOR") == "0" 163 | } 164 | 165 | // IsColorForced returns true if environment variable CLICOLOR_FORCE is set to any value other than "" or "0", forcing colored terminal output.
166 | func IsColorForced() bool { 167 | value := os.Getenv("CLICOLOR_FORCE") 168 | return value != "" && value != "0" 169 | } 170 | 171 | // is256ColorSupported reports whether TERM or COLORTERM mentions "256"; terminals advertising true color also qualify. 172 | func is256ColorSupported() bool { 173 | return isTrueColorSupported() || 174 | strings.Contains(os.Getenv("TERM"), "256") || 175 | strings.Contains(os.Getenv("COLORTERM"), "256") 176 | } 177 | 178 | // isTrueColorSupported reports whether TERM or COLORTERM mentions "24bit" or "truecolor". 179 | func isTrueColorSupported() bool { 180 | // Named termEnv rather than term so the local does not shadow the golang.org/x/term import. 181 | termEnv := os.Getenv("TERM") 182 | colorTermEnv := os.Getenv("COLORTERM") 183 | 184 | return strings.Contains(termEnv, "24bit") || 185 | strings.Contains(termEnv, "truecolor") || 186 | strings.Contains(colorTermEnv, "24bit") || 187 | strings.Contains(colorTermEnv, "truecolor") 188 | } 189 | -------------------------------------------------------------------------------- /pkg/term/env_test.go: -------------------------------------------------------------------------------- 1 | // Package term provides information about the terminal that the current process is connected to (if any), 2 | // for example measuring the dimensions of the terminal and inspecting whether it's safe to output color.
3 | package term 4 | 5 | import ( 6 | "testing" 7 | ) 8 | 9 | func TestFromEnv(t *testing.T) { // NOTE(review): cases that leave GH_FORCE_TTY empty assume the test's stdout is not a terminal 10 | tests := []struct { 11 | name string 12 | env map[string]string 13 | wantTerminal bool 14 | wantColor bool 15 | want256Color bool 16 | wantTrueColor bool 17 | }{ 18 | { 19 | name: "default", 20 | env: map[string]string{ // every variable is listed, even as "", so values from the outer environment cannot leak into a case 21 | "GH_FORCE_TTY": "", 22 | "CLICOLOR": "", 23 | "CLICOLOR_FORCE": "", 24 | "NO_COLOR": "", 25 | "TERM": "", 26 | "COLORTERM": "", 27 | }, 28 | wantTerminal: false, 29 | wantColor: false, 30 | want256Color: false, 31 | wantTrueColor: false, 32 | }, 33 | { 34 | name: "force color", 35 | env: map[string]string{ 36 | "GH_FORCE_TTY": "", 37 | "CLICOLOR": "", 38 | "CLICOLOR_FORCE": "1", 39 | "NO_COLOR": "", 40 | "TERM": "", 41 | "COLORTERM": "", 42 | }, 43 | wantTerminal: false, 44 | wantColor: true, 45 | want256Color: false, 46 | wantTrueColor: false, 47 | }, 48 | { 49 | name: "force tty", 50 | env: map[string]string{ 51 | "GH_FORCE_TTY": "true", // neither a number nor a percentage: forces TTY mode without overriding the width 52 | "CLICOLOR": "", 53 | "CLICOLOR_FORCE": "", 54 | "NO_COLOR": "", 55 | "TERM": "", 56 | "COLORTERM": "", 57 | }, 58 | wantTerminal: true, 59 | wantColor: true, 60 | want256Color: false, 61 | wantTrueColor: false, 62 | }, 63 | { 64 | name: "has 256-color support", 65 | env: map[string]string{ 66 | "GH_FORCE_TTY": "true", 67 | "CLICOLOR": "", 68 | "CLICOLOR_FORCE": "", 69 | "NO_COLOR": "", 70 | "TERM": "256-color", 71 | "COLORTERM": "", 72 | }, 73 | wantTerminal: true, 74 | wantColor: true, 75 | want256Color: true, 76 | wantTrueColor: false, 77 | }, 78 | { 79 | name: "has truecolor support", 80 | env: map[string]string{ 81 | "GH_FORCE_TTY": "true", 82 | "CLICOLOR": "", 83 | "CLICOLOR_FORCE": "", 84 | "NO_COLOR": "", 85 | "TERM": "truecolor", 86 | "COLORTERM": "", 87 | }, 88 | wantTerminal: true, 89 | wantColor: true, 90 | want256Color: true, 91 | wantTrueColor: true, 92 | }, 93 | } 94 | for _, tt := range tests { 95 | t.Run(tt.name, func(t *testing.T) { 96 | for key, value := range tt.env { 97 | t.Setenv(key,
value) 98 | } 99 | terminal := FromEnv() 100 | if got := terminal.IsTerminalOutput(); got != tt.wantTerminal { 101 | t.Errorf("expected terminal %v, got %v", tt.wantTerminal, got) 102 | } 103 | if got := terminal.IsColorEnabled(); got != tt.wantColor { 104 | t.Errorf("expected color %v, got %v", tt.wantColor, got) 105 | } 106 | if got := terminal.Is256ColorSupported(); got != tt.want256Color { 107 | t.Errorf("expected 256-color %v, got %v", tt.want256Color, got) 108 | } 109 | if got := terminal.IsTrueColorSupported(); got != tt.wantTrueColor { 110 | t.Errorf("expected truecolor %v, got %v", tt.wantTrueColor, got) 111 | } 112 | }) 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /pkg/text/text.go: -------------------------------------------------------------------------------- 1 | // Package text is a set of utility functions for text processing and outputting to the terminal. 2 | package text 3 | 4 | import ( 5 | "fmt" 6 | "regexp" 7 | "strings" 8 | "time" 9 | "unicode" 10 | 11 | "github.com/charmbracelet/lipgloss" 12 | "github.com/muesli/reflow/truncate" 13 | "golang.org/x/text/runes" 14 | "golang.org/x/text/transform" 15 | "golang.org/x/text/unicode/norm" 16 | ) 17 | 18 | const ( 19 | ellipsis = "..." 20 | minWidthForEllipsis = len(ellipsis) + 2 // Truncate omits the ellipsis below this width 21 | ) 22 | 23 | var indentRE = regexp.MustCompile(`(?m)^`) // matches the start of every line 24 | 25 | // Indent returns a copy of s with indent prefixed to every line. 26 | // A string that is empty or contains only whitespace is returned unchanged. 27 | func Indent(s, indent string) string { 28 | if len(strings.TrimSpace(s)) == 0 { 29 | return s 30 | } 31 | return indentRE.ReplaceAllLiteralString(s, indent) 32 | } 33 | 34 | // DisplayWidth calculates the rendered width of string s, as measured by lipgloss.Width. 35 | func DisplayWidth(s string) int { 36 | return lipgloss.Width(s) 37 | } 38 | 39 | // Truncate returns a copy of s shortened to fit maxWidth display columns, ending in "..." when maxWidth is at least 5; if the cut leaves the result narrower than maxWidth, a single space is appended.
40 | func Truncate(maxWidth int, s string) string { 41 | w := DisplayWidth(s) 42 | if w <= maxWidth { 43 | return s 44 | } 45 | tail := "" 46 | if maxWidth >= minWidthForEllipsis { 47 | tail = ellipsis 48 | } 49 | r := truncate.StringWithTail(s, uint(max(maxWidth, 0)), tail) // clamp: uint of a negative width wraps to a huge width, which would leave s untruncated 50 | if DisplayWidth(r) < maxWidth { 51 | r += " " 52 | } 53 | return r 54 | } 55 | 56 | // PadRight returns a copy of s padded on the right with spaces up to maxWidth display columns; 57 | // strings already at least that wide are returned unchanged. 58 | func PadRight(maxWidth int, s string) string { 59 | if padWidth := maxWidth - DisplayWidth(s); padWidth > 0 { 60 | s += strings.Repeat(" ", padWidth) 61 | } 62 | return s 63 | } 64 | 65 | // Pluralize returns num followed by thing, adding a plain "s" suffix unless num is exactly 1. 66 | func Pluralize(num int, thing string) string { 67 | if num == 1 { 68 | return fmt.Sprintf("%d %s", num, thing) 69 | } 70 | return fmt.Sprintf("%d %ss", num, thing) 71 | } 72 | 73 | func fmtDuration(amount int, unit string) string { // e.g. "about 3 days ago" 74 | return fmt.Sprintf("about %s ago", Pluralize(amount, unit)) 75 | } 76 | 77 | // RelativeTimeAgo describes how long before a (the reference time) b was, using the largest fitting unit (minute, hour, day, 78 | // 30-day month, 365-day year) with the count rounded down; if b is less than a minute before a, or after it, the result is "less than a minute ago". 79 | func RelativeTimeAgo(a, b time.Time) string { 80 | ago := a.Sub(b) 81 | 82 | if ago < time.Minute { 83 | return "less than a minute ago" 84 | } 85 | if ago < time.Hour { 86 | return fmtDuration(int(ago.Minutes()), "minute") 87 | } 88 | if ago < 24*time.Hour { 89 | return fmtDuration(int(ago.Hours()), "hour") 90 | } 91 | if ago < 30*24*time.Hour { 92 | return fmtDuration(int(ago.Hours())/24, "day") 93 | } 94 | if ago < 365*24*time.Hour { // NOTE(review): 360-364 days renders as "about 12 months ago" 95 | return fmtDuration(int(ago.Hours())/24/30, "month") 96 | } 97 | 98 | return fmtDuration(int(ago.Hours()/24/365), "year") 99 | } 100 | 101 | // RemoveDiacritics returns the input value without "diacritics", or accent marks.
102 | func RemoveDiacritics(value string) string { 103 | // Mn = "Mark, nonspacing" unicode character category 104 | removeMnTransformer := runes.Remove(runes.In(unicode.Mn)) 105 | 106 | // 1. Decompose the text into characters and diacritical marks 107 | // 2. Remove the diacritical marks 108 | // 3. Recompose the text 109 | t := transform.Chain(norm.NFD, removeMnTransformer, norm.NFC) 110 | normalized, _, err := transform.String(t, value) 111 | if err != nil { 112 | return value // on a transform error the input is returned unchanged rather than partially stripped 113 | } 114 | return normalized 115 | } 116 | -------------------------------------------------------------------------------- /pkg/x/color/accessibility.go: -------------------------------------------------------------------------------- 1 | package color 2 | 3 | import ( 4 | "os" 5 | 6 | "github.com/cli/go-gh/v2/pkg/config" 7 | ) 8 | 9 | const ( 10 | // AccessibleColorsEnv is the name of the environment variable to enable accessible color features. 11 | AccessibleColorsEnv = "GH_ACCESSIBLE_COLORS" 12 | 13 | // AccessibleColorsSetting is the name of the `gh config` setting to enable accessible color features. 14 | AccessibleColorsSetting = "accessible_colors" 15 | ) 16 | 17 | // IsAccessibleColorsEnabled returns true if accessible colors are enabled via environment variable 18 | // or configuration setting with the environment variable having higher precedence. 19 | // 20 | // If the environment variable is set, then any value other than the following will enable accessible colors: 21 | // empty, "0", "false", or "no". Otherwise the config setting must be exactly "enabled"; an unreadable config counts as disabled. 22 | func IsAccessibleColorsEnabled() bool { 23 | // Environment variable only has the highest precedence if actually set. 24 | if envVar, set := os.LookupEnv(AccessibleColorsEnv); set { 25 | switch envVar { 26 | case "", "0", "false", "no": 27 | return false 28 | default: 29 | return true 30 | } 31 | } 32 |
33 | // Config errors are deliberately not surfaced: callers should not fail because the config could not be read, 34 | // so an unreadable configuration is treated as "disabled". The error must still be checked: on failure the 35 | // returned config cannot be trusted (NOTE(review): presumably a nil *Config — confirm against pkg/config), and Get would be called on it. 36 | cfg, err := config.Read(nil) 37 | if err != nil { 38 | return false 39 | } 40 | accessibleConfigValue, _ := cfg.Get([]string{AccessibleColorsSetting}) 41 | 42 | return accessibleConfigValue == "enabled" 43 | } 44 | -------------------------------------------------------------------------------- /pkg/x/color/accessibility_test.go: -------------------------------------------------------------------------------- 1 | package color 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/MakeNowJust/heredoc" 7 | "github.com/cli/go-gh/v2/internal/testutils" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestIsAccessibleColorsEnabled(t *testing.T) { 12 | tests := []struct { 13 | name string 14 | env map[string]string 15 | cfgStr string 16 | wantOut bool 17 | }{ 18 | { 19 | name: "When the accessibility configuration and env var are both unset, it should return false", 20 | cfgStr: "", 21 | wantOut: false, 22 | }, 23 | { 24 | name: "When the accessibility configuration is unset but the env var is set to something truthy (not '0' or 'false'), it should return true", 25 | env: map[string]string{ 26 | "GH_ACCESSIBLE_COLORS": "1", 27 | }, 28 | cfgStr: "", 29 | wantOut: true, 30 | }, 31 | { 32 | name: "When the accessibility configuration is unset and the env var returns '0', it should return false", 33 | env: map[string]string{ 34 | "GH_ACCESSIBLE_COLORS": "0", 35 | }, 36 | cfgStr: "", 37 | wantOut: false, 38 | }, 39 | { 40 | name: "When the accessibility configuration is unset and the env var returns 'false', it should return false", 41 | env: map[string]string{ 42 | "GH_ACCESSIBLE_COLORS": "false", 43 | }, 44 | cfgStr: "", 45 | wantOut: false, 46 | }, 47 | { 48 | name: "When the accessibility configuration is unset and the env var returns '', it should return false", 49 | env: map[string]string{ 50 | "GH_ACCESSIBLE_COLORS": "", 51 | }, 52 | cfgStr: "", 53 | wantOut: false, 54 | }, 55 | { 56 | name: "When the accessibility configuration is
set to enabled and the env var is unset, it should return true", 57 | cfgStr: accessibilityEnabledConfig(), 58 | wantOut: true, 59 | }, 60 | { 61 | name: "When the accessibility configuration is set to disabled and the env var is unset, it should return false", 62 | cfgStr: accessibilityDisabledConfig(), 63 | wantOut: false, 64 | }, 65 | { 66 | name: "When the accessibility configuration is set to disabled and the env var is set to something truthy (not '0' or 'false'), it should return true", 67 | env: map[string]string{ 68 | "GH_ACCESSIBLE_COLORS": "true", 69 | }, 70 | cfgStr: accessibilityDisabledConfig(), 71 | wantOut: true, 72 | }, 73 | { 74 | name: "When the accessibility configuration is set to enabled and the env var is set to '0', it should return false", 75 | env: map[string]string{ 76 | "GH_ACCESSIBLE_COLORS": "0", 77 | }, 78 | cfgStr: accessibilityEnabledConfig(), 79 | wantOut: false, 80 | }, 81 | { 82 | name: "When the accessibility configuration is set to enabled and the env var is set to 'false', it should return false", 83 | env: map[string]string{ 84 | "GH_ACCESSIBLE_COLORS": "false", 85 | }, 86 | cfgStr: accessibilityEnabledConfig(), 87 | wantOut: false, 88 | }, 89 | { 90 | name: "When the accessibility configuration is set to enabled and the env var is set to '', it should return false", 91 | env: map[string]string{ 92 | "GH_ACCESSIBLE_COLORS": "", 93 | }, 94 | cfgStr: accessibilityEnabledConfig(), 95 | wantOut: false, 96 | }, 97 | { 98 | name: "When the accessibility configuration is set to enabled and the env var is set to 'no', it should return false", 99 | env: map[string]string{ 100 | "GH_ACCESSIBLE_COLORS": "no", 101 | }, 102 | cfgStr: accessibilityEnabledConfig(), 103 | wantOut: false, 104 | }, 105 | } 106 | for _, tt := range tests { 107 | t.Run(tt.name, func(t *testing.T) { 108 | for k, v := range tt.env { 109 | t.Setenv(k, v) 110 | } 111 | testutils.StubConfig(t, tt.cfgStr) 112 | assert.Equal(t, tt.wantOut, IsAccessibleColorsEnabled()) 113 | }) 114 | } 115 | } 116 | 117 | func accessibilityEnabledConfig() string { 118 | return heredoc.Doc(` 119 | accessible_colors: enabled 120 | `) 121 | } 122 | 123 | func accessibilityDisabledConfig() string { 124 | return heredoc.Doc(` 125 | accessible_colors: disabled 126 | `)
127 | } 128 | -------------------------------------------------------------------------------- /pkg/x/color/color.go: -------------------------------------------------------------------------------- 1 | // Package color handles experimental GitHub CLI user experiences focused on color rendering concerns such as accessibility and color roles. 2 | // 3 | // Note this is an experimental package where the API is subject to change. 4 | package color 5 | -------------------------------------------------------------------------------- /pkg/x/markdown/accessibility.go: -------------------------------------------------------------------------------- 1 | package markdown 2 | 3 | import ( 4 | "fmt" 5 | "strconv" 6 | 7 | "github.com/charmbracelet/glamour/ansi" 8 | "github.com/charmbracelet/glamour/styles" 9 | ) 10 | 11 | // glamourStyleColor represents color codes used to customize glamour style elements. 12 | type glamourStyleColor int 13 | 14 | // Do not change the order of the following glamour color constants, 15 | // which matches 4-bit colors with their respective color codes.
16 | const ( 17 | black glamourStyleColor = iota 18 | red 19 | green 20 | yellow 21 | blue 22 | magenta 23 | cyan 24 | white 25 | brightBlack 26 | brightRed 27 | brightGreen 28 | brightYellow 29 | brightBlue 30 | brightMagenta 31 | brightCyan 32 | brightWhite 33 | ) 34 | 35 | // code returns gsc's decimal color code (e.g. "12") as a newly allocated *string, the form glamour's style color fields take. 36 | func (gsc glamourStyleColor) code() *string { 37 | s := strconv.Itoa(int(gsc)) 38 | return &s 39 | } 40 | 41 | // parseGlamourStyleColor is the inverse of code: it accepts exactly "0" through "15" and returns an error for anything else. 42 | func parseGlamourStyleColor(code string) (glamourStyleColor, error) { 43 | switch code { 44 | case "0": 45 | return black, nil 46 | case "1": 47 | return red, nil 48 | case "2": 49 | return green, nil 50 | case "3": 51 | return yellow, nil 52 | case "4": 53 | return blue, nil 54 | case "5": 55 | return magenta, nil 56 | case "6": 57 | return cyan, nil 58 | case "7": 59 | return white, nil 60 | case "8": 61 | return brightBlack, nil 62 | case "9": 63 | return brightRed, nil 64 | case "10": 65 | return brightGreen, nil 66 | case "11": 67 | return brightYellow, nil 68 | case "12": 69 | return brightBlue, nil 70 | case "13": 71 | return brightMagenta, nil 72 | case "14": 73 | return brightCyan, nil 74 | case "15": 75 | return brightWhite, nil 76 | default: 77 | return 0, fmt.Errorf("invalid color code: %s", code) 78 | } 79 | } 80 | 81 | // AccessibleStyleConfig returns glamour's light or dark style with selected element colors replaced by basic 16-color ANSI codes. 82 | // theme must be "light" or "dark"; any other value yields an empty ansi.StyleConfig. 83 | func AccessibleStyleConfig(theme string) ansi.StyleConfig { 84 | switch theme { 85 | case "light": 86 | return accessibleLightStyleConfig() 87 | case "dark": 88 | return accessibleDarkStyleConfig() 89 | default: 90 | return ansi.StyleConfig{} 91 | } 92 | } 93 | 94 | // accessibleDarkStyleConfig copies glamour's dark style and replaces selected element colors with basic 16-color ANSI codes. 95 | func accessibleDarkStyleConfig() ansi.StyleConfig { 96 | cfg := styles.DarkStyleConfig 97 | 98 | // Text color 99 | cfg.Document.StylePrimitive.Color = white.code() 100 | 101 | // Link colors 102 | cfg.Link.Color = brightCyan.code() 103 | cfg.LinkText.Color = brightCyan.code() 104 | 105 | // Heading colors 106 | cfg.Heading.StylePrimitive.Color = brightMagenta.code() 107 | cfg.H1.StylePrimitive.Color = brightWhite.code() 108 | cfg.H1.StylePrimitive.BackgroundColor = brightBlue.code() 109 |
cfg.H6.StylePrimitive.Color = brightMagenta.code() 110 | 111 | // Code colors 112 | cfg.Code.BackgroundColor = brightWhite.code() 113 | cfg.Code.Color = red.code() 114 | 115 | // Image colors 116 | cfg.Image.Color = brightMagenta.code() 117 | cfg.ImageText.Color = brightMagenta.code() 118 | 119 | // Horizontal rule colors 120 | cfg.HorizontalRule.Color = white.code() 121 | 122 | // Code block colors 123 | // Unsetting StyleBlock color until we understand what it does versus Chroma style 124 | cfg.CodeBlock.StyleBlock.StylePrimitive.Color = nil 125 | 126 | return cfg 127 | } 128 | 129 | // accessibleLightStyleConfig copies glamour's light style and replaces selected element colors with basic 16-color ANSI codes. 130 | func accessibleLightStyleConfig() ansi.StyleConfig { 131 | cfg := styles.LightStyleConfig 132 | 133 | // Text color 134 | cfg.Document.StylePrimitive.Color = black.code() 135 | 136 | // Link colors 137 | cfg.Link.Color = brightBlue.code() 138 | cfg.LinkText.Color = brightBlue.code() 139 | 140 | // Heading colors 141 | cfg.Heading.StylePrimitive.Color = magenta.code() 142 | cfg.H1.StylePrimitive.Color = brightWhite.code() 143 | cfg.H1.StylePrimitive.BackgroundColor = blue.code() 144 | 145 | // Code colors 146 | cfg.Code.BackgroundColor = brightWhite.code() 147 | cfg.Code.Color = red.code() 148 | 149 | // Image colors 150 | cfg.Image.Color = magenta.code() 151 | cfg.ImageText.Color = magenta.code() 152 | 153 | // Horizontal rule colors 154 | cfg.HorizontalRule.Color = white.code() 155 | 156 | // Code block colors 157 | // Unsetting StyleBlock color until we understand what it does versus Chroma style 158 | cfg.CodeBlock.StyleBlock.StylePrimitive.Color = nil 159 | 160 | return cfg 161 | } 162 | -------------------------------------------------------------------------------- /pkg/x/markdown/accessibility_test.go: -------------------------------------------------------------------------------- 1 | package markdown 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/charmbracelet/glamour/ansi" 8 | "github.com/charmbracelet/glamour/styles" 9 |
"github.com/stretchr/testify/assert" 10 | ) 11 | 12 | // TestGlamourStyleColors ensures that the resulting string color codes match the expected values. 13 | func TestGlamourStyleColors(t *testing.T) { 14 | tests := []struct { 15 | name string 16 | c glamourStyleColor 17 | want string 18 | }{ 19 | { 20 | name: "black", 21 | c: black, 22 | want: "0", 23 | }, 24 | { 25 | name: "red", 26 | c: red, 27 | want: "1", 28 | }, 29 | { 30 | name: "green", 31 | c: green, 32 | want: "2", 33 | }, 34 | { 35 | name: "yellow", 36 | c: yellow, 37 | want: "3", 38 | }, 39 | { 40 | name: "blue", 41 | c: blue, 42 | want: "4", 43 | }, 44 | { 45 | name: "magenta", 46 | c: magenta, 47 | want: "5", 48 | }, 49 | { 50 | name: "cyan", 51 | c: cyan, 52 | want: "6", 53 | }, 54 | { 55 | name: "white", 56 | c: white, 57 | want: "7", 58 | }, 59 | { 60 | name: "bright black", 61 | c: brightBlack, 62 | want: "8", 63 | }, 64 | { 65 | name: "bright red", 66 | c: brightRed, 67 | want: "9", 68 | }, 69 | { 70 | name: "bright green", 71 | c: brightGreen, 72 | want: "10", 73 | }, 74 | { 75 | name: "bright yellow", 76 | c: brightYellow, 77 | want: "11", 78 | }, 79 | { 80 | name: "bright blue", 81 | c: brightBlue, 82 | want: "12", 83 | }, 84 | { 85 | name: "bright magenta", 86 | c: brightMagenta, 87 | want: "13", 88 | }, 89 | { 90 | name: "bright cyan", 91 | c: brightCyan, 92 | want: "14", 93 | }, 94 | { 95 | name: "bright white", 96 | c: brightWhite, 97 | want: "15", 98 | }, 99 | } 100 | for _, tt := range tests { 101 | t.Run(tt.name, func(t *testing.T) { 102 | t.Parallel() 103 | 104 | assert.Equal(t, tt.want, *tt.c.code()) 105 | }) 106 | } 107 | } 108 | 109 | func TestAccessibleStyleConfig(t *testing.T) { 110 | tests := []struct { 111 | name string 112 | theme string 113 | want ansi.StyleConfig 114 | }{ 115 | { 116 | name: "light", 117 | theme: "light", 118 | want: accessibleLightStyleConfig(), 119 | }, 120 | { 121 | name: "dark", 122 | theme: "dark", 123 | want: accessibleDarkStyleConfig(), 124 | }, 125 |
{ 126 | name: "fallback", 127 | theme: "foo", 128 | want: ansi.StyleConfig{}, 129 | }, 130 | } 131 | for _, tt := range tests { 132 | t.Run(tt.name, func(t *testing.T) { 133 | t.Parallel() 134 | 135 | assert.Equal(t, tt.want, AccessibleStyleConfig(tt.theme)) 136 | }) 137 | } 138 | } 139 | 140 | func TestAccessibleDarkStyleConfig(t *testing.T) { 141 | cfg := accessibleDarkStyleConfig() 142 | assert.Equal(t, white.code(), cfg.Document.StylePrimitive.Color) 143 | assert.Equal(t, brightCyan.code(), cfg.Link.Color) 144 | assert.Equal(t, brightMagenta.code(), cfg.Heading.StylePrimitive.Color) 145 | assert.Equal(t, brightWhite.code(), cfg.H1.StylePrimitive.Color) 146 | assert.Equal(t, brightBlue.code(), cfg.H1.StylePrimitive.BackgroundColor) 147 | assert.Equal(t, brightWhite.code(), cfg.Code.BackgroundColor) 148 | assert.Equal(t, red.code(), cfg.Code.Color) 149 | assert.Equal(t, brightMagenta.code(), cfg.Image.Color) 150 | assert.Equal(t, white.code(), cfg.HorizontalRule.Color) 151 | 152 | // Test that we haven't changed the original style 153 | assert.Equal(t, styles.DarkStyleConfig.H2, cfg.H2) 154 | } 155 | 156 | func TestAccessibleDarkStyleConfigIs4Bit(t *testing.T) { 157 | t.Parallel() 158 | 159 | cfg := accessibleDarkStyleConfig() 160 | validateColors(t, reflect.ValueOf(cfg), "StyleConfig") 161 | } 162 | 163 | func TestAccessibleLightStyleConfig(t *testing.T) { 164 | t.Parallel() 165 | 166 | cfg := accessibleLightStyleConfig() 167 | assert.Equal(t, black.code(), cfg.Document.StylePrimitive.Color) 168 | assert.Equal(t, brightBlue.code(), cfg.Link.Color) 169 | assert.Equal(t, magenta.code(), cfg.Heading.StylePrimitive.Color) 170 | assert.Equal(t, brightWhite.code(), cfg.H1.StylePrimitive.Color) 171 | assert.Equal(t, blue.code(), cfg.H1.StylePrimitive.BackgroundColor) 172 | assert.Equal(t, brightWhite.code(), cfg.Code.BackgroundColor) 173 | assert.Equal(t, red.code(), cfg.Code.Color) 174 | assert.Equal(t, magenta.code(), cfg.Image.Color) 175 | assert.Equal(t,
white.code(), cfg.HorizontalRule.Color) 176 | 177 | // Test that we haven't changed the original style 178 | assert.Equal(t, styles.LightStyleConfig.H2, cfg.H2) 179 | } 180 | 181 | func TestAccessibleLightStyleConfigIs4Bit(t *testing.T) { 182 | t.Parallel() 183 | 184 | cfg := accessibleLightStyleConfig() 185 | validateColors(t, reflect.ValueOf(cfg), "StyleConfig") 186 | } 187 | 188 | // Walk every field in the StyleConfig struct, checking that the Color and 189 | // BackgroundColor fields are valid 4-bit colors. 190 | // 191 | // This test skips Chroma fields because their Color fields are RGB hex values 192 | // that are downsampled to 4-bit colors unlike Glamour, which are 8-bit colors. 193 | // For more information, https://github.com/alecthomas/chroma/blob/0bf0e9f9ae2a81d463afe769cce01ff821bee3ba/formatters/tty_indexed.go#L32-L44 194 | func validateColors(t *testing.T, v reflect.Value, path string) { 195 | if v.Kind() == reflect.Ptr { 196 | if v.IsNil() { 197 | return 198 | } 199 | v = v.Elem() 200 | } 201 | 202 | switch v.Kind() { 203 | case reflect.Struct: 204 | for i := range v.NumField() { 205 | field := v.Field(i) 206 | fieldType := v.Type().Field(i) 207 | 208 | // Construct path for better error reporting 209 | fieldPath := path + "."
+ fieldType.Name 210 | 211 | // Ensure we only check Glamour "Color" and "BackgroundColor" 212 | if fieldType.Name == "Chroma" { 213 | continue 214 | } else if (fieldType.Name == "Color" || fieldType.Name == "BackgroundColor") && 215 | fieldType.Type.Kind() == reflect.Ptr && fieldType.Type.Elem().Kind() == reflect.String { 216 | 217 | if field.IsNil() { 218 | continue 219 | } 220 | color := field.Elem().String() 221 | _, err := parseGlamourStyleColor(color) 222 | assert.NoError(t, err, "Failed to parse color '%s' in %s", color, fieldPath) 223 | } else { 224 | // Recurse into nested structs 225 | validateColors(t, field, fieldPath) 226 | } 227 | } 228 | case reflect.Slice: 229 | // Handle slices of structs 230 | for i := range v.Len() { 231 | validateColors(t, v.Index(i), path+"[]") 232 | } 233 | } 234 | } 235 | -------------------------------------------------------------------------------- /pkg/x/markdown/markdown.go: -------------------------------------------------------------------------------- 1 | // Package markdown handles experimental GitHub CLI user experiences focused on markdown rendering concerns such as accessibility. 2 | // 3 | // Note this is an experimental package where the API is subject to change. 4 | package markdown 5 | -------------------------------------------------------------------------------- /pkg/x/x.go: -------------------------------------------------------------------------------- 1 | // Package x is a collection of experimental features for use within the GitHub CLI and CLI extensions. 2 | // 3 | // Anything contained is subject to change without notice until it is considered stable enough to be promoted. 4 | package x 5 | --------------------------------------------------------------------------------