├── go.mod
├── .github
└── workflows
│ └── go.yml
├── go.sum
├── LICENSE
├── release
├── release.go
└── asset
│ ├── download_test.go
│ └── download.go
├── README.md
├── upgrade.go
└── checksum
├── checksum.go
└── checksum_test.go
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/getsavvyinc/upgrade-cli
2 |
3 | go 1.21.6
4 |
5 | require (
6 | github.com/hashicorp/go-version v1.6.0
7 | github.com/stretchr/testify v1.8.4
8 | )
9 |
10 | require (
11 | github.com/davecgh/go-spew v1.1.1 // indirect
12 | github.com/pmezard/go-difflib v1.0.0 // indirect
13 | gopkg.in/yaml.v3 v3.0.1 // indirect
14 | )
15 |
16 | retract v0.7.0 // missing fallback for arm64 -> all
17 |
--------------------------------------------------------------------------------
/.github/workflows/go.yml:
--------------------------------------------------------------------------------
# This workflow will build a golang project
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go

name: Go

on:
  push:
    branches: [ "main" ]
  pull_request:
    branches: [ "*" ]

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Set up Go
        uses: actions/setup-go@v5
        with:
          go-version-file: './go.mod'

      # Every `run` step executes in its own shell, so a plain `export` would
      # vanish when this step ends. Appending to $GITHUB_ENV makes ENV=CI
      # visible to all later steps of the job.
      - name: Set CI ENV
        run: echo "ENV=CI" >> "$GITHUB_ENV"

      - name: Test
        run: go test -v -count=1 ./...

      - name: vet
        run: go vet ./...
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
3 | github.com/hashicorp/go-version v1.6.0 h1:feTTfFNnjP967rlCxM/I9g701jU+RN74YKx2mOkIeek=
4 | github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
5 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
6 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
7 | github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
8 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
9 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
10 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
11 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
12 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
13 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 savvy
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/release/release.go:
--------------------------------------------------------------------------------
1 | package release
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "fmt"
7 | "net/http"
8 | )
9 |
10 | type Asset struct {
11 | Name string `json:"name"`
12 | BrowserDownloadURL string `json:"browser_download_url"`
13 | }
14 |
15 | // Info holds information about a release.
16 | type Info struct {
17 | TagName string `json:"tag_name"`
18 | Assets []Asset `json:"assets"`
19 | }
20 |
21 | type Getter interface {
22 | GetLatestRelease(ctx context.Context) (*Info, error)
23 | }
24 |
25 | type githubReleaseGetter struct {
26 | repo, owner string
27 | }
28 |
29 | var _ Getter = (*githubReleaseGetter)(nil)
30 |
31 | func NewReleaseGetter(repo, owner string) *githubReleaseGetter {
32 | return &githubReleaseGetter{
33 | repo: repo,
34 | owner: owner,
35 | }
36 | }
37 |
38 | func (g *githubReleaseGetter) GetLatestRelease(ctx context.Context) (*Info, error) {
39 | url := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", g.owner, g.repo)
40 | return getLatestRelease(ctx, url)
41 | }
42 |
43 | // getLatestRelease fetches the latest release from GitHub.
44 | func getLatestRelease(ctx context.Context, url string) (*Info, error) {
45 | req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
46 | if err != nil {
47 | return nil, err
48 | }
49 | resp, err := http.DefaultClient.Do(req)
50 | if err != nil {
51 | return nil, err
52 | }
53 | defer resp.Body.Close()
54 |
55 | var release Info
56 | if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
57 | return nil, err
58 | }
59 | return &release, nil
60 | }
61 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # UpgradeCLI
2 |
3 |
6 |
7 |
8 | UpgradeCLI makes it easy to add an `upgrade` command to your CLI.
9 |
10 | UpgradeCLI was built to implement the `upgrade` command for [Savvy's](https://getsavvy.so) OSS [CLI](https://github.com/getsavvyinc/savvy-cli).
11 |
12 | > Savvy's CLI helps developers create and share high quality runbooks right from the terminal.
13 |
14 | ## Install
15 |
16 | ```sh
17 | go get github.com/getsavvyinc/upgrade-cli
18 |
19 | ```
20 |
21 | ## Usage
22 |
23 | ```go
24 | package cmd
25 |
26 | import (
27 | "context"
28 | "os"
29 |
30 | "github.com/getsavvyinc/savvy-cli/config"
31 | "github.com/getsavvyinc/savvy-cli/display"
32 | "github.com/getsavvyinc/upgrade-cli"
33 | "github.com/spf13/cobra"
34 | )
35 |
36 | const owner = "getsavvyinc"
37 | const repo = "savvy-cli"
38 |
39 | // upgradeCmd represents the upgrade command
40 | var upgradeCmd = &cobra.Command{
41 | Use: "upgrade",
42 | Short: "upgrade savvy to the latest version",
43 | Long: `upgrade savvy to the latest version`,
44 | Run: func(cmd *cobra.Command, args []string) {
45 | executablePath, err := os.Executable()
46 | if err != nil {
47 | display.Error(err)
48 | os.Exit(1)
49 | }
50 | version := config.Version()
51 |
52 | upgrader := upgrade.NewUpgrader(owner, repo, executablePath)
53 |
54 | if ok, err := upgrader.IsNewVersionAvailable(context.Background(), version); err != nil {
55 | display.Error(err)
56 | return
57 | } else if !ok {
58 | display.Info("Savvy is already up to date")
59 | return
60 | }
61 |
62 | display.Info("Upgrading savvy...")
63 | if err := upgrader.Upgrade(context.Background(), version); err != nil {
64 | display.Error(err)
65 | os.Exit(1)
66 | } else {
67 | display.Success("Savvy has been upgraded to the latest version")
68 | }
69 | },
70 | }
71 |
72 | func init() {
73 | rootCmd.AddCommand(upgradeCmd)
74 | }
75 | ```
76 |
77 | ## Requirements
78 |
79 | > `upgrade-cli` is fully compatible with releases generated using [goreleaser](https://github.com/goreleaser/goreleaser).
80 |
81 | `upgrade-cli` makes the following assumptions about release assets:
82 |
83 | * The checksum file has a `checksums.txt` suffix
84 | * The checksum file format matches the example below:
85 |
86 | ```sh
87 | 6796a0fb64d0c78b2de5410a94749a3bfb77291747c1835fbd427e8bf00f6af3 savvy_darwin_arm64
88 | 3853c410eeee629f71a981844975700b2925ac7582bf5559c384c391be8abbcb savvy_darwin_x86_64
89 | 00637eae6cf7588d990d64113a02caca831ea5391ef6f66c88db2dfa576ca6bd savvy_linux_arm64
90 | 1e9c98dbb0f54ee06119d957fa140b42780aa330d11208ad0a21c2a06832eca3 savvy_linux_i386
91 | 3040ff4c07dda6c7ff65f9476b57277b14a72d0b33381b35aa8810df3e1785ea savvy_linux_x86_64
92 | ```
93 | * The download URL of the binary asset for a given `$os` and `$arch` ends with `${os}_${arch}` (for example `savvy_linux_arm64`). Go's `amd64` is also matched as `x86_64`.
94 |
95 | ## Contributing
96 |
97 | All contributions are welcome - bug reports, pull requests and ideas for improving the package.
98 |
99 | 1. Join the `#upgrade-cli` channel on [Discord](https://getsavvy.so/discord)
100 | 2. Open an [issue on GitHub](https://github.com/getsavvyinc/upgrade-cli/issues/new) to report bugs or request features
101 | 3. Please follow a ["fork and pull request"](https://docs.github.com/en/get-started/exploring-projects-on-github/contributing-to-a-project) workflow for submitting changes to the repository.
102 |
--------------------------------------------------------------------------------
/release/asset/download_test.go:
--------------------------------------------------------------------------------
1 | package asset
2 |
import (
	"context"
	"io"
	"net/http"
	"net/http/httptest"
	"path/filepath"
	"testing"

	"github.com/getsavvyinc/upgrade-cli/release"
	"github.com/stretchr/testify/assert"
)
13 |
const (
	// downloadData is the body served by the test server, i.e. the exact
	// content every successful download in these tests must reproduce.
	// NOTE: there is deliberately no newline at the end.
	downloadData = "#!/bin/sh\n\necho \"Hello, World!\""

	// downloadDataChecksum is the hex-encoded SHA-256 of downloadData.
	downloadDataChecksum = "88fd602a930bc7c0bb78c385f3cb70e976a0cdc3517020be32f19aae8c8eba17"
)
22 |
// setupTestServer starts an httptest server backed by handler. The server is
// shut down through t.Cleanup, so it stays up for the rest of the calling
// test, including its subtests.
func setupTestServer(t *testing.T, handler http.Handler) *httptest.Server {
	t.Helper()
	srv := httptest.NewServer(handler)
	// t.Cleanup already defers the call to the end of the test; wrapping it
	// in `defer` only delayed the registration and obscured the intent.
	t.Cleanup(srv.Close)
	return srv
}
28 |
29 | func downloadDataHandler(w http.ResponseWriter, r *http.Request) {
30 | w.Header().Set("Content-Type", "application/octet-stream")
31 | w.WriteHeader(200)
32 | io.WriteString(w, downloadData)
33 | }
34 |
// shouldNeverBeCalled returns a handler that marks t as failed if any request
// reaches it. It backs servers whose URLs the code under test must not fetch.
func shouldNeverBeCalled(t *testing.T) http.Handler {
	var failOnRequest http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
		t.Errorf("unexpected URL: %s", r.URL.Path)
	}
	return failOnRequest
}
40 |
41 | func TestAssetDownloader(t *testing.T) {
42 | const executablePath = "savvy"
43 | t.Run("TryDownloadingMissingAsset", func(t *testing.T) {
44 | srv := setupTestServer(t, shouldNeverBeCalled(t))
45 | ctx := context.Background()
46 | downloader := NewAssetDownloader(executablePath)
47 | asset, cleanupFn, err := downloader.DownloadAsset(ctx, []release.Asset{
48 | {BrowserDownloadURL: srv.URL + "/nonexistent"},
49 | })
50 | assert.ErrorIs(t, err, ErrNoAsset)
51 | assert.Nil(t, asset)
52 | assert.Nil(t, cleanupFn)
53 | })
54 | t.Run("EnsureDownloadedDoesntChangeContent", func(t *testing.T) {
55 | srv := setupTestServer(t, http.HandlerFunc(downloadDataHandler))
56 | ctx := context.Background()
57 | downloader := NewAssetDownloader(executablePath, WithOS("os"), WithArch("arch"))
58 | asset, cleanupFn, err := downloader.DownloadAsset(ctx, []release.Asset{
59 | {BrowserDownloadURL: srv.URL + "/download_os_arch"},
60 | })
61 | assert.NoError(t, err)
62 | assert.NotNil(t, asset)
63 | assert.NotNil(t, cleanupFn)
64 |
65 | assert.Equal(t, downloadDataChecksum, asset.Checksum)
66 | t.Run("VerifyCleanup", func(t *testing.T) {
67 | tmpFile := asset.DownloadedBinaryFilePath
68 | assert.FileExists(t, tmpFile)
69 | assert.NoError(t, cleanupFn())
70 | assert.NoFileExists(t, tmpFile)
71 | })
72 | })
73 | t.Run("VerifyFallback", func(t *testing.T) {
74 | srv := setupTestServer(t, http.HandlerFunc(downloadDataHandler))
75 | ctx := context.Background()
76 | t.Run("DownloadFailsWithoutFallback", func(t *testing.T) {
77 | downloader := NewAssetDownloader(executablePath, WithOS("os"), WithArch("amd64"))
78 | asset, cleanupFn, err := downloader.DownloadAsset(ctx, []release.Asset{
79 | {BrowserDownloadURL: srv.URL + "/download_os_x86_64"},
80 | })
81 | assert.ErrorIs(t, err, ErrNoAsset)
82 | assert.Nil(t, asset)
83 | assert.Nil(t, cleanupFn)
84 | })
85 | t.Run("DownloadSucceedsWithFallback", func(t *testing.T) {
86 | downloader := NewAssetDownloader(executablePath,
87 | WithOS("os"),
88 | WithArch("amd64"),
89 | WithLookupArchFallback(
90 | map[string][]string{"amd64": {"all", "x86_64"}},
91 | ))
92 | asset, cleanupFn, err := downloader.DownloadAsset(ctx, []release.Asset{
93 | {BrowserDownloadURL: srv.URL + "/download_os_x86_64"},
94 | })
95 | assert.NoError(t, err)
96 | assert.NotNil(t, asset)
97 | assert.NotNil(t, cleanupFn)
98 | assert.Equal(t, downloadDataChecksum, asset.Checksum)
99 | })
100 | })
101 | }
102 |
--------------------------------------------------------------------------------
/upgrade.go:
--------------------------------------------------------------------------------
1 | package upgrade
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "os"
8 | "path/filepath"
9 |
10 | "github.com/getsavvyinc/upgrade-cli/checksum"
11 | "github.com/getsavvyinc/upgrade-cli/release"
12 | "github.com/getsavvyinc/upgrade-cli/release/asset"
13 | "github.com/hashicorp/go-version"
14 | )
15 |
16 | type Upgrader interface {
17 | IsNewVersionAvailable(ctx context.Context, currentVersion string) (bool, error)
18 | // Upgrade upgrades the current binary to the latest version.
19 | Upgrade(ctx context.Context, currentVersion string) error
20 | }
21 |
22 | type upgrader struct {
23 | executablePath string
24 | repo string
25 | owner string
26 | releaseGetter release.Getter
27 | assetDownloader asset.Downloader
28 | checksumDownloader checksum.Downloader
29 | checksumValidator checksum.CheckSumValidator
30 | }
31 |
32 | var _ Upgrader = (*upgrader)(nil)
33 |
34 | type Opt func(*upgrader)
35 |
36 | func WithAssetDownloader(d asset.Downloader) Opt {
37 | return func(u *upgrader) {
38 | u.assetDownloader = d
39 | }
40 | }
41 |
42 | func WithCheckSumDownloader(c checksum.Downloader) Opt {
43 | return func(u *upgrader) {
44 | u.checksumDownloader = c
45 | }
46 | }
47 |
48 | func WithCheckSumValidator(c checksum.CheckSumValidator) Opt {
49 | return func(u *upgrader) {
50 | u.checksumValidator = c
51 | }
52 | }
53 |
54 | func NewUpgrader(owner string, repo string, executablePath string, opts ...Opt) Upgrader {
55 | u := &upgrader{
56 | repo: repo,
57 | owner: owner,
58 | executablePath: executablePath,
59 | releaseGetter: release.NewReleaseGetter(repo, owner),
60 | assetDownloader: asset.NewAssetDownloader(executablePath, asset.WithLookupArchFallback(map[string][]string{
61 | "amd64": {"x86_64"},
62 | "386": {"i86", "all"},
63 | "arm64": {"all"},
64 | })),
65 | checksumDownloader: checksum.NewCheckSumDownloader(),
66 | checksumValidator: checksum.NewCheckSumValidator(),
67 | }
68 | for _, opt := range opts {
69 | opt(u)
70 | }
71 | return u
72 | }
73 |
74 | var ErrInvalidCheckSum = errors.New("invalid checksum")
75 |
76 | func (u *upgrader) IsNewVersionAvailable(ctx context.Context, currentVersion string) (bool, error) {
77 | curr, err := version.NewVersion(currentVersion)
78 | if err != nil {
79 | return false, fmt.Errorf("failed to parse current version: %s with err %w", currentVersion, err)
80 | }
81 |
82 | releaseInfo, err := u.releaseGetter.GetLatestRelease(ctx)
83 | if err != nil {
84 | return false, err
85 | }
86 |
87 | latest, err := version.NewVersion(releaseInfo.TagName)
88 | if err != nil {
89 | return false, fmt.Errorf("failed to parse latest version: %s with err %w", releaseInfo.TagName, err)
90 | }
91 |
92 | return latest.GreaterThan(curr), nil
93 | }
94 |
95 | func (u *upgrader) Upgrade(ctx context.Context, currentVersion string) error {
96 | curr, err := version.NewVersion(currentVersion)
97 | if err != nil {
98 | return err
99 | }
100 |
101 | releaseInfo, err := u.releaseGetter.GetLatestRelease(ctx)
102 | if err != nil {
103 | return err
104 | }
105 |
106 | latest, err := version.NewVersion(releaseInfo.TagName)
107 | if err != nil {
108 | return err
109 | }
110 |
111 | if latest.LessThanOrEqual(curr) {
112 | return nil
113 | }
114 |
115 | // from the releaseInfo, download the binary for the architecture
116 |
117 | downloadInfo, cleanup, err := u.assetDownloader.DownloadAsset(ctx, releaseInfo.Assets)
118 | if err != nil {
119 | return err
120 | }
121 |
122 | if cleanup != nil {
123 | defer cleanup()
124 | }
125 |
126 | // download the checksum file
127 | checksumInfo, err := u.checksumDownloader.Download(ctx, releaseInfo.Assets)
128 | if err != nil {
129 | return err
130 | }
131 |
132 | executableName := filepath.Base(u.executablePath)
133 | // verify the checksum
134 | if !u.checksumValidator.IsCheckSumValid(ctx, executableName, checksumInfo, downloadInfo.Checksum) {
135 | return ErrInvalidCheckSum
136 | }
137 |
138 | if err := replaceBinary(downloadInfo.DownloadedBinaryFilePath, u.executablePath); err != nil {
139 | return fmt.Errorf("failed to replace binary: %w", err)
140 | }
141 |
142 | return nil
143 | }
144 |
// replaceBinary moves the downloaded file at tmpFilePath onto
// currentBinaryPath with a single os.Rename. Both paths are expected to be on
// the same filesystem; the downloader creates the temporary file next to the
// executable for exactly this reason.
func replaceBinary(tmpFilePath, currentBinaryPath string) error {
	err := os.Rename(tmpFilePath, currentBinaryPath)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to replace binary: %w", err)
}
154 |
--------------------------------------------------------------------------------
/release/asset/download.go:
--------------------------------------------------------------------------------
1 | package asset
2 |
3 | import (
4 | "context"
5 | "crypto/sha256"
6 | "encoding/hex"
7 | "errors"
8 | "fmt"
9 | "io"
10 | "net/http"
11 | "os"
12 | "path/filepath"
13 | "runtime"
14 | "strings"
15 |
16 | "github.com/getsavvyinc/upgrade-cli/release"
17 | )
18 |
19 | type cleanupFn func() error
20 |
21 | type Downloader interface {
22 | DownloadAsset(ctx context.Context, ReleaseAssets []release.Asset) (*Info, cleanupFn, error)
23 | }
24 |
25 | type Info struct {
26 | Checksum string
27 | DownloadedBinaryFilePath string
28 | }
29 |
30 | type downloader struct {
31 | os string
32 | arch string
33 | lookupArchFallback map[string][]string
34 | executablePath string
35 | }
36 |
37 | var _ Downloader = (*downloader)(nil)
38 |
39 | type AssetDownloadOpt func(*downloader)
40 |
41 | func WithOS(os string) AssetDownloadOpt {
42 | return func(d *downloader) {
43 | d.os = os
44 | }
45 | }
46 |
47 | func WithArch(arch string) AssetDownloadOpt {
48 | return func(d *downloader) {
49 | d.arch = arch
50 | }
51 | }
52 |
53 | func WithLookupArchFallback(lookupArchFallback map[string][]string) AssetDownloadOpt {
54 | return func(d *downloader) {
55 | d.lookupArchFallback = lookupArchFallback
56 | }
57 | }
58 |
59 | func NewAssetDownloader(executablePath string, opts ...AssetDownloadOpt) Downloader {
60 | d := &downloader{
61 | os: runtime.GOOS,
62 | arch: runtime.GOARCH,
63 | executablePath: executablePath,
64 | }
65 | for _, opt := range opts {
66 | opt(d)
67 | }
68 | return d
69 | }
70 |
71 | var ErrNoAsset = errors.New("no asset found")
72 |
73 | func (d *downloader) DownloadAsset(ctx context.Context, assets []release.Asset) (*Info, cleanupFn, error) {
74 | // iterate through the assets and find the one that matches the os and arch
75 | suffix := d.os + "_" + d.arch
76 | asset, found := d.assetForSuffix(assets, suffix)
77 | if found {
78 | return d.downloadAsset(ctx, asset.BrowserDownloadURL)
79 | }
80 | // if asset not found, try a fallback. e.g amd64 -> x86_64
81 | if len(d.lookupArchFallback) == 0 {
82 | return nil, nil, fmt.Errorf("%w: os:%s arch:%s", ErrNoAsset, d.os, d.arch)
83 | }
84 |
85 | fallbackArchs, ok := d.lookupArchFallback[d.arch]
86 | if !ok {
87 | return nil, nil, fmt.Errorf("%w: os:%s arch:%s", ErrNoAsset, d.os, d.arch)
88 | }
89 |
90 | // Try to find an asset for each fallback architecture
91 | for _, fallbackArch := range fallbackArchs {
92 | fallbackSuffix := d.os + "_" + fallbackArch
93 | asset, found = d.assetForSuffix(assets, fallbackSuffix)
94 | if found {
95 | return d.downloadAsset(ctx, asset.BrowserDownloadURL)
96 | }
97 | }
98 | return nil, nil, fmt.Errorf("%w: os:%s arch:%s", ErrNoAsset, d.os, d.arch)
99 | }
100 |
101 | func (d *downloader) assetForSuffix(assets []release.Asset, suffix string) (release.Asset, bool) {
102 | for _, asset := range assets {
103 | if strings.HasSuffix(asset.BrowserDownloadURL, suffix) {
104 | return asset, true
105 | }
106 | }
107 | return release.Asset{}, false
108 | }
109 |
110 | func (d *downloader) downloadAsset(ctx context.Context, url string) (*Info, cleanupFn, error) {
111 | // Download the file
112 | req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
113 | if err != nil {
114 | return nil, nil, err
115 | }
116 |
117 | resp, err := http.DefaultClient.Do(req)
118 | if err != nil {
119 | return nil, nil, err
120 | }
121 | defer resp.Body.Close()
122 |
123 | // Create a temporary file in the same directory as the executable
124 | // Doing so avoids issues where the downloaded file is on a different filesystem/mount point from the executable.
125 | executable, executableDir := filepath.Base(d.executablePath), filepath.Dir(d.executablePath)
126 | tmpFile, err := os.CreateTemp(executableDir, executable)
127 | if err != nil {
128 | return nil, nil, err
129 | }
130 | defer tmpFile.Close()
131 |
132 | cleanupFn := func() error {
133 | return os.Remove(tmpFile.Name())
134 | }
135 |
136 | // sha256 checksum
137 | hasher := sha256.New()
138 |
139 | // Write the response body to the temporary file and hasher
140 | rd := io.TeeReader(resp.Body, hasher)
141 | _, err = io.Copy(tmpFile, rd)
142 | if err != nil {
143 | cleanupFn()
144 | return nil, nil, err
145 | }
146 |
147 | // Ensure the downloaded file has executable permissions
148 | if err := os.Chmod(tmpFile.Name(), 0755); err != nil {
149 | cleanupFn()
150 | return nil, nil, err
151 | }
152 |
153 | return &Info{
154 | Checksum: hex.EncodeToString(hasher.Sum(nil)),
155 | DownloadedBinaryFilePath: tmpFile.Name(),
156 | }, cleanupFn, nil
157 | }
158 |
--------------------------------------------------------------------------------
/checksum/checksum.go:
--------------------------------------------------------------------------------
1 | package checksum
2 |
3 | import (
4 | "bufio"
5 | "context"
6 | "errors"
7 | "fmt"
8 | "net/http"
9 | "runtime"
10 | "strings"
11 |
12 | "github.com/getsavvyinc/upgrade-cli/release"
13 | )
14 |
15 | type Downloader interface {
16 | Download(ctx context.Context, assets []release.Asset) (*Info, error)
17 | }
18 |
19 | type Info struct {
20 | // keyed on $binary_os_$arch
21 | Checksums map[string]string
22 | }
23 |
// checksumDownloader is the default Downloader. It picks the first asset
// whose URL ends in assetSuffix.
type checksumDownloader struct {
	assetSuffix string
}

// DownloadOpt customizes a checksumDownloader.
type DownloadOpt func(*checksumDownloader)

// WithAssetSuffix changes the URL suffix that identifies the checksum asset.
func WithAssetSuffix(suffix string) DownloadOpt {
	return func(c *checksumDownloader) { c.assetSuffix = suffix }
}
35 |
36 | func NewCheckSumDownloader(opts ...DownloadOpt) Downloader {
37 | d := &checksumDownloader{
38 | assetSuffix: "checksums.txt",
39 | }
40 | for _, opt := range opts {
41 | opt(d)
42 | }
43 | return d
44 | }
45 |
46 | var ErrNoCheckSumAsset = errors.New("no checksum asset found")
47 |
48 | func (c *checksumDownloader) Download(ctx context.Context, assets []release.Asset) (*Info, error) {
49 | // iterate through the assets and find the one that matches the os and arch
50 | for _, asset := range assets {
51 | if strings.HasSuffix(asset.BrowserDownloadURL, c.assetSuffix) {
52 | checksums, err := downloadCheckSum(ctx, asset.BrowserDownloadURL)
53 | if err != nil {
54 | return nil, err
55 | }
56 | return checksums, nil
57 | }
58 | }
59 | return nil, ErrNoCheckSumAsset
60 | }
61 |
62 | var ErrInvalidChecksumFile = errors.New("invalid checksum file")
63 |
64 | func downloadCheckSum(ctx context.Context, url string) (*Info, error) {
65 | // download the checksum file
66 | req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
67 | if err != nil {
68 | return nil, err
69 | }
70 |
71 | resp, err := http.DefaultClient.Do(req)
72 | if err != nil {
73 | return nil, err
74 | }
75 | defer resp.Body.Close()
76 |
77 | checksums := make(map[string]string)
78 |
79 | scanner := bufio.NewScanner(resp.Body)
80 | // parse the file and return the checksums
81 | for scanner.Scan() {
82 | line := scanner.Text()
83 | // parse the line and extract the checksum
84 | line = strings.TrimSpace(line)
85 | // there maybe one or more blank spaces between the checksum and the file name
86 | parts := strings.Fields(line)
87 | // parts[0] is the checksum, parts[1] is the file name
88 | if len(parts) != 2 {
89 | return nil, fmt.Errorf("%w: checksum file is malformed", ErrInvalidChecksumFile)
90 | }
91 | checksums[parts[1]] = parts[0]
92 | }
93 |
94 | if len(checksums) == 0 {
95 | return nil, fmt.Errorf("%w: checksum file is empty", ErrInvalidChecksumFile)
96 | }
97 | return &Info{Checksums: checksums}, nil
98 | }
99 |
100 | type CheckSumValidator interface {
101 | IsCheckSumValid(ctx context.Context, binary string, checksums *Info, downloadedChecksum string) bool
102 | }
103 |
104 | type validator struct {
105 | os string
106 | arch string
107 | }
108 |
109 | // String maps arch to string.
110 | //
111 | // String maps 386 to i386 and amd64 to x86_64 for consistency across linux and darwin.
112 |
113 | type ValidatorOption func(*validator)
114 |
115 | func WithOS(os string) ValidatorOption {
116 | return func(v *validator) {
117 | v.os = os
118 | }
119 | }
120 |
// fallbackArchMap lists, per Go GOARCH value, alternative arch names to try
// in order when the checksum file has no "<binary>_<os>_<GOARCH>" entry.
// Release assets are named e.g. savvy_linux_i386 / savvy_linux_x86_64 rather
// than 386 / amd64. "all" presumably names an arch-independent (e.g.
// universal) build; confirm against the release tooling's configuration.
var fallbackArchMap = map[string][]string{
	"386":   {"i386", "all"},
	"amd64": {"x86_64"},
	"arm64": {"all"},
}
126 |
127 | func WithArch(a string) ValidatorOption {
128 | return func(v *validator) {
129 | v.arch = strings.ToLower(a)
130 | }
131 | }
132 |
133 | func NewCheckSumValidator(opts ...ValidatorOption) CheckSumValidator {
134 | v := &validator{
135 | os: runtime.GOOS,
136 | arch: strings.ToLower(runtime.GOARCH),
137 | }
138 |
139 | for _, opt := range opts {
140 | opt(v)
141 | }
142 | return v
143 | }
144 |
145 | func (v *validator) IsCheckSumValid(ctx context.Context, binary string, info *Info, downloadedChecksum string) bool {
146 |
147 | key := fmt.Sprintf("%s_%s_%s", binary, v.os, v.arch)
148 | expectedChecksum, ok := info.Checksums[key]
149 | if !ok {
150 | return v.tryFallbackArch(binary, info, downloadedChecksum)
151 | }
152 | return expectedChecksum == downloadedChecksum
153 | }
154 |
155 | func (v *validator) tryFallbackArch(binary string, info *Info, downloadedChecksum string) bool {
156 | archs, ok := fallbackArchMap[v.arch]
157 | if !ok {
158 | return false
159 | }
160 |
161 | for _, arch := range archs {
162 | key := fmt.Sprintf("%s_%s_%s", binary, v.os, arch)
163 | expectedChecksum, ok := info.Checksums[key]
164 | if ok {
165 | return expectedChecksum == downloadedChecksum
166 | }
167 | }
168 | return false
169 | }
170 |
--------------------------------------------------------------------------------
/checksum/checksum_test.go:
--------------------------------------------------------------------------------
1 | package checksum
2 |
3 | import (
4 | "context"
5 | "io"
6 | "net/http"
7 | "net/http/httptest"
8 | "strings"
9 | "testing"
10 |
11 | "github.com/getsavvyinc/upgrade-cli/release"
12 | "github.com/stretchr/testify/assert"
13 | )
14 |
// checksumData is a sample checksum file: one "<checksum> <file name>" entry
// per line, where each checksum is "checksum_" + file name.
//
// Every line deliberately starts with a space and separates the two fields
// with several spaces, to exercise the parser's whitespace tolerance.
// Interpreted string literals keep that significant whitespace visible.
const checksumData = " checksum_savvy_darwin_arm64  savvy_darwin_arm64\n" +
	" checksum_savvy_darwin_x86_64   savvy_darwin_x86_64\n" +
	" checksum_savvy_linux_arm64  savvy_linux_arm64\n" +
	" checksum_savvy_linux_i386    savvy_linux_i386\n" +
	" checksum_savvy_linux_x86_64  savvy_linux_x86_64\n"

// malformedChecksumData has a checksum split by a space, giving three fields
// on the line instead of two.
const malformedChecksumData = "6796a0fb64d0c78b2de5410a94749a 3bfb77291747c1835fbd427e8bf00f6af3 savvy_darwin_arm64\n"
27 |
// setupTestServer starts an httptest server backed by handler. The server is
// shut down through t.Cleanup, so it stays up for the rest of the calling
// test, including its subtests.
func setupTestServer(t *testing.T, handler http.Handler) *httptest.Server {
	t.Helper()
	srv := httptest.NewServer(handler)
	// t.Cleanup already defers the call to the end of the test; wrapping it
	// in `defer` only delayed the registration and obscured the intent.
	t.Cleanup(srv.Close)
	return srv
}
33 |
34 | func checkSumDataHandler(t *testing.T) http.Handler {
35 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
36 | w.Header().Set("Content-Type", "application/octet-stream")
37 | w.WriteHeader(200)
38 | if r.URL.Path == "/checksums.txt" {
39 | io.WriteString(w, checksumData)
40 | return
41 | }
42 | if r.URL.Path == "/empty_checksums.txt" {
43 | io.WriteString(w, "")
44 | return
45 | }
46 | if r.URL.Path == "/malformed_checksums.txt" {
47 | io.WriteString(w, malformedChecksumData)
48 | return
49 | }
50 | // DownloadCheckSum should only be called with routes that end with checksums.txt
51 | t.Errorf("unexpected URL: %s", r.URL.Path)
52 | })
53 | }
54 |
55 | func TestDownloadCheckSum(t *testing.T) {
56 | srv := setupTestServer(t, checkSumDataHandler(t))
57 | ctx := context.Background()
58 | testSuffix := "checksums.txt"
59 | t.Run("ValidCheckSumFile", func(t *testing.T) {
60 | checksumURL := srv.URL + "/checksums.txt"
61 | downloader := NewCheckSumDownloader(WithAssetSuffix(testSuffix))
62 | checksums, err := downloader.Download(ctx, []release.Asset{
63 | {BrowserDownloadURL: checksumURL},
64 | {BrowserDownloadURL: srv.URL + "/malformed_path.txt"},
65 | })
66 | assert.NoError(t, err)
67 | assert.NotNil(t, checksums)
68 | assert.NotEmpty(t, checksums.Checksums)
69 | for k, v := range checksums.Checksums {
70 | assert.NotEmpty(t, k)
71 | assert.NotEmpty(t, v)
72 | assert.Equal(t, strings.Join([]string{"checksum", k}, "_"), v)
73 | }
74 | })
75 | t.Run("InvalidCheckSumFile", func(t *testing.T) {
76 | testCases := []struct {
77 | name string
78 | url string
79 | }{
80 | {
81 | name: "MalformedChecksumFile",
82 | url: srv.URL + "/malformed_checksums.txt",
83 | },
84 | {
85 | name: "EmptyCheckSumFile",
86 | url: srv.URL + "/empty_checksums.txt",
87 | },
88 | }
89 | for _, tc := range testCases {
90 | t.Run(tc.name, func(t *testing.T) {
91 | downloader := NewCheckSumDownloader(WithAssetSuffix(testSuffix))
92 | checksums, err := downloader.Download(ctx, []release.Asset{
93 | {BrowserDownloadURL: tc.url},
94 | })
95 | assert.Error(t, err)
96 | assert.Nil(t, checksums)
97 | assert.ErrorIs(t, err, ErrInvalidChecksumFile)
98 | })
99 | }
100 | })
101 | t.Run("NoCheckSumAsset", func(t *testing.T) {
102 | downloader := NewCheckSumDownloader(WithAssetSuffix(testSuffix))
103 | checksums, err := downloader.Download(ctx, []release.Asset{
104 | {BrowserDownloadURL: srv.URL + "/savvy_darwin_arm64"},
105 | })
106 | assert.Error(t, err)
107 | assert.Nil(t, checksums)
108 | assert.ErrorIs(t, err, ErrNoCheckSumAsset)
109 | })
110 | }
111 |
112 | func TestCheckSumValidator(t *testing.T) {
113 | binary := "savvy"
114 | const checksum = "checksum"
115 | checksumInfo := &Info{
116 | Checksums: map[string]string{
117 | binary + "_darwin_x86_64": checksum,
118 | binary + "_linux_x86_64": checksum,
119 | binary + "_linux_i386": checksum,
120 | },
121 | }
122 |
123 | testCases := []struct {
124 | name string
125 | downloadedChecksum string
126 | isValid bool
127 | os string
128 | arch string
129 | binary string
130 | }{
131 | {
132 | name: "ValidChecksums",
133 | downloadedChecksum: checksum,
134 | os: "linux",
135 | arch: "x86_64",
136 | isValid: true,
137 | binary: binary,
138 | },
139 | {
140 | name: "ValidChecksumsWithAmd64",
141 | downloadedChecksum: checksum,
142 | os: "linux",
143 | arch: "amd64",
144 | isValid: true,
145 | binary: binary,
146 | },
147 | {
148 | name: "ValidChecksumsWith386",
149 | downloadedChecksum: checksum,
150 | os: "linux",
151 | arch: "386",
152 | isValid: true,
153 | binary: binary,
154 | },
155 | {
156 | name: "InvalidChecksums",
157 | downloadedChecksum: "invalid_checksum",
158 | os: "darwin",
159 | arch: "x86_64",
160 | isValid: false,
161 | binary: binary,
162 | },
163 | {
164 | name: "InvalidOS",
165 | downloadedChecksum: checksum,
166 | os: "windows",
167 | arch: "x86_64",
168 | isValid: false,
169 | binary: binary,
170 | },
171 | {
172 | name: "InvalidArch",
173 | downloadedChecksum: checksum,
174 | os: "linux",
175 | arch: "not_suppported",
176 | isValid: false,
177 | binary: binary,
178 | },
179 | {
180 | name: "InvalidBinary",
181 | downloadedChecksum: checksum,
182 | os: "linux",
183 | arch: "x86_64",
184 | isValid: false,
185 | binary: "invalid_binary",
186 | },
187 | }
188 |
189 | for _, tc := range testCases {
190 | t.Run(tc.name, func(t *testing.T) {
191 | csv := NewCheckSumValidator(WithArch(tc.arch), WithOS(tc.os))
192 | isValid := csv.IsCheckSumValid(context.Background(), tc.binary, checksumInfo, tc.downloadedChecksum)
193 | assert.Equal(t, tc.isValid, isValid)
194 | })
195 | }
196 | }
197 |
--------------------------------------------------------------------------------